jax中的高阶多元导数

Higher-order multivariate derivatives in jax

我对如何在 jax 中计算高阶多元导数感到困惑。

例如,您如何计算

的 d^2f / dx dy
def f(x, y):
     return jnp.sin(jnp.dot(x, y.T))

其中 R^n 中的 x、y,n >= 1?

我一直在尝试 jax.jvpjax.partial,但我没有取得任何成功。

由于 xy 是矢量值而 f(x, y) 是标量,我相信您可以通过组合 jax.jacfwd and jax.jacrev 函数来计算您想要的结果使用适当的参数:

import jax.numpy as jnp
from jax import jacfwd, jacrev

def f(x, y):
     return jnp.sin(jnp.dot(x, y.T))

d2f_dxdy = jacfwd(jacrev(f, argnums=1), argnums=0)
  
x = jnp.arange(4.0)
y = jnp.ones(4)

print(d2f_dxdy(x, y))

# DeviceArray([[0.96017027, 0.        , 0.        , 0.        ],
#              [0.2794155 , 1.2395858 , 0.2794155 , 0.2794155 ],
#              [0.558831  , 0.558831  , 1.5190012 , 0.558831  ],
#              [0.83824646, 0.83824646, 0.83824646, 1.7984167 ]],
#             dtype=float32)