Transfer Learning/Fine Tuning - how to keep BatchNormalization in inference mode?

In the tutorial Transfer learning and fine-tuning by TensorFlow below, it is explained that when you unfreeze a model containing BatchNormalization (BN) layers, those layers should be kept in inference mode by passing training=False when calling the base model.

[…]

Important notes about BatchNormalization layer

Many image models contain BatchNormalization layers. That layer is a special case on every imaginable count. Here are a few things to keep in mind.

  • BatchNormalization contains 2 non-trainable weights that get updated during training. These are the variables tracking the mean and variance of the inputs.
  • When you set bn_layer.trainable = False, the BatchNormalization layer will run in inference mode, and will not update its mean & variance statistics. This is not the case for other layers in general, as weight trainability & inference/training modes are two orthogonal concepts. But the two are tied in the case of the BatchNormalization layer.
  • When you unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, you should keep the BatchNormalization layers in inference mode by passing training=False when calling the base model. Otherwise the updates applied to the non-trainable weights will suddenly destroy what the model has learned.

[…]

In the example they pass training=False when calling the base model, but later they set base_model.trainable=True, which to me looks like the opposite of inference mode, because the BN layers will then be set to trainable as well.
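
The pattern I am referring to looks roughly like this (a condensed sketch of the tutorial's approach, not the exact tutorial code; the input shape and head are just examples):

import tensorflow as tf

base_model = tf.keras.applications.Xception(weights='imagenet', include_top=False)
base_model.trainable = False  # frozen for the transfer-learning phase

inputs = tf.keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)  # base model called in inference mode
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)

# ... train the new head ...

base_model.trainable = True  # later unfrozen for fine-tuning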

From my understanding, inference mode should mean 0 trainable_weights and 4 non_trainable_weights, which is what you get when you set bn_layer.trainable=False to run the bn_layer in inference mode.

I checked, and both the number of trainable_weights and the number of non_trainable_weights are 2.
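
For a standalone BN layer the counts look like this (a minimal sketch):

import tensorflow as tf

bn_layer = tf.keras.layers.BatchNormalization()
bn_layer.build((None, 4))  # creates gamma, beta, moving_mean, moving_variance

print(len(bn_layer.trainable_weights))      # 2 (gamma, beta)
print(len(bn_layer.non_trainable_weights))  # 2 (moving_mean, moving_variance)

bn_layer.trainable = False
print(len(bn_layer.trainable_weights))      # 0
print(len(bn_layer.non_trainable_weights))  # 4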

The tutorial confuses me: when fine-tuning the model, how can I actually be sure that the BN layers are in inference mode?

Does passing training=False on the model override the behavior of bn_layer.trainable=True? So that, even though 2 trainable_weights are listed, they will not be updated during training (fine-tuning)?


Update:

Here I found some further information: BatchNormalization layer - on keras.io

[...]

About setting layer.trainable = False on a BatchNormalization layer:

The meaning of setting layer.trainable = False is to freeze the layer, i.e. its internal state will not change during training: its trainable weights will not be updated during fit() or train_on_batch(), and its state updates will not be run.

Usually, this does not necessarily mean that the layer is run in inference mode (which is normally controlled by the training argument that can be passed when calling a layer). "Frozen state" and "inference mode" are two separate concepts.

However, in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).

This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case.

Note that:

  • Setting trainable on a model containing other layers will recursively set the trainable value of all inner layers.
  • If the value of the trainable attribute is changed after calling compile() on a model, the new value doesn't take effect for this model until compile() is called again.
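
To make the "two separate concepts" point concrete, a Dropout layer is a handy example (a small sketch, not from the docs):

import numpy as np
import tensorflow as tf

dropout = tf.keras.layers.Dropout(0.5)
dropout.trainable = False            # frozen state ...
x = np.ones((4, 10), dtype='float32')
y = dropout(x, training=True)        # ... but not inference mode
print(np.array_equal(y.numpy(), x))  # False: dropout is still applied despite trainable = False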

Questions:

  1. In case I want to fine-tune the whole model, so I unfreeze it with base_model.trainable = True, do I have to manually set the BN layers to bn_layer.trainable = False to keep them in inference mode?
  2. What happens when the call to the base_model passes training=False and additionally base_model.trainable=True is set? Do layers such as BatchNormalization and Dropout stay in inference mode?

After reading the documentation and looking at the source code of TensorFlow's implementations of tf.keras.layers.Layer, tf.keras.layers.Dense and tf.keras.layers.BatchNormalization, I came to the following understanding.

If training = False is passed when calling a layer or the model/base model, it will run in inference mode. This has nothing to do with the trainable attribute, which means something different. There would probably be less misunderstanding if the argument were called training_mode instead of training. I would even prefer it to be defined the other way around and called inference_mode.
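
For example (a sketch; base_model stands for any pre-trained Keras model):

import tensorflow as tf

inputs = tf.keras.Input(shape=(150, 150, 3))
# the mode is fixed here, at call time, independently of any later
# change to base_model.trainable:
x = base_model(inputs, training=False)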

When doing transfer learning or fine-tuning, training = False should be passed when calling the base model itself. As far as I know, this only affects layers such as tf.keras.layers.Dropout and tf.keras.layers.BatchNormalization and no other layers. Running in inference mode via training = False results in (see the sketch after this list):

  • tf.keras.layers.Dropout not applying the dropout rate at all. Since tf.keras.layers.Dropout has no trainable weights, setting the attribute trainable = False has no effect on this layer.
  • tf.keras.layers.BatchNormalization normalizing its inputs with the moving mean and moving variance statistics learned during training.
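
The sketch below illustrates both bullets on standalone layers:

import numpy as np
import tensorflow as tf

x = np.random.rand(8, 4).astype('float32')

dropout = tf.keras.layers.Dropout(0.5)
print(np.array_equal(dropout(x, training=False).numpy(), x))  # True: no units dropped

bn = tf.keras.layers.BatchNormalization()
bn(x, training=True)                   # training mode: updates the moving statistics
before = bn.moving_mean.numpy().copy()
bn(x, training=False)                  # inference mode: uses them, does not update them
print(np.array_equal(before, bn.moving_mean.numpy()))  # True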

The trainable attribute only enables or disables updating the trainable weights of a layer.
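
The two flags are therefore orthogonal: with trainable = True the BN layer's gamma and beta still receive gradient updates even when the layer is called with training = False; only the moving statistics are frozen. A quick standalone check:

import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
x = tf.random.normal((8, 4))
bn(x)  # build the layer

with tf.GradientTape() as tape:
    y = bn(x, training=False)          # inference mode for the statistics
    loss = tf.reduce_mean(y ** 2)
grads = tape.gradient(loss, bn.trainable_weights)

print([w.name for w in bn.trainable_weights])  # gamma, beta
print([g is not None for g in grads])          # [True, True] -> they would still be updated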

And yes, after setting base_model.trainable = True you need to put the BN layers back into inference mode yourself:

from tensorflow.keras import layers

# set only the BatchNormalization layers back to non-trainable;
# everything else in base_model stays trainable
for layer in base_model.layers:
    if isinstance(layer, layers.BatchNormalization):
        layer.trainable = False
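
And, as the keras.io note quoted above says, if the model was already compiled you have to compile it again for the changed trainable flags to take effect (a sketch, assuming model is your full model built on top of base_model; the optimizer and loss are just examples):

import tensorflow as tf

model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),  # low learning rate for fine-tuning
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])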

You can inspect each layer and check its trainable flag (the training/inference mode itself is decided by the training argument at call time and is not stored on the layer):

for lnum, layer in enumerate(model.layers):
    print('layer: {}, name:{}, trainable:{}, dtype: {}'.format(lnum, layer.name, layer.trainable, layer.dtype))
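
The loop above only shows the trainable flag. To be really sure the BN statistics stay frozen during fine-tuning, you can compare the moving statistics before and after a training step (a sketch, assuming model, base_model and a small batch x_batch, y_batch from your own setup):

import numpy as np
import tensorflow as tf

bn_layers = [layer for layer in base_model.layers
             if isinstance(layer, tf.keras.layers.BatchNormalization)]
before = [layer.moving_mean.numpy().copy() for layer in bn_layers]

model.train_on_batch(x_batch, y_batch)

after = [layer.moving_mean.numpy() for layer in bn_layers]
print(all(np.array_equal(b, a) for b, a in zip(before, after)))  # True -> statistics untouched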