在 Tensorflow 中恢复检查点时如何获取 global_step?

How to get the global_step when restoring checkpoints in Tensorflow?

我正在这样保存我的会话状态:

self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)

当我稍后恢复时,我想为我从中恢复的检查点获取 global_step 的值。这是为了从中设置一些超参数。

执行此操作的 hacky 方法是 运行 通过并解析检查点目录中的文件名。但是一定有更好的内置方法来做到这一点吗?

一般模式是使用 global_step 变量来跟踪步骤

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

然后你可以用

保存
saver.save(sess, save_path, global_step=global_step)

恢复时,global_step 的值也会被恢复

现在的0.10rc0版本好像不一样了,没有了tf.saver()。现在是 tf.train.Saver()。此外,保存命令将信息添加到 global_step 的 save_path 文件名中,因此我们不能只在同一个 save_path 上调用恢复,因为那不是实际的保存文件。

我现在看到的最简单的方法是将 SessionManager 与这样的保护程序一起使用:

my_checkpoint_dir = "/tmp/checkpoint_dir"
# make a saver to use with SessionManager for restoring
saver = tf.train.Saver()
# Build an initialization operation to run below.
init = tf.initialize_all_variables()
# use a SessionManager to help with automatic variable restoration
sm = tf.train.SessionManager()
# try to find the latest checkpoint in my_checkpoint_dir, then create a session with that restored
# if no such checkpoint, then call the init_op after creating a new session
sess = sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=my_checkpoint_dir))

就是这样。现在你有一个从 my_checkpoint_dir 恢复的会话(在调用它之前确保该目录存在),或者如果那里没有检查点那么它会创建一个新会话并调用 init_op 来初始化你的变量.

当你想保存的时候,你只需要在那个目录中保存到你想要的任何名称,然后将 global_step 传入。这是一个示例,我将循环中的步骤变量保存为 global_step, 所以如果你杀死程序并重新启动它,它会回到那个点,以便它恢复检查点:

checkpoint_path = os.path.join(my_checkpoint_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=step)

这会在 my_checkpoint_dir 中创建文件,例如 "model.ckpt-1000",其中 1000 是传入的 global_step。如果它保留 运行,那么您会更像 "model.ckpt-2000".当程序重新启动时,上面的 SessionManager 会获取其中最新的一个。 checkpoint_path 可以是您想要的任何文件名,只要它在 checkpoint_dir 中即可。 save() 将创建附加了 global_step 的文件(如上所示)。它还会创建一个 "checkpoint" 索引文件,这是 SessionManager 找到最新保存检查点的方式。

这有点hack,但其他答案对我根本不起作用

ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 

#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])

2017 年 9 月更新

我不确定这是否由于更新而开始工作,但以下方法似乎可以有效地让 global_step 正确更新和加载:

创建两个操作。一个用于保存 global_step,另一个用于增加它:

    global_step = tf.Variable(0, trainable=False, name='global_step')
    increment_global_step = tf.assign_add(global_step,1,
                                            name = 'increment_global_step')

现在在你的训练循环中 运行 每次你 运行 你的训练操作时的增量操作。

sess.run([train_op,increment_global_step],feed_dict=feed_dict)

如果您想在任何时候将全局步长值检索为整数,只需在加载模型后使用以下命令:

sess.run(global_step)

这对于创建文件名或计算当前纪元是有用的,而无需第二个 tensorflow 变量来保存该值。例如,计算加载时的当前纪元类似于:

loaded_epoch = sess.run(global_step)//(batch_size*num_train_records)

我和 Lawrence Du 有同样的问题,我找不到通过恢复模型来获得 global_step 的方法。所以我应用了 to the inception v3 training code in the Tensorflow/models github repo 我正在使用。下面的代码还包含与 pretrained_model_checkpoint_path.

相关的修复

如果您有更好的解决方案,或者知道我遗漏了什么,请发表评论!

无论如何,这段代码对我有用:

...

# When not restoring start at 0
last_step = 0
if FLAGS.pretrained_model_checkpoint_path:
    # A model consists of three files, use the base name of the model in
    # the checkpoint path. E.g. my-model-path/model.ckpt-291500
    #
    # Because we need to give the base name you can't assert (will always fail)
    # assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)

    variables_to_restore = tf.get_collection(
        slim.variables.VARIABLES_TO_RESTORE)
    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
    print('%s: Pre-trained model restored from %s' %
          (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

    # HACK : global step is not restored for some unknown reason
    last_step = int(os.path.basename(FLAGS.pretrained_model_checkpoint_path).split('-')[1])

    # assign to global step
    sess.run(global_step.assign(last_step))

...

for step in range(last_step + 1, FLAGS.max_steps):

  ...

请注意我关于全局步骤保存和恢复的解决方案。

保存:

global_step = tf.Variable(0, trainable=False, name='global_step')
saver.save(sess, model_path + model_name, global_step=_global_step)

恢复:

if os.path.exists(model_path):
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    print("Model restore finished, current globle step: %d" % global_step.eval())

您可以使用 global_step 变量来跟踪步骤,但如果在您的代码中,您正在初始化或将此值分配给另一个 step 变量,它可能不一致。

例如,您定义 global_step 使用:

global_step = tf.Variable(0, name='global_step', trainable=False)

分配给你的训练操作:

train_op = optimizer.minimize(loss, global_step=global_step)

保存在你的检查点:

saver.save(sess, checkpoint_path, global_step=global_step)

并从您的检查点恢复:

saver.restore(sess, checkpoint_path) 

global_step 的值也被恢复,但是如果您要将它分配给另一个变量,比如 step,那么您必须执行如下操作:

step = global_step.eval(session=sess)

变量step,包含检查点中最后保存的global_step

从图中定义 global_step 比零变量(如前定义)更好:

global_step = tf.train.get_or_create_global_step()

这将获取您的最后一个 global_step(如果存在)或创建一个(如果不存在)。

TL;DR

作为 tensorflow 变量(将在会话中评估)

global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter 
# this variable will be evaluated later in the session
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore all variables from checkpoint
    saver.restore(sess, checkpoint_path)
    # than init table and local variables and start training/evaluation ...

或: 作为 numpy 整数(没有任何会话):

reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor('global_step')


长答案

至少有两种方法可以从检查点检索全局变量。作为 tensorflow 变量或 numpy 整数。如果未在 Saversave 方法中将 global_step 作为参数提供,则解析文件名将不起作用。对于预训练模型,请参阅答案末尾的备注。

作为 Tensorflow 变量

如果您需要 global_step 变量来计算一些超参数,您可以使用 tf.train.get_or_create_global_step()。这将 return 一个张量流变量。因为变量将在稍后的会话中被评估,所以你只能使用 tensorflow 操作来计算你的超参数。因此,例如:max(global_step, 100) 将不起作用。您必须使用 tensorflow 等价物 tf.maximum(global_step, 100),可以在稍后的会话中对其进行评估。

在会话中,您可以使用 saver.restore(sess, checkpoint_path)

使用检查点初始化全局步骤变量
global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter 
# this variable will be evaluated later in the session
hyper_parameter = tf.maximum(global_step, 100) 
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore all variables from checkpoint
    saver.restore(sess, checkpoint_path)
    # than init table and local variables and start training/evaluation ...

    # for verification you can print the global step and your hyper parameter
    print(sess.run([global_step, hyper_parameter]))

或者:作为 numpy 整数(无会话)

如果您需要全局步骤变量作为标量而不启动会话,您也可以直接从检查点文件中读取此变量。你只需要一个NewCheckpointReader。由于旧版 tensorflow 中的 bug,您应该将检查点文件的路径转换为绝对路径。使用 reader 您可以获得模型的所有张量作为 numpy 变量。 全局步骤变量的名称是常量字符串 tf.GraphKeys.GLOBAL_STEP 定义为 'global_step'.

absolute_checkpoint_path = os.path.abspath(checkpoint_path)
reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)

对预训练模型的备注:在大多数在线可用的预训练模型中,全局步长重置为零。因此,这些模型可用于初始化模型参数以进行微调,而不会覆盖全局步骤。

变量没有按预期恢复的原因很可能是因为它是在您的 tf.Saver() 对象创建之后创建的。

当您没有明确指定 var_list 或为 var_list 指定 None 时,创建 tf.Saver() 对象的位置很重要。许多程序员的预期行为是在调用 save() 方法时保存图中的所有变量,但事实并非如此,也许应该这样记录。创建对象时会保存图中所有变量的快照。

除非您遇到任何性能问题,否则最安全的做法是在您决定保存进度时立即创建保存程序对象。否则,请确保在创建所有变量后创建保护程序对象。

此外,传递给saver.save(sess, save_path, global_step=global_step)global_step只是一个用于创建文件名的计数器,与是否将其恢复为global_step变量无关.在我看来,这是一个用词不当的参数,因为如果您要在每个纪元结束时保存您的进度,那么最好为该参数传递您的纪元编号。