在损失函数中使用 Flux (Julia) 中的分位数
Using quantile in Flux (Julia) in loss function
我正在尝试使用损失函数中的分位数进行训练! (为了某些稳健性,比如最小修剪平方),但它会改变数组并且 Zygote 会抛出一个错误 Mutating arrays is not supported
,来自 sort!
。下面是一个简单的例子(当然内容没有意义):
using Flux, StatsBase
xdata = randn(2, 100)
ydata = randn(100)
model = Chain(Dense(2,10), Dense(10, 1))
function trimmedLoss(x,y; trimFrac=0.f05)
yhat = model(x)
absRes = abs.(yhat .- y) |> vec
trimVal = quantile(absRes, 1.f0-trimFrac)
s = sum(ifelse.(absRes .> trimVal, 0.f0 , absRes ))/(length(absRes)*(1.f0-trimFrac))
#s = sum(absRes)/length(absRes) # using this and commenting out the two above works (no surprise)
end
println(trimmedLoss(xdata, ydata)) #works ok
Flux.train!(trimmedLoss, params(model), zip([xdata], [ydata]), ADAM())
println(trimmedLoss(xdata, ydata)) #changed loss?
这一切都在 Flux 0.10 和 Julia 1.2
提前感谢您提供任何提示或解决方法!
理想情况下,我们会定义一个 custom adjoint for quantile
so that this works out of the box. (Feel free to open an issue 来提醒我们这样做。)
与此同时,有一个快速解决方法。实际上,这里的排序会造成麻烦,因此如果您执行 quantile(xs, p, sorted=true)
它就会起作用。显然这需要对 xs
进行排序以获得正确的结果,因此您可能需要使用 quantile(sort(xs), ...)
.
根据您的 Zygote 版本,您可能还需要 sort
的伴随物。这个很简单:
julia> using Zygote: @adjoint
julia> @adjoint function sort(x)
p = sortperm(x)
x[p], x̄ -> (x̄[invperm(p)],)
end
julia> gradient(x -> quantile(sort(x), 0.5, sorted=true), [1, 2, 3, 3])
([0.0, 0.5, 0.5, 0.0],)
我们将在下一个 Zygote 版本中内置它,但现在如果您将它添加到您的脚本中,它将使您的代码正常工作。
我正在尝试使用损失函数中的分位数进行训练! (为了某些稳健性,比如最小修剪平方),但它会改变数组并且 Zygote 会抛出一个错误 Mutating arrays is not supported
,来自 sort!
。下面是一个简单的例子(当然内容没有意义):
using Flux, StatsBase
xdata = randn(2, 100)
ydata = randn(100)
model = Chain(Dense(2,10), Dense(10, 1))
function trimmedLoss(x,y; trimFrac=0.f05)
yhat = model(x)
absRes = abs.(yhat .- y) |> vec
trimVal = quantile(absRes, 1.f0-trimFrac)
s = sum(ifelse.(absRes .> trimVal, 0.f0 , absRes ))/(length(absRes)*(1.f0-trimFrac))
#s = sum(absRes)/length(absRes) # using this and commenting out the two above works (no surprise)
end
println(trimmedLoss(xdata, ydata)) #works ok
Flux.train!(trimmedLoss, params(model), zip([xdata], [ydata]), ADAM())
println(trimmedLoss(xdata, ydata)) #changed loss?
这一切都在 Flux 0.10 和 Julia 1.2
提前感谢您提供任何提示或解决方法!
理想情况下,我们会定义一个 custom adjoint for quantile
so that this works out of the box. (Feel free to open an issue 来提醒我们这样做。)
与此同时,有一个快速解决方法。实际上,这里的排序会造成麻烦,因此如果您执行 quantile(xs, p, sorted=true)
它就会起作用。显然这需要对 xs
进行排序以获得正确的结果,因此您可能需要使用 quantile(sort(xs), ...)
.
根据您的 Zygote 版本,您可能还需要 sort
的伴随物。这个很简单:
julia> using Zygote: @adjoint
julia> @adjoint function sort(x)
p = sortperm(x)
x[p], x̄ -> (x̄[invperm(p)],)
end
julia> gradient(x -> quantile(sort(x), 0.5, sorted=true), [1, 2, 3, 3])
([0.0, 0.5, 0.5, 0.0],)
我们将在下一个 Zygote 版本中内置它,但现在如果您将它添加到您的脚本中,它将使您的代码正常工作。