jax中的高阶多元导数
Higher-order multivariate derivatives in jax
我对如何在 jax 中计算高阶多元导数感到困惑。
例如,您如何计算
的 d^2f / dx dy
def f(x, y):
return jnp.sin(jnp.dot(x, y.T))
其中 R^n 中的 x、y,n >= 1?
我一直在尝试 jax.jvp
和 jax.partial
,但我没有取得任何成功。
由于 x
和 y
是矢量值而 f(x, y)
是标量,我相信您可以通过组合 jax.jacfwd
and jax.jacrev
函数来计算您想要的结果使用适当的参数:
import jax.numpy as jnp
from jax import jacfwd, jacrev
def f(x, y):
return jnp.sin(jnp.dot(x, y.T))
d2f_dxdy = jacfwd(jacrev(f, argnums=1), argnums=0)
x = jnp.arange(4.0)
y = jnp.ones(4)
print(d2f_dxdy(x, y))
# DeviceArray([[0.96017027, 0. , 0. , 0. ],
# [0.2794155 , 1.2395858 , 0.2794155 , 0.2794155 ],
# [0.558831 , 0.558831 , 1.5190012 , 0.558831 ],
# [0.83824646, 0.83824646, 0.83824646, 1.7984167 ]],
# dtype=float32)
我对如何在 jax 中计算高阶多元导数感到困惑。
例如,您如何计算
的 d^2f / dx dydef f(x, y):
return jnp.sin(jnp.dot(x, y.T))
其中 R^n 中的 x、y,n >= 1?
我一直在尝试 jax.jvp
和 jax.partial
,但我没有取得任何成功。
由于 x
和 y
是矢量值而 f(x, y)
是标量,我相信您可以通过组合 jax.jacfwd
and jax.jacrev
函数来计算您想要的结果使用适当的参数:
import jax.numpy as jnp
from jax import jacfwd, jacrev
def f(x, y):
return jnp.sin(jnp.dot(x, y.T))
d2f_dxdy = jacfwd(jacrev(f, argnums=1), argnums=0)
x = jnp.arange(4.0)
y = jnp.ones(4)
print(d2f_dxdy(x, y))
# DeviceArray([[0.96017027, 0. , 0. , 0. ],
# [0.2794155 , 1.2395858 , 0.2794155 , 0.2794155 ],
# [0.558831 , 0.558831 , 1.5190012 , 0.558831 ],
# [0.83824646, 0.83824646, 0.83824646, 1.7984167 ]],
# dtype=float32)