如何从 Jax 中的联合累积密度函数计算联合概率密度函数?
How to compute the joint probability density function from a joint cumulative density function in Jax?
我在 python 中定义了一个联合累积密度函数作为 jax 数组的函数并返回单个值。
类似于:
def cumulative(inputs: array) -> float:
...
要获得梯度,我知道我可以做 grad(cumulative)
,但这只会给我关于输入变量的累积偏导数 first-order。
相反,我想做的是计算这个,假设 F 是我的函数,f 是联合概率密度函数:
偏导的顺序无关紧要。
所以,我有几个问题:
- 如何在 Jax 中有效地计算这个?我想我不能只打电话给 grad n 次
- 一旦计算出结果函数,结果函数的调用复杂度是否会比原始函数高(是增加了O(n),还是常数,还是其他)?
- 或者,我如何计算仅关于输入数组变量之一的单个偏导数,而不是整个数组? (我将重复 n 次,每个变量一次)
JAX 通常将渐变视为相对于单个参数,而不是参数中的元素。在此上下文中,一个类似于您想要执行的操作(但不完全相同)的内置函数是 jax.hessian
,它计算二阶导数的 hessian 矩阵;例如:
import jax
import jax.numpy as jnp
def f(x):
return jnp.prod(x ** 2)
x = jnp.arange(1.0, 4.0)
print(jax.hessian(f)(x))
# [[72. 72. 48.]
# [72. 18. 24.]
# [48. 24. 8.]]
对于数组的各个元素的高阶导数,我认为您必须手动嵌套梯度。您可以使用如下所示的辅助函数来执行此操作:
def grad_all(f):
def gradfun(x):
args = tuple(x)
f_args = lambda *args: f(jnp.array(args))
for i in range(len(args)):
f_args = jax.grad(f_args, argnums=i)
return f_args(*args)
return gradfun
print(grad_all(f)(x))
# 48.0
我在 python 中定义了一个联合累积密度函数作为 jax 数组的函数并返回单个值。 类似于:
def cumulative(inputs: array) -> float:
...
要获得梯度,我知道我可以做 grad(cumulative)
,但这只会给我关于输入变量的累积偏导数 first-order。
相反,我想做的是计算这个,假设 F 是我的函数,f 是联合概率密度函数:
偏导的顺序无关紧要。
所以,我有几个问题:
- 如何在 Jax 中有效地计算这个?我想我不能只打电话给 grad n 次
- 一旦计算出结果函数,结果函数的调用复杂度是否会比原始函数高(是增加了O(n),还是常数,还是其他)?
- 或者,我如何计算仅关于输入数组变量之一的单个偏导数,而不是整个数组? (我将重复 n 次,每个变量一次)
JAX 通常将渐变视为相对于单个参数,而不是参数中的元素。在此上下文中,一个类似于您想要执行的操作(但不完全相同)的内置函数是 jax.hessian
,它计算二阶导数的 hessian 矩阵;例如:
import jax
import jax.numpy as jnp
def f(x):
return jnp.prod(x ** 2)
x = jnp.arange(1.0, 4.0)
print(jax.hessian(f)(x))
# [[72. 72. 48.]
# [72. 18. 24.]
# [48. 24. 8.]]
对于数组的各个元素的高阶导数,我认为您必须手动嵌套梯度。您可以使用如下所示的辅助函数来执行此操作:
def grad_all(f):
def gradfun(x):
args = tuple(x)
f_args = lambda *args: f(jnp.array(args))
for i in range(len(args)):
f_args = jax.grad(f_args, argnums=i)
return f_args(*args)
return gradfun
print(grad_all(f)(x))
# 48.0