重新编译 JIT 函数时的 JAX 挂钩/信息/警告
JAX hook / information / warning when a JIT function is re-compiled
在 JAX 中是否有可能在 JAX 即时编译器必须重新编译函数时获得通知(因为输入已更改并且无法评估缓存的编译版本)?
现在,我使用一个 hacky 解决方法来了解重新编译。目前的实现中,tracer只在需要编译的时候执行一次函数,允许有sideeffects,只有在函数重新编译时才执行:
import jax
recompilation_count: int = 0
@jax.jit
def func(z):
global recompilation_count
recompilation_count += 1
return z * z + 100 / z
func(1)
print(recompilation_count)
func(2)
print(recompilation_count)
func(jax.numpy.arange(10))
print(recompilation_count)
func(jax.numpy.arange(10, 20))
print(recompilation_count)
func(jax.numpy.arange(10) ** 2)
print(recompilation_count)
assert recompilation_count == 2
但是,这是 JAX 实现的内部实现,因此不能以可靠的方式使用。如果函数经常发生,是否有另一种方法可以通知并可能阻止函数的重新编译?
我不相信有任何内置的 API 可以满足您的要求。但类似的功能目前正在积极讨论中(参见 https://github.com/google/jax/issues/8655)
但是请注意,如果您愿意,有一种内置的方法可以跟踪编译计数:
import jax
@jax.jit
def f(x):
return x
print(f._cache_size())
# 0
_ = f(jnp.arange(3))
print(f._cache_size())
# 1
_ = f(jnp.arange(3)) # should not trigger a recompilation
print(f._cache_size())
# 1
_ = f(jnp.arange(100)) # should trigger a recompilation
print(f._cache_size())
# 2
在 JAX 中是否有可能在 JAX 即时编译器必须重新编译函数时获得通知(因为输入已更改并且无法评估缓存的编译版本)?
现在,我使用一个 hacky 解决方法来了解重新编译。目前的实现中,tracer只在需要编译的时候执行一次函数,允许有sideeffects,只有在函数重新编译时才执行:
import jax
recompilation_count: int = 0
@jax.jit
def func(z):
global recompilation_count
recompilation_count += 1
return z * z + 100 / z
func(1)
print(recompilation_count)
func(2)
print(recompilation_count)
func(jax.numpy.arange(10))
print(recompilation_count)
func(jax.numpy.arange(10, 20))
print(recompilation_count)
func(jax.numpy.arange(10) ** 2)
print(recompilation_count)
assert recompilation_count == 2
但是,这是 JAX 实现的内部实现,因此不能以可靠的方式使用。如果函数经常发生,是否有另一种方法可以通知并可能阻止函数的重新编译?
我不相信有任何内置的 API 可以满足您的要求。但类似的功能目前正在积极讨论中(参见 https://github.com/google/jax/issues/8655)
但是请注意,如果您愿意,有一种内置的方法可以跟踪编译计数:
import jax
@jax.jit
def f(x):
return x
print(f._cache_size())
# 0
_ = f(jnp.arange(3))
print(f._cache_size())
# 1
_ = f(jnp.arange(3)) # should not trigger a recompilation
print(f._cache_size())
# 1
_ = f(jnp.arange(100)) # should trigger a recompilation
print(f._cache_size())
# 2