Tensorflow:在 class 和 运行 之外创建图表
Tensorflow: Creating a graph in a class and running it outside
我相信我很难理解图在 tensorflow 中的工作原理以及如何访问它们。我的直觉是 'with graph:' 下的线条将图形作为单个实体形成。因此,我决定创建一个 class ,它将在实例化时构建一个图形,并拥有一个 运行 该图形的函数,如下所示;
class Graph(object):
#To build the graph when instantiated
def __init__(self, parameters ):
self.graph = tf.Graph()
with self.graph.as_default():
...
prediction = ...
cost = ...
optimizer = ...
...
# To launch the graph
def launchG(self, inputs):
with tf.Session(graph=self.graph) as sess:
...
sess.run(optimizer, feed_dict)
loss = sess.run(cost, feed_dict)
...
return variables
接下来的步骤是创建一个主文件,该文件将 assemble 参数传递给 class,构建图表,然后 运行 它;
#Main file
...
parameters_dict = { 'n_input': 28, 'learnRate': 0.001, ... }
#Building graph
G = Graph(parameters_dict)
P = G.launchG(Input)
...
这对我来说非常优雅,但它不太管用(很明显)。确实,似乎 launchG 函数无法访问图中定义的节点,这给我错误,例如;
---> 26 sess.run(optimizer, feed_dict)
NameError: name 'optimizer' is not defined
也许是我的 python(和 tensorflow)理解太有限了,但我有一种奇怪的印象,即创建的图(G),运行与这个会话graph 作为参数应该可以访问其中的节点,而不需要我提供明确的访问权限。
有什么启示吗?
节点prediction
、cost
、optimizer
是方法__init__
中创建的局部变量,不能在方法launchG
中访问.
最简单的解决方法是将它们声明为您的 class Graph
:
的属性
class Graph(object):
#To build the graph when instantiated
def __init__(self, parameters ):
self.graph = tf.Graph()
with self.graph.as_default():
...
self.prediction = ...
self.cost = ...
self.optimizer = ...
...
# To launch the graph
def launchG(self, inputs):
with tf.Session(graph=self.graph) as sess:
...
sess.run(self.optimizer, feed_dict)
loss = sess.run(self.cost, feed_dict)
...
return variables
您还可以使用 graph.get_tensor_by_name
和 graph.get_operation_by_name
的确切名称检索图的节点。
我相信我很难理解图在 tensorflow 中的工作原理以及如何访问它们。我的直觉是 'with graph:' 下的线条将图形作为单个实体形成。因此,我决定创建一个 class ,它将在实例化时构建一个图形,并拥有一个 运行 该图形的函数,如下所示;
class Graph(object):
#To build the graph when instantiated
def __init__(self, parameters ):
self.graph = tf.Graph()
with self.graph.as_default():
...
prediction = ...
cost = ...
optimizer = ...
...
# To launch the graph
def launchG(self, inputs):
with tf.Session(graph=self.graph) as sess:
...
sess.run(optimizer, feed_dict)
loss = sess.run(cost, feed_dict)
...
return variables
接下来的步骤是创建一个主文件,该文件将 assemble 参数传递给 class,构建图表,然后 运行 它;
#Main file
...
parameters_dict = { 'n_input': 28, 'learnRate': 0.001, ... }
#Building graph
G = Graph(parameters_dict)
P = G.launchG(Input)
...
这对我来说非常优雅,但它不太管用(很明显)。确实,似乎 launchG 函数无法访问图中定义的节点,这给我错误,例如;
---> 26 sess.run(optimizer, feed_dict)
NameError: name 'optimizer' is not defined
也许是我的 python(和 tensorflow)理解太有限了,但我有一种奇怪的印象,即创建的图(G),运行与这个会话graph 作为参数应该可以访问其中的节点,而不需要我提供明确的访问权限。
有什么启示吗?
节点prediction
、cost
、optimizer
是方法__init__
中创建的局部变量,不能在方法launchG
中访问.
最简单的解决方法是将它们声明为您的 class Graph
:
class Graph(object):
#To build the graph when instantiated
def __init__(self, parameters ):
self.graph = tf.Graph()
with self.graph.as_default():
...
self.prediction = ...
self.cost = ...
self.optimizer = ...
...
# To launch the graph
def launchG(self, inputs):
with tf.Session(graph=self.graph) as sess:
...
sess.run(self.optimizer, feed_dict)
loss = sess.run(self.cost, feed_dict)
...
return variables
您还可以使用 graph.get_tensor_by_name
和 graph.get_operation_by_name
的确切名称检索图的节点。