按索引过滤并在 numpy 中展平,例如 tf.sequence_mask
Filter by index and flattened in numpy, like tf.sequence_mask
我想用一个索引过滤我的二维数组,然后只用过滤器中的值来平整这个数组。这几乎就是 tf.sequence_mask 会做的,但我需要在 numpy 或其他灯库中使用它。
谢谢!
PD:
这是一个例子:
array_2d = [[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]] # this is a numpy array
array_len = [6,5,3]
expected_output = [0,1,2,3,4,5,8,9,10,11,12,21,22,21]
这是一种使用布尔掩码并将其应用于扁平化 array_2d
的方法
array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]])
array_len = [6,5,3]
# Create a boolean mask
mask = np.zeros((array_2d.shape), dtype=bool)
# Change to True for elements to be kept
for i, j in enumerate(array_len):
mask[i][0:j] = True
expected_output = array_2d.flatten()[mask.flatten()]
输出
array([ 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 21, 22, 21])
这是一个 vectorized
解决方案,使用布尔掩码索引 array_2d
:
array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]])
array_len = [6,5,3]
m = ~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T
array_2d[m]
array([ 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 21, 22, 21])
详情
使用与 array_2d
相同形状的 cumsum
over an ndarray of ones
创建掩码,并执行逐行比较以查看哪些元素大于 array_len
。
所以第一步是创建以下 ndarray
:
np.ones(array_2d.shape).cumsum(axis=1)
array([[1., 2., 3., 4., 5., 6.],
[1., 2., 3., 4., 5., 6.],
[1., 2., 3., 4., 5., 6.]])
并与array_len
进行逐行比较:
~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T
array([[ True, True, True, True, True, True],
[ True, True, True, True, True, False],
[ True, True, True, False, False, False]])
然后您只需使用以下内容过滤数组:
array_2d[m]
array([ 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 21, 22, 21])
我想用一个索引过滤我的二维数组,然后只用过滤器中的值来平整这个数组。这几乎就是 tf.sequence_mask 会做的,但我需要在 numpy 或其他灯库中使用它。
谢谢!
PD: 这是一个例子:
array_2d = [[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]] # this is a numpy array
array_len = [6,5,3]
expected_output = [0,1,2,3,4,5,8,9,10,11,12,21,22,21]
这是一种使用布尔掩码并将其应用于扁平化 array_2d
array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]])
array_len = [6,5,3]
# Create a boolean mask
mask = np.zeros((array_2d.shape), dtype=bool)
# Change to True for elements to be kept
for i, j in enumerate(array_len):
mask[i][0:j] = True
expected_output = array_2d.flatten()[mask.flatten()]
输出
array([ 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 21, 22, 21])
这是一个 vectorized
解决方案,使用布尔掩码索引 array_2d
:
array_2d = np.array([[0,1,2,3,4,5],[8,9,10,11,12,0],[21,22,21,0,0,0]])
array_len = [6,5,3]
m = ~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T
array_2d[m]
array([ 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 21, 22, 21])
详情
使用与 array_2d
相同形状的 cumsum
over an ndarray of ones
创建掩码,并执行逐行比较以查看哪些元素大于 array_len
。
所以第一步是创建以下 ndarray
:
np.ones(array_2d.shape).cumsum(axis=1)
array([[1., 2., 3., 4., 5., 6.],
[1., 2., 3., 4., 5., 6.],
[1., 2., 3., 4., 5., 6.]])
并与array_len
进行逐行比较:
~(np.ones(array_2d.shape).cumsum(axis=1).T > array_len).T
array([[ True, True, True, True, True, True],
[ True, True, True, True, True, False],
[ True, True, True, False, False, False]])
然后您只需使用以下内容过滤数组:
array_2d[m]
array([ 0, 1, 2, 3, 4, 5, 8, 9, 10, 11, 12, 21, 22, 21])