为什么归一化会导致我的网络在训练中出现梯度爆炸?

Why is Normalization causing my network to have exploding gradients in training?

我建立了一个网络(在 Pytorch 中),它在图像恢复方面表现良好。我正在使用带有 Resnet50 编码器 backbone 的自动编码器,但是,我只使用了 1 的批量大小。我正在试验一些频域的东西,这些东西一次只允许我处理一张图像。

我发现我的网络表现相当不错,但是,只有当我从网络中删除所有批归一化时它才会表现良好。当然,批处理规范对于 1 的批处理大小是无用的,所以我切换到为此目的设计的组规范化。但是,即使使用组规范,我的梯度也会爆炸。训练可以进行 20 - 100 个 epoch,然后游戏就结束了。有时它会恢复并再次爆炸。

我还应该说,在训练中,每张输入的新图像都被赋予了截然不同的噪声量来训练随机噪声量。以前已经这样做过,但可能加上批量大小为 1 可能会出现问题。

我正在为这个问题挠头,我想知道是否有人有建议。我已经调整了我的学习率并削减了最大梯度,但这并没有真正解决实际问题。我可以 post 一些代码,但我不确定从哪里开始,希望有人能给我一个理论。有任何想法吗?谢谢!

为了回答我自己的问题,我的网络在训练中不稳定,因为批量大小为 1 使得批次之间的数据差异太大。或者正如论文所说的那样,内部协变量偏移过高。

不仅我的图像是从一个非常大的不同数据集中绘制的,而且它们还被随机旋转和翻转。除此之外,还为每张图像选择了 0 到 30 之间的随机高斯噪声噪声,因此在某些情况下,一幅图像可能几乎没有噪声,而下一幅图像可能几乎无法区分。或者正如论文所说的那样,内部协变量偏移过高。

在上面的问题中我提到了组规范——我的网络很复杂,一些代码是从其他工作改编的。在我的代码中仍然隐藏着我错过的批量规范函数。我删除了它们。我仍然不确定为什么 BN 让事情变得更糟。

在此之后,我用大小为 32 的组重新实现了组规范,现在训练得更好了。

简而言之,删除额外的 BN 并添加 Group 范数很有帮助。