重复数组中的行
Repeating rows from array
我有一个问题,因为我想在不使用循环的情况下重复 n 次 array(X, Y) 中的所有行并获取 array(n*X, Y)
import jax.numpy as jnp
arr = jnp.array([[12, 14, 12, 0, 1],
[0, 14, 12, 0, 1],
[0, 0, 12, 0, 1]])
n = 3
result = jnp.array([[12 14 12 0 1],
[12 14 12 0 1],
[12 14 12 0 1],
[0 14 12 0 1],
[0 14 12 0 1],
[0 14 12 0 1],
[0 0 12 0 1],
[0 0 12 0 1],
[0 0 12 0 1]])
我还没有找到任何内置方法来执行此操作,尝试使用 jnp.tile、jnp.repeat。
jnp.repeat
arr_r = jnp.repeat(arr, n, axis=1)
Output:
[[12 12 12 14 14 14 12 12 12 0 0 0 1 1 1]
[ 0 0 0 14 14 14 12 12 12 0 0 0 1 1 1]
[ 0 0 0 0 0 0 12 12 12 0 0 0 1 1 1]]
arr_t = jnp.tile(arr, n)
Output:
[[12 14 12 0 1 12 14 12 0 1 12 14 12 0 1]
[ 0 14 12 0 1 0 14 12 0 1 0 14 12 0 1]
[ 0 0 12 0 1 0 0 12 0 1 0 0 12 0 1]]
也许我可以从 array_t...
构建结果数组
你说你试过 jnp.repeat
但没有解释为什么它没有达到你想要的效果。我猜你忽略了 axis
参数:
jnp.repeat(arr, n, axis=0)
我有一个问题,因为我想在不使用循环的情况下重复 n 次 array(X, Y) 中的所有行并获取 array(n*X, Y)
import jax.numpy as jnp
arr = jnp.array([[12, 14, 12, 0, 1],
[0, 14, 12, 0, 1],
[0, 0, 12, 0, 1]])
n = 3
result = jnp.array([[12 14 12 0 1],
[12 14 12 0 1],
[12 14 12 0 1],
[0 14 12 0 1],
[0 14 12 0 1],
[0 14 12 0 1],
[0 0 12 0 1],
[0 0 12 0 1],
[0 0 12 0 1]])
我还没有找到任何内置方法来执行此操作,尝试使用 jnp.tile、jnp.repeat。
jnp.repeat
arr_r = jnp.repeat(arr, n, axis=1)
Output:
[[12 12 12 14 14 14 12 12 12 0 0 0 1 1 1]
[ 0 0 0 14 14 14 12 12 12 0 0 0 1 1 1]
[ 0 0 0 0 0 0 12 12 12 0 0 0 1 1 1]]
arr_t = jnp.tile(arr, n)
Output:
[[12 14 12 0 1 12 14 12 0 1 12 14 12 0 1]
[ 0 14 12 0 1 0 14 12 0 1 0 14 12 0 1]
[ 0 0 12 0 1 0 0 12 0 1 0 0 12 0 1]]
也许我可以从 array_t...
构建结果数组你说你试过 jnp.repeat
但没有解释为什么它没有达到你想要的效果。我猜你忽略了 axis
参数:
jnp.repeat(arr, n, axis=0)