Loading a saved model to resume training
I am training a ResNet model to classify car brands. During training I saved the weights after every epoch. For testing purposes, I stopped training at epoch 3.
import os
import tensorflow as tf

# checkpoint = ModelCheckpoint("best_model.hdf5", monitor='loss', verbose=1)
checkpoint_path = "weights/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1,
    # Save the model at the end of every epoch.
    save_freq='epoch')

model.save_weights(checkpoint_path.format(epoch=0))
history = model.fit_generator(
    training_set,
    validation_data=test_set,
    epochs=50,
    steps_per_epoch=len(training_set),
    validation_steps=len(test_set),
    callbacks=[cp_callback]
)
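The filename template explains why cp-0003.ckpt holds the state after three completed epochs. The sketch below is plain Python and assumes tf.keras 2.x behavior, where ModelCheckpoint substitutes the 1-based number of the epoch that just finished into {epoch}; completed_epochs is a hypothetical helper, not part of Keras.

```python
import re

# Filename template from the question.
CHECKPOINT_TEMPLATE = "weights/cp-{epoch:04d}.ckpt"


def completed_epochs(path):
    """Hypothetical helper: number of finished epochs encoded in a checkpoint path."""
    match = re.search(r"cp-(\d+)\.ckpt", path)
    if match is None:
        raise ValueError(f"not a checkpoint path: {path!r}")
    return int(match.group(1))


# cp-0000 comes from the manual model.save_weights() call before training;
# cp-0001 .. cp-0003 are written by ModelCheckpoint after epochs 1, 2 and 3.
paths = [CHECKPOINT_TEMPLATE.format(epoch=e) for e in range(4)]
print(paths[-1])                    # weights/cp-0003.ckpt
print(completed_epochs(paths[-1]))  # 3
```

In other words, the number in a checkpoint's filename is how many epochs had already run when it was written.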
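However, when I load them, I am not sure whether training resumes from the last saved epoch, because the output shows Epoch 1/50 again. Below is the code I use to load the last saved model.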
from keras.models import Sequential, load_model

# load the model
new_model = load_model('./weights/cp-0003.ckpt')

# fit the model
history = new_model.fit_generator(
    training_set,
    validation_data=test_set,
    epochs=50,
    steps_per_epoch=len(training_set),
    validation_steps=len(test_set),
    callbacks=[cp_callback]
)
This is what it looks like:
[Image: the run resumed from the saved weights starts again at Epoch 1/50]
Can anyone help?
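You can use the initial_epoch argument of fit_generator. It defaults to 0, but you can set it to any positive integer, i.e. the epoch from which training should resume: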
from keras.models import Sequential, load_model
import os
import tensorflow as tf

checkpoint_path = "weights/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1,
    # Save the model at the end of every epoch.
    save_freq='epoch')

# First run: train for 3 epochs (the callback writes cp-0001.ckpt to cp-0003.ckpt).
model.save_weights(checkpoint_path.format(epoch=0))
history = model.fit_generator(
    training_set,
    validation_data=test_set,
    epochs=3,
    steps_per_epoch=len(training_set),
    validation_steps=len(test_set),
    callbacks=[cp_callback]
)

# Resume: load the checkpoint saved after epoch 3 and continue from there.
new_model = load_model('./weights/cp-0003.ckpt')
history = new_model.fit_generator(
    training_set,
    validation_data=test_set,
    epochs=50,
    steps_per_epoch=len(training_set),
    validation_steps=len(test_set),
    callbacks=[cp_callback],
    initial_epoch=3
)
This will train your model for 50 - 3 = 47 additional epochs.
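The 47 comes from how Keras interprets the two arguments: together with initial_epoch, epochs acts as the index of the final epoch rather than a count of extra epochs, and the progress bar prints the 0-based index plus one. The plain-Python sketch below only models that counting (the helper names are hypothetical, not Keras APIs); it also shows why the original call, with the default initial_epoch=0, printed Epoch 1/50 again.

```python
def epochs_to_run(initial_epoch, epochs):
    """0-based epoch indices a Keras-style loop visits: range(initial_epoch, epochs)."""
    return list(range(initial_epoch, epochs))


def progress_labels(initial_epoch, epochs):
    """Labels as they appear in the progress bar, e.g. 'Epoch 4/50'."""
    return [f"Epoch {i + 1}/{epochs}" for i in epochs_to_run(initial_epoch, epochs)]


resumed = progress_labels(initial_epoch=3, epochs=50)
print(len(resumed))  # 47
print(resumed[0])    # Epoch 4/50
print(resumed[-1])   # Epoch 50/50

# Default initial_epoch=0: the counter starts over, as in the question.
print(progress_labels(initial_epoch=0, epochs=50)[0])  # Epoch 1/50
```

A consequence of this design: passing epochs=47 together with initial_epoch=3 would run only 44 more epochs, and because the same cp_callback is reused, the first file written after resuming is cp-0004.ckpt.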
Some comments on your code if you are using TensorFlow 2.X:
- fit_generator is deprecated, since fit now supports generators
- You should replace imports of the form from keras... with from tensorflow.keras...
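Both the checkpoint path and initial_epoch=3 are hardcoded above. One way to avoid that is to derive them from the newest file in weights/. The sketch below is a hypothetical helper (not a Keras API) that assumes the question's cp-{epoch:04d}.ckpt naming; it skips cp-0000 because that file comes from model.save_weights() and holds weights only, which load_model cannot open.

```python
import os
import re

# Matches the question's naming scheme. The optional suffix also covers the
# ".index" / ".data-..." files that weights-only checkpoints produce.
_CKPT_RE = re.compile(r"^cp-(\d{4})\.ckpt(\..+)?$")


def pick_latest(names):
    """Return the highest epoch number (> 0) among checkpoint names, or 0."""
    epochs = [int(m.group(1)) for m in map(_CKPT_RE.match, names) if m]
    return max((e for e in epochs if e > 0), default=0)


def latest_checkpoint(directory):
    """Return (path, epoch) for the newest full-model checkpoint in directory,
    or (None, 0) if training has to start from scratch."""
    epoch = pick_latest(os.listdir(directory))
    if epoch == 0:
        return None, 0
    return os.path.join(directory, f"cp-{epoch:04d}.ckpt"), epoch


# Usage (TensorFlow 2.x, tf.keras), shown as comments:
#   path, initial_epoch = latest_checkpoint("weights")
#   new_model = tf.keras.models.load_model(path)
#   new_model.fit(training_set, validation_data=test_set, epochs=50,
#                 initial_epoch=initial_epoch, callbacks=[cp_callback])
```

Scanning filenames keeps the resume logic tied to the same template the callback writes, so the epoch number and the loaded model cannot drift apart.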