Tensorflow:如何替换计算图中的节点?

Tensorflow: How to replace a node in a calculation graph?

如果您有两个不相交的图,并且想要 link 它们,请转此:

x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = f(y)

进入这个:

x = tf.placeholder('float')
y = f(x)
z = g(y)

有办法吗?在某些情况下,它似乎可以使施工更容易。

例如,如果你有一个输入图像为 tf.placeholder 的图形,并且想要优化输入图像,deep-dream 风格,有没有办法只用 [= 替换占位符=13=]节点?还是在构建图形之前必须考虑到这一点?

TL;DR:如果您可以将这两个计算定义为 Python 函数,您应该这样做。如果您不能,TensorFlow 中有更高级的功能来序列化和导入图形,这允许您从不同的来源组合图形。

在 TensorFlow 中执行此操作的一种方法是将不相交的计算构建为单独的 tf.Graph 对象,然后使用 Graph.as_graph_def():

将它们转换为序列化协议缓冲区
with tf.Graph().as_default() as g_1:
  input = tf.placeholder(tf.float32, name="input")
  y = f(input)
  # NOTE: using identity to get a known name for the output tensor.
  output = tf.identity(y, name="output")

gdef_1 = g_1.as_graph_def()

with tf.Graph().as_default() as g_2:  # NOTE: g_2 not g_1       
  input = tf.placeholder(tf.float32, name="input")
  z = g(input)
  output = tf.identity(y, name="output")

gdef_2 = g_2.as_graph_def()

然后您可以使用 tf.import_graph_def():

gdef_1gdef_2 组合成第三个图形
with tf.Graph().as_default() as g_combined:
  x = tf.placeholder(tf.float32, name="")

  # Import gdef_1, which performs f(x).
  # "input:0" and "output:0" are the names of tensors in gdef_1.
  y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                           return_elements=["output:0"])

  # Import gdef_2, which performs g(y)
  z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                           return_elements=["output:0"]

如果您想合并经过训练的模型(例如,在新模型中重用预训练模型的一部分),您可以使用 Saver 保存第一个模型的检查点,然后恢复该模型(全部或部分)进入另一个模型。

例如,假设您想在模型 2 中重用模型 1 的权重 w,并将 x 从占位符转换为变量:

with tf.Graph().as_default() as g1:
    x = tf.placeholder('float')
    w = tf.Variable(1., name="w")
    y = x * w
    saver = tf.train.Saver()

with tf.Session(graph=g1) as sess:
    w.initializer.run()
    # train...
    saver.save(sess, "my_model1.ckpt")

with tf.Graph().as_default() as g2:
    x = tf.Variable(2., name="v")
    w = tf.Variable(0., name="w")
    z = x + w
    restorer = tf.train.Saver([w]) # only restore w

with tf.Session(graph=g2) as sess:
    x.initializer.run()  # x now needs to be initialized
    restorer.restore(sess, "my_model1.ckpt") # restores w=1
    print(z.eval())  # prints 3.

事实证明,tf.train.import_meta_graph 将所有附加参数传递给具有 input_map 参数的底层 import_scoped_meta_graph,并在它自己(内部)调用 import_graph_def

它没有记录在案,我花了太多时间才找到它,但它确实有效!

实际例子:

import tensorflow as tf
g1 = tf.Graph()
with g1.as_default():
    # set variables/placeholders
    tf.placeholder(tf.int32, [], name='g1_a')
    tf.placeholder(tf.int32, [], name='g1_b')

    # example on exacting tensor by name
    a = g1.get_tensor_by_name('g1_a:0')
    b = g1.get_tensor_by_name('g1_b:0')

    # operation ==>>     c = 2 * 3 = 6
    mul_op = tf.multiply(a, b, name='g1_mul')
    sess = tf.Session()
    g1_mul_results = sess.run(mul_op, feed_dict={'g1_a:0': 2, 'g1_b:0': 3})
    print('graph1 mul = ', g1_mul_results)  # output = 6

    print('\ngraph01 operations/variables:')
    for op in g1.get_operations():
        print(op.name)

g2 = tf.Graph()
with g2.as_default():
    # set variables/placeholders
    tf.import_graph_def(g1.as_graph_def())
    g2_c = tf.placeholder(tf.int32, [], name='g2_c')

    # example on exacting tensor by name
    g1_b = g2.get_tensor_by_name('import/g1_b:0')
    g1_mul = g2.get_tensor_by_name('import/g1_mul:0')

    # operation ==>>
    b = tf.multiply(g1_b, g2_c, name='g2_var_times_g1_a')
    f = tf.multiply(g1_mul, g1_b, name='g1_mul_times_g1_b')

    print('\ngraph01 operations/variables:')
    for op in g2.get_operations():
        print(op.name)
    sess = tf.Session()

    # graph1 variable 'a' times graph2 variable 'c'(graph2)
    ans = sess.run('g2_var_times_g1_a:0', feed_dict={'g2_c:0': 4, 'import/g1_b:0': 5})
    print('\ngraph2 g2_var_times_g1_a = ', ans)  # output = 20

    # graph1 mul_op (a*b) times graph1 variable 'b'
    ans = sess.run('g1_a_times_g1_b:0',
                   feed_dict={'import/g1_a:0': 6, 'import/g1_b:0': 7})
    print('\ngraph2 g1_mul_times_g1_b:0 = ', ans)  # output = (6*7)*7 = 294

''' output
graph1 mul =  6

graph01 operations/variables:
g1_a
g1_b
g1_mul

graph01 operations/variables:
import/g1_a
import/g1_b
import/g1_mul
g2_c
g2_var_times_g1_a
g1_a_times_g1_b

graph2 g2_var_times_g1_a =  20

graph2 g1_a_times_g1_b:0 =  294
'''

参考LINK