
How to dynamically update batch norm momentum in TF2?

I found a PyTorch implementation that decays the batch norm momentum parameter from 0.1 in the first epoch to 0.001 in the last epoch. Any suggestions on how to do this with the batch norm momentum parameter in TF2 (i.e., start at 0.9 and end at 0.999)? For example, this is what is done in the PyTorch code:

# in training script
momentum = initial_momentum * np.exp(-epoch/args.epochs * np.log(initial_momentum/final_momentum))
model_pos_train.set_bn_momentum(momentum)

# model class function
def set_bn_momentum(self, momentum):
    self.expand_bn.momentum = momentum
    for bn in self.layers_bn:
        bn.momentum = momentum
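For reference, the schedule above is a plain exponential decay. A minimal standalone sketch (plain Python; the 0.1 → 0.001 endpoints are from the question, the epoch count is an arbitrary choice for illustration) confirms it hits both endpoints:

```python
import math

initial_momentum = 0.1
final_momentum = 0.001
epochs = 100  # arbitrary, for illustration

def bn_momentum(epoch):
    # Same formula as the PyTorch training script above.
    return initial_momentum * math.exp(
        -epoch / epochs * math.log(initial_momentum / final_momentum)
    )

print(bn_momentum(0))       # starts at 0.1
print(bn_momentum(epochs))  # decays to 0.001 by the last epoch
```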

Solution:

The answer selected below provides a working solution when using the tf.keras.Model.fit() API. However, I was using a custom training loop. Here is what I did:

After each epoch:

mi = 1 - initial_momentum  # e.g., initial_momentum = 0.9 -> mi = 0.1
mf = 1 - final_momentum    # e.g., final_momentum = 0.999 -> mf = 0.001
momentum = 1 - mi * np.exp(-epoch / epochs * np.log(mi / mf))
model = set_bn_momentum(model, momentum)
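Note that the decay is applied in 1 - momentum space: the same curve that takes PyTorch's momentum from 0.1 down to 0.001 takes the TF momentum up from 0.9 to 0.999. A quick standalone check (plain Python; the epoch count is an arbitrary choice for illustration):

```python
import math

initial_momentum = 0.9
final_momentum = 0.999
epochs = 80  # arbitrary, for illustration

mi = 1 - initial_momentum  # 0.1
mf = 1 - final_momentum    # 0.001

# Momentum value for every epoch, using the formula above.
schedule = [1 - mi * math.exp(-e / epochs * math.log(mi / mf))
            for e in range(epochs + 1)]

print(schedule[0])   # starts at 0.9
print(schedule[-1])  # ends at 0.999
```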

The set_bn_momentum function (credit to this article):

import os
import tempfile

import tensorflow as tf


def set_bn_momentum(model, momentum):
    for layer in model.layers:
        if hasattr(layer, 'momentum'):
            print(layer.name, layer.momentum)
            setattr(layer, 'momentum', momentum)

    # Changing the layer attributes only updates the model config,
    # so rebuild the model from that config for the change to take effect
    model_json = model.to_json()

    # Save the weights before reloading the model.
    tmp_weights_path = os.path.join(tempfile.gettempdir(), 'tmp_weights.h5')
    model.save_weights(tmp_weights_path)

    # load the model from the config
    model = tf.keras.models.model_from_json(model_json)

    # Reload the model weights
    model.load_weights(tmp_weights_path, by_name=True)
    return model

This method does not add significant overhead to the training routine.

You can set an action at the beginning/end of each batch, so you can control any parameter during an epoch.

Below is the callback option:

from tensorflow import keras


class CustomCallback(keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))
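Putting the two together, on_epoch_begin is a natural place to apply the momentum schedule. Below is a framework-free sketch of that idea; BNMomentumScheduler and DummyBN are hypothetical names, DummyBN stands in for BatchNormalization layers, and in a real callback you would walk self.model.layers (and, as the accepted solution notes, may still need the JSON-reload trick for the change to take effect):

```python
import math

class BNMomentumScheduler:
    """Updates the `momentum` attribute of the given layers each epoch."""

    def __init__(self, layers, epochs, initial_momentum=0.9, final_momentum=0.999):
        self.layers = layers
        self.epochs = epochs
        self.mi = 1 - initial_momentum
        self.mf = 1 - final_momentum

    def on_epoch_begin(self, epoch):
        # Same 1 - momentum decay as in the accepted solution above.
        momentum = 1 - self.mi * math.exp(
            -epoch / self.epochs * math.log(self.mi / self.mf)
        )
        for layer in self.layers:
            if hasattr(layer, "momentum"):
                layer.momentum = momentum

class DummyBN:
    """Stand-in for a BatchNormalization layer."""
    momentum = 0.9

layers = [DummyBN(), DummyBN()]
sched = BNMomentumScheduler(layers, epochs=10)
sched.on_epoch_begin(10)   # last epoch
print(layers[0].momentum)  # close to 0.999
```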

You can set the momentum directly:

batch = tf.keras.layers.BatchNormalization()
batch.momentum = 0.001

In a model, you must address the correct layer:

model.layers[1].momentum = 0.001

You can find more information and examples in writing_your_own_callbacks.