How to define a custom training loop in Flux.jl

I'm trying to set up a training loop for my ML workflow using Flux.jl. I know I can train with the built-in Flux.train!() function, but I need more customization than that API gives me out of the box. How do I define my own custom training loop in Flux?

According to the Flux.jl docs on Training Loops, you can do the following:

using Flux
using Flux: Params, gradient     # implicit-parameters API (Params comes from Zygote, Flux's AD backend)
using Flux.Optimise: update!     # applies one optimiser step to the parameters

function my_custom_train!(loss, ps, data, opt)
  # training_loss is declared local so it will be available for logging outside the gradient calculation.
  local training_loss
  ps = Params(ps)
  for d in data
    gs = gradient(ps) do
      training_loss = loss(d...)
      # Code inserted here will be differentiated; unless you need that gradient information,
      # it is better to do the work outside this block.
      return training_loss
    end
    # Insert whatever code you want here that needs training_loss, e.g. logging.
    # logging_callback(training_loss)
    # Insert whatever code you want here that needs the gradient,
    # e.g. logging it with TensorBoardLogger.jl as a histogram so you can see if it is becoming huge.
    update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping.
  end
end
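
For context, here is a minimal sketch of how such a loop might be called. The model, loss, data and optimiser below are illustrative placeholders of my own, not part of the docs example:

using Flux

model = Chain(Dense(4, 8, relu), Dense(8, 1))       # toy model
loss(x, y) = Flux.Losses.mse(model(x), y)           # loss closing over the model
data = [(rand(Float32, 4, 16), rand(Float32, 1, 16)) for _ in 1:10]  # random (x, y) batches
opt = Descent(0.1)                                   # any Flux optimiser works here

my_custom_train!(loss, Flux.params(model), data, opt)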

You can also simplify the example above by hard-coding the loss function, as sketched below.
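
A rough sketch of that simplification; the function name simple_train!, the model argument, and mse standing in for your loss are illustrative assumptions, not taken from the Flux docs:

using Flux
using Flux.Optimise: update!

# Simplified sketch: the loss is written inline instead of being passed in as an argument.
function simple_train!(model, data, opt)
  ps = Flux.params(model)
  for (x, y) in data
    gs = Flux.gradient(ps) do
      Flux.Losses.mse(model(x), y)   # inline loss; swap in whatever loss your task needs
    end
    update!(opt, ps, gs)             # apply one optimiser step
  end
end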