重新加载张量流模型
Reloading tensorflow model
我有两个独立的 tensorflow 进程,一个是训练模型并用 tensorflow.python.client.graph_util.convert_variables_to_constants
写出 graph_defs,另一个是用 tensorflow.import_graph_def
读取 graph_def .我希望第二个进程在第一个进程更新时定期重新加载 graph_def 。不幸的是,似乎每次我阅读 graph_def 后,旧的仍然在使用,即使我关闭当前会话并创建一个新会话。我也试过用 sess.graph.as_default()
包装 import_graph_def call
,但无济于事。这是我当前的 graph_def 加载代码:
if self.sess is not None:
self.sess.close()
self.sess = tf.Session()
graph_def = tf.GraphDef()
with open(self.graph_path, 'rb') as f:
graph_def.ParseFromString(f.read())
with self.sess.graph.as_default():
tf.import_graph_def(graph_def, name='')
这里的问题是,当您创建不带参数的 tf.Session
时,它使用 当前默认图 。假设您没有在代码中的其他任何地方创建 tf.Graph
,您将获得在进程启动时创建的全局默认图,并且它在所有会话之间共享。结果,with self.sess.graph.as_default():
没有效果。
很难根据您在问题中显示的片段推荐新结构——特别是,我不知道您是如何创建上一个图表的,也不知道 class 结构是什么——但一个可能是将 self.sess = tf.Session()
替换为以下内容:
self.sess = tf.Session(graph=tf.Graph()) # Creates a new graph for the session.
现在 with self.sess.graph.as_default():
将使用为会话创建的图形,您的程序应该具有预期的效果。
一个更可取的(至少对我而言)替代方案是显式构建图表:
with tf.Graph().as_default() as imported_graph:
tf.import_graph_def(graph_def, ...)
sess = tf.Session(graph=imported_graph)
我有两个独立的 tensorflow 进程,一个是训练模型并用 tensorflow.python.client.graph_util.convert_variables_to_constants
写出 graph_defs,另一个是用 tensorflow.import_graph_def
读取 graph_def .我希望第二个进程在第一个进程更新时定期重新加载 graph_def 。不幸的是,似乎每次我阅读 graph_def 后,旧的仍然在使用,即使我关闭当前会话并创建一个新会话。我也试过用 sess.graph.as_default()
包装 import_graph_def call
,但无济于事。这是我当前的 graph_def 加载代码:
if self.sess is not None:
self.sess.close()
self.sess = tf.Session()
graph_def = tf.GraphDef()
with open(self.graph_path, 'rb') as f:
graph_def.ParseFromString(f.read())
with self.sess.graph.as_default():
tf.import_graph_def(graph_def, name='')
这里的问题是,当您创建不带参数的 tf.Session
时,它使用 当前默认图 。假设您没有在代码中的其他任何地方创建 tf.Graph
,您将获得在进程启动时创建的全局默认图,并且它在所有会话之间共享。结果,with self.sess.graph.as_default():
没有效果。
很难根据您在问题中显示的片段推荐新结构——特别是,我不知道您是如何创建上一个图表的,也不知道 class 结构是什么——但一个可能是将 self.sess = tf.Session()
替换为以下内容:
self.sess = tf.Session(graph=tf.Graph()) # Creates a new graph for the session.
现在 with self.sess.graph.as_default():
将使用为会话创建的图形,您的程序应该具有预期的效果。
一个更可取的(至少对我而言)替代方案是显式构建图表:
with tf.Graph().as_default() as imported_graph:
tf.import_graph_def(graph_def, ...)
sess = tf.Session(graph=imported_graph)