重复数组中的行

Repeating rows from array

我有一个问题,因为我想在不使用循环的情况下重复 n 次 array(X, Y) 中的所有行并获取 array(n*X, Y)

import jax.numpy as jnp

arr = jnp.array([[12, 14, 12, 0, 1],
                [0, 14, 12, 0, 1],
                [0, 0, 12, 0, 1]])
n = 3

result = jnp.array([[12 14 12 0 1],
                    [12 14 12 0 1],
                    [12 14 12 0 1],
                    [0 14 12 0 1],
                    [0 14 12 0 1],
                    [0 14 12 0 1],
                    [0 0 12 0 1],
                    [0 0 12 0 1],
                    [0 0 12 0 1]])

我还没有找到任何内置方法来执行此操作,尝试使用 jnp.tile、jnp.repeat。

jnp.repeat

arr_r = jnp.repeat(arr, n, axis=1)

Output:
[[12 12 12 14 14 14 12 12 12  0  0  0  1  1  1]
 [ 0  0  0 14 14 14 12 12 12  0  0  0  1  1  1]
 [ 0  0  0  0  0  0 12 12 12  0  0  0  1  1  1]]

arr_t = jnp.tile(arr, n)

Output:
[[12 14 12  0  1 12 14 12  0  1 12 14 12  0  1]
 [ 0 14 12  0  1  0 14 12  0  1  0 14 12  0  1]
 [ 0  0 12  0  1  0  0 12  0  1  0  0 12  0  1]]

也许我可以从 array_t...

构建结果数组

你说你试过 jnp.repeat 但没有解释为什么它没有达到你想要的效果。我猜你忽略了 axis 参数:

jnp.repeat(arr, n, axis=0)