朱莉娅:Zygote.@adjoint 来自 Enzyme.autodiff

Julia: Zygote.@adjoint from Enzyme.autodiff

给定下面的函数 f! :

function f!(s::Vector, a::Vector, b::Vector)
  
  s .= a .+ b
  return nothing

end # f!

如何根据

Zygote定义伴随物

Enzyme.autodiff(f!, Const, Duplicated(s, dz_ds). Duplicated(a, zero(a)), Duplicated(b, zero(b))) ?

Zygote.@adjoint f!(s, a, b) = f!(s, a, b), # What would come here ?

想出了办法,在这里分享。

对于给定函数 fooZygote.pullback(foo, args...) returns foo(args...) 和向后传递(允许梯度计算)。

我的目标是告诉 Zygote 使用 Enzyme 进行反向传递。

这可以通过 Zygote.@adjoint 来完成(查看更多 here)。

在 array-valued 函数的情况下,Enzyme 需要 returns nothing 的变异版本及其结果在 args 中(查看更多here).

问题 post 中的函数 f! 是两个数组求和的 Enzyme 兼容版本。

因为 f! returns nothing, Zygote 只会 return nothing 当对传递给我们的某些梯度调用反向传递时。

一个解决方案是将 f! 放在一个包装器(比如 f)中,return 是数组 s

并为 f 定义 Zygote.@adjoint,而不是 f!

因此,

function f(a::Vector, b::Vector)

  s = zero(a)
  f!(s, a, b)
  return s

end
function enzyme_back(dzds, a, b)

  s    = zero(a)
  dzda = zero(dzds)
  dzdb = zero(dzds)
  Enzyme.autodiff(
    f!,
    Const,
    Duplicated(s, dzds),
    Duplicated(a, dzda),
    Duplicated(b, dzdb)
  )
  return (dzda, dzdb)

end

Zygote.@adjoint f(a, b) = f(a, b), dzds -> enzyme_back(dzds, a, b)

通知Zygote在反向传播中使用Enzyme


最后,您可以检查是否在

上调用 Zygote.gradient
g1(a::Vector, b::Vector) = sum(abs2, a + b)

g2(a::Vector, b::Vector) = sum(abs2, f(a, b))

产生相同的结果。