在 Tensorflow 中恢复检查点时如何获取 global_step?
How to get the global_step when restoring checkpoints in Tensorflow?
我正在这样保存我的会话状态:
self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)
当我稍后恢复时,我想为我从中恢复的检查点获取 global_step 的值。这是为了从中设置一些超参数。
执行此操作的 hacky 方法是 运行 通过并解析检查点目录中的文件名。但是一定有更好的内置方法来做到这一点吗?
一般模式是使用 global_step
变量来跟踪步骤
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)
然后你可以用
保存
saver.save(sess, save_path, global_step=global_step)
恢复时,global_step
的值也会被恢复
现在的0.10rc0版本好像不一样了,没有了tf.saver()。现在是 tf.train.Saver()。此外,保存命令将信息添加到 global_step 的 save_path 文件名中,因此我们不能只在同一个 save_path 上调用恢复,因为那不是实际的保存文件。
我现在看到的最简单的方法是将 SessionManager 与这样的保护程序一起使用:
my_checkpoint_dir = "/tmp/checkpoint_dir"
# make a saver to use with SessionManager for restoring
saver = tf.train.Saver()
# Build an initialization operation to run below.
init = tf.initialize_all_variables()
# use a SessionManager to help with automatic variable restoration
sm = tf.train.SessionManager()
# try to find the latest checkpoint in my_checkpoint_dir, then create a session with that restored
# if no such checkpoint, then call the init_op after creating a new session
sess = sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=my_checkpoint_dir))
就是这样。现在你有一个从 my_checkpoint_dir 恢复的会话(在调用它之前确保该目录存在),或者如果那里没有检查点那么它会创建一个新会话并调用 init_op 来初始化你的变量.
当你想保存的时候,你只需要在那个目录中保存到你想要的任何名称,然后将 global_step 传入。这是一个示例,我将循环中的步骤变量保存为 global_step, 所以如果你杀死程序并重新启动它,它会回到那个点,以便它恢复检查点:
checkpoint_path = os.path.join(my_checkpoint_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=step)
这会在 my_checkpoint_dir 中创建文件,例如 "model.ckpt-1000",其中 1000 是传入的 global_step。如果它保留 运行,那么您会更像 "model.ckpt-2000".当程序重新启动时,上面的 SessionManager 会获取其中最新的一个。 checkpoint_path 可以是您想要的任何文件名,只要它在 checkpoint_dir 中即可。 save() 将创建附加了 global_step 的文件(如上所示)。它还会创建一个 "checkpoint" 索引文件,这是 SessionManager 找到最新保存检查点的方式。
这有点hack,但其他答案对我根本不起作用
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
2017 年 9 月更新
我不确定这是否由于更新而开始工作,但以下方法似乎可以有效地让 global_step 正确更新和加载:
创建两个操作。一个用于保存 global_step,另一个用于增加它:
global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step,1,
name = 'increment_global_step')
现在在你的训练循环中 运行 每次你 运行 你的训练操作时的增量操作。
sess.run([train_op,increment_global_step],feed_dict=feed_dict)
如果您想在任何时候将全局步长值检索为整数,只需在加载模型后使用以下命令:
sess.run(global_step)
这对于创建文件名或计算当前纪元是有用的,而无需第二个 tensorflow 变量来保存该值。例如,计算加载时的当前纪元类似于:
loaded_epoch = sess.run(global_step)//(batch_size*num_train_records)
我和 Lawrence Du 有同样的问题,我找不到通过恢复模型来获得 global_step 的方法。所以我应用了 to the inception v3 training code in the Tensorflow/models github repo 我正在使用。下面的代码还包含与 pretrained_model_checkpoint_path
.
相关的修复
如果您有更好的解决方案,或者知道我遗漏了什么,请发表评论!
无论如何,这段代码对我有用:
...
# When not restoring start at 0
last_step = 0
if FLAGS.pretrained_model_checkpoint_path:
# A model consists of three files, use the base name of the model in
# the checkpoint path. E.g. my-model-path/model.ckpt-291500
#
# Because we need to give the base name you can't assert (will always fail)
# assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
variables_to_restore = tf.get_collection(
slim.variables.VARIABLES_TO_RESTORE)
restorer = tf.train.Saver(variables_to_restore)
restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
print('%s: Pre-trained model restored from %s' %
(datetime.now(), FLAGS.pretrained_model_checkpoint_path))
# HACK : global step is not restored for some unknown reason
last_step = int(os.path.basename(FLAGS.pretrained_model_checkpoint_path).split('-')[1])
# assign to global step
sess.run(global_step.assign(last_step))
...
for step in range(last_step + 1, FLAGS.max_steps):
...
请注意我关于全局步骤保存和恢复的解决方案。
保存:
global_step = tf.Variable(0, trainable=False, name='global_step')
saver.save(sess, model_path + model_name, global_step=_global_step)
恢复:
if os.path.exists(model_path):
saver.restore(sess, tf.train.latest_checkpoint(model_path))
print("Model restore finished, current globle step: %d" % global_step.eval())
您可以使用 global_step
变量来跟踪步骤,但如果在您的代码中,您正在初始化或将此值分配给另一个 step
变量,它可能不一致。
例如,您定义 global_step
使用:
global_step = tf.Variable(0, name='global_step', trainable=False)
分配给你的训练操作:
train_op = optimizer.minimize(loss, global_step=global_step)
保存在你的检查点:
saver.save(sess, checkpoint_path, global_step=global_step)
并从您的检查点恢复:
saver.restore(sess, checkpoint_path)
global_step
的值也被恢复,但是如果您要将它分配给另一个变量,比如 step
,那么您必须执行如下操作:
step = global_step.eval(session=sess)
变量step
,包含检查点中最后保存的global_step
。
从图中定义 global_step 比零变量(如前定义)更好:
global_step = tf.train.get_or_create_global_step()
这将获取您的最后一个 global_step
(如果存在)或创建一个(如果不存在)。
TL;DR
作为 tensorflow 变量(将在会话中评估)
global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter
# this variable will be evaluated later in the session
saver = tf.train.Saver()
with tf.Session() as sess:
# restore all variables from checkpoint
saver.restore(sess, checkpoint_path)
# than init table and local variables and start training/evaluation ...
或: 作为 numpy 整数(没有任何会话):
reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor('global_step')
长答案
至少有两种方法可以从检查点检索全局变量。作为 tensorflow 变量或 numpy 整数。如果未在 Saver
的 save
方法中将 global_step
作为参数提供,则解析文件名将不起作用。对于预训练模型,请参阅答案末尾的备注。
作为 Tensorflow 变量
如果您需要 global_step
变量来计算一些超参数,您可以使用 tf.train.get_or_create_global_step()
。这将 return 一个张量流变量。因为变量将在稍后的会话中被评估,所以你只能使用 tensorflow 操作来计算你的超参数。因此,例如:max(global_step, 100)
将不起作用。您必须使用 tensorflow 等价物 tf.maximum(global_step, 100)
,可以在稍后的会话中对其进行评估。
在会话中,您可以使用 saver.restore(sess, checkpoint_path)
使用检查点初始化全局步骤变量
global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter
# this variable will be evaluated later in the session
hyper_parameter = tf.maximum(global_step, 100)
saver = tf.train.Saver()
with tf.Session() as sess:
# restore all variables from checkpoint
saver.restore(sess, checkpoint_path)
# than init table and local variables and start training/evaluation ...
# for verification you can print the global step and your hyper parameter
print(sess.run([global_step, hyper_parameter]))
或者:作为 numpy 整数(无会话)
如果您需要全局步骤变量作为标量而不启动会话,您也可以直接从检查点文件中读取此变量。你只需要一个NewCheckpointReader
。由于旧版 tensorflow 中的 bug,您应该将检查点文件的路径转换为绝对路径。使用 reader 您可以获得模型的所有张量作为 numpy 变量。
全局步骤变量的名称是常量字符串 tf.GraphKeys.GLOBAL_STEP
定义为 'global_step'
.
absolute_checkpoint_path = os.path.abspath(checkpoint_path)
reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
对预训练模型的备注:在大多数在线可用的预训练模型中,全局步长重置为零。因此,这些模型可用于初始化模型参数以进行微调,而不会覆盖全局步骤。
变量没有按预期恢复的原因很可能是因为它是在您的 tf.Saver()
对象创建之后创建的。
当您没有明确指定 var_list
或为 var_list
指定 None
时,创建 tf.Saver()
对象的位置很重要。许多程序员的预期行为是在调用 save()
方法时保存图中的所有变量,但事实并非如此,也许应该这样记录。创建对象时会保存图中所有变量的快照。
除非您遇到任何性能问题,否则最安全的做法是在您决定保存进度时立即创建保存程序对象。否则,请确保在创建所有变量后创建保护程序对象。
此外,传递给saver.save(sess, save_path, global_step=global_step)
的global_step
只是一个用于创建文件名的计数器,与是否将其恢复为global_step
变量无关.在我看来,这是一个用词不当的参数,因为如果您要在每个纪元结束时保存您的进度,那么最好为该参数传递您的纪元编号。
我正在这样保存我的会话状态:
self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)
当我稍后恢复时,我想为我从中恢复的检查点获取 global_step 的值。这是为了从中设置一些超参数。
执行此操作的 hacky 方法是 运行 通过并解析检查点目录中的文件名。但是一定有更好的内置方法来做到这一点吗?
一般模式是使用 global_step
变量来跟踪步骤
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)
然后你可以用
保存saver.save(sess, save_path, global_step=global_step)
恢复时,global_step
的值也会被恢复
现在的0.10rc0版本好像不一样了,没有了tf.saver()。现在是 tf.train.Saver()。此外,保存命令将信息添加到 global_step 的 save_path 文件名中,因此我们不能只在同一个 save_path 上调用恢复,因为那不是实际的保存文件。
我现在看到的最简单的方法是将 SessionManager 与这样的保护程序一起使用:
my_checkpoint_dir = "/tmp/checkpoint_dir"
# make a saver to use with SessionManager for restoring
saver = tf.train.Saver()
# Build an initialization operation to run below.
init = tf.initialize_all_variables()
# use a SessionManager to help with automatic variable restoration
sm = tf.train.SessionManager()
# try to find the latest checkpoint in my_checkpoint_dir, then create a session with that restored
# if no such checkpoint, then call the init_op after creating a new session
sess = sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=my_checkpoint_dir))
就是这样。现在你有一个从 my_checkpoint_dir 恢复的会话(在调用它之前确保该目录存在),或者如果那里没有检查点那么它会创建一个新会话并调用 init_op 来初始化你的变量.
当你想保存的时候,你只需要在那个目录中保存到你想要的任何名称,然后将 global_step 传入。这是一个示例,我将循环中的步骤变量保存为 global_step, 所以如果你杀死程序并重新启动它,它会回到那个点,以便它恢复检查点:
checkpoint_path = os.path.join(my_checkpoint_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=step)
这会在 my_checkpoint_dir 中创建文件,例如 "model.ckpt-1000",其中 1000 是传入的 global_step。如果它保留 运行,那么您会更像 "model.ckpt-2000".当程序重新启动时,上面的 SessionManager 会获取其中最新的一个。 checkpoint_path 可以是您想要的任何文件名,只要它在 checkpoint_dir 中即可。 save() 将创建附加了 global_step 的文件(如上所示)。它还会创建一个 "checkpoint" 索引文件,这是 SessionManager 找到最新保存检查点的方式。
这有点hack,但其他答案对我根本不起作用
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
2017 年 9 月更新
我不确定这是否由于更新而开始工作,但以下方法似乎可以有效地让 global_step 正确更新和加载:
创建两个操作。一个用于保存 global_step,另一个用于增加它:
global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step,1,
name = 'increment_global_step')
现在在你的训练循环中 运行 每次你 运行 你的训练操作时的增量操作。
sess.run([train_op,increment_global_step],feed_dict=feed_dict)
如果您想在任何时候将全局步长值检索为整数,只需在加载模型后使用以下命令:
sess.run(global_step)
这对于创建文件名或计算当前纪元是有用的,而无需第二个 tensorflow 变量来保存该值。例如,计算加载时的当前纪元类似于:
loaded_epoch = sess.run(global_step)//(batch_size*num_train_records)
我和 Lawrence Du 有同样的问题,我找不到通过恢复模型来获得 global_step 的方法。所以我应用了 pretrained_model_checkpoint_path
.
如果您有更好的解决方案,或者知道我遗漏了什么,请发表评论!
无论如何,这段代码对我有用:
...
# When not restoring start at 0
last_step = 0
if FLAGS.pretrained_model_checkpoint_path:
# A model consists of three files, use the base name of the model in
# the checkpoint path. E.g. my-model-path/model.ckpt-291500
#
# Because we need to give the base name you can't assert (will always fail)
# assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
variables_to_restore = tf.get_collection(
slim.variables.VARIABLES_TO_RESTORE)
restorer = tf.train.Saver(variables_to_restore)
restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
print('%s: Pre-trained model restored from %s' %
(datetime.now(), FLAGS.pretrained_model_checkpoint_path))
# HACK : global step is not restored for some unknown reason
last_step = int(os.path.basename(FLAGS.pretrained_model_checkpoint_path).split('-')[1])
# assign to global step
sess.run(global_step.assign(last_step))
...
for step in range(last_step + 1, FLAGS.max_steps):
...
请注意我关于全局步骤保存和恢复的解决方案。
保存:
global_step = tf.Variable(0, trainable=False, name='global_step')
saver.save(sess, model_path + model_name, global_step=_global_step)
恢复:
if os.path.exists(model_path):
saver.restore(sess, tf.train.latest_checkpoint(model_path))
print("Model restore finished, current globle step: %d" % global_step.eval())
您可以使用 global_step
变量来跟踪步骤,但如果在您的代码中,您正在初始化或将此值分配给另一个 step
变量,它可能不一致。
例如,您定义 global_step
使用:
global_step = tf.Variable(0, name='global_step', trainable=False)
分配给你的训练操作:
train_op = optimizer.minimize(loss, global_step=global_step)
保存在你的检查点:
saver.save(sess, checkpoint_path, global_step=global_step)
并从您的检查点恢复:
saver.restore(sess, checkpoint_path)
global_step
的值也被恢复,但是如果您要将它分配给另一个变量,比如 step
,那么您必须执行如下操作:
step = global_step.eval(session=sess)
变量step
,包含检查点中最后保存的global_step
。
从图中定义 global_step 比零变量(如前定义)更好:
global_step = tf.train.get_or_create_global_step()
这将获取您的最后一个 global_step
(如果存在)或创建一个(如果不存在)。
TL;DR
作为 tensorflow 变量(将在会话中评估)
global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter
# this variable will be evaluated later in the session
saver = tf.train.Saver()
with tf.Session() as sess:
# restore all variables from checkpoint
saver.restore(sess, checkpoint_path)
# than init table and local variables and start training/evaluation ...
或: 作为 numpy 整数(没有任何会话):
reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor('global_step')
长答案
至少有两种方法可以从检查点检索全局变量。作为 tensorflow 变量或 numpy 整数。如果未在 Saver
的 save
方法中将 global_step
作为参数提供,则解析文件名将不起作用。对于预训练模型,请参阅答案末尾的备注。
作为 Tensorflow 变量
如果您需要 global_step
变量来计算一些超参数,您可以使用 tf.train.get_or_create_global_step()
。这将 return 一个张量流变量。因为变量将在稍后的会话中被评估,所以你只能使用 tensorflow 操作来计算你的超参数。因此,例如:max(global_step, 100)
将不起作用。您必须使用 tensorflow 等价物 tf.maximum(global_step, 100)
,可以在稍后的会话中对其进行评估。
在会话中,您可以使用 saver.restore(sess, checkpoint_path)
global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter
# this variable will be evaluated later in the session
hyper_parameter = tf.maximum(global_step, 100)
saver = tf.train.Saver()
with tf.Session() as sess:
# restore all variables from checkpoint
saver.restore(sess, checkpoint_path)
# than init table and local variables and start training/evaluation ...
# for verification you can print the global step and your hyper parameter
print(sess.run([global_step, hyper_parameter]))
或者:作为 numpy 整数(无会话)
如果您需要全局步骤变量作为标量而不启动会话,您也可以直接从检查点文件中读取此变量。你只需要一个NewCheckpointReader
。由于旧版 tensorflow 中的 bug,您应该将检查点文件的路径转换为绝对路径。使用 reader 您可以获得模型的所有张量作为 numpy 变量。
全局步骤变量的名称是常量字符串 tf.GraphKeys.GLOBAL_STEP
定义为 'global_step'
.
absolute_checkpoint_path = os.path.abspath(checkpoint_path)
reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
对预训练模型的备注:在大多数在线可用的预训练模型中,全局步长重置为零。因此,这些模型可用于初始化模型参数以进行微调,而不会覆盖全局步骤。
变量没有按预期恢复的原因很可能是因为它是在您的 tf.Saver()
对象创建之后创建的。
当您没有明确指定 var_list
或为 var_list
指定 None
时,创建 tf.Saver()
对象的位置很重要。许多程序员的预期行为是在调用 save()
方法时保存图中的所有变量,但事实并非如此,也许应该这样记录。创建对象时会保存图中所有变量的快照。
除非您遇到任何性能问题,否则最安全的做法是在您决定保存进度时立即创建保存程序对象。否则,请确保在创建所有变量后创建保护程序对象。
此外,传递给saver.save(sess, save_path, global_step=global_step)
的global_step
只是一个用于创建文件名的计数器,与是否将其恢复为global_step
变量无关.在我看来,这是一个用词不当的参数,因为如果您要在每个纪元结束时保存您的进度,那么最好为该参数传递您的纪元编号。