Flux.jl 中是否有 `zero_grad()` 函数
Is there a `zero_grad()` function in Flux.jl
在 PyTorch 中,您通常必须在进行反向传播之前将梯度归零。在 Flux 中是这样吗?如果是这样,执行此操作的编程方式是什么?
tl;博士
不,不需要。
说明
Flux 曾经使用 Tracker,这是一个微分系统,其中每个跟踪的数组都可能包含一个梯度。我认为这是类似于pytorch的设计。反向传播两次可能导致归零旨在避免的问题(尽管默认值试图保护您):
julia> using Tracker
julia> x_tr = Tracker.param([1 2 3])
Tracked 1×3 Matrix{Float64}:
1.0 2.0 3.0
julia> y_tr = sum(abs2, x_tr)
14.0 (tracked)
julia> Tracker.back!(y_tr, 1; once=false)
julia> x_tr.grad
1×3 Matrix{Float64}:
2.0 4.0 6.0
julia> Tracker.back!(y_tr, 1; once=false) # by default (i.e. with once=true) this would be an error
julia> x_tr.grad
1×3 Matrix{Float64}:
4.0 8.0 12.0
现在使用Zygote,不使用tracked数组类型。相反,要跟踪的评估必须在调用 Zygote.gradient
时发生,然后它可以查看和操作源代码以为梯度编写新代码。重复调用每次都会生成相同的梯度;没有需要清理的存储状态。
julia> using Zygote
julia> x = [1 2 3] # an ordinary Array
1×3 Matrix{Int64}:
1 2 3
julia> Zygote.gradient(x -> sum(abs2, x), x)
([2 4 6],)
julia> Zygote.gradient(x -> sum(abs2, x), x)
([2 4 6],)
julia> y, bk = Zygote.pullback(x -> sum(abs2, x), x);
julia> bk(1.0)
([2.0 4.0 6.0],)
julia> bk(1.0)
([2.0 4.0 6.0],)
Tracker 也可以这样使用,而不是自己处理 param
和 back!
:
julia> Tracker.gradient(x -> sum(abs2, x), [1, 2, 3])
([2.0, 4.0, 6.0] (tracked),)
在 PyTorch 中,您通常必须在进行反向传播之前将梯度归零。在 Flux 中是这样吗?如果是这样,执行此操作的编程方式是什么?
tl;博士
不,不需要。
说明
Flux 曾经使用 Tracker,这是一个微分系统,其中每个跟踪的数组都可能包含一个梯度。我认为这是类似于pytorch的设计。反向传播两次可能导致归零旨在避免的问题(尽管默认值试图保护您):
julia> using Tracker
julia> x_tr = Tracker.param([1 2 3])
Tracked 1×3 Matrix{Float64}:
1.0 2.0 3.0
julia> y_tr = sum(abs2, x_tr)
14.0 (tracked)
julia> Tracker.back!(y_tr, 1; once=false)
julia> x_tr.grad
1×3 Matrix{Float64}:
2.0 4.0 6.0
julia> Tracker.back!(y_tr, 1; once=false) # by default (i.e. with once=true) this would be an error
julia> x_tr.grad
1×3 Matrix{Float64}:
4.0 8.0 12.0
现在使用Zygote,不使用tracked数组类型。相反,要跟踪的评估必须在调用 Zygote.gradient
时发生,然后它可以查看和操作源代码以为梯度编写新代码。重复调用每次都会生成相同的梯度;没有需要清理的存储状态。
julia> using Zygote
julia> x = [1 2 3] # an ordinary Array
1×3 Matrix{Int64}:
1 2 3
julia> Zygote.gradient(x -> sum(abs2, x), x)
([2 4 6],)
julia> Zygote.gradient(x -> sum(abs2, x), x)
([2 4 6],)
julia> y, bk = Zygote.pullback(x -> sum(abs2, x), x);
julia> bk(1.0)
([2.0 4.0 6.0],)
julia> bk(1.0)
([2.0 4.0 6.0],)
Tracker 也可以这样使用,而不是自己处理 param
和 back!
:
julia> Tracker.gradient(x -> sum(abs2, x), [1, 2, 3])
([2.0, 4.0, 6.0] (tracked),)