用向量索引 3d 矩阵的最佳方法是什么?
What is the best way to index 3d matrix with vectors?
import jax.numpy as jnp
向量和数组是jnp.array(dtype=jnp.int32)
我有一个形状为 [x, d, y] (3x3x3) 的数组
[[[0 0 0],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 0 0],
[0 0 0]]]
和向量x = [2 0 3], y = [ 2 0 1], d = [0 0 1]
我想通过索引得到这样的东西,但我试过了,但我真的不知道怎么做,jax.numpy。
[[[0 0 2],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 3 0],
[0 0 0]]]
编辑:我想说明我想将 x 中的数字及其索引放入数组,但仅当 x > 0 时。我尝试使用布尔掩码。
像这样
mask = x > 0
array = array.at[mask, d, y].set(array[mask, d, y] + x)
您有一个 three-dimensional 数组,因此您可以使用三个索引数组对其进行索引。由于您希望 d
和 y
与第二个和第三个维度相关联,因此您需要为第一个维度创建另一个索引数组:
import jax.numpy as jnp
arr = jnp.zeros((3, 3, 3), dtype='int32')
x = jnp.array([2, 0, 3])
y = jnp.array([2, 0, 1])
d = jnp.array([0, 0, 1])
i = jnp.arange(len(x))
mask = x > 0
out = arr.at[i[mask], d[mask], y[mask]].set(x[mask])
print(out)
# [[[0 0 2]
# [0 0 0]
# [0 0 0]]
# [[0 0 0]
# [0 0 0]
# [0 0 0]]
# [[0 0 0]
# [0 3 0]
# [0 0 0]]]
在这种情况下,无论您是否使用掩码,结果都是相同的(即 arr.at[i, d, y].set(x)
将给出相同的结果)但是因为您的问题明确指定您只想使用值 x > 0
我把它包括在内了。
import jax.numpy as jnp
向量和数组是jnp.array(dtype=jnp.int32)
我有一个形状为 [x, d, y] (3x3x3) 的数组
[[[0 0 0],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 0 0],
[0 0 0]]]
和向量x = [2 0 3], y = [ 2 0 1], d = [0 0 1]
我想通过索引得到这样的东西,但我试过了,但我真的不知道怎么做,jax.numpy。
[[[0 0 2],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 0 0],
[0 0 0]],
[[0 0 0],
[0 3 0],
[0 0 0]]]
编辑:我想说明我想将 x 中的数字及其索引放入数组,但仅当 x > 0 时。我尝试使用布尔掩码。 像这样
mask = x > 0
array = array.at[mask, d, y].set(array[mask, d, y] + x)
您有一个 three-dimensional 数组,因此您可以使用三个索引数组对其进行索引。由于您希望 d
和 y
与第二个和第三个维度相关联,因此您需要为第一个维度创建另一个索引数组:
import jax.numpy as jnp
arr = jnp.zeros((3, 3, 3), dtype='int32')
x = jnp.array([2, 0, 3])
y = jnp.array([2, 0, 1])
d = jnp.array([0, 0, 1])
i = jnp.arange(len(x))
mask = x > 0
out = arr.at[i[mask], d[mask], y[mask]].set(x[mask])
print(out)
# [[[0 0 2]
# [0 0 0]
# [0 0 0]]
# [[0 0 0]
# [0 0 0]
# [0 0 0]]
# [[0 0 0]
# [0 3 0]
# [0 0 0]]]
在这种情况下,无论您是否使用掩码,结果都是相同的(即 arr.at[i, d, y].set(x)
将给出相同的结果)但是因为您的问题明确指定您只想使用值 x > 0
我把它包括在内了。