从 .pb 文件恢复图形 def 时出现 Tensorflow 错误
Tensorflow error while restoring graph def from .pb file
我正在关注有关使用 tensorflow 进行文本分类的 wildml 博客。我更改了代码以保存图形定义,如下所示:
tf.train.write_graph(sess.graph_def,'./DeepLearn/model/','train.pb', as_text=False)
稍后在一个单独的文件中,我将按如下方式恢复图形:
with tf.gfile.FastGFile(os.path.join('./DeepLearn/model/','train.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
t = sess.graph.get_tensor_by_name('embedding/W:0')
sess.run(t)
当我尝试 运行 张量并获取其值时,出现以下错误:
tensorflow.python.framework.errors.FailedPreconditionError: Attempting to use uninitialized value embedding/W
此错误的可能原因是什么。张量应该已经初始化,因为我正在从保存的图表中恢复它。
谢谢亚历山大!
是的,我需要加载图表(来自 .pb 文件)和权重(来自检查点文件)。使用了以下示例代码(取自博客)并且对我有用。
with tf.Session() as persisted_sess:
print("load graph")
with gfile.FastGFile("/tmp/load/test.pb",'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
persisted_sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0")
tf.add_to_collection(tf.GraphKeys.VARIABLES,persisted_result)
try:
saver = tf.train.Saver(tf.all_variables())
except:pass
print("load data")
saver.restore(persisted_sess, "checkpoint.data") # now OK
print(persisted_result.eval())
print("DONE")
我正在关注有关使用 tensorflow 进行文本分类的 wildml 博客。我更改了代码以保存图形定义,如下所示:
tf.train.write_graph(sess.graph_def,'./DeepLearn/model/','train.pb', as_text=False)
稍后在一个单独的文件中,我将按如下方式恢复图形:
with tf.gfile.FastGFile(os.path.join('./DeepLearn/model/','train.pb'), 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
t = sess.graph.get_tensor_by_name('embedding/W:0')
sess.run(t)
当我尝试 运行 张量并获取其值时,出现以下错误:
tensorflow.python.framework.errors.FailedPreconditionError: Attempting to use uninitialized value embedding/W
此错误的可能原因是什么。张量应该已经初始化,因为我正在从保存的图表中恢复它。
谢谢亚历山大! 是的,我需要加载图表(来自 .pb 文件)和权重(来自检查点文件)。使用了以下示例代码(取自博客)并且对我有用。
with tf.Session() as persisted_sess:
print("load graph")
with gfile.FastGFile("/tmp/load/test.pb",'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
persisted_sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0")
tf.add_to_collection(tf.GraphKeys.VARIABLES,persisted_result)
try:
saver = tf.train.Saver(tf.all_variables())
except:pass
print("load data")
saver.restore(persisted_sess, "checkpoint.data") # now OK
print(persisted_result.eval())
print("DONE")