numpy.where 函数有什么问题?
What is wrong with numpy.where function?
我有一个 numpy 数组(2 个元素列表的列表)a
在下面给出,我有一个我想要查找的 2 个元素列表 [30.94, 0.]
。
当我执行以下操作时,我没有得到想要的结果。为什么?
import numpy as np
a = np.array([[ 5.73, 0. ],
[ 57.73, 10. ],
[ 57.73, 20. ],
[ 30.94, 0. ],
[ 30.94, 10. ],
[ 30.94, 20. ],
[ 4.14, 0. ],
[ 4.14, 10. ]])
np.where(a==np.array([30.94, 0.]))
但是我明白了
(array([0, 3, 3, 4, 5, 6]), array([1, 0, 1, 0, 0, 1]))
这不是真的。
正如 Divakar 所暗示的,a == np.array([30.94, 0.])
不是您所期望的。广播数组,并按元素进行比较。这是结果:
array([[False, True],
[False, False],
[False, False],
[ True, True],
[ True, False],
[ True, False],
[False, True],
[False, False]], dtype=bool)
然而,我们可以通过np.all
得到我们想要的:
>>> np.all(a==np.array([30.94, 0.]), axis=-1)
array([False, False, False, True, False, False, False, False], dtype=bool)
>>> np.where(_)
(array([3]),)
因此您可以看到第 3 行符合预期。请注意,将 ==
与浮点数一起使用的常见注意事项将适用于此处。
另一种解决方案,但请注意,这会比 慢一点,尤其是对于大型数组。
In [1]: cond = np.array([30.94, 0.])
In [2]: arr = np.array([[ 5.73, 0. ],
[ 57.73, 10. ],
[ 57.73, 20. ],
[ 30.94, 0. ],
[ 30.94, 10. ],
[ 30.94, 20. ],
[ 4.14, 0. ],
[ 4.14, 10. ]])
In [3]: [idx for idx, el in enumerate(arr) if np.array_equal(el, cond)]
Out[3]: [3]
我有一个 numpy 数组(2 个元素列表的列表)a
在下面给出,我有一个我想要查找的 2 个元素列表 [30.94, 0.]
。
当我执行以下操作时,我没有得到想要的结果。为什么?
import numpy as np
a = np.array([[ 5.73, 0. ],
[ 57.73, 10. ],
[ 57.73, 20. ],
[ 30.94, 0. ],
[ 30.94, 10. ],
[ 30.94, 20. ],
[ 4.14, 0. ],
[ 4.14, 10. ]])
np.where(a==np.array([30.94, 0.]))
但是我明白了
(array([0, 3, 3, 4, 5, 6]), array([1, 0, 1, 0, 0, 1]))
这不是真的。
正如 Divakar 所暗示的,a == np.array([30.94, 0.])
不是您所期望的。广播数组,并按元素进行比较。这是结果:
array([[False, True],
[False, False],
[False, False],
[ True, True],
[ True, False],
[ True, False],
[False, True],
[False, False]], dtype=bool)
然而,我们可以通过np.all
得到我们想要的:
>>> np.all(a==np.array([30.94, 0.]), axis=-1)
array([False, False, False, True, False, False, False, False], dtype=bool)
>>> np.where(_)
(array([3]),)
因此您可以看到第 3 行符合预期。请注意,将 ==
与浮点数一起使用的常见注意事项将适用于此处。
另一种解决方案,但请注意,这会比
In [1]: cond = np.array([30.94, 0.])
In [2]: arr = np.array([[ 5.73, 0. ],
[ 57.73, 10. ],
[ 57.73, 20. ],
[ 30.94, 0. ],
[ 30.94, 10. ],
[ 30.94, 20. ],
[ 4.14, 0. ],
[ 4.14, 10. ]])
In [3]: [idx for idx, el in enumerate(arr) if np.array_equal(el, cond)]
Out[3]: [3]