Tensorflow 使用 `tf.train.MonitoredTrainingSession` 恢复 `tf.Session` 保存的检查点

Tensorflow restore `tf.Session` saved checkpoint using `tf.train.MonitoredTrainingSession`

我有使用 tf.train.MonitoredTrainingSession 训练 CNN 的代码。

当我创建一个新的 tf.train.MonitoredTrainingSession 时,我可以将 checkpoint 目录作为输入参数传递给会话,它会自动恢复它能找到的最新保存的 checkpoint。我可以设置 hooks 来训练直到某个步骤。例如,如果 checkpoint 的步骤是 150,000,我想训练到 200,000,我会将 last_step 放到 200,000

只要使用 tf.train.MonitoredTrainingSession 保存最新的 checkpoint,上述过程就可以完美运行。但是,如果我尝试恢复使用普通 tf.Session 保存的 checkpoint,那么一切都会崩溃。它无法在图中找到某些键和所有键。

训练是这样完成的:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.retrain_dir,
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps),
           tf.train.NanTensorHook(loss),
           _LoggerHook()],
    config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(train_op)

如果 checkpoint_dir 属性有一个没有检查点的文件夹,这将从头开始。如果它有一个 checkpoint 从之前的训练中保存下来,它将恢复最新的 checkpoint 并继续训练。

现在,我正在恢复最新的checkpoint并修改一些变量并保存它们:

saver = tf.train.Saver(variables_to_restore)

ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)

with tf.Session() as sess:
  if ckpt and ckpt.model_checkpoint_path:
    # Restores from checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
    print(ckpt.model_checkpoint_path)
    restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps
  else:
    print('No checkpoint file found')
    return

  prune_convs(sess)
  saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step)

如您所见,就在 saver.save... 我正在 p运行 网络中的所有卷积层之前。无需描述如何以及为什么这样做。关键是网络实际上被修改了。然后我将网络保存到 checkpoint.

现在,如果我在保存的修改后的网络上部署测试,测试工作正常。但是,当我尝试 运行 保存的 checkpoint 上的 tf.train.MonitoredTrainingSession 时,它说:

Key conv1/weight_loss/avg not found in checkpoint

此外,我注意到用 tf.Session 保存的 checkpoint 的大小是用 tf.train.MonitoredTrainingSession[=41 保存的 checkpoint 的一半=]

我知道我做错了,有什么关于如何使它起作用的建议吗?

我明白了。显然,tf.Saver 不会从 checkpoint 中恢复所有变量。我尝试立即恢复并保存,但输出只有一半大小。

我使用 tf.train.list_variables 从最新的 checkpoint 中获取所有变量,然后将它们转换为 tf.Variable 并从中创建一个 dict。然后我将 dict 传递给 tf.Saver,它恢复了我所有的变量。

接下来是initialize所有变量然后修改权重。

现在可以使用了。