如何在使用 jax 时检查值是否在数组中

How to check if a value is in an array while using jax

我有一个负采样函数,我想使用 JAX 的 @jit 但我所做的一切都使它停止工作。

参数为:

我的函数如下所示,其目标是从评分中抽取 100 个样本,并为每个样本检索该用户尚未消费的项目。

BATCH_SIZE = 100

@jit
def sample(key, ratings, user_positives, num_items):
    new_key, subkey = jax.random.split(key)
    
    sampled_ratings = jax.random.choice(subkey, ratings, shape=(BATCH_SIZE,))
    sampled_users = jnp.zeros(BATCH_SIZE)
    sampled_positives = jnp.zeros(BATCH_SIZE)
    sampled_negatives = jnp.zeros(BATCH_SIZE)
    idx = 0
    
    for u, i, r in sampled_ratings:
        negative = user_positives[u][0]
        new_key, subkey = jax.random.split(key)
        while jnp.isin(jnp.array([negative]), user_positives[u])[0]:
            negative = jax.random.randint(current_subkey, (1,), 0, num_items)
            current_subkey = jax.random.split(subkey)
        
        sampled_users.at[idx].set(u)
        sampled_positives.at[idx].set(i)
        sampled_negatives.at[idx].set(negative)
        idx += 1
    
    return new_key, sampled_users, sampled_positives, sampled_negatives

但是,每当我 运行 并尝试更改它时,都会产生新的错误,并且我陷入了以下错误。谁能帮我做这个?

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
While tracing the function sample at /tmp/ipykernel_11557/2294038851.py:1 for jit, this concrete value was not available in Python because it depends on the values of the arguments 'key', 'ratings', and 'user_positives'.

编辑 1:输入示例为:

rng_key = 
rng_key, su, sp, sn = sample(
    rng_key,
    np.array([(0, 0, 1), (0, 1, 1), (1, 2, 1)]),
    np.array([np.array([0, 1]), np.array([2])]),
    15
)

一般来说,如果您想 jit-compile 一个条件取决于 non-static 数量的 while 循环,您必须用 jax.lax.while_loop. For more information see JAX Sharp Bits: Structured control flow primitives 来表达它。

如果您可以添加预期输入的示例,我将尝试使用基于您的代码的示例来编辑我的答案。