numpy.where() 方法如何处理数组元素和目标不是同一数据类型时的相等条件

how does numpy.where() method handle the equality condition when the array element and target are not the same data type

我有一个很长的列表,它的元素类型是 int。我想找到等于某个数字的元素的索引,我使用 np.where 来实现这一点。

以下是我的原代码,

# suppose x is [1, 1, 2, 3]
y = np.array(x, dtype=np.float32)
idx = list(np.where(y==1)[0])
# output is [0, 1]

一段时间后检查代码后,我意识到我不应该使用 dtype=np.float32,因为它会将 y 的数据类型更改为 float。正确的代码应该是下面的,

# suppose x is [1, 1, 2, 3]
y = np.array(x)
idx = list(np.where(y==1)[0])
# output is also [0, 1]

令人惊讶的是,这两个代码片段产生了完全相同的结果。

我的问题

当数组和目标的数据类型不兼容(例如 int 与 float,例如)时,我是否在 numpy.where 中处理了相等性测试条件?

NumPy where (source code here) 不关心数据类型的比较:它的第一个参数是一个 bool 类型的数组。当你写 y == 1 时,这是一个数组比较操作,returns 一个布尔数组,然后作为参数传递给 where

相关方法是 equal,您可以通过编写 y == 1 隐式调用它。它的文档说:

What is compared are values, not types.

例如,

x, y, z = np.float64(0.25), np.float32(0.25), 0.25

这些都是不同的类型,(numpy.float64, numpy.float32, float) 但是 x == y 和 y == z 以及 x == z 是正确的。这里重要的是 0.25 在二进制系统中精确表示 (1/4)。

x, y, z = np.float64(0.2), np.float32(0.2), 0.2

我们看到 x == y 是 False,y == z 是 False,但是 x == z 是 True,因为 Python 浮点数和 np.float64 一样是 64 位的。由于 1/5 不能精确地用二进制表示,使用 32 位和 64 位会导致两个不同的 1/5 近似值,这就是相等性失败的原因:不是因为类型,而是因为 np.float64(0.2)np.float32(0.2) 实际上是不同的值(它们的差异大约是 3e-9)。