JIT Jax 中的最小二乘损失函数
JIT a least squares loss function in Jax
我有一个简单的损失函数,看起来像这样
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
我想优化参数 r
并使用一些静态参数 x
和 y
来计算残差。所有有问题的参数都是 DeviceArrays
.
为了做到这一点,我尝试了以下操作
@partial(jax.jit, static_argnums=(1, 2))
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
但是我得到这个错误
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'numpy.ndarray'> for function loss is non-hashable.
我从 #6233 了解到这是设计使然,但我想知道这里的解决方法是什么,因为这似乎是一个非常常见的用例,您有一些固定的(输入、输出)训练数据对和一些自由变量。
感谢任何提示!
编辑:这是我尝试使用 jax.jit
时遇到的错误
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function loss at /path/to/my/script:9 for jit, this concrete value was not available in Python because it depends on the value of the argument 'r'.`
这听起来像是您将静态参数视为“在计算之间不发生变化的值”。在 JAX 的 JIT 中,最好将静态参数视为“可散列编译时常量”。在您的情况下,您没有可哈希的编译时常量;你有数组,所以你可以在没有静态参数的情况下进行 JIT 编译:
@jit
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
如果您真的想让 JAX 机器知道您的数组是常量,您可以通过闭包或部分传递它们来实现;例如:
from functools import partial
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
loss = jit(partial(loss, x=x, y=y))
但是,对于您正在进行的计算类型,其中常量是由 JAX 数组函数操作的数组,这两种方法导致基本相同的低 XLA 代码,因此您不妨使用更简单的方法。
我有一个简单的损失函数,看起来像这样
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
我想优化参数 r
并使用一些静态参数 x
和 y
来计算残差。所有有问题的参数都是 DeviceArrays
.
为了做到这一点,我尝试了以下操作
@partial(jax.jit, static_argnums=(1, 2))
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
但是我得到这个错误
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 1) of type <class 'numpy.ndarray'> for function loss is non-hashable.
我从 #6233 了解到这是设计使然,但我想知道这里的解决方法是什么,因为这似乎是一个非常常见的用例,您有一些固定的(输入、输出)训练数据对和一些自由变量。
感谢任何提示!
编辑:这是我尝试使用 jax.jit
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
While tracing the function loss at /path/to/my/script:9 for jit, this concrete value was not available in Python because it depends on the value of the argument 'r'.`
这听起来像是您将静态参数视为“在计算之间不发生变化的值”。在 JAX 的 JIT 中,最好将静态参数视为“可散列编译时常量”。在您的情况下,您没有可哈希的编译时常量;你有数组,所以你可以在没有静态参数的情况下进行 JIT 编译:
@jit
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
如果您真的想让 JAX 机器知道您的数组是常量,您可以通过闭包或部分传递它们来实现;例如:
from functools import partial
def loss(r, x, y):
resid = f(r, x) - y
return jnp.mean(jnp.square(resid))
loss = jit(partial(loss, x=x, y=y))
但是,对于您正在进行的计算类型,其中常量是由 JAX 数组函数操作的数组,这两种方法导致基本相同的低 XLA 代码,因此您不妨使用更简单的方法。