JAX 仅在 jit 下的数组切片上应用函数
JAX Apply function only on slice of array under jit
我正在使用 JAX,我想执行类似
的操作
@jax.jit
def fun(x, index):
x[:index] = other_fun(x[:index])
return x
这无法在 jit
下执行。有没有办法用 jax.ops
或 jax.lax
来做到这一点?
我想过使用 jax.ops.index_update(x, idx, y)
但我找不到一种计算 y
的方法而不再次出现同样的问题。
您的实施过程中似乎存在两个问题。首先,切片正在生成动态形状的数组(在 jitted 代码中不允许)。其次,与 numpy 数组不同,JAX 数组是不可变的(即数组的内容不能更改)。
您可以通过组合 static_argnums
和 jax.lax.dynamic_update_slice
来解决这两个问题。这是一个例子:
def other_fun(x):
return x + 1
@jax.partial(jax.jit, static_argnums=(1,))
def fun(x, index):
update = other_fun(x[:index])
return jax.lax.dynamic_update_slice(x, update, (0,))
x = jnp.arange(5)
print(fun(x, 3)) # prints [1 2 3 3 4]
本质上,上面的示例使用 static_argnums
来指示函数应该针对不同的 index
值重新编译,并且 jax.lax.dynamic_update_slice
使用更新的值创建 x
的副本在 :len(update)
.
如果您的索引是静态的,@rvinas 使用 dynamic_slice
的 效果很好,但您也可以使用 jnp.where
使用动态索引来完成此操作。例如:
import jax
import jax.numpy as jnp
def other_fun(x):
return x + 1
@jax.jit
def fun(x, index):
mask = jnp.arange(x.shape[0]) < index
return jnp.where(mask, other_fun(x), x)
x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
我正在使用 JAX,我想执行类似
的操作@jax.jit
def fun(x, index):
x[:index] = other_fun(x[:index])
return x
这无法在 jit
下执行。有没有办法用 jax.ops
或 jax.lax
来做到这一点?
我想过使用 jax.ops.index_update(x, idx, y)
但我找不到一种计算 y
的方法而不再次出现同样的问题。
您的实施过程中似乎存在两个问题。首先,切片正在生成动态形状的数组(在 jitted 代码中不允许)。其次,与 numpy 数组不同,JAX 数组是不可变的(即数组的内容不能更改)。
您可以通过组合 static_argnums
和 jax.lax.dynamic_update_slice
来解决这两个问题。这是一个例子:
def other_fun(x):
return x + 1
@jax.partial(jax.jit, static_argnums=(1,))
def fun(x, index):
update = other_fun(x[:index])
return jax.lax.dynamic_update_slice(x, update, (0,))
x = jnp.arange(5)
print(fun(x, 3)) # prints [1 2 3 3 4]
本质上,上面的示例使用 static_argnums
来指示函数应该针对不同的 index
值重新编译,并且 jax.lax.dynamic_update_slice
使用更新的值创建 x
的副本在 :len(update)
.
如果您的索引是静态的,@rvinas 使用 dynamic_slice
的 jnp.where
使用动态索引来完成此操作。例如:
import jax
import jax.numpy as jnp
def other_fun(x):
return x + 1
@jax.jit
def fun(x, index):
mask = jnp.arange(x.shape[0]) < index
return jnp.where(mask, other_fun(x), x)
x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]