Flux.jl中最简单的例子中的错误

Error in the most simplest example in Flux.jl

我正在测试这里的示例:https://fluxml.ai/Flux.jl/stable/models/overview/

using Flux
actual(x) = 4x + 2
x_train, x_test= hcat(0:5...), hcat(6:10...)
y_train, y_test = actual.(x_train), actual.(x_test)

predict = Dense(1 => 1)
predict(x_train)

loss(x,y) = Flux.Losses.mse(predict(x),y)
loss(x_train,y_train)

using Flux:train!
opt = Descent(0.1)
data = [(x_train, y_train)]

parameters = Flux.params(predict)
predict.weight in parameters, predict.bias in parameters

train!(loss, parameters, data, opt)

loss(x_train, y_train)

for epoch in 1:1000
    train!(loss, parameters, data, opt)
end

loss(x_train, y_train)

predict(x_test)
y_test

如你所见,它只是一个非常简单的模型actual(x) = 4x + 2。如果你运行这些代码你将得到一个近乎完美的预测结果。

1×5 Matrix{Float32}: 26.0001 30.0001 34.0001 38.0001 42.0001

1×5 Matrix{Int64}: 26 30 34 38 42

但是如果我在为模型提供更多数据方面做一些小改动,就像这样:

x_train, x_test= hcat(0:6...), hcat(6:10...)

所以除了第 3 行我什么都没改。我只是把 5 改成了 6。 那么预测结果就会变成无穷大

1×5 Matrix{Float32}: NaN NaN NaN NaN NaN

1×5 Matrix{Int64}: 26 30 34 38 42

但是为什么呢?

我认为这只是一个高学习率出错的案例。我可以用 Descent(0.1) 重现相同的 NaN 行为。我试着把它打印出来,损失先到 Inf,然后再到 NaN——这是由于高学习率而出现分歧的典型标志。所以我尝试了 0.01 的学习率并且它工作得很好 - 它给出了预期的答案。当 x_trainhcat(0:6...) 时,它可能会发散。较小的学习率允许网络采取较小的步骤,并且它设法找到预期的最小值。