为什么 tf.keras BatchNormalization 导致 GAN 产生无意义的损失和准确性?
Why is tf.keras BatchNormalization causing GANs to produce nonsense loss and accuracy?
背景:
在使用 tf.keras 在鉴别器中训练带有批量归一化层的 GAN 时,我遇到了不寻常的损失和准确性。 GAN 的最佳 objective 函数值为 log(4),当鉴别器完全无法辨别真假样本时会出现这种情况,因此对所有样本的预测值为 0.5。当我在鉴别器中包含 BatchNormalization 层时,生成器和鉴别器都获得了近乎完美的分数(高精度、低损失),这在对抗性设置中是不可能的。
没有 BatchNorm:
This figure 显示了不使用 BN 时每个 epoch (x) 的损失 (y)。请注意,偶尔低于理论最小值的值是由于训练是一个迭代过程。
This figure 显示不使用 BN 时的准确度,分别稳定在 50% 左右。这两个数字都显示了合理的值。
使用 BatchNorm:
This figure shows the losses (y) per epoch (x) when BN is used. See how the GAN objective, which shouldn't fall below log(4), approaches 0. This figure显示使用BN时的准确率,均接近100%。 GAN 是对抗性的;生成器和鉴别器不能都具有 100% 的准确率。
问题:
可以找到构建和训练 GAN 的代码 here。我是不是遗漏了什么,我是不是在实施中犯了错误,或者 tf.keras 中是否存在错误?我很确定这是一个技术问题,而不是 "GAN-hacks" 可以解决的理论问题。请注意,这仅涉及在鉴别器中使用 BatchNormalization 层;在生成器中使用它们不会导致此问题。
TF 2.0 和 2.1 中 Tensorflow 的 BatchNormalization 层存在问题;降级到 TF 1.15 解决了这个问题。问题原因尚未确定。
这是相关的 GitHub 问题:https://github.com/tensorflow/tensorflow/issues/37673
问题的原因很简单。判别器学习区分 BatchNormalization 层的训练和测试阶段,而不是训练来区分数据。
在训练阶段,BN 中使用实际的批量均值和方差,而在测试阶段,使用存储在 BN 中的移动均值和移动方差。
背景:
在使用 tf.keras 在鉴别器中训练带有批量归一化层的 GAN 时,我遇到了不寻常的损失和准确性。 GAN 的最佳 objective 函数值为 log(4),当鉴别器完全无法辨别真假样本时会出现这种情况,因此对所有样本的预测值为 0.5。当我在鉴别器中包含 BatchNormalization 层时,生成器和鉴别器都获得了近乎完美的分数(高精度、低损失),这在对抗性设置中是不可能的。
没有 BatchNorm:
This figure 显示了不使用 BN 时每个 epoch (x) 的损失 (y)。请注意,偶尔低于理论最小值的值是由于训练是一个迭代过程。 This figure 显示不使用 BN 时的准确度,分别稳定在 50% 左右。这两个数字都显示了合理的值。
使用 BatchNorm:
This figure shows the losses (y) per epoch (x) when BN is used. See how the GAN objective, which shouldn't fall below log(4), approaches 0. This figure显示使用BN时的准确率,均接近100%。 GAN 是对抗性的;生成器和鉴别器不能都具有 100% 的准确率。
问题:
可以找到构建和训练 GAN 的代码 here。我是不是遗漏了什么,我是不是在实施中犯了错误,或者 tf.keras 中是否存在错误?我很确定这是一个技术问题,而不是 "GAN-hacks" 可以解决的理论问题。请注意,这仅涉及在鉴别器中使用 BatchNormalization 层;在生成器中使用它们不会导致此问题。
TF 2.0 和 2.1 中 Tensorflow 的 BatchNormalization 层存在问题;降级到 TF 1.15 解决了这个问题。问题原因尚未确定。
这是相关的 GitHub 问题:https://github.com/tensorflow/tensorflow/issues/37673
问题的原因很简单。判别器学习区分 BatchNormalization 层的训练和测试阶段,而不是训练来区分数据。
在训练阶段,BN 中使用实际的批量均值和方差,而在测试阶段,使用存储在 BN 中的移动均值和移动方差。