What does `train!()` do in Flux.jl?

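In some machine learning frameworks, a train function does not actually perform any training itself; it only sets a mode (i.e., makes sure the model and related objects are ready for training). Is that the case for the train function in Flux, or does `train!()` actually perform the training?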

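According to the Flux.jl docs, the `train!()` function does perform the actual training. The function signature is `train!(loss, params, data, opt; cb)`, and the docs describe it as follows: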

For each datapoint `d` in `data`, compute the gradient of `loss` with respect to `params` through backpropagation and call the optimizer `opt`. If `d` is a tuple of arguments to `loss`, call `loss(d...)`, else call `loss(d)`. A callback is given with the keyword argument `cb`. For example, this will print "training" every 10 seconds (using `Flux.throttle`): `train!(loss, params, data, opt, cb = throttle(() -> println("training"), 10))`. The callback can call `Flux.stop` to interrupt the training loop. Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
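To make that description concrete, here is a minimal sketch of a hand-written loop that does roughly what the docstring says `train!` does, minus the callbacks. It assumes an older Flux release (around 0.12 or 0.13) where the implicit-parameter API (`Flux.params`, `Flux.Optimise.update!`) is available. The toy model, data, and learning rate are made up for illustration and are not from the docs.

```julia
using Flux

# Hypothetical toy setup: a single dense layer fit to y = 2x.
model = Dense(1, 1)
loss(x, y) = Flux.Losses.mse(model(x), y)

xs = Float32[1 2 3 4]          # 1×4 matrix: 4 samples with 1 feature each
ys = 2f0 .* xs
data = [(xs, ys)]              # one datapoint d, a tuple of arguments to loss

ps = Flux.params(model)        # implicit parameters (older Flux API)
opt = Descent(0.01)

before = loss(xs, ys)

# Roughly what train!(loss, ps, data, opt) does, without callbacks:
for d in data
    gs = gradient(() -> loss(d...), ps)    # gradient via backpropagation
    Flux.Optimise.update!(opt, ps, gs)     # optimizer step, mutates the model
end

after = loss(xs, ys)
println("loss before: ", before, ", after: ", after)
```

The loss printed after the loop should be lower than the one before it, because the parameters are updated in place. That is the sense in which `train!` really trains and does not just switch a mode.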

Another example: `@epochs 2 Flux.train!(loss, ps, dataset, opt)` runs two epochs of training. You can find more information in the Flux transfer learning tutorial.
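As far as I know, `@epochs n expr` just runs `expr` n times and logs the epoch number, so a plain `for` loop has the same effect and lets you record the loss after each pass. A sketch, again assuming an older Flux release with the implicit-parameter `train!(loss, params, data, opt)` method; the model and dataset below are illustrative only.

```julia
using Flux

# Illustrative setup (not from the original post): fit y = 3x + 1.
model = Dense(1, 1)
loss(x, y) = Flux.Losses.mse(model(x), y)

xs = Float32[0 1 2 3]
ys = 3f0 .* xs .+ 1f0
dataset = [(xs, ys)]

ps = Flux.params(model)
opt = Descent(0.05)

losses = Float32[loss(xs, ys)]         # loss before any training
for epoch in 1:2                       # same effect as `@epochs 2 ...`
    Flux.train!(loss, ps, dataset, opt)
    push!(losses, loss(xs, ys))        # loss after this epoch
end
println(losses)
```

Each call to `train!` makes one pass over `dataset`, so the recorded losses should decrease from one epoch to the next.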