使用 jax.jit 时编译出错

Error upon compilation while using jax.jit

抱歉,我对 JAX 的内部工作机制还不太熟悉,正在想办法解决这个问题。我有一段代码,不使用 jit 时可以正常运行,但一旦加上 jit 就会抛出错误。我最初在代码中使用了 if/else 语句,但那样也行不通,因此只好把代码改写成下面这种不含 if/else 的形式。我该如何解决这个问题?最小可复现示例(MWE)如下。


import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

# Problem dimensions: the parameters form a (num_rows, num_cols) grid.
num_rows = 5
num_cols = 20
# One entry per row; rows whose entry is +inf get a single slot in the flat
# vectors built below, every other row gets num_cols slots (see kvals).
smf = jnp.array([jnp.inf, 0.1, 0.1, 0.1, 0.1])
par_init = jnp.array([1.0,2.0,3.0,4.0,5.0])
# Per-row values used in the log transform below (presumably lower and upper
# bounds, judging by the names).
lb = jnp.array([0.1, 0.1, 0.1, 0.1, 0.1])
ub = jnp.array([10.0, 10.0, 10.0, 10.0, 10.0])
# Repeat each row's initial value across all columns.
par = jnp.broadcast_to(par_init[:,None],(num_rows,num_cols))

# Segment width per row: 1 where smf is infinite, num_cols otherwise.
kvals = jnp.where(jnp.isinf(smf), 1, num_cols)
# Prepend 0 so the cumulative sum yields segment boundaries:
# row i occupies flat indices kvals[i]:kvals[i+1].
kvals = jnp.insert(kvals, 0, 0)
# For the values above this is [0, 1, 21, 41, 61, 81]. Note that kvals is a
# JAX array, which is what later makes kvals[i] non-static inside jit.
kvals = jnp.cumsum(kvals)

# Flat vectors of total length num_rows*num_cols minus (num_cols-1) for every
# infinite entry of smf (81 for the values above).
par0_col = jnp.zeros(num_rows*num_cols - (num_cols-1) * jnp.sum(jnp.isinf(smf)))
lb_col = jnp.zeros(num_rows*num_cols - (num_cols-1) * jnp.sum(jnp.isinf(smf)))
ub_col = jnp.zeros(num_rows*num_cols- (num_cols-1) * jnp.sum(jnp.isinf(smf)))



# Fill each row's segment. This runs eagerly (outside jit), so slicing with
# values taken from the JAX array kvals works here.
for i in range(num_rows):
    # Copy as many leading entries of row i as the segment is wide.
    par0_col = par0_col.at[kvals[i]:kvals[i+1]].set(par[i, :kvals[i+1]-kvals[i]])
    # Broadcast the row's scalar lb/ub over its whole segment.
    lb_col = lb_col.at[kvals[i]:kvals[i+1]].set(lb[i])
    ub_col = ub_col.at[kvals[i]:kvals[i+1]].set(ub[i])




# Element-wise log10 of (par - lb) / (1 - par / ub) over the flat vectors.
par_log = jnp.log10((par0_col - lb_col) / (1 - par0_col / ub_col))



@jax.jit  
def compute(p):
    """Scatter the flat ``par_log`` vector back into a (num_rows, num_cols) array.

    NOTE: the argument ``p`` is never read; the body uses the module-level
    ``par_log``, ``kvals``, ``num_rows`` and ``num_cols`` instead.

    As written, this function fails while being traced by ``jax.jit``:
    ``kvals`` is a JAX array, so ``kvals[i]`` is a traced value inside jit and
    cannot be used as a slice start/stop (IndexError: "Array slice indices
    must have static start/stop/step ...").

    Returns:
        ``arr_1``; ``arr_2`` is filled as well but never returned.
    """
    arr_1 = jnp.zeros(shape = (num_rows, num_cols))
    arr_2 = jnp.zeros(shape = (num_rows, num_cols))
    for i in range(num_rows):
 
        # This slice is where the error is raised under jit: the bounds
        # kvals[i] and kvals[i+1] are tracers, not static Python integers.
        arr_1 = arr_1.at[i, :].set((par_log[kvals[i]:kvals[i+1]]))
        # Same slice, mapped back through 10**x; the result is discarded.
        arr_2 = arr_2.at[i, :].set(10**par_log[kvals[i]:kvals[i+1]])

    return arr_1

# The first call triggers tracing/compilation of compute; this is the call
# that raises the IndexError when kvals is a JAX array.
arr = compute(par_log)
print(arr)


# WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# Traceback (most recent call last):
#   File "test_7.py", line 47, in <module>
#     arr = compute(par_log)
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
#     return fun(*args, **kwargs)
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/_src/api.py", line 424, in cache_miss
#     out_flat = xla.xla_call(
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/core.py", line 1661, in bind
#     return call_bind(self, fun, *args, **params)
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/core.py", line 1652, in call_bind
#     outs = primitive.process(top_trace, fun, tracers, params)
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/core.py", line 1664, in process
#     return trace.process_call(self, fun, tracers, params)
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/core.py", line 633, in process_call
#     return primitive.impl(f, *tracers, **params)
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/_src/dispatch.py", line 128, in _xla_call_impl
#     compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/linear_util.py", line 263, in memoized_fun
#     ans = call(fun, *args)
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/_src/dispatch.py", line 155, in _xla_callable_uncached
#     return lower_xla_callable(fun, device, backend, name, donated_invars,
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
#     return func(*args, **kwargs)
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/_src/dispatch.py", line 169, in lower_xla_callable
#     jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/_src/profiler.py", line 206, in wrapper
#     return func(*args, **kwargs)
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1566, in trace_to_jaxpr_final
#     jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1543, in trace_to_subjaxpr_dynamic
#     ans = fun.call_wrapped(*in_tracers)
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
#     ans = self.f(*args, **dict(self.params, **kwargs))
#   File "test_7.py", line 42, in compute
#     arr_1 = arr_1.at[i, :].set((par_log[kvals[i]:kvals[i+1]]))
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5704, in _rewriting_take
#     return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5713, in _gather
#     indexer = _index_to_gather(shape(arr), idx)  # shared with _scatter_update
#   File "/home/richinex/anaconda3/envs/numerical/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5956, in _index_to_gather
#     raise IndexError(msg)
# jax._src.traceback_util.UnfilteredStackTrace: IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, Traced<ShapedArray(int64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

# The stack trace below excludes JAX-internal frames.
# The preceding is the original exception that occurred, unmodified.

# --------------------

# The above exception was the direct cause of the following exception:

问题在于:在 JAX 中使用 NumPy 风格的切片索引时,切片的起止位置必须是静态值;而在 JIT 内部,kvals[i] 不是静态值(因为它是由 JAX 数组计算得到的)。

就您的情况而言,一个简单的解决办法是让 kvals 不再是 JAX 数组;例如,在定义它时改成这样:

# Convert to a Python list outside of jit so the slice bounds are no longer
# looked up from a JAX array inside the jitted function.
kvals = list(jnp.cumsum(kvals))

这种做法在这里可行,因为 kvals 是在 jit 编译的函数之外创建的。更一般地,如果您的索引本身也是在 JIT 函数内部计算出来的,可以使用 lax.dynamic_slice 来计算切片:它支持动态的起始索引(但切片的大小仍然必须是静态的)。

关于静态值与跟踪值(traced value)的更多背景知识,请阅读 JAX 文档中的《How to Think in JAX》。