Compute Hessian matrices efficiently in JAX
In JAX's quickstart tutorial, I found that the Hessian matrix of a differentiable function fun can be computed efficiently with the following lines:
from jax import jit, jacfwd, jacrev

def hessian(fun):
    return jit(jacfwd(jacrev(fun)))
However, the Hessian can also be computed with any of the following:
def hessian(fun):
    return jit(jacrev(jacfwd(fun)))

def hessian(fun):
    return jit(jacfwd(jacfwd(fun)))

def hessian(fun):
    return jit(jacrev(jacrev(fun)))
Here is a minimal working example:
import jax.numpy as jnp
from jax import jit
from jax import jacfwd, jacrev


def comp_hessian():
    x = jnp.arange(1.0, 4.0)

    def sum_logistics(x):
        return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

    def hessian_1(fun):
        return jit(jacfwd(jacrev(fun)))

    def hessian_2(fun):
        return jit(jacrev(jacfwd(fun)))

    def hessian_3(fun):
        return jit(jacrev(jacrev(fun)))

    def hessian_4(fun):
        return jit(jacfwd(jacfwd(fun)))

    hessian_fn = hessian_1(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_2(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_3(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_4(sum_logistics)
    print(hessian_fn(x))


def main():
    comp_hessian()


if __name__ == "__main__":
    main()
I would like to know which approach is best to use, and when. I would also like to know whether grad() can be used to compute the Hessian, and how grad() differs from jacfwd and jacrev.
The answer to your question is in the JAX documentation; see, for example, this section: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev
Quoting its discussion of jacrev and jacfwd:
These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
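To see what “tall” and “wide” mean concretely, here is a minimal sketch of my own (the matrix A and the dimensions are arbitrary, purely for illustration):

import jax.numpy as jnp
from jax import jacfwd, jacrev

A = jnp.ones((2, 100))  # arbitrary 2x100 matrix, for illustration only

def f(x):
    return A @ x    # f: R^100 -> R^2, so its Jacobian is "wide" (2 x 100)

def g(y):
    return A.T @ y  # g: R^2 -> R^100, so its Jacobian is "tall" (100 x 2)

x = jnp.ones(100)
y = jnp.ones(2)

# Both modes return the same values (up to floating-point error):
assert jnp.allclose(jacfwd(f)(x), jacrev(f)(x))

# But the cost differs: reverse mode needs one backward pass per output
# (2 for f), while forward mode needs one forward pass per input (2 for g).
J_f = jacrev(f)(x)  # preferred for the wide Jacobian
J_g = jacfwd(g)(y)  # preferred for the tall Jacobian
print(J_f.shape, J_g.shape)  # (2, 100) (100, 2)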
And a bit further down in the same documentation:
To implement hessian, we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That’s because in the inner Jacobian computation we’re often differentiating a function with a wide Jacobian (maybe like a loss function f : ℝⁿ → ℝ), while in the outer Jacobian computation we’re differentiating a function with a square Jacobian (since ∇f : ℝⁿ → ℝⁿ), which is where forward-mode wins out.
Since your function looks like f : ℝⁿ → ℝ, jit(jacfwd(jacrev(fun))) is likely the most efficient choice.
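For what it's worth, JAX also exposes a built-in jax.hessian transform, which uses this same forward-over-reverse composition under the hood, so for your example you could simply write:

import jax
import jax.numpy as jnp

def sum_logistics(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x = jnp.arange(1.0, 4.0)

# jax.hessian(fun) is jacfwd(jacrev(fun)) internally
print(jax.hessian(sum_logistics)(x))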
As for why you can't implement hessian in terms of grad alone: grad is only designed for derivatives of functions with scalar outputs. A Hessian matrix is, by definition, a composition of vector-valued Jacobians, not of scalar gradients.
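To make that concrete, here is a minimal sketch of my own (not from the linked docs): the inner derivative of a scalar-output function can be taken with grad, but the outer one must be a Jacobian transform, because the gradient itself is vector-valued:

import jax
import jax.numpy as jnp

def sum_logistics(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x = jnp.arange(1.0, 4.0)

g = jax.grad(sum_logistics)  # fine: sum_logistics has a scalar output
print(g(x))                  # the gradient, a vector in R^3

# jax.grad(g)(x) raises a TypeError, since g has a vector (non-scalar) output.
# The Jacobian transforms accept vector outputs, so this yields the Hessian:
print(jax.jacfwd(g)(x))      # matches jacfwd(jacrev(sum_logistics))(x) above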