是否可以 jit 使用 jax.numpy.unique 的函数?
is it possible to jit a function which uses jax.numpy.unique?
以下代码不起作用:
def get_unique(arr):
return jnp.unique(arr)
get_unique = jit(get_unique)
get_unique(jnp.ones((10,)))
关于jnp.unique
的使用的错误信息:
FilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=0/1)>
The error arose in jnp.unique()
documentation on sharp bits 解释了如果内部数组的形状取决于参数值,则 jit 不起作用。这正是这里的情况。
根据文档,一个可能的解决方法是指定静态参数。但这不适用于我的情况。几乎每个函数调用的参数都会改变。我已将我的代码拆分为一个预处理步骤,该步骤执行诸如 jnp.unique
之类的计算,以及一个可以 jitted 的计算步骤。
但我还是想问一下,是否有一些我不知道的解决方法?
不,由于您提到的原因,目前无法对非静态值使用 jnp.unique
。
在类似的情况下,JAX 有时会添加额外的参数,这些参数可用于指定输出的静态大小(例如,jax.numpy.nonzero
) but nothing like that is currently implemented for jnp.unique
. If that is something you'd like, it would be worth filing a feature request 中的 size
参数。
以下代码不起作用:
def get_unique(arr):
return jnp.unique(arr)
get_unique = jit(get_unique)
get_unique(jnp.ones((10,)))
关于jnp.unique
的使用的错误信息:
FilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float32[10])>with<DynamicJaxprTrace(level=0/1)>
The error arose in jnp.unique()
documentation on sharp bits 解释了如果内部数组的形状取决于参数值,则 jit 不起作用。这正是这里的情况。
根据文档,一个可能的解决方法是指定静态参数。但这不适用于我的情况。几乎每个函数调用的参数都会改变。我已将我的代码拆分为一个预处理步骤,该步骤执行诸如 jnp.unique
之类的计算,以及一个可以 jitted 的计算步骤。
但我还是想问一下,是否有一些我不知道的解决方法?
不,由于您提到的原因,目前无法对非静态值使用 jnp.unique
。
在类似的情况下,JAX 有时会添加额外的参数,这些参数可用于指定输出的静态大小(例如,jax.numpy.nonzero
) but nothing like that is currently implemented for jnp.unique
. If that is something you'd like, it would be worth filing a feature request 中的 size
参数。