如何从 Jax 中的联合累积密度函数计算联合概率密度函数?

How to compute the joint probability density function from a joint cumulative density function in Jax?

我在 python 中定义了一个联合累积密度函数作为 jax 数组的函数并返回单个值。 类似于:

def cumulative(inputs: array) -> float:
    ...

要获得梯度,我知道我可以做 grad(cumulative),但这只会给我关于输入变量的累积偏导数 first-order。 相反,我想做的是计算这个,假设 F 是我的函数,f 是联合概率密度函数:

formula

偏导的顺序无关紧要。

所以,我有几个问题:

JAX 通常将渐变视为相对于单个参数,而不是参数中的元素。在此上下文中,一个类似于您想要执行的操作(但不完全相同)的内置函数是 jax.hessian,它计算二阶导数的 hessian 矩阵;例如:

import jax
import jax.numpy as jnp

def f(x):
  return jnp.prod(x ** 2)

x = jnp.arange(1.0, 4.0)
print(jax.hessian(f)(x))
# [[72. 72. 48.]
#  [72. 18. 24.]
#  [48. 24.  8.]]

对于数组的各个元素的高阶导数,我认为您必须手动嵌套梯度。您可以使用如下所示的辅助函数来执行此操作:

def grad_all(f):
  def gradfun(x):
    args = tuple(x)
    f_args = lambda *args: f(jnp.array(args))
    for i in range(len(args)):
      f_args = jax.grad(f_args, argnums=i)
    return f_args(*args)
  return gradfun

print(grad_all(f)(x))
# 48.0