统计每个热码出现的次数

count the number of occurance of each one hot code

我有一个 numpy 数组列表(one-hot 表示),如下例所示,我想计算每个 one-hot 代码的出现次数。

[0 0 1 0 0 0 0 0 0 0]
[0 0 1 0 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0]
[0 0 0 0 0 1 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0]
[0 0 0 0 1 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 1]
[0 0 0 0 1 0 0 0 0 0]
[1 0 0 0 0 0 0 0 0 0]
[0 0 0 1 0 0 0 0 0 0]
[0 1 0 0 0 0 0 0 0 0]
[0 0 0 0 0 0 0 0 0 1]

编辑: 预期输出:

[1 0 0 0 0 0 0 0 0 0] ==> 1 occurrence
[0 0 1 0 0 0 0 0 0 0] ==> 2 occurrences
[0 1 0 0 0 0 0 0 0 0] ==> 3 occurrences
[0 0 0 0 0 1 0 0 0 0] ==> 1 occurrence
[0 0 0 0 1 0 0 0 0 0] ==> 2 occurrences
[0 0 0 0 0 0 0 0 0 1] ==> 2 occurrences

我想你可以得到你想要的结果:

[1 3 2 1 2 1 0 0 0 2]

使用 ndarray.sum():

通过一个简单的 column-wise 求和来表示该位置上一个热点的出现次数
import numpy
data = numpy.array([
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
])
print(numpy.ndarray.sum(data, axis=0))

或更简洁地表示为:

print(data.sum(axis=0))

两者都应该给你:

[1 3 2 1 2 1 0 0 0 2]

利用每行1个热的面,可以做如下操作:

temp = np.array([[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0 ,0 ,0 ,1 ,0 ,0 ,0 ,0 ,0, 0],
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]])

将 one-hot 转换为索引可以按如下方式完成:

temp2 = np.argmax(temp, axis=1)  # array([2, 2, 1, 5, 1, 4, 9, 4, 0, 3, 1, 9])

然后可以使用 np.histogram 计算出现次数。我们知道您有 10 个可能的值,因此我们使用 10 个 bin,如下所示:

temp3 = np.histogram(temp2, bins=10, range=(-0.5,9.5))

np.histogram returns 一个双元组,其中索引 [0] 保存直方图值,索引 [1] 保存箱子。在你的情况下:

(array([1, 3, 2, 1, 2, 1, 0, 0, 0, 2]),
 array([-0.5,  0.5,  1.5,  2.5,  3.5,  4.5,  5.5,  6.5,  7.5,  8.5,  9.5]))