我如何在 Keras Tuner 中使用 tf.keras.callbacks.ModelCheckpoint?

How can i use tf.keras.callbacks.ModelCheckpoint in Keras Tuner?

所以我想在 Keras Tuner 中使用 tf.keras.callbacks.ModelCheckpoint,但是您选择保存检查点的路径的方式不允许您将其保存为具有特定名称的文件,名称与该检查点的试验和执行相关联,仅与一个纪元相关联。

也就是说,如果我只是把这个回调放在Keras Tuner中,在检查点保存发生的那一刻,最后,我不知道如何将保存的检查点关联到试验和试验执行,只有到时代。

您可以将 tf.keras.callbacks.ModelCheckpoint 用于 Keras tuner,就像在其他模型中使用的那样保存检查点。

根据this模型使用从搜索中获得的超参数训练模型后,您可以定义模型检查点并将其保存如下:

hypermodel = tuner.hypermodel.build(best_hps)

# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)

import os
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)
history = hypermodel.fit(img_train, label_train, epochs=5, validation_split=0.2, callbacks=[cp_callback])
os.listdir(checkpoint_dir)

# Re-evaluate the model
loss, acc = hypermodel.evaluate(img_test, label_test, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

# Loads the weights
hypermodel.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = hypermodel.evaluate(img_test, label_test, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

有关保存和加载模型检查点的更多信息,请参阅 this link。