tf.keras.layers.BatchNormalization with trainable=False appears to not update its internal moving mean and variance

I am trying to figure out the exact behavior of the BatchNormalization layer in TensorFlow. I came up with the following code, which as far as I can tell should be a perfectly valid Keras model, but the mean and variance of the BatchNormalization layer do not appear to update.

From the documentation at https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization:

  in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).

I would expect the model to return a different value on each subsequent predict call. What I see instead is the exact same value returned all 10 times. Can anyone explain to me why the BatchNormalization layer does not update its internal values?

import tensorflow as tf
import numpy as np

if __name__ == '__main__':

    np.random.seed(1)
    x = np.random.randn(3, 5) * 5 + 0.3

    bn = tf.keras.layers.BatchNormalization(trainable=False, epsilon=1e-9)
    z = input = tf.keras.layers.Input([5])
    z = bn(z)

    model = tf.keras.Model(inputs=input, outputs=z)

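    # prints the exact same normalized output on every iteration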
    for i in range(10):
        print(x)
        print(model.predict(x))
        print()

I am using TensorFlow 2.1.0.
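
For what it's worth, you can verify this directly by reading the layer's moving statistics: moving_mean and moving_variance are exposed as non-trainable variables on the layer. A minimal check along the lines of the snippet above:

import numpy as np
import tensorflow as tf

np.random.seed(1)
x = np.random.randn(3, 5) * 5 + 0.3

bn = tf.keras.layers.BatchNormalization(trainable=False, epsilon=1e-9)
z = inp = tf.keras.layers.Input([5])
z = bn(z)
model = tf.keras.Model(inputs=inp, outputs=z)

print(bn.moving_mean.numpy())      # zeros at initialization
print(bn.moving_variance.numpy())  # ones at initialization
model.predict(x)
print(bn.moving_mean.numpy())      # still zeros: predict does not update them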

OK, I found the flaw in my assumptions. The moving averages are updated during training, not during inference as I thought. This makes perfect sense, since updating the moving averages during inference would likely result in an unstable production model (for example, a long sequence of highly pathological input samples [e.g. samples whose generating distribution differs drastically from the one the network was trained on] could potentially bias the network and degrade its performance on valid input samples).
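
For reference, the documentation describes the training-time update rule for these statistics as an exponential moving average controlled by the momentum argument: moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum). A quick NumPy illustration of the arithmetic (made-up numbers, nothing model-specific):

import numpy as np

momentum = 0.99            # the layer's default
moving_mean = np.zeros(5)  # BatchNormalization initializes moving_mean to zeros

np.random.seed(1)
batch_mean = (np.random.randn(3, 5) * 5 + 0.3).mean(axis=0)

# one training step's worth of update: the moving mean creeps toward the batch mean
moving_mean = moving_mean * momentum + batch_mean * (1 - momentum)
print(moving_mean)         # tiny values: each step only moves 1% of the way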

The trainable parameter is useful when you are fine-tuning a pretrained model and want to freeze some of the network's layers even during training, because when you call model.predict(x) (or even model(x) or model(x, training=False)) the layer automatically uses the moving averages instead of the batch statistics. A minimal fine-tuning sketch follows the demonstration code below.

The code below demonstrates this clearly:

import tensorflow as tf
import numpy as np

if __name__ == '__main__':

    np.random.seed(1)
    x = np.random.randn(10, 5) * 5 + 0.3

    z = input = tf.keras.layers.Input([5])
    z = tf.keras.layers.BatchNormalization(trainable=True, epsilon=1e-9, momentum=0.99)(z)

    model = tf.keras.Model(inputs=input, outputs=z)
    
    # a dummy loss function
    model.compile(loss=lambda x, y: (x - y) ** 2)

    # a dummy fit just to update the batchnorm moving averages
    model.fit(x, x, batch_size=3, epochs=10)
    
    # the first inference call uses the moving averages accumulated during fit
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()
    
    # outputs the same thing as the previous call
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()
    
    # calling the model with training=True updates the moving averages;
    # furthermore, it uses the batch mean and variance as in training,
    # so the result is very different
    pred = model(x, training=True).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()
    
    # here we see again that the moving averages are used, but they differ slightly
    # from before, as expected after the previous training=True call
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()
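
As an aside on the fine-tuning use case mentioned earlier: the usual pattern (per the Keras transfer-learning guide) is to set trainable = False on the pretrained base so that its BatchNormalization layers keep running in inference mode even inside model.fit. A minimal sketch; MobileNetV2 with weights=None is only a stand-in so the example stays self-contained, in practice you would load pretrained weights:

import tensorflow as tf

base = tf.keras.applications.MobileNetV2(input_shape=(96, 96, 3),
                                         include_top=False, weights=None)
base.trainable = False  # freezes the weights and keeps its BN layers in inference mode

inputs = tf.keras.layers.Input([96, 96, 3])
z = base(inputs, training=False)  # explicit, so BN stays frozen even if base is unfrozen later
z = tf.keras.layers.GlobalAveragePooling2D()(z)
outputs = tf.keras.layers.Dense(1)(z)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='adam', loss='mse')
# model.fit(...) now trains only the Dense head; the moving averages in base do not move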

Finally, I found that the documentation (https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization) mentions this:

  1. When performing inference using a model containing batch normalization, it is generally (though not always) desirable to use accumulated statistics rather than mini-batch statistics. This is accomplished by passing training=False when calling the model, or using model.predict.

Hopefully this helps someone who runs into a similar misunderstanding in the future.