如何使用 JIT 处理 JAX 重塑
How to handle JAX reshape with JIT
我正在尝试实现 here 中描述的 entmax-alpha。
这是代码。
import jax
import jax.numpy as jnp
from jax import custom_jvp
from jax import jit
from jax import lax
from jax import vmap
@jax.partial(jit, static_argnums=(2,))
def p_tau(z, tau, alpha=1.5):
return jnp.clip((alpha - 1) * z - tau, a_min=0) ** (1 / (alpha - 1))
@jit
def get_tau(tau, tau_max, tau_min, z_value):
return lax.cond(z_value < 1,
lambda _: (tau, tau_min),
lambda _: (tau_max, tau),
operand=None
)
@jit
def body(kwargs, x):
tau_min = kwargs['tau_min']
tau_max = kwargs['tau_max']
z = kwargs['z']
alpha = kwargs['alpha']
tau = (tau_min + tau_max) / 2
z_value = p_tau(z, tau, alpha).sum()
taus = get_tau(tau, tau_max, tau_min, z_value)
tau_max, tau_min = taus[0], taus[1]
return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None
@jax.partial(jit, static_argnums=(1, 2,))
def map_row(z_input, alpha, T):
z = (alpha - 1) * z_input
tau_min, tau_max = jnp.min(z) - 1, jnp.max(z) - z.shape[0] ** (1 - alpha)
result, _ = lax.scan(body, {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, xs=None,
length=T)
tau = (result['tau_max'] + result['tau_min']) / 2
result = p_tau(z, tau, alpha)
return result / result.sum()
@jax.partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
reduce_length = input.shape[axis]
input = jnp.swapaxes(input, -1, axis)
input = input.reshape(input.size / reduce_length, reduce_length)
result = vmap(jax.partial(map_row, alpha=alpha, T=T), 0)(input)
return jnp.swapaxes(result, -1, axis)
@jax.partial(jit, static_argnums=(1, 2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
input = primals[0]
Y = entmax(input, axis, alpha, T)
gppr = Y ** (2 - alpha)
grad_output = tangents[0]
dX = grad_output * gppr
q = dX.sum(axis=axis) / gppr.sum(axis=axis)
q = jnp.expand_dims(q, axis=axis)
dX -= q * gppr
return Y, dX
@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
return _entmax_jvp_impl(axis, alpha, T, primals, tangents)
当我用下面的代码调用它时:
import numpy as np
from jax import value_and_grad
input = jnp.array(np.random.randn(64, 10))
weight = jnp.array(np.random.randn(64, 10))
def toy(input, weight):
return (weight*entmax(input, axis=-1, alpha=1.5, T=20)).sum()
value_and_grad(toy)(input, weight)
我收到以下错误。
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-3a62e54c67d2> in <module>()
7 return (weight*entmax(input, axis=-1, alpha=1.5, T=20)).sum()
8
----> 9 value_and_grad(toy)(input, weight)
35 frames
<ipython-input-1-d85b1daec668> in entmax(input, axis, alpha, T)
49 @jax.partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
50 def entmax(input, axis=-1, alpha=1.5, T=10):
---> 51 reduce_length = input.shape[axis]
52 input = jnp.swapaxes(input, -1, axis)
53 input = input.reshape(input.size / reduce_length, reduce_length)
TypeError: tuple indices must be integers or slices, not DynamicJaxprTracer
它似乎总是与重塑操作有关。我不确定为什么会这样,非常感谢您的帮助。
要重现问题,这里是 colab notebook
非常感谢。
错误是由于您试图用跟踪数量 axis
索引 Python 元组。您可以通过将 axis
设为静态参数来修复此错误:
@jax.partial(jit, static_argnums=(0, 1, 2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
...
不幸的是,这揭示了另一个问题:p_tau
声明 alpha
参数是静态的,但 body()
调用它时带有跟踪数量。在 body
中不能轻易将此数量标记为静态,因为它是在包含正在跟踪的输入的参数字典中传递的。
要解决此问题,您必须重写函数签名,在每个函数签名中仔细标记哪些输入是静态的,哪些不是,并确保两者不会混合在函数调用层中。
我正在尝试实现 here 中描述的 entmax-alpha。
这是代码。
import jax
import jax.numpy as jnp
from jax import custom_jvp
from jax import jit
from jax import lax
from jax import vmap
@jax.partial(jit, static_argnums=(2,))
def p_tau(z, tau, alpha=1.5):
return jnp.clip((alpha - 1) * z - tau, a_min=0) ** (1 / (alpha - 1))
@jit
def get_tau(tau, tau_max, tau_min, z_value):
return lax.cond(z_value < 1,
lambda _: (tau, tau_min),
lambda _: (tau_max, tau),
operand=None
)
@jit
def body(kwargs, x):
tau_min = kwargs['tau_min']
tau_max = kwargs['tau_max']
z = kwargs['z']
alpha = kwargs['alpha']
tau = (tau_min + tau_max) / 2
z_value = p_tau(z, tau, alpha).sum()
taus = get_tau(tau, tau_max, tau_min, z_value)
tau_max, tau_min = taus[0], taus[1]
return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None
@jax.partial(jit, static_argnums=(1, 2,))
def map_row(z_input, alpha, T):
z = (alpha - 1) * z_input
tau_min, tau_max = jnp.min(z) - 1, jnp.max(z) - z.shape[0] ** (1 - alpha)
result, _ = lax.scan(body, {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, xs=None,
length=T)
tau = (result['tau_max'] + result['tau_min']) / 2
result = p_tau(z, tau, alpha)
return result / result.sum()
@jax.partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
reduce_length = input.shape[axis]
input = jnp.swapaxes(input, -1, axis)
input = input.reshape(input.size / reduce_length, reduce_length)
result = vmap(jax.partial(map_row, alpha=alpha, T=T), 0)(input)
return jnp.swapaxes(result, -1, axis)
@jax.partial(jit, static_argnums=(1, 2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
input = primals[0]
Y = entmax(input, axis, alpha, T)
gppr = Y ** (2 - alpha)
grad_output = tangents[0]
dX = grad_output * gppr
q = dX.sum(axis=axis) / gppr.sum(axis=axis)
q = jnp.expand_dims(q, axis=axis)
dX -= q * gppr
return Y, dX
@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
return _entmax_jvp_impl(axis, alpha, T, primals, tangents)
当我用下面的代码调用它时:
import numpy as np
from jax import value_and_grad
input = jnp.array(np.random.randn(64, 10))
weight = jnp.array(np.random.randn(64, 10))
def toy(input, weight):
return (weight*entmax(input, axis=-1, alpha=1.5, T=20)).sum()
value_and_grad(toy)(input, weight)
我收到以下错误。
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-3a62e54c67d2> in <module>()
7 return (weight*entmax(input, axis=-1, alpha=1.5, T=20)).sum()
8
----> 9 value_and_grad(toy)(input, weight)
35 frames
<ipython-input-1-d85b1daec668> in entmax(input, axis, alpha, T)
49 @jax.partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
50 def entmax(input, axis=-1, alpha=1.5, T=10):
---> 51 reduce_length = input.shape[axis]
52 input = jnp.swapaxes(input, -1, axis)
53 input = input.reshape(input.size / reduce_length, reduce_length)
TypeError: tuple indices must be integers or slices, not DynamicJaxprTracer
它似乎总是与重塑操作有关。我不确定为什么会这样,非常感谢您的帮助。
要重现问题,这里是 colab notebook
非常感谢。
错误是由于您试图用跟踪数量 axis
索引 Python 元组。您可以通过将 axis
设为静态参数来修复此错误:
@jax.partial(jit, static_argnums=(0, 1, 2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
...
不幸的是,这揭示了另一个问题:p_tau
声明 alpha
参数是静态的,但 body()
调用它时带有跟踪数量。在 body
中不能轻易将此数量标记为静态,因为它是在包含正在跟踪的输入的参数字典中传递的。
要解决此问题,您必须重写函数签名,在每个函数签名中仔细标记哪些输入是静态的,哪些不是,并确保两者不会混合在函数调用层中。