如何通过一维列表中的值屏蔽二维 numpy 矩阵的行?

How to mask rows of a 2D numpy matrix by values in 1D list?

我有一个 2D numpy 数组,如下所示:

a = np.array([[0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4]])

还有一个如下所示的一维列表:

b = [4, 3, 2, 3, 4]

我想根据给定行中的哪些值小于我的一维列表 (b) 中的相应值来屏蔽我的二维数组 (a)。例如,a[0] 行将根据该行中的哪些值小于 b[0] 处的值进行屏蔽;与行 a[1] 和 b[1] 处的值相同,依此类推...

我希望得到的是一个二维的布尔数组:

mask_bools = [[True, True, True, True, False],
              [True, True, True, False, False],
              [True, True, False, False, False],
              [True, True, True, False, False],
              [True, True, True, True, False]]

我有一个愚蠢的方法来实现这个循环:

mask_bools = []
for i in range(len(b)):
    mask_bools.append(np.ma.masked_less(a[i], b[i]).mask)
mask_bools = np.array(mask_bools)

但我觉得 必须 是一种 better/faster 更好地利用 numpy 功能的方法。有任何想法吗?谢谢!

您可以使用 np.vectorize 来实现这一点。创建一个函数,比较适当索引处的值和 returns 布尔值。

它基本上会做与您现在所做的相同的事情,但正如您所说,这只是 NumPy 实现。

你可以试试这个:

import numpy as np
a = np.array([[0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4]])


b = np.array([[4, 3, 2, 3, 4],]*len(a)).T

a < b

array([[ True,  True,  True,  True, False],
       [ True,  True,  True, False, False],
       [ True,  True, False, False, False],
       [ True,  True,  True, False, False],
       [ True,  True,  True,  True, False]])

为了提高使用效率:

b = np.tile(np.array([[4, 3, 2, 3, 4]]).transpose(), (1, len(a)))

然而,这更难阅读。

尝试broadcasting小于:

a < b[:, None]
[[ True  True  True  True False]
 [ True  True  True False False]
 [ True  True False False False]
 [ True  True  True False False]
 [ True  True  True  True False]]
import numpy as np

a = np.array([[0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4],
              [0, 1, 2, 3, 4]])

b = np.array([4, 3, 2, 3, 4])

c = a < b[:, None]

# Test equality with expected output
mask_bools = np.array([[True, True, True, True, False],
                       [True, True, True, False, False],
                       [True, True, False, False, False],
                       [True, True, True, False, False],
                       [True, True, True, True, False]])

print((c == mask_bools).all().all())  # True