jax:从 random.choice 中抽取许多观察值并在它们之间进行替换

jax: sample many observations from random.choice with replacement between them

我想从数组中选取两个索引。这些索引不能相同。 一个这样的样本可以通过以下方式获得:

random.choice(next(key), num_items, (2,), replace=False)

出于性能原因,我想对采样进行批处理:

num_samples = 100
samples = random.choice(next(key), num_items, (num_samples, 2), replace=False)

由于 replace=False,这不起作用。它引发错误:

ValueError: Cannot take a larger sample than population when 'replace=False'

对于每个新样本,我想要 replace=True。在一个示例中,我想要 replace=False。 有办法吗?

我的随机抽样中的next(key)是句法糖。为方便起见,我使用此代码段:

def reset_key(seed=42):
    key = random.PRNGKey(seed)
    while True:
        key, subkey = random.split(key)
        yield subkey
        
key = reset_key()

最好的方法是使用 jax.vmap 映射各个样本。例如:

from jax import random, vmap

def sample_two(key, num_items):
  return random.choice(key , num_items, (2,), replace=False)

key = random.PRNGKey(0)
num_samples = 10
num_items = 5

key_array = random.split(key, num_samples)
print(vmap(sample_two, in_axes=(0, None))(key_array, num_items))
# [[2 0]
#  [1 4]
#  [2 1]
#  [3 4]
#  [4 2]
#  [2 0]
#  [1 3]
#  [2 1]
#  [1 0]
#  [2 4]]

有关 jax.vmap 的更多信息,请参阅 Automatic Vectorization in JAX