在numpy中过滤一个ndarray

Filtering a ndarray in numpy

我有一个 ndarray,我想过滤掉它的特定值。我的数组是:

arr = np.array([
   [1., 6., 1.],
   [1., 7., 0.],
   [1., 8., 0.],
   [3., 5., 1.],
   [5., 1., 1.],
   [5., 2., 2.],
   [6., 1., 1.],
   [6., 2., 2.],
   [6., 7., 3.],
   [6., 8., 0.]
])

我要过滤掉[6., 1., 1.]。所以我试过了:

arr[arr != [6., 1., 1.]]

我得到了:

array([1., 6., 1., 7., 0., 1., 8., 0., 3., 5., 5., 5., 2., 2., 2., 2., 7.,
   3., 8., 0.])

这不是我想要的(而且还破坏了数组之前的结构)。我也试过:

arr[arr[:] != [6., 1., 1.]]

但我得到了和以前一样的输出。

P.S.: 我知道我可以通过索引删除一个元素,但我不想那样做。我想检查特定元素。

P.P.S.: 对于一维数组,我的方法有效。

你非常接近。您获得的布尔数组告诉您每行中有多少个元素匹配。您需要确保一行中的所有元素都匹配才能删除它,或者任何元素不匹配才能保留它:

arr[(arr != [6, 1, 1]).any(axis=1)]

也可以写成

arr[~(arr == [6, 1, 1]).all(axis=1)]