为 Flux 自定义渐变而不是使用 Zygote A.D

Custom gradients for Flux rather than using Zygote A.D

我有一个机器学习模型,其中模型参数的梯度是解析的,不需要自动微分。但是,我仍然希望能够在 Flux 中利用不同的优化器,而不必依赖 Zygote 进行差异化。这是我的一些代码片段。

W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = Flux.Params([b, c, U, W])

opt = ADAM(0.01)

然后我有一个函数可以计算我的模型参数的解析梯度,θ

function gradients(x) # x = one input data point or a batch of input data points
    # stuff to calculate gradients of each parameter
    # returns gradients of each parameter

然后我希望能够执行以下操作。

grads = gradients(x)
update!(opt, θ, grads)

我的问题是:我的 gradient(x) 函数需要 return 什么 form/type 才能完成 update!(opt, θ, grads),我该怎么做?

如果您不使用 Params,那么 grads 只需要作为渐变即可。唯一的要求是 θgrads 大小相同。

例如,map((x, g) -> update!(opt, x, g), θ, grads) 其中 θ == [b, c, U, W]grads = [gradients(b), gradients(c), gradients(U), gradients(W)](不太确定 gradients 期望您输入什么)。

更新:但要回答您原来的问题,gradients 需要 return 一个 Grads 对象在这里找到:https://github.com/FluxML/Zygote.jl/blob/359e586766129878ca0e56121037ed80afda6289/src/compiler/interface.jl#L88

所以像

# within gradient function body assuming gb is the gradient w.r.t b
g = Zygote.Grads(IdDict())
g.grads[θ[1]] = gb # assuming θ[1] == b

但不使用 Params 可能更容易调试。唯一的问题是没有 update! 可以处理一组参数,但您可以轻松定义自己的:

function Flux.Optimise.update!(opt, xs::Tuple, gs)
    for (x, g) in zip(xs, gs)
        update!(opt, x, g)
    end
end

# use it like this
W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = (b, c, U, W)

opt = ADAM(0.01)
x = # generate input to gradients
grads = gradients(x) # return tuple (gb, gc, gU, gW)
update!(opt, θ, grads)

更新 2:

另一种选择是仍然使用 Zygote 获取梯度,以便它自动为您设置 Grads 对象,但使用自定义伴随,以便它使用您的分析函数来计算伴随。假设您的 ML 模型被定义为名为 f 的函数,因此 f(x) return 是输入 x 的模型输出。我们还假设 gradients(x) return 是分析梯度 w.r.t。 x 就像您在问题中提到的那样。那么下面的代码仍然会使用 Zygote 的 AD,它会正确地填充 Grads 对象,但是它会使用你定义的函数计算梯度 f:

W = rand(Nh, N)
U = rand(N, Nh)
b = rand(N)
c = rand(Nh)

θ = Flux.Params([b, c, U, W])

f(x) = # define your model
gradients(x) = # define your analytical gradient

# set up the custom adjoint
Zygote.@adjoint f(x) = f(x), Δ -> (gradients(x),)

opt = ADAM(0.01)
x = # generate input to model
y = # output of model
grads = Zygote.gradient(() -> Flux.mse(f(x), y), θ)
update!(opt, θ, grads)

请注意,我在上面使用 Flux.mse 作为示例损失。这种方法的一个缺点是 Zygote 的 gradient 函数需要标量输出。如果您的模型被传递到一些会输出标量误差值的损失中,那么 @adjoint 是最好的方法。这适用于您正在执行标准 ML 的情况,唯一的变化是您希望 Zygote 使用您的函数分析计算 f 的梯度。

如果您正在做一些更复杂的事情并且不能使用 Zygote.gradient,那么第一种方法(不使用 Params)是最合适的。 Params 确实只是为了向后兼容 Flux 的旧 AD,因此最好尽可能避免使用它。