在特定迭代或检查点将模型加载/恢复到 tensorflow

Load / restore models into tensorflow at specific iteration or checkpoint

我有一个模型,我每迭代 10 次就保存一次。所以,我在保存的目录中有以下文件。

checkpoint  model-50.data-00000-of-00001  model-50.index  model-50.meta
model-60.data-00000-of-00001  model-60.index  model-60.meta

依此类推,直到 100。我只需要加载 model-50。因为我有 70 次迭代后的 NaN 值。默认情况下,当我恢复时,保护程序将寻找最终检查点。那么,我该如何专门加载 model-50.请帮助,否则,我必须 运行 从头开始​​构建模型,这很耗时。

由于您正在使用 tf.train.Saver 的函数 restore(),您可以使用 last_checkpoints 函数来获取所有可用检查点的列表。您将在此列表中同时看到 model-50model-60

选择正确的模型,像这样直接传给restore()

saver.restore(sess, ckpt_path)

我不确定过去是否有所不同,但至少到现在为止,您可以使用包含 all_model_checkpoint_paths.

tf.train.get_checkpoint_state() to get CheckpointState 原型

当您执行大多数关于 saving/restoring 模型的教程中显示的命令时 saver.restore(sess, tf.train.latest_checkpoint(_dir_models)) 您传递的第二个参数只是模型路径的字符串。这是在 saver.restore 文档中定义的。

save_path: Path where parameters were previously saved.

所以你可以在其中设置任何字符串的路径,latest_checkpoint 只是一个方便的函数,可以从 checkpoint 文件中提取此路径。在笔记本中打开此文件,您将看到所有可用的模型路径以及最新的路径。

您可以将该路径替换为您想要的任何路径。您可以从该文件中获取它(手动打开它或使用 get_checkpoin_state 以编程方式为您完成。