How to dynamically update batch norm momentum in TF2?

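I found a PyTorch implementation that decays the batch norm momentum parameter from 0.1 in the first epoch to 0.001 in the last epoch. Any suggestions on how to do this with the batch norm momentum parameter in TF2 (i.e., start at 0.9 and end at 0.999)? For example, this is what the PyTorch code does: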

# in training script
momentum = initial_momentum * np.exp(-epoch/args.epochs * np.log(initial_momentum/final_momentum))

# model class function
def set_bn_momentum(self, momentum):
    self.expand_bn.momentum = momentum
    for bn in self.layers_bn:
        bn.momentum = momentum
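
The 0.9 to 0.999 range follows from the two libraries weighting the running statistics in opposite ways: PyTorch's momentum multiplies the new batch statistic, while Keras's momentum multiplies the old running value, so the Keras value is one minus the PyTorch one. A minimal sketch in plain Python (the function names are illustrative, not library APIs):

```python
def pytorch_style_update(running, batch_stat, momentum):
    # PyTorch convention: momentum is the weight of the new batch statistic.
    return (1.0 - momentum) * running + momentum * batch_stat


def keras_style_update(running, batch_stat, momentum):
    # Keras convention: momentum is the weight of the old running value.
    return momentum * running + (1.0 - momentum) * batch_stat


running, batch_stat = 0.0, 1.0
pt = pytorch_style_update(running, batch_stat, momentum=0.1)
tf_equiv = keras_style_update(running, batch_stat, momentum=0.9)
print(pt, tf_equiv)  # the two updates agree up to floating-point rounding
```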


The answer selected below provides a working solution when using the tf.keras.Model.fit() API. However, I am using a custom training loop. Here is what I did:


mi = 1 - initial_momentum  # e.g., initial_momentum = 0.9, so mi = 0.1
mf = 1 - final_momentum  # e.g., final_momentum = 0.999, so mf = 0.001
momentum = 1 - mi * np.exp(-epoch / epochs * np.log(mi / mf))
model = set_bn_momentum(model, momentum)
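
To check the endpoints of this schedule, the formula can be wrapped in a small function. This is only a sketch: the name `bn_momentum_schedule` and its default arguments are assumptions for illustration, and it uses `math` instead of `np` so that it runs on its own.

```python
import math


def bn_momentum_schedule(epoch, epochs, initial_momentum=0.9, final_momentum=0.999):
    """Exponentially move a Keras-style momentum from initial to final."""
    mi = 1.0 - initial_momentum
    mf = 1.0 - final_momentum
    progress = epoch / epochs  # 0.0 at the start, 1.0 when epoch == epochs
    return 1.0 - mi * math.exp(-progress * math.log(mi / mf))


epochs = 10
schedule = [bn_momentum_schedule(e, epochs) for e in range(epochs + 1)]
print(schedule[0], schedule[-1])  # starts near 0.9 and ends near 0.999
```

Note that a loop over `range(epochs)` stops at `epoch = epochs - 1`, so the final value is approached but not reached exactly. Dividing by `epochs - 1` instead would land on it in the last epoch.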

The set_bn_momentum function (credit to this article):

import os
import tempfile

import tensorflow as tf


def set_bn_momentum(model, momentum):
    for layer in model.layers:
        if hasattr(layer, 'momentum'):
            print(layer.name, layer.momentum)
            setattr(layer, 'momentum', momentum)

    # When we change the layer attributes, the change only shows up in the model config,
    # so the model is rebuilt from that config below.
    model_json = model.to_json()

    # Save the weights before reloading the model.
    tmp_weights_path = os.path.join(tempfile.gettempdir(), 'tmp_weights.h5')
    model.save_weights(tmp_weights_path)

    # Load the model from the config.
    model = tf.keras.models.model_from_json(model_json)

    # Reload the model weights.
    model.load_weights(tmp_weights_path, by_name=True)
    return model




class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))


batch = tf.keras.layers.BatchNormalization()
batch.momentum = 0.001


model.layers[1].momentum = 0.001
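
The two snippets above assign `momentum` directly on a layer. Applied to every layer of a model, that could look like the sketch below. The helper name `set_bn_momentum_inplace` is hypothetical, and the sketch assumes `momentum` is a plain Python attribute that the layer reads each time it is called. That assumption is worth checking: if the training step is wrapped in `tf.function`, a Python value may be captured when the function is traced, in which case later assignments would not reach the compiled step.

```python
def set_bn_momentum_inplace(model, momentum):
    """Assign `momentum` on every layer that exposes the attribute.

    Hypothetical helper; it uses the same hasattr check as set_bn_momentum
    and returns the names of the layers that were updated.
    """
    updated = []
    for layer in model.layers:
        if hasattr(layer, 'momentum'):
            layer.momentum = momentum
            updated.append(layer.name)
    return updated
```

In a custom training loop this would be called once per epoch, before the batches of that epoch are processed, with the scheduled momentum value.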

