How to use the functions merge and switch of tensorflow?

merge and switch may not be intended for general users. I have searched the source code.

The docstring of merge says:

Returns the value of an available element of inputs.

What does "available" mean? Is it something returned by switch? Here is a demo:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

x_0, x_1 = control_flow_ops.switch(tf.constant(2), False)
x_2, x_3 = control_flow_ops.switch(tf.constant(7), True)
y = control_flow_ops.merge([x_0, x_1, x_2, x_3])
with tf.Session() as sess:
    print(sess.run(y))

switch

Let's start by examining the control_flow_ops.switch function:

x_0, x_1 = control_flow_ops.switch(tf.constant(2), False)
x_2, x_3 = control_flow_ops.switch(tf.constant(7), True)
with tf.Session() as sess:
  print(sess.run(x_0))    # prints 2
  print(sess.run(x_3))    # prints 7

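control_flow_ops.switch returns a tuple of tensors, but only one of them has a value (which one depends on the condition argument). In the example above, x_0 = 2 for the first switch and x_3 = 7 for the second. Trying to evaluate x_1 or x_2 results in a "Retval does not have value" error: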

  sess.run(x_1)  # FAILS!
  sess.run(x_2)  # FAILS!

In other words, x_0 and x_3 are available, while x_1 and x_2 are not.
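To make "available" concrete, here is a toy model of the switch semantics in plain Python. It is only a sketch and not TensorFlow code: DEAD, toy_switch and is_available are hypothetical names I made up, with DEAD standing in for an output that has no value.

```python
# Toy model of switch semantics in plain Python (NOT TensorFlow code).
# DEAD is a hypothetical sentinel for an output that carries no value.
DEAD = object()

def toy_switch(data, pred):
    """Return (output_false, output_true); only one of them carries data."""
    return (DEAD, data) if pred else (data, DEAD)

def is_available(value):
    """A value counts as 'available' when it is not the DEAD sentinel."""
    return value is not DEAD

x_0, x_1 = toy_switch(2, False)   # pred is False, so the first output is live
x_2, x_3 = toy_switch(7, True)    # pred is True, so the second output is live

print(is_available(x_0), is_available(x_1))  # True False
print(is_available(x_2), is_available(x_3))  # False True
```

In the real graph, evaluating the dead output is what raises the error, whereas the toy model just lets you inspect the sentinel.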

merge

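control_flow_ops.merge performs the inverse operation: given a tuple of tensors, it selects the one that is available. More precisely, it returns a named tuple ["output", "value_index"] describing the tensor that has a value. According to the current documentation, the inputs should contain exactly one available tensor, which means your demo is, strictly speaking, not supported and leads to undefined behavior. Here is an example: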

merge = control_flow_ops.merge  # shorthand used in the calls below
with tf.Session() as sess:
  print(sess.run(merge([x_0, x_1])))       # Merge(output=2, value_index=0)
  print(sess.run(merge([x_1, x_0])))       # Merge(output=2, value_index=1)
  print(sess.run(merge([x_2, x_3])))       # Merge(output=7, value_index=1)
  print(sess.run(merge([x_3, x_2])))       # Merge(output=7, value_index=0)
  print(sess.run(merge([x_0, x_1, x_2])))  # Merge(output=2, value_index=0)
  print(sess.run(merge([x_1, x_2, x_3])))  # Merge(output=7, value_index=2)
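The same selection rule can be sketched in plain Python. This is a toy model, not the real op: DEAD, toy_switch and toy_merge are hypothetical names, and picking the first live input is my own simplification, since the documentation only covers the case of exactly one available input.

```python
from collections import namedtuple

# Toy model of merge semantics in plain Python (NOT TensorFlow code).
DEAD = object()  # hypothetical sentinel for an output without a value
Merge = namedtuple("Merge", ["output", "value_index"])

def toy_switch(data, pred):
    """Return (output_false, output_true); only one of them carries data."""
    return (DEAD, data) if pred else (data, DEAD)

def toy_merge(inputs):
    """Return the first live input together with its position.

    Assumption: with several live inputs the real op is undefined, so
    taking the first one is purely a choice made for this sketch.
    """
    for index, value in enumerate(inputs):
        if value is not DEAD:
            return Merge(output=value, value_index=index)
    raise ValueError("no available input")  # simplification of the real op

x_0, x_1 = toy_switch(2, False)
x_2, x_3 = toy_switch(7, True)

print(toy_merge([x_1, x_0]))       # Merge(output=2, value_index=1)
print(toy_merge([x_1, x_2, x_3]))  # Merge(output=7, value_index=2)
```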

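Both functions can be handy for controlling the computation flow. For example, the gradient of control_flow_ops.switch is itself implemented with these ops (see the tensorflow source code).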
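As an illustration of the control-flow use, a conditional can be assembled from one switch and one merge. As far as I know, the graph-mode tf.cond in TensorFlow 1.x follows the same idea, but the code below is only a plain-Python sketch with hypothetical names (DEAD, toy_switch, toy_merge, toy_cond), not the library's implementation.

```python
# Toy conditional built from a switch and a merge (NOT TensorFlow code).
DEAD = object()  # hypothetical sentinel for an output without a value

def toy_switch(data, pred):
    """Return (output_false, output_true); only one of them carries data."""
    return (DEAD, data) if pred else (data, DEAD)

def toy_merge(inputs):
    """Return the first live input; assumes exactly one input is live."""
    for value in inputs:
        if value is not DEAD:
            return value
    raise ValueError("no available input")

def toy_cond(pred, data, true_fn, false_fn):
    """Route data into one branch, then merge the two branch results."""
    out_false, out_true = toy_switch(data, pred)
    # A branch fed with a dead input produces a dead result (deadness propagates).
    res_true = DEAD if out_true is DEAD else true_fn(out_true)
    res_false = DEAD if out_false is DEAD else false_fn(out_false)
    return toy_merge([res_false, res_true])

print(toy_cond(True, 3, lambda v: v * 10, lambda v: v - 1))   # 30
print(toy_cond(False, 3, lambda v: v * 10, lambda v: v - 1))  # 2
```

Only the branch that receives the live value does any work, which is the point of pairing the two ops.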

Maybe you can try this demo.

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

x_0, x_1 = control_flow_ops.switch(tf.constant(2), False)
x_2, x_3 = control_flow_ops.switch(tf.constant(7), True)
with tf.Session() as sess:
  print("anchor, output:{}".format(sess.run(x_0)))    # prints anchor, output:2
  print("anchor, output:{}".format(sess.run(x_3)))    # prints anchor, output:7

merge_0 = control_flow_ops.merge([x_0, x_2])
with tf.Session() as sess:
  print("anchor, output:{}".format(sess.run(merge_0)))    # Merge(output=2, value_index=0), since x_0 is the available input