JAX 仅在 jit 下的数组切片上应用函数

JAX Apply function only on slice of array under jit

我正在使用 JAX,我想执行类似

的操作
@jax.jit
def fun(x, index):
    x[:index] = other_fun(x[:index])
    return x

这无法在 jit 下执行。有没有办法用 jax.opsjax.lax 来做到这一点? 我想过使用 jax.ops.index_update(x, idx, y) 但我找不到一种计算 y 的方法而不再次出现同样的问题。

您的实施过程中似乎存在两个问题。首先,切片正在生成动态形状的数组(在 jitted 代码中不允许)。其次,与 numpy 数组不同,JAX 数组是不可变的(即数组的内容不能更改)。

您可以通过组合 static_argnumsjax.lax.dynamic_update_slice 来解决这两个问题。这是一个例子:

def other_fun(x):
    return x + 1

@jax.partial(jax.jit, static_argnums=(1,))
def fun(x, index):
    update = other_fun(x[:index])
    return jax.lax.dynamic_update_slice(x, update, (0,))

x = jnp.arange(5)
print(fun(x, 3))  # prints [1 2 3 3 4]

本质上,上面的示例使用 static_argnums 来指示函数应该针对不同的 index 值重新编译,并且 jax.lax.dynamic_update_slice 使用更新的值创建 x 的副本在 :len(update).

如果您的索引是静态的,@rvinas 使用 dynamic_slice 效果很好,但您也可以使用 jnp.where 使用动态索引来完成此操作。例如:

import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@jax.jit
def fun(x, index):
  mask = jnp.arange(x.shape[0]) < index
  return jnp.where(mask, other_fun(x), x)

x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]