Tensorflow 使用 `tf.train.MonitoredTrainingSession` 恢复 `tf.Session` 保存的检查点
Tensorflow restore `tf.Session` saved checkpoint using `tf.train.MonitoredTrainingSession`
我有使用 tf.train.MonitoredTrainingSession
训练 CNN 的代码。
当我创建一个新的 tf.train.MonitoredTrainingSession
时,我可以将 checkpoint
目录作为输入参数传递给会话,它会自动恢复它能找到的最新保存的 checkpoint
。我可以设置 hooks
来训练直到某个步骤。例如,如果 checkpoint
的步骤是 150,000
,我想训练到 200,000
,我会将 last_step
放到 200,000
。
只要使用 tf.train.MonitoredTrainingSession
保存最新的 checkpoint
,上述过程就可以完美运行。但是,如果我尝试恢复使用普通 tf.Session
保存的 checkpoint
,那么一切都会崩溃。它无法在图中找到某些键和所有键。
训练是这样完成的:
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.retrain_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
如果 checkpoint_dir
属性有一个没有检查点的文件夹,这将从头开始。如果它有一个 checkpoint
从之前的训练中保存下来,它将恢复最新的 checkpoint
并继续训练。
现在,我正在恢复最新的checkpoint
并修改一些变量并保存它们:
saver = tf.train.Saver(variables_to_restore)
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
with tf.Session() as sess:
if ckpt and ckpt.model_checkpoint_path:
# Restores from checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
print(ckpt.model_checkpoint_path)
restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps
else:
print('No checkpoint file found')
return
prune_convs(sess)
saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step)
如您所见,就在 saver.save...
我正在 p运行 网络中的所有卷积层之前。无需描述如何以及为什么这样做。关键是网络实际上被修改了。然后我将网络保存到 checkpoint
.
现在,如果我在保存的修改后的网络上部署测试,测试工作正常。但是,当我尝试 运行 保存的 checkpoint
上的 tf.train.MonitoredTrainingSession
时,它说:
Key conv1/weight_loss/avg not found in checkpoint
此外,我注意到用 tf.Session
保存的 checkpoint
的大小是用 tf.train.MonitoredTrainingSession
[=41 保存的 checkpoint
的一半=]
我知道我做错了,有什么关于如何使它起作用的建议吗?
我明白了。显然,tf.Saver
不会从 checkpoint
中恢复所有变量。我尝试立即恢复并保存,但输出只有一半大小。
我使用 tf.train.list_variables
从最新的 checkpoint
中获取所有变量,然后将它们转换为 tf.Variable
并从中创建一个 dict
。然后我将 dict
传递给 tf.Saver
,它恢复了我所有的变量。
接下来是initialize
所有变量然后修改权重。
现在可以使用了。
我有使用 tf.train.MonitoredTrainingSession
训练 CNN 的代码。
当我创建一个新的 tf.train.MonitoredTrainingSession
时,我可以将 checkpoint
目录作为输入参数传递给会话,它会自动恢复它能找到的最新保存的 checkpoint
。我可以设置 hooks
来训练直到某个步骤。例如,如果 checkpoint
的步骤是 150,000
,我想训练到 200,000
,我会将 last_step
放到 200,000
。
只要使用 tf.train.MonitoredTrainingSession
保存最新的 checkpoint
,上述过程就可以完美运行。但是,如果我尝试恢复使用普通 tf.Session
保存的 checkpoint
,那么一切都会崩溃。它无法在图中找到某些键和所有键。
训练是这样完成的:
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.retrain_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
如果 checkpoint_dir
属性有一个没有检查点的文件夹,这将从头开始。如果它有一个 checkpoint
从之前的训练中保存下来,它将恢复最新的 checkpoint
并继续训练。
现在,我正在恢复最新的checkpoint
并修改一些变量并保存它们:
saver = tf.train.Saver(variables_to_restore)
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
with tf.Session() as sess:
if ckpt and ckpt.model_checkpoint_path:
# Restores from checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
print(ckpt.model_checkpoint_path)
restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps
else:
print('No checkpoint file found')
return
prune_convs(sess)
saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step)
如您所见,就在 saver.save...
我正在 p运行 网络中的所有卷积层之前。无需描述如何以及为什么这样做。关键是网络实际上被修改了。然后我将网络保存到 checkpoint
.
现在,如果我在保存的修改后的网络上部署测试,测试工作正常。但是,当我尝试 运行 保存的 checkpoint
上的 tf.train.MonitoredTrainingSession
时,它说:
Key conv1/weight_loss/avg not found in checkpoint
此外,我注意到用 tf.Session
保存的 checkpoint
的大小是用 tf.train.MonitoredTrainingSession
[=41 保存的 checkpoint
的一半=]
我知道我做错了,有什么关于如何使它起作用的建议吗?
我明白了。显然,tf.Saver
不会从 checkpoint
中恢复所有变量。我尝试立即恢复并保存,但输出只有一半大小。
我使用 tf.train.list_variables
从最新的 checkpoint
中获取所有变量,然后将它们转换为 tf.Variable
并从中创建一个 dict
。然后我将 dict
传递给 tf.Saver
,它恢复了我所有的变量。
接下来是initialize
所有变量然后修改权重。
现在可以使用了。