使用 Keras 进行深度学习的训练检查点
Training checkpoints for deep learning with Keras
我正在使用 Google Colab,并在我的驱动器上保存权重。
培训:
def train(model, network_input, network_output):
""" train the neural network """
filepath = "/content/gdrive/MyDrive/weights-improvement-{epoch:02d}-{loss:.4f}-bigger.hdf5"
checkpoint = ModelCheckpoint(
filepath,
monitor='loss',
verbose=0,
save_best_only=True,
mode='min'
)
callbacks_list = [checkpoint]
model.fit(network_input, network_output, epochs=200, batch_size=128, callbacks=callbacks_list)
经过一段时间的训练,我的权重是:
weights in my drive
然后我在不修改函数的情况下恢复训练,输出单元格如下所示:
output cell
我如何知道训练是从迄今为止的最佳权重恢复,即“weights-improvement-06-4.1851-bigger.hdf5”,还是只是从头开始?如果它是根据保存的权重进行训练,它不应该以某种方式表明这一点吗?也许向我展示纪元从它停止的地方开始,从纪元 4/200 而不是 1/200 开始。
如果您仍在使用相同的实例化模型对象(即您没有实例化新模型对象),它将从中断的地方恢复训练 - 它不会重新开始。
但是,如果您想使用相同的配置实例化一个新模型并从之前保存的一组权重(检查点)开始,您可以使用 tensorflow 的 latest_checkpoint
从您的模型中加载最新的检查点权重在将这些权重传递给模型之前的目录。
from tensorflow.train import latest_checkpoint
last_ckpt = latest_checkpoint(os.path.join('my','checkpoint','directory'))
# this is the newly instantiated model using the same config
model.load_weights(last_ckpt)
我正在使用 Google Colab,并在我的驱动器上保存权重。
培训:
def train(model, network_input, network_output):
""" train the neural network """
filepath = "/content/gdrive/MyDrive/weights-improvement-{epoch:02d}-{loss:.4f}-bigger.hdf5"
checkpoint = ModelCheckpoint(
filepath,
monitor='loss',
verbose=0,
save_best_only=True,
mode='min'
)
callbacks_list = [checkpoint]
model.fit(network_input, network_output, epochs=200, batch_size=128, callbacks=callbacks_list)
经过一段时间的训练,我的权重是: weights in my drive
然后我在不修改函数的情况下恢复训练,输出单元格如下所示: output cell
我如何知道训练是从迄今为止的最佳权重恢复,即“weights-improvement-06-4.1851-bigger.hdf5”,还是只是从头开始?如果它是根据保存的权重进行训练,它不应该以某种方式表明这一点吗?也许向我展示纪元从它停止的地方开始,从纪元 4/200 而不是 1/200 开始。
如果您仍在使用相同的实例化模型对象(即您没有实例化新模型对象),它将从中断的地方恢复训练 - 它不会重新开始。
但是,如果您想使用相同的配置实例化一个新模型并从之前保存的一组权重(检查点)开始,您可以使用 tensorflow 的 latest_checkpoint
从您的模型中加载最新的检查点权重在将这些权重传递给模型之前的目录。
from tensorflow.train import latest_checkpoint
last_ckpt = latest_checkpoint(os.path.join('my','checkpoint','directory'))
# this is the newly instantiated model using the same config
model.load_weights(last_ckpt)