停止跟踪 Flux 中的数组 (Julia)
Stop tracking Arrays in Flux (Julia)
我目前正在尝试在 Flux for Julia 中实现批量更新。
在我的计算过程中,我通过反复做得到了一批标量
δ = Gt - model(St)[1]
push!(deltas,δ)
其中模型是神经网络
global model= Chain(
Dense(statesize,10, leakyrelu),
Dense(10,10,leakyrelu),
Dense(10,1))
我最终得到数组增量,我想在第二个神经网络上执行批量梯度更新(批量大小 = 19),其中每个梯度都由适当的增量加权。我写的更新函数是
function vupdate2!(S_batch,model,α,deltas)
function v_loss_total(x)
return sum(reshape(deltas,(1,19)) .* model(x))
end
local ps = Flux.params(model)
local gs = Flux.Tracker.gradient(() -> v_loss_total(S_batch), ps)
for p in ps
Flux.Tracker.update!( p, α.* gs[p])
end
end
问题是,计算梯度的行抛出错误:MethodError: no method matching Float32(::Tracker.TrackedReal{Float64})
我认为问题是我的增量数组被跟踪了。查看随机输入的 v_loss_total 函数的输出,我得到:
julia> v_loss_total(S_batch)
-6752.433690476287 (tracked) (tracked)
有趣的是,这个数字被跟踪了两次(?),我猜这是因为将两个跟踪的数字相乘(即增量和模型(S_batch)的条目)。有没有办法首先取消跟踪增量数组?如果有任何帮助,我将不胜感激。
好的,事实证明,有一个函数
Flux.Tracker.data()
这正是我所需要的。它需要一个跟踪号码和 returns Float 本身。另见:https://github.com/FluxML/Flux.jl/issues/640
在 julia 1.2 中对我有用的是使用 .data
将浮点数作为字段访问
以上功能仅由 GreenLogic 建议 returns 另一个 Tracker。
我目前正在尝试在 Flux for Julia 中实现批量更新。
在我的计算过程中,我通过反复做得到了一批标量
δ = Gt - model(St)[1]
push!(deltas,δ)
其中模型是神经网络
global model= Chain(
Dense(statesize,10, leakyrelu),
Dense(10,10,leakyrelu),
Dense(10,1))
我最终得到数组增量,我想在第二个神经网络上执行批量梯度更新(批量大小 = 19),其中每个梯度都由适当的增量加权。我写的更新函数是
function vupdate2!(S_batch,model,α,deltas)
function v_loss_total(x)
return sum(reshape(deltas,(1,19)) .* model(x))
end
local ps = Flux.params(model)
local gs = Flux.Tracker.gradient(() -> v_loss_total(S_batch), ps)
for p in ps
Flux.Tracker.update!( p, α.* gs[p])
end
end
问题是,计算梯度的行抛出错误:MethodError: no method matching Float32(::Tracker.TrackedReal{Float64})
我认为问题是我的增量数组被跟踪了。查看随机输入的 v_loss_total 函数的输出,我得到:
julia> v_loss_total(S_batch)
-6752.433690476287 (tracked) (tracked)
有趣的是,这个数字被跟踪了两次(?),我猜这是因为将两个跟踪的数字相乘(即增量和模型(S_batch)的条目)。有没有办法首先取消跟踪增量数组?如果有任何帮助,我将不胜感激。
好的,事实证明,有一个函数
Flux.Tracker.data()
这正是我所需要的。它需要一个跟踪号码和 returns Float 本身。另见:https://github.com/FluxML/Flux.jl/issues/640
在 julia 1.2 中对我有用的是使用 .data
以上功能仅由 GreenLogic 建议 returns 另一个 Tracker。