KeyError: 'Failed to format this callback filepath: Reason: \'lr\''

I recently switched from TensorFlow 2.2.0 to 2.4.1, and now the ModelCheckpoint callback filepath no longer works. The code below runs fine in an environment with tf 2.2, but raises an error with tf 2.4.1.

checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')

history = model.fit(training_data, training_data,
                    epochs=10,
                    batch_size=32,
                    shuffle=True,
                    validation_data=(validation_data, validation_data),
                    verbose=verbose, callbacks=[checkpoint])

Error:

KeyError: 'Failed to format this callback filepath: "path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}". Reason: 'lr''

In ModelCheckpoint, the named placeholders in the filepath argument can only be epoch plus the keys present in logs at the end of the epoch.
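To illustrate where the KeyError comes from, here is a minimal sketch in plain Python, with no TensorFlow needed. It assumes the checkpoint path is built with str.format from the epoch number and the logs dict, which is what the error message suggests. The helper format_checkpoint_path and the values in logs are made up for illustration.

```python
# Shortened version of the filepath pattern from the question.
filepath = 'epoch-{epoch}_lr-{lr:.2e}_val_loss-{val_loss:.3e}'

# Made-up logs for one epoch; note that there is no 'lr' key.
logs = {'loss': 0.5, 'val_loss': 0.25}

def format_checkpoint_path(filepath, epoch, logs):
    # Hypothetical helper: fills the placeholders from epoch and logs,
    # assuming this is roughly what the callback does internally.
    return filepath.format(epoch=epoch + 1, **logs)

try:
    format_checkpoint_path(filepath, 0, logs)
except KeyError as err:
    print('missing key:', err.args[0])  # missing key: lr

# Once 'lr' is present in logs, the same call succeeds.
logs['lr'] = 1e-3
print(format_checkpoint_path(filepath, 0, logs))
# epoch-1_lr-1.00e-03_val_loss-2.500e-01
```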

You can check which keys are available in logs like this:

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Log keys: {}".format(keys))

model.fit(..., callbacks=[CustomCallback()])

If you run the code above, you will see something like this:

Log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']

This shows the keys you can use (plus epoch), and lr is not among them. Your filepath uses three keys: epoch, lr, and val_loss.
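If you want to check a filepath pattern before training, one option is to compare its placeholders against the log keys. The sketch below uses string.Formatter from the standard library; missing_placeholders is a hypothetical helper, not part of Keras, and log_keys is copied from the example output above.

```python
from string import Formatter

def missing_placeholders(filepath, log_keys):
    """Return placeholder names in filepath that are neither 'epoch' nor in log_keys."""
    available = set(log_keys) | {'epoch'}
    # Formatter().parse yields (literal_text, field_name, format_spec, conversion).
    names = [field for _, field, _, _ in Formatter().parse(filepath) if field]
    return [name for name in names if name not in available]

checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
log_keys = ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error']

print(missing_placeholders(checkpoint_filepath, log_keys))  # ['lr']
```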


Solution:

You can add the learning rate to logs yourself:

import tensorflow as tf
import tensorflow.keras.backend as K

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # Add the optimizer's current learning rate to logs under the key 'lr'
        logs.update({'lr': K.eval(self.model.optimizer.lr)})
        keys = list(logs.keys())
        print("Log keys: {}".format(keys))  # `lr` is now listed

checkpoint_filepath = 'path_to/temp_checkpoints/model/epoch-{epoch}_loss-{lr:.2e}_loss-{val_loss:.3e}'
checkpoint = ModelCheckpoint(checkpoint_filepath, monitor='val_loss')

history = model.fit(training_data, training_data,
                    epochs=10,
                    batch_size=32,
                    shuffle=True,
                    validation_data=(validation_data, validation_data),
                    # CustomCallback goes first so 'lr' is added to logs
                    # before ModelCheckpoint formats the filepath
                    verbose=verbose, callbacks=[CustomCallback(), checkpoint])
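
One thing to watch with this approach is the order of the callbacks list. The sketch below is a plain-Python simulation, not Keras code. It assumes callbacks are invoked in list order and receive the same logs dict; whether the dict is really shared may depend on the TensorFlow version. AddLr, FormatPath and run_epoch_end are hypothetical stand-ins for the custom callback, ModelCheckpoint and the Keras callback loop.

```python
class AddLr:
    """Stand-in for the custom callback: writes 'lr' into logs."""
    def on_epoch_end(self, epoch, logs):
        logs['lr'] = 1e-3  # made-up value standing in for the optimizer's rate

class FormatPath:
    """Stand-in for ModelCheckpoint: only formats the filepath."""
    def __init__(self, filepath):
        self.filepath = filepath
        self.last_path = None

    def on_epoch_end(self, epoch, logs):
        self.last_path = self.filepath.format(epoch=epoch + 1, **logs)

def run_epoch_end(callbacks, epoch, logs):
    # Assumption: callbacks run in list order on one shared logs dict.
    for callback in callbacks:
        callback.on_epoch_end(epoch, logs)

filepath = 'epoch-{epoch}_lr-{lr:.2e}'

# 'lr' is added before the path is formatted, so this works.
saver = FormatPath(filepath)
run_epoch_end([AddLr(), saver], 0, {'val_loss': 0.25})
print(saver.last_path)  # epoch-1_lr-1.00e-03

# With the order reversed, the path is formatted before 'lr' exists.
try:
    run_epoch_end([FormatPath(filepath), AddLr()], 0, {'val_loss': 0.25})
except KeyError as err:
    print('KeyError:', err.args[0])  # KeyError: lr
```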