在 tensorflow 中恢复图形失败,因为没有要保存的变量

Restoring graph in tensorflow fails because there is no variable to save

我知道堆栈和 github 等方面有无数关于如何在 Tensorflow 中恢复训练好的模型的问题。我已经阅读了其中的大部分 (,2,3)。

我遇到的问题几乎与问题 3 完全相同,但我希望尽可能以不同的方式解决它,因为我的训练和测试需要在从 shell 调用的单独脚本中进行,我确实这样做了不想添加我用来在测试脚本中定义图表的完全相同的行,所以我不能使用 tensorflow FLAGS 和其他基于手动重新 运行 图表的答案。

我也不想 sess.run 每个变量并手动映射它们,因为我的图表很大(使用 import_graph_def 和参数 input_map) .

所以我 运行 一些图表并在特定脚本中对其进行训练。例如(但没有训练部分)

#Script 1
import tensorflow as tf
import cPickle as pickle

x=tf.Variable(42)
saver=tf.train.Saver()
sess=tf.Session()
#Saving the graph
graph_def=sess.graph_def
with open('graph.pkl','wb') as output:
  pickle.dump(graph_def,output,HIGHEST_PROTOCOL)


#Training the model
sess.run(tf.initialize_all_variables())
#Saving the variables
saver.save(sess,"pretrained_model.ckpt")

我现在已经保存了图形和变量,所以即使我的图形中有额外的训练节点,我也应该能够从另一个脚本运行我的测试模型。

#Script 2
import tensorflow as tf
import cPickle as pickle

sess=tf.Session()
with open('graph.pkl','rb') as input:
  graph_def=pickle.load(input)


tf.import_graph_def(graph_def,name='persisted')

然后显然我想使用保护程序恢复变量,但我遇到了与 3 相同的问题,因为没有找到要保存的变量,甚至无法创建保护程序。所以我不能写:

saver=tf.train.Saver()
saver.restore(sess,"pretrained_model.ckpt")

有没有办法绕过这些限制?我认为通过导入图形它会恢复每个节点中未初始化的变量,但似乎不会。我真的需要像大多数给出的答案一样重新运行 第二次吗?

变量列表保存在 Collection 中,未保存在 GraphDef 中。 Saver 默认使用 ops.GraphKeys.VARIABLES 集合中的列表(可通过 tf.all_variables() 访问),如果您从 GraphDef 恢复而不是使用 Python API要构建您的模型,该集合是空的。您可以在 tf.train.Saver(var_list=['MyVariable1:0', 'MyVariable2:0',...]).

中手动指定变量列表

或者您可以使用 MetaGraphDef 来代替 GraphDef 来保存集合,最近添加了 MetaGraphDef HowTo

据我所知和我的测试,您不能简单地将名称传递给 tf.train.Saver 对象。它必须是变量列表或字典。

我还想从 graph_def 读取模型,然后使用 saver 加载变量 - 然而尝试它只会导致错误消息:"Variable to save is not a variable"