jax:从 random.choice 中抽取许多观察值并在它们之间进行替换
jax: sample many observations from random.choice with replacement between them
我想从数组中选取两个索引。这些索引不能相同。
一个这样的样本可以通过以下方式获得:
random.choice(next(key), num_items, (2,), replace=False)
出于性能原因,我想对采样进行批处理:
num_samples = 100
samples = random.choice(next(key), num_items, (num_samples, 2), replace=False)
由于 replace=False
,这不起作用。它引发错误:
ValueError: Cannot take a larger sample than population when 'replace=False'
对于每个新样本,我想要 replace=True
。在一个示例中,我想要 replace=False
。
有办法吗?
我的随机抽样中的next(key)
是句法糖。为方便起见,我使用此代码段:
def reset_key(seed=42):
key = random.PRNGKey(seed)
while True:
key, subkey = random.split(key)
yield subkey
key = reset_key()
最好的方法是使用 jax.vmap
映射各个样本。例如:
from jax import random, vmap
def sample_two(key, num_items):
return random.choice(key , num_items, (2,), replace=False)
key = random.PRNGKey(0)
num_samples = 10
num_items = 5
key_array = random.split(key, num_samples)
print(vmap(sample_two, in_axes=(0, None))(key_array, num_items))
# [[2 0]
# [1 4]
# [2 1]
# [3 4]
# [4 2]
# [2 0]
# [1 3]
# [2 1]
# [1 0]
# [2 4]]
有关 jax.vmap
的更多信息,请参阅 Automatic Vectorization in JAX。
我想从数组中选取两个索引。这些索引不能相同。 一个这样的样本可以通过以下方式获得:
random.choice(next(key), num_items, (2,), replace=False)
出于性能原因,我想对采样进行批处理:
num_samples = 100
samples = random.choice(next(key), num_items, (num_samples, 2), replace=False)
由于 replace=False
,这不起作用。它引发错误:
ValueError: Cannot take a larger sample than population when 'replace=False'
对于每个新样本,我想要 replace=True
。在一个示例中,我想要 replace=False
。
有办法吗?
我的随机抽样中的next(key)
是句法糖。为方便起见,我使用此代码段:
def reset_key(seed=42):
key = random.PRNGKey(seed)
while True:
key, subkey = random.split(key)
yield subkey
key = reset_key()
最好的方法是使用 jax.vmap
映射各个样本。例如:
from jax import random, vmap
def sample_two(key, num_items):
return random.choice(key , num_items, (2,), replace=False)
key = random.PRNGKey(0)
num_samples = 10
num_items = 5
key_array = random.split(key, num_samples)
print(vmap(sample_two, in_axes=(0, None))(key_array, num_items))
# [[2 0]
# [1 4]
# [2 1]
# [3 4]
# [4 2]
# [2 0]
# [1 3]
# [2 1]
# [1 0]
# [2 4]]
有关 jax.vmap
的更多信息,请参阅 Automatic Vectorization in JAX。