创建一个零的 3D 张量,在 numpy/jax 中的每个切片上随机放置一个“1”
Create a 3D tensor of zeros with exactly one '1' randomly placed on every slice in numpy/jax
我需要创建一个像这样的 3D 张量 (5,3,2) 例如
array([[[0, 0],
[0, 1],
[0, 0]],
[[1, 0],
[0, 0],
[0, 0]],
[[0, 0],
[1, 0],
[0, 0]],
[[0, 0],
[0, 0],
[1, 0]],
[[0, 0],
[0, 1],
[0, 0]]])
每个切片中应该恰好有一个 'one' 随机放置(如果您将张量视为一条面包)。这可以使用循环来完成,但我想将这部分矢量化。
尝试生成一个随机数组,然后找到 max
:
a = np.random.rand(5,3,2)
out = (a == a.max(axis=(1,2))[:,None,None]).astype(int)
最直接的方法可能是创建一个零数组,并将随机索引设置为 1。在 NumPy 中,它可能如下所示:
import numpy as np
K, M, N = 5, 3, 2
i = np.random.randint(0, M, K)
j = np.random.randint(0, N, K)
x = np.zeros((K, M, N))
x[np.arange(K), i, j] = 1
在 JAX 中,它可能看起来像这样:
import jax.numpy as jnp
from jax import random
K, M, N = 5, 3, 2
key1, key2 = random.split(random.PRNGKey(0))
i = random.randint(key1, (K,), 0, M)
j = random.randint(key2, (K,), 0, N)
x = jnp.zeros((K, M, N)).at[jnp.arange(K), i, j].set(1)
一个更简洁的选项也保证每个切片一个 1
将使用具有适当构造范围的随机整数的广播相等性:
r = random.randint(random.PRNGKey(0), (K, 1, 1), 0, M * N)
x = (r == jnp.arange(M * N).reshape(M, N)).astype(int)
您可以创建一个零数组,其中每个子数组的第一个元素为 1,然后 permute
它横跨最后两个轴:
x = np.zeros((5,3,2)); x[:,0,0] = 1
rng = np.random.default_rng()
x = rng.permuted(rng.permuted(x, axis=-1), axis=-2)
我需要创建一个像这样的 3D 张量 (5,3,2) 例如
array([[[0, 0],
[0, 1],
[0, 0]],
[[1, 0],
[0, 0],
[0, 0]],
[[0, 0],
[1, 0],
[0, 0]],
[[0, 0],
[0, 0],
[1, 0]],
[[0, 0],
[0, 1],
[0, 0]]])
每个切片中应该恰好有一个 'one' 随机放置(如果您将张量视为一条面包)。这可以使用循环来完成,但我想将这部分矢量化。
尝试生成一个随机数组,然后找到 max
:
a = np.random.rand(5,3,2)
out = (a == a.max(axis=(1,2))[:,None,None]).astype(int)
最直接的方法可能是创建一个零数组,并将随机索引设置为 1。在 NumPy 中,它可能如下所示:
import numpy as np
K, M, N = 5, 3, 2
i = np.random.randint(0, M, K)
j = np.random.randint(0, N, K)
x = np.zeros((K, M, N))
x[np.arange(K), i, j] = 1
在 JAX 中,它可能看起来像这样:
import jax.numpy as jnp
from jax import random
K, M, N = 5, 3, 2
key1, key2 = random.split(random.PRNGKey(0))
i = random.randint(key1, (K,), 0, M)
j = random.randint(key2, (K,), 0, N)
x = jnp.zeros((K, M, N)).at[jnp.arange(K), i, j].set(1)
一个更简洁的选项也保证每个切片一个 1
将使用具有适当构造范围的随机整数的广播相等性:
r = random.randint(random.PRNGKey(0), (K, 1, 1), 0, M * N)
x = (r == jnp.arange(M * N).reshape(M, N)).astype(int)
您可以创建一个零数组,其中每个子数组的第一个元素为 1,然后 permute
它横跨最后两个轴:
x = np.zeros((5,3,2)); x[:,0,0] = 1
rng = np.random.default_rng()
x = rng.permuted(rng.permuted(x, axis=-1), axis=-2)