如何在Flux.jl中定义自定义损失函数?

How to define a custom loss function in Flux.jl?

查看 Flux.jl 文档,我看到有大量内置损失函数:https://fluxml.ai/Flux.jl/stable/models/losses/。我的问题是,如果我想要更深奥的东西,我该如何在 Flux 中定义和使用我自己的损失函数?

您可以使用任何returns单个浮点值的可微函数作为您的损失,如上面评论所述,准备好的函数只是为了您的方便。 您可以传递任何内容,例如

using Flux
yourcustomloss(ŷ, y) = sum(.- sum(y .* logsoftmax(ŷ), dims = 1))

并计算它的梯度或将其传递给train!函数。