训练和验证模式 tensorflow 的 SAME 数据损失不一致
Inconsistency in loss on SAME data for train and validation modes tensorflow
我正在使用图像实现语义分割模型。作为一个好的做法,我只用一张图像测试了我的训练管道,并尝试过拟合该图像。令我惊讶的是,当使用完全相同的图像进行训练时,损失会像预期的那样接近 0,但在评估相同图像时,损失要高得多,并且随着训练的继续而不断上升。所以当 training=False
时分割输出是垃圾,但是当 运行 和 training=True
时效果很好。
为了让任何人都能重现这一点,我采用了官方 segmentation tutorial 并对其进行了一些修改,以便从头开始训练一个卷积神经网络,并且只有 1 张图像。该模型非常简单,只是一系列具有批量归一化和 Relu 的 Conv2D。结果如下
如您所见,损失和 eval_loss 确实不同,对图像进行推理在训练模式下给出了完美的结果,在评估模式下是垃圾。
我知道 Batchnormalization 在推理时表现不同,因为它使用训练时计算的平均统计数据。尽管如此,由于我们仅使用 1 张相同的图像进行训练并在相同的图像中进行评估,所以这不应该发生,对吧?此外,我在 Pytorch 中使用相同的优化器实现了相同的架构,但这并没有发生。使用 pytorch 进行训练,eval_loss 收敛到训练损失
在这里你可以找到上面提到的https://colab.research.google.com/drive/18LipgAmKVDA86n3ljFW8X0JThVEeFf0a#scrollTo=TWDATghoRczu
最后还有 Pytorch 实现
它必须对 tensorflow 使用的默认值做更多的事情。 Batchnormalization 有一个参数 momentum
控制批统计的平均。公式为:moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
如果您在 BatchNorm 层中设置 momentum=0.0
,则平均统计数据应该与当前批次(只有 1 张图像)的统计数据完全匹配。如果这样做,您会发现验证损失几乎立即与训练损失相匹配。此外,如果您尝试使用 momentum=0.9
(这是 pytorch 中的等效默认值)并且它可以更快地工作和收敛(如在 pytorch 中)。
我正在使用图像实现语义分割模型。作为一个好的做法,我只用一张图像测试了我的训练管道,并尝试过拟合该图像。令我惊讶的是,当使用完全相同的图像进行训练时,损失会像预期的那样接近 0,但在评估相同图像时,损失要高得多,并且随着训练的继续而不断上升。所以当 training=False
时分割输出是垃圾,但是当 运行 和 training=True
时效果很好。
为了让任何人都能重现这一点,我采用了官方 segmentation tutorial 并对其进行了一些修改,以便从头开始训练一个卷积神经网络,并且只有 1 张图像。该模型非常简单,只是一系列具有批量归一化和 Relu 的 Conv2D。结果如下
如您所见,损失和 eval_loss 确实不同,对图像进行推理在训练模式下给出了完美的结果,在评估模式下是垃圾。
我知道 Batchnormalization 在推理时表现不同,因为它使用训练时计算的平均统计数据。尽管如此,由于我们仅使用 1 张相同的图像进行训练并在相同的图像中进行评估,所以这不应该发生,对吧?此外,我在 Pytorch 中使用相同的优化器实现了相同的架构,但这并没有发生。使用 pytorch 进行训练,eval_loss 收敛到训练损失
在这里你可以找到上面提到的https://colab.research.google.com/drive/18LipgAmKVDA86n3ljFW8X0JThVEeFf0a#scrollTo=TWDATghoRczu 最后还有 Pytorch 实现
它必须对 tensorflow 使用的默认值做更多的事情。 Batchnormalization 有一个参数 momentum
控制批统计的平均。公式为:moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)
如果您在 BatchNorm 层中设置 momentum=0.0
,则平均统计数据应该与当前批次(只有 1 张图像)的统计数据完全匹配。如果这样做,您会发现验证损失几乎立即与训练损失相匹配。此外,如果您尝试使用 momentum=0.9
(这是 pytorch 中的等效默认值)并且它可以更快地工作和收敛(如在 pytorch 中)。