用向量索引 3d 矩阵的最佳方法是什么?

What is the best way to index 3d matrix with vectors?

import jax.numpy as jnp

向量和数组是jnp.array(dtype=jnp.int32)

我有一个形状为 [x, d, y] (3x3x3) 的数组

[[[0 0 0],
 [0 0 0],
 [0 0 0]],

[[0 0 0],
 [0 0 0],
 [0 0 0]],

[[0 0 0],
 [0 0 0],
 [0 0 0]]]

和向量x = [2 0 3], y = [ 2 0 1], d = [0 0 1]

我想通过索引得到这样的东西,但我试过了,但我真的不知道怎么做,jax.numpy。

[[[0 0 2],
 [0 0 0],
 [0 0 0]],

[[0 0 0],
 [0 0 0],
 [0 0 0]],

[[0 0 0],
 [0 3 0],
 [0 0 0]]]

编辑:我想说明我想将 x 中的数字及其索引放入数组,但仅当 x > 0 时。我尝试使用布尔掩码。 像这样

mask = x > 0
array = array.at[mask, d, y].set(array[mask, d, y] + x)

您有一个 three-dimensional 数组,因此您可以使用三个索引数组对其进行索引。由于您希望 dy 与第二个和第三个维度相关联,因此您需要为第一个维度创建另一个索引数组:

import jax.numpy as jnp

arr = jnp.zeros((3, 3, 3), dtype='int32')
x = jnp.array([2, 0, 3])
y = jnp.array([2, 0, 1])
d = jnp.array([0, 0, 1])

i = jnp.arange(len(x))
mask = x > 0

out = arr.at[i[mask], d[mask], y[mask]].set(x[mask])
print(out)
# [[[0 0 2]
#   [0 0 0]
#   [0 0 0]]

#  [[0 0 0]
#   [0 0 0]
#   [0 0 0]]

#  [[0 0 0]
#   [0 3 0]
#   [0 0 0]]]

在这种情况下,无论您是否使用掩码,结果都是相同的(即 arr.at[i, d, y].set(x) 将给出相同的结果)但是因为您的问题明确指定您只想使用值 x > 0 我把它包括在内了。