创建一个零的 3D 张量,在 numpy/jax 中的每个切片上随机放置一个“1”

Create a 3D tensor of zeros with exactly one '1' randomly placed on every slice in numpy/jax

我需要创建一个像这样的 3D 张量 (5,3,2) 例如

array([[[0, 0],
        [0, 1],
        [0, 0]],

       [[1, 0],
        [0, 0],
        [0, 0]],

       [[0, 0],
        [1, 0],
        [0, 0]],

       [[0, 0],
        [0, 0],
        [1, 0]],

       [[0, 0],
        [0, 1],
        [0, 0]]])

每个切片中应该恰好有一个 'one' 随机放置(如果您将张量视为一条面包)。这可以使用循环来完成,但我想将这部分矢量化。

尝试生成一个随机数组,然后找到 max:

a = np.random.rand(5,3,2)
out = (a == a.max(axis=(1,2))[:,None,None]).astype(int)

最直接的方法可能是创建一个零数组,并将随机索引设置为 1。在 NumPy 中,它可能如下所示:

import numpy as np

K, M, N = 5, 3, 2
i = np.random.randint(0, M, K)
j = np.random.randint(0, N, K)
x = np.zeros((K, M, N))
x[np.arange(K), i, j] = 1

在 JAX 中,它可能看起来像这样:

import jax.numpy as jnp
from jax import random

K, M, N = 5, 3, 2
key1, key2 = random.split(random.PRNGKey(0))
i = random.randint(key1, (K,), 0, M)
j = random.randint(key2, (K,), 0, N)
x = jnp.zeros((K, M, N)).at[jnp.arange(K), i, j].set(1)

一个更简洁的选项也保证每个切片一个 1 将使用具有适当构造范围的随机整数的广播相等性:

r = random.randint(random.PRNGKey(0), (K, 1, 1), 0, M * N)
x = (r == jnp.arange(M * N).reshape(M, N)).astype(int)

您可以创建一个零数组,其中每个子数组的第一个元素为 1,然后 permute 它横跨最后两个轴:

x = np.zeros((5,3,2)); x[:,0,0] = 1

rng = np.random.default_rng()
x = rng.permuted(rng.permuted(x, axis=-1), axis=-2)