为什么这个布尔值在这个贝叶斯分类器中? (Python 个问题?)

Why this boolean is in this bayes classifier? (Python question?)

我正在研究GANs(并且我是python的初学者),我发现之前练习中的这部分代码我看不懂。具体来说,我不明白为什么使用第 9 行的布尔值 (Xk = X[Y == k]),原因我在下面写下

class BayesClassifier:
  def fit(self, X, Y):
    # assume classes are numbered 0...K-1
    self.K = len(set(Y))

    self.gaussians = []
    self.p_y = np.zeros(self.K)
    for k in range(self.K):
      Xk = X[Y == k]
      self.p_y[k] = len(Xk)
      mean = Xk.mean(axis=0)
      cov = np.cov(Xk.T)
      g = {'m': mean, 'c': cov}
      self.gaussians.append(g)
    # normalize p(y)
    self.p_y /= self.p_y.sum()
  1. 布尔值 return 0 或 1 取决于 Y == 的真实性 k,因此 Xk 总是 X 列表的第一个或第二个值。 Y 没有找到它的用处。
  2. 在第 10 行中,len(Xk) 始终为 1,为什么它使用该参数而不是单个 1?
  3. 下一行的均值和协方差每次只计算一个值。

我觉得我不太了解一些非常基本的东西。

您应该考虑到 X, Y, k 是 NumPy 数组,而不是标量,并且某些运算符已为它们重载。特别是 == 和基于布尔的索引。 == 将是逐元素比较,而不是整个数组比较。

看看它是如何工作的:

In [9]: Y = np.array([0,1,2])                                                                                        
In [10]: k = np.array([0,1,3])                                                                                       
In [11]: Y==k                                                                                                        

Out[11]: array([ True,  True, False])

所以,==的结果是一个布尔数组。

In [12]: X=np.array([0,2,4])                                                                                         
In [13]: X[Y==k]                                                                                                     

Out[13]: array([0, 2])

当条件为True

时,结果是从X中选择元素的数组

因此 len(Xk) 将是 Xk 之间的匹配元素数。

谢谢,阿尔乔姆,

你是对的。我在另一个渠道找到了另一个答案,这里是:

It's a Numpy array - it's a special feature of NumPy arrays called boolean indexing that lets you filter out only the values in the array where the filter returns True:

https://docs.scipy.org/doc/numpy-1.13.0/user/basics.indexing.html?fbclid=IwAR3sGlgSwhv3i7IETsIxp4ROu9oZvNaaaBxZS01DrM5ShjWWRz22ShP2rIg#boolean-or-mask-index-arrays

import numpy as np

a = np.array([1, 2, 3, 4, 5]) filter = a > 3

print(filter)

[False, False, False, True, True]

打印(a[过滤器])

[4, 5]