批处理结束后如何更改 Tensorflow 中的学习率?

How to change Learning rate in Tensorflow after a batch end?

我需要创建一个 class 来搜索模型的最佳学习率,在每批中将学习率的值增加 %5。我看到了 on_train_batch_end() 回调,但我无法设置它。

查看 tf.keras.callbacks.ReduceLROnPlateau 的来源。

https://github.com/keras-team/keras/blob/v2.6.0/keras/callbacks.py#L2581-L2701

秘密命令是

tf.keras.backend.set_value(self.model.optimizer.lr)

不能赋值,也不能使用tf.Variable.assign(),因为在Keras中,如果构建还没有发生,它可能是成员,或者在构建之后它可能是一个变量。

Reduce Learning rate on plateau 只调整一个epoch结束时的学习率。它不会在批次结束时到期。为此,您需要创建自定义回调。如果我理解正确的话,你想在每批结束时将学习率降低 5%。下面的代码将为您做到这一点。在回调模型中是编译模型的名称。 freq 是一个整数,决定学习率调整的频率。如果 freq=1 它将在每批结束时进行调整。如果 freq=2 它将每隔一批进行调整等。reduction_pct 是一个浮点数。这是学习率将降低的百分比。详细是一个布尔值。如果 verbose=True,每次调整学习率时都会打印输出,显示用于刚刚完成的批次的 LR 和将用于下一批次的 LR。如果 verbose=False 则不会生成打印输出。

class ADJLR_ON_BATCH(keras.callbacks.Callback):
    def __init__ (self, model, freq, reduction_pct, verbose):
        super(ADJLR_ON_BATCH, self).__init__()
        self.model=model
        self.freq=freq
        self.reduction_pct =reduction_pct
        self.verbose=verbose
        self.adj_batch=freq
        self.factor= 1.0-reduction_pct * .01
    def on_train_batch_end(self, batch, logs=None):
        lr=float(tf.keras.backend.get_value(self.model.optimizer.lr)) # get the current learning rate
        if batch + 1 == self.adj_batch:
            new_lr=lr * self.factor
            tf.keras.backend.set_value(self.model.optimizer.lr, new_lr) # set the learning rate in the optimizer
            self.adj_batch +=self.freq
            if verbose:
                print('\nat the end of batch ',batch + 1, ' lr was adjusted from ', lr, ' to ', new_lr)

下面是使用回调的示例,我相信您希望使用这些值

model=your_model_name # variable name of your model
reduction_pct=5.0 # reduce lr by 5%
verbose=True  # print out each time the LR is adjusted
frequency=1   # adjust LR at the end of every batch
callbacks=[ADJLR_ON_BATCH(model, frequency, reduction_pct, verbose)]

记得在 model.fit 中包含 callbacks=callbacks。下面是结果打印输出的示例,以 .001

的 LR 开头
at the end of batch  1  lr was adjusted from  0.0010000000474974513  to  0.0009500000451225787
  1/374 [..............................] - ETA: 1:14:55 - loss: 9.3936 - accuracy: 0.3333
at the end of batch  2  lr was adjusted from  0.0009500000160187483  to  0.0009025000152178108

at the end of batch  3  lr was adjusted from  0.0009025000035762787  to  0.0008573750033974647
  3/374 [..............................] - ETA: 25:04 - loss: 9.1338 - accuracy: 0.4611  
at the end of batch  4  lr was adjusted from  0.0008573749801144004  to  0.0008145062311086804