重新编译 JIT 函数时的 JAX 挂钩/信息/警告

JAX hook / information / warning when a JIT function is re-compiled

在 JAX 中是否有可能在 JAX 即时编译器必须重新编译函数时获得通知(因为输入已更改并且无法评估缓存的编译版本)?

现在,我使用一个 hacky 解决方法来了解重新编译。目前的实现中,tracer只在需要编译的时候执行一次函数,允许有sideeffects,只有在函数重新编译时才执行:


import jax
recompilation_count: int = 0

@jax.jit
def func(z):
    global recompilation_count
    recompilation_count += 1
    return z * z + 100 / z


func(1)
print(recompilation_count)
func(2)
print(recompilation_count)
func(jax.numpy.arange(10))
print(recompilation_count)
func(jax.numpy.arange(10, 20))
print(recompilation_count)
func(jax.numpy.arange(10) ** 2)
print(recompilation_count)

assert recompilation_count == 2

但是,这是 JAX 实现的内部实现,因此不能以可靠的方式使用。如果函数经常发生,是否有另一种方法可以通知并可能阻止函数的重新编译?

我不相信有任何内置的 API 可以满足您的要求。但类似的功能目前正在积极讨论中(参见 https://github.com/google/jax/issues/8655

但是请注意,如果您愿意,有一种内置的方法可以跟踪编译计数:

import jax

@jax.jit
def f(x):
  return x

print(f._cache_size())
# 0

_ = f(jnp.arange(3))
print(f._cache_size())
# 1

_ = f(jnp.arange(3))  # should not trigger a recompilation
print(f._cache_size())
# 1

_ = f(jnp.arange(100))  # should trigger a recompilation
print(f._cache_size())
# 2