如何在使用 jax 时检查值是否在数组中
How to check if a value is in an array while using jax
我有一个负采样函数,我想使用 JAX 的 @jit
但我所做的一切都使它停止工作。
参数为:
key
:jax.random
的关键
ratings
:三元组列表(user_id, item_id, 1)
;
user_positives
:列表的列表,其中第 i 个内部列表包含第 i 个项目用户已消费;
num_items
: 项目总数
我的函数如下所示,其目标是从评分中抽取 100 个样本,并为每个样本检索该用户尚未消费的项目。
BATCH_SIZE = 100
@jit
def sample(key, ratings, user_positives, num_items):
new_key, subkey = jax.random.split(key)
sampled_ratings = jax.random.choice(subkey, ratings, shape=(BATCH_SIZE,))
sampled_users = jnp.zeros(BATCH_SIZE)
sampled_positives = jnp.zeros(BATCH_SIZE)
sampled_negatives = jnp.zeros(BATCH_SIZE)
idx = 0
for u, i, r in sampled_ratings:
negative = user_positives[u][0]
new_key, subkey = jax.random.split(key)
while jnp.isin(jnp.array([negative]), user_positives[u])[0]:
negative = jax.random.randint(current_subkey, (1,), 0, num_items)
current_subkey = jax.random.split(subkey)
sampled_users.at[idx].set(u)
sampled_positives.at[idx].set(i)
sampled_negatives.at[idx].set(negative)
idx += 1
return new_key, sampled_users, sampled_positives, sampled_negatives
但是,每当我 运行 并尝试更改它时,都会产生新的错误,并且我陷入了以下错误。谁能帮我做这个?
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function sample at /tmp/ipykernel_11557/2294038851.py:1 for jit, this concrete value was not available in Python because it depends on the values of the arguments 'key', 'ratings', and 'user_positives'.
编辑 1:输入示例为:
rng_key =
rng_key, su, sp, sn = sample(
rng_key,
np.array([(0, 0, 1), (0, 1, 1), (1, 2, 1)]),
np.array([np.array([0, 1]), np.array([2])]),
15
)
一般来说,如果您想 jit-compile 一个条件取决于 non-static 数量的 while
循环,您必须用 jax.lax.while_loop
. For more information see JAX Sharp Bits: Structured control flow primitives 来表达它。
如果您可以添加预期输入的示例,我将尝试使用基于您的代码的示例来编辑我的答案。
我有一个负采样函数,我想使用 JAX 的 @jit
但我所做的一切都使它停止工作。
参数为:
key
:jax.random
的关键
ratings
:三元组列表(user_id, item_id, 1)
;user_positives
:列表的列表,其中第 i 个内部列表包含第 i 个项目用户已消费;num_items
: 项目总数
我的函数如下所示,其目标是从评分中抽取 100 个样本,并为每个样本检索该用户尚未消费的项目。
BATCH_SIZE = 100
@jit
def sample(key, ratings, user_positives, num_items):
new_key, subkey = jax.random.split(key)
sampled_ratings = jax.random.choice(subkey, ratings, shape=(BATCH_SIZE,))
sampled_users = jnp.zeros(BATCH_SIZE)
sampled_positives = jnp.zeros(BATCH_SIZE)
sampled_negatives = jnp.zeros(BATCH_SIZE)
idx = 0
for u, i, r in sampled_ratings:
negative = user_positives[u][0]
new_key, subkey = jax.random.split(key)
while jnp.isin(jnp.array([negative]), user_positives[u])[0]:
negative = jax.random.randint(current_subkey, (1,), 0, num_items)
current_subkey = jax.random.split(subkey)
sampled_users.at[idx].set(u)
sampled_positives.at[idx].set(i)
sampled_negatives.at[idx].set(negative)
idx += 1
return new_key, sampled_users, sampled_positives, sampled_negatives
但是,每当我 运行 并尝试更改它时,都会产生新的错误,并且我陷入了以下错误。谁能帮我做这个?
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function sample at /tmp/ipykernel_11557/2294038851.py:1 for jit, this concrete value was not available in Python because it depends on the values of the arguments 'key', 'ratings', and 'user_positives'.
编辑 1:输入示例为:
rng_key =
rng_key, su, sp, sn = sample(
rng_key,
np.array([(0, 0, 1), (0, 1, 1), (1, 2, 1)]),
np.array([np.array([0, 1]), np.array([2])]),
15
)
一般来说,如果您想 jit-compile 一个条件取决于 non-static 数量的 while
循环,您必须用 jax.lax.while_loop
. For more information see JAX Sharp Bits: Structured control flow primitives 来表达它。
如果您可以添加预期输入的示例,我将尝试使用基于您的代码的示例来编辑我的答案。