Changing the regularization factor during training in TensorFlow

Is there an easy way to do this?

For example, the learning rate can easily be changed with tf.keras.optimizers.schedules:
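import tensorflow as tf

# ExponentialDecay also requires decay_steps and decay_rate; these values are only illustrative
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    0.001, decay_steps=10000, decay_rate=0.96)
optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)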
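For reference, a schedule object is just a function of the training step. Below is a plain-Python sketch of the formula documented for ExponentialDecay, `initial * decay_rate ** (step / decay_steps)`, with the exponent floored when `staircase=True`. The function name `exponential_decay` and the parameter values are illustrative assumptions, not part of TensorFlow:

```python
import math


def exponential_decay(initial, step, decay_steps, decay_rate, staircase=False):
    """Hypothetical helper mirroring the ExponentialDecay formula:
    initial * decay_rate ** (step / decay_steps)."""
    p = step / decay_steps
    if staircase:
        # staircase=True makes the value drop in discrete jumps every decay_steps steps
        p = math.floor(p)
    return initial * decay_rate ** p


# Value after 5,000 and 10,000 steps with illustrative parameters
print(exponential_decay(0.001, 5000, decay_steps=10000, decay_rate=0.96))
print(exponential_decay(0.001, 10000, decay_steps=10000, decay_rate=0.96))
```

Any value that follows this curve could in principle be fed to a regularizer, which is what the question below is after.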

Is there an easy way to do the same for the regularization factor? Something like this:

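# Wished-for usage (decay_steps/decay_rate are illustrative); L2 expects a plain number
r_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    0.1, decay_steps=10000, decay_rate=0.96)
regularizer = tf.keras.regularizers.L2(l2=r_schedule)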

If not, how can I gradually change the regularization factor with minimal effort?

IIUC, you should be able to use a custom callback that implements the same or similar logic as tf.keras.optimizers.schedules.ExponentialDecay (though it may go beyond minimal effort):

import tensorflow as tf

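class Decay(tf.keras.callbacks.Callback):
  """Decays the tf.Variable `l2` once per epoch, following the
  ExponentialDecay formula (exactly so when staircase=False)."""

  def __init__(self, l2, decay_steps, decay_rate, staircase):
    super().__init__()
    self.l2 = l2
    self.decay_steps = decay_steps
    self.decay_rate = decay_rate
    self.staircase = staircase

  def on_epoch_end(self, epoch, logs=None):
    # params['steps'] is the number of batches per epoch, not a global step.
    # Multiplying by the same factor every epoch compounds, so after e
    # completed epochs l2 == initial * decay_rate ** (e * steps / decay_steps).
    steps_per_epoch = self.params.get('steps')
    p = steps_per_epoch / self.decay_steps
    if self.staircase:
      # Caveat: this floors the per-epoch exponent, so l2 never changes
      # while steps_per_epoch < decay_steps.
      p = tf.floor(p)
    self.l2.assign(tf.multiply(
        self.l2, tf.pow(self.decay_rate, p)))
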
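# Non-trainable variable holding the current regularization factor
l2 = tf.Variable(initial_value=0.01, trainable=False)

def l2_regularizer(weights):
    tf.print(l2)  # prints the current factor on every training step (demo only)
    loss = l2 * tf.reduce_sum(tf.square(weights))
    return loss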

model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(1, kernel_regularizer=l2_regularizer))
model.compile(optimizer='adam', loss='mse')
model.fit(tf.random.normal((50, 1)), tf.random.normal((50, 1)),
          batch_size=4, epochs=3,
          callbacks=[Decay(l2, decay_steps=100000, decay_rate=0.56, staircase=False)])

Output:
Epoch 1/3
0.01
 1/13 [=>............................] - ETA: 8s - loss: 0.63850.01
0.01
0.01
0.01
0.01
0.01
0.01
0.01
 9/13 [===================>..........] - ETA: 0s - loss: 2.13940.01
0.01
0.01
0.01
13/13 [==============================] - 1s 6ms/step - loss: 2.4884
Epoch 2/3
0.00999924541
 1/13 [=>............................] - ETA: 0s - loss: 1.97210.00999924541
0.00999924541
0.00999924541
0.00999924541
0.00999924541
0.00999924541
0.00999924541
0.00999924541
 9/13 [===================>..........] - ETA: 0s - loss: 2.37490.00999924541
0.00999924541
0.00999924541
0.00999924541
13/13 [==============================] - 0s 7ms/step - loss: 2.4541
Epoch 3/3
0.00999849103
 1/13 [=>............................] - ETA: 0s - loss: 0.81400.00999849103
0.00999849103
0.00999849103
0.00999849103
0.00999849103
0.00999849103
 7/13 [===============>..............] - ETA: 0s - loss: 2.71970.00999849103
0.00999849103
0.00999849103
0.00999849103
0.00999849103
0.00999849103
13/13 [==============================] - 0s 10ms/step - loss: 2.4195
<keras.callbacks.History at 0x7f7a5a4ff5d0>
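A possible variant, sketched under the assumption that updating the factor after every batch is acceptable: instead of re-deriving the decay by hand, the callback counts training batches across epochs and evaluates an existing schedule object at that global step, so the variable follows the same curve the schedule describes. `ScheduleCallback` is a hypothetical helper name, `decay_steps=13` and `decay_rate=0.5` are chosen only so the change is visible in a short run, and TF 2.x `tf.keras` is assumed:

```python
import tensorflow as tf


class ScheduleCallback(tf.keras.callbacks.Callback):
    """Hypothetical helper: after every training batch, sets `variable` to
    schedule(global_step), where `schedule` is any Keras schedule object
    (used here only as a function of the step)."""

    def __init__(self, variable, schedule):
        super().__init__()
        self.variable = variable
        self.schedule = schedule
        self.step = 0  # counts batches across all epochs, unlike params['steps']

    def on_train_batch_end(self, batch, logs=None):
        self.step += 1
        self.variable.assign(self.schedule(self.step))


# Start at the schedule's initial value; trainable=False keeps the optimizer off it
l2 = tf.Variable(0.01, trainable=False, dtype=tf.float32)
r_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.01, decay_steps=13, decay_rate=0.5)


def l2_regularizer(weights):
    # Reads the variable at run time, so the callback's assignments are
    # picked up by the compiled training step
    return l2 * tf.reduce_sum(tf.square(weights))


model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(1, kernel_regularizer=l2_regularizer),
])
model.compile(optimizer="adam", loss="mse")

decay_cb = ScheduleCallback(l2, r_schedule)
# 50 samples with batch_size=4 gives 13 batches per epoch, 26 in total
model.fit(tf.random.normal((50, 1)), tf.random.normal((50, 1)),
          batch_size=4, epochs=2, verbose=0, callbacks=[decay_cb])
print(float(l2.numpy()))  # about 0.0025 = 0.01 * 0.5 ** (26 / 13)
```

Because the step counter is global rather than per epoch, passing a schedule built with `staircase=True` also decays at the expected step boundaries here.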