在给定值之后屏蔽一个 numpy 数组

Mask a numpy array after a given value

我有两个像这样的 numpy 数组:

a = [False, False, False, False, False, True, False, False]

b = [1, 2, 3, 4, 5, 6, 7, 8]

我需要对 b 求和,而不是对整个数组求和,但仅当 a 中具有等效索引的元素为 True

换句话说,我想做 1+2+3+4+5=15 而不是 1+2+3+4+5+6+7+8=36

我需要一个高效的解决方案,我想我需要屏蔽b中第一个True之后a中的所有元素] 并使它们成为 0.

旁注:我的代码在 jax.numpy 中,而不是原始的 numpy,但我想这并不重要。

你可以做一个累加和

np.sum(b[np.cumsum(a)==0])

我建议用 .tolist() 将数组转换为列表,然后应用 .index() 获取第一个 True 的索引:i = a.tolist().index(True)。 然后简单的切片和求和:total = numpy.sum(b[:i])

我可以想到两种方法:您可以通过使用 cumsum 构造一个掩码来实现(这也适用于常规 numpy):

a = jnp.array([False, False, False, False, False, True, False, False])
b = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])

mask = a.cumsum() == 0
b.sum(where=mask) # 15

或者您可以使用 jnp.where 找到第一个 True 索引(请注意 size 参数仅存在于 JAX 的 jnp.where 版本中,而不存在于 numpy 的版本中:

idx = jnp.where(a, size=1)[0][0]
b[:idx].sum() # 15

您可以做一些微基准测试来确定对于您所关注的数组大小哪个更有效。