如何正确使用交叉熵损失与 Softmax 进行分类?

How to correctly use Cross Entropy Loss vs Softmax for classification?

我想使用 Pytorch 训练多 class classifier。

下面的 the official Pytorch doc 展示了如何在最后一层类型 nn.Linear(84, 10) 之后使用 nn.CrossEntropyLoss()

但是,我记得这是 Softmax 所做的。

这让我很困惑。


  1. 如何以最佳方式训练“标准”class化网络?
  2. 如果网络有一个最终的线性层,如何推断每个 class 的概率?
  3. 如果网络有一个最终的 softmax 层,如何训练网络(哪个损失,以及如何训练)?

我在 Pytorch 论坛上找到了 this thread,它可能回答了所有这些问题,但我无法将其编译成可工作且可读的 Pytorch 代码。


我假设的答案:

  1. 喜欢the doc says
  2. 线性层输出的幂,实际上是 logits(对数概率)。
  3. 没看懂

我认为理解 softmax 和交叉熵很重要,至少从实践的角度来看。一旦掌握了这两个概念,就应该清楚如何在 ML 上下文中“正确”使用它们。

交叉熵 H(p, q)

交叉熵是比较两个概率分布的函数。从实用的角度来看,可能不值得深入探讨交叉熵的正式动机,但如果您有兴趣,我会推荐 Cover 和 Thomas 的 Elements of Information Theory 作为介绍性文字。这个概念很早就引入了(我相信是第 2 章)。这是我在研究生院使用的介绍文字,我认为它做得很好(当然我也有一位很棒的导师)。

要注意的关键是交叉熵是一个函数,它以两个概率分布作为输入:q 和 p 以及 returns 当 q 和 p 相等时的最小值. q代表估计分布,p代表真实分布。

在 ML classification 的上下文中,我们知道训练数据的实际标签,因此 true/target 分布 p 的真实标签概率为 1,其他地方为 0,即 p 是一个单热向量。

另一方面,估计分布(模型的输出)q 通常包含一些不确定性,因此 q 中任何 class 的概率将在 0 和 1 之间。通过训练系统为了最小化交叉熵,我们告诉系统我们希望它尝试使估计分布尽可能接近真实分布。所以,你的模型认为最有可能的class就是q的最大值对应的class。

软最大

同样,有一些复杂的统计方法可以解释 softmax,我们不会在这里讨论。从实用的角度来看,关键是 softmax 是一个函数,它将无界值列表作为输入,并输出一个有效的概率质量函数 并保持相对顺序 。重要的是要强调关于相对顺序的第二点。这意味着softmax输入中的最大元素对应于softmax输出中的最大元素。

考虑经过训练以最小化交叉熵的 softmax 激活模型。在这种情况下,在 softmax 之前,模型的目标是为正确标签生成可能的最高值,为错误标签生成可能的最低值。

PyTorch 中的交叉熵损失

PyTorch中CrossEntropyLoss的定义是softmax和cross-entropy的结合。具体

CrossEntropyLoss(x, y) := H(one_hot(y), softmax(x))

请注意,one_hot 是一个接受索引 y 并将其扩展为单热向量的函数。

等效地,您可以将 CrossEntropyLoss 表示为 PyTorch 中 LogSoftmax and negative log-likelihood loss (i.e. NLLLoss 的组合)

LogSoftmax(x) := ln(softmax(x))

CrossEntropyLoss(x, y) := NLLLoss(LogSoftmax(x), y)

由于 softmax 中的求幂,有一些计算“技巧”使得直接使用 CrossEntropyLoss 比分阶段计算更稳定(更准确,不太可能得到 NaN)。

结论

根据以上讨论,您问题的答案是

1。如何以最佳方式训练“标准”class化网络?

正如文档所说。

2。如果网络有一个最终的线性层,如何推断每个class的概率?

将 softmax 应用于网络的输出以推断每个 class 的概率。如果目标只是找到相对顺序或最高概率 class 那么只需将 argsort 或 argmax 直接应用于输出(因为 softmax 保持相对顺序)。

3。如果网络有一个最终的 softmax 层,如何训练网络(哪些损失,以及如何)?

通常,出于上述稳定性原因,您不想训练输出 softmax 输出的网络。

就是说,如果您出于某种原因绝对需要这样做,您可以记录输出并将它们提供给 NLLLoss

criterion = nn.NLLLoss()
...
x = model(data)    # assuming the output of the model is softmax activated
loss = criterion(torch.log(x), y)

这在数学上等同于对 使用 softmax 激活的模型使用 CrossEntropyLoss。

criterion = nn.CrossEntropyLoss()
...
x = model(data)    # assuming the output of the model is NOT softmax activated
loss = criterion(x, y)