Computing Hessian matrices efficiently in JAX

In JAX's quickstart tutorial, I found that the Hessian matrix of a differentiable function fun can be computed efficiently with the following lines of code:

from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

However, the Hessian can also be computed with any of the following compositions:

def hessian(fun):
  return jit(jacrev(jacfwd(fun)))

def hessian(fun):
  return jit(jacfwd(jacfwd(fun)))

def hessian(fun):
  return jit(jacrev(jacrev(fun)))

Here is a minimal working example:

import jax.numpy as jnp
from jax import jit
from jax import jacfwd, jacrev

def comp_hessian():

    x = jnp.arange(1.0, 4.0)  # x = [1., 2., 3.]

    # Scalar-valued test function, R^3 -> R.
    def sum_logistics(x):
        return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

    def hessian_1(fun):
        return jit(jacfwd(jacrev(fun)))  # forward-over-reverse

    def hessian_2(fun):
        return jit(jacrev(jacfwd(fun)))  # reverse-over-forward

    def hessian_3(fun):
        return jit(jacrev(jacrev(fun)))  # reverse-over-reverse

    def hessian_4(fun):
        return jit(jacfwd(jacfwd(fun)))  # forward-over-forward

    hessian_fn = hessian_1(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_2(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_3(sum_logistics)
    print(hessian_fn(x))

    hessian_fn = hessian_4(sum_logistics)
    print(hessian_fn(x))


def main():
    comp_hessian()


if __name__ == "__main__":
    main()

I would like to know which method is best to use and when. I would also like to know whether grad() can be used to compute the Hessian, and how grad() differs from jacfwd and jacrev.

The answer to your question is in the JAX documentation; see, for example, this section: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev

Quoting its discussion of jacrev and jacfwd:

These two functions compute the same values (up to machine numerics), but differ in their implementation: jacfwd uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while jacrev uses reverse-mode, which is more efficient for “wide” Jacobian matrices. For matrices that are near-square, jacfwd probably has an edge over jacrev.
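To make "tall" vs "wide" concrete, here is a small shape-check sketch of my own (the toy functions tall and wide and their sizes are arbitrary choices, not from the docs):

import jax.numpy as jnp
from jax import jacfwd, jacrev

def tall(x):
    # "Tall" Jacobian: R^2 -> R^1000 (1000 rows, 2 columns).
    return jnp.tile(x ** 2, 500)

def wide(x):
    # "Wide" Jacobian: R^1000 -> R^2 (2 rows, 1000 columns).
    return jnp.stack([x.sum(), (x ** 2).sum()])

print(jacfwd(tall)(jnp.ones(2)).shape)     # (1000, 2): forward mode needs one pass per input, so 2 passes
print(jacrev(wide)(jnp.ones(1000)).shape)  # (2, 1000): reverse mode needs one pass per output, so 2 passes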

Further down in the same section, the documentation says:

To implement hessian, we could have used jacfwd(jacrev(f)) or jacrev(jacfwd(f)) or any other composition of the two. But forward-over-reverse is typically the most efficient. That's because in the inner Jacobian computation we're often differentiating a function with a wide Jacobian (maybe like a loss function f : ℝⁿ → ℝ), while in the outer Jacobian computation we're differentiating a function with a square Jacobian (since ∇f : ℝⁿ → ℝⁿ), which is where forward-mode wins out.

Since your function looks like f : ℝⁿ → ℝ, jit(jacfwd(jacrev(fun))) is likely the most efficient approach.
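If you want to verify this on your own function, a quick timing sketch of my own could look like the following (the input size of 1000 and the repeat count are arbitrary; on the tiny 3-element input from the MWE the differences would be lost in noise):

import timeit
import jax.numpy as jnp
from jax import jit, jacfwd, jacrev

def sum_logistics(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x = jnp.arange(1.0, 1001.0)  # larger input than the MWE, so differences show up

compositions = {
    "fwd-over-rev": jit(jacfwd(jacrev(sum_logistics))),
    "rev-over-fwd": jit(jacrev(jacfwd(sum_logistics))),
    "fwd-over-fwd": jit(jacfwd(jacfwd(sum_logistics))),
    "rev-over-rev": jit(jacrev(jacrev(sum_logistics))),
}

for name, h in compositions.items():
    h(x).block_until_ready()  # trigger compilation once before timing
    t = timeit.timeit(lambda: h(x).block_until_ready(), number=100)
    print(f"{name}: {t / 100 * 1e3:.3f} ms per call")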

As for why you can't implement hessian with grad: that's because grad is only designed for taking derivatives of functions with scalar output. By definition, the Hessian matrix is a composition of vector-valued Jacobians, not of scalar gradients.
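That said, grad can still serve as the inner step, since the original function is scalar-valued; it is only the outer differentiation, applied to the vector-valued gradient, that requires jacfwd or jacrev. A minimal sketch, reusing sum_logistics from the example above:

import jax.numpy as jnp
from jax import grad, jacfwd

def sum_logistics(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x = jnp.arange(1.0, 4.0)

# grad is reverse-mode, so this is again forward-over-reverse:
hess = jacfwd(grad(sum_logistics))
print(hess(x))

# By contrast, grad(grad(sum_logistics))(x) would raise an error:
# grad(sum_logistics) maps R^3 -> R^3, and grad only accepts
# functions with scalar output.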