tf.all_variables() 相当于各州

tf.all_variables() equivalent for states

tf.all_variables() 给出所有 Graph 变量。是否有所有状态张量的等价物?

背景:我想调试 Graph 的行为。所以一次又一次,我想在给定一些输入馈送的情况下检查状态张量。通常你会在构建图表时定义这些然后

session.run(looking_for_this_tensor, input_feed)

检查它们。

然而,我更愿意有这样的东西:

for v in tf.all_state_tensors_type_of_method():
    print (v.name, ': ', session.run(v, input_feed))

有这样的吗?找遍了也没找到

一种方法是 运行 对图表中的每个操作进行循环,并为每个操作打印其输出。

代码看起来像这样:

graph = tf.get_default_graph()
with tf.Session() as sess:
    for op in graph.get_operations():
        for tensor in op.outputs:
            print tensor.name, ':', sess.run(tensor, feed_dict=input_feed)

警告

用大图和大形状的张量做这件事会导致一团糟,因为它会打印计算中使用的每个个子张量。
您还需要参考 Tensorboard 以获取张量的确切名称,and/or 采用良好的命名约定(例如 tf.name_scope())。

您还可以将要检查的张量添加到集合中。

t1 = ... variable / constant
t2 = ... variable / constant

tf.add_to_collection("inspect", t1)
tf.add_to_collection("inspect", t2)
sess = tf.InteractiveSession()
sess.run(..., feed_dict={..})
for v in tf.get_collection("inspect):
    v.eval()