cond 可以支持有副作用的 TF 操作吗?

Can cond support TF ops with side effects?

tf.cond 的(源代码)文档不清楚在评估谓词时要执行的功能是否会产生副作用。我做了一些测试,但得到的结果相互矛盾。例如下面的代码不起作用:

import tensorflow as tf
from tensorflow.python.ops import control_flow_ops

pred = tf.placeholder(tf.bool, [])
count = tf.Variable(0)
adder = count.assign_add(1)
subtractor = count.assign_sub(2)

my_op = control_flow_ops.cond(pred, lambda: adder, lambda: subtractor)

sess = tf.InteractiveSession()
tf.initialize_all_variables().run()

my_op.eval(feed_dict={pred: True})
count.eval() # returns -1

my_op.eval(feed_dict={pred: False})
count.eval() # returns -2

即无论谓词的计算结果是什么,两个函数都得到 运行,因此最终结果是减去 1。另一方面,这段代码片段确实有效,唯一的区别是我添加了新的每次调用 my_op 时对图表进行操作:

pred = tf.placeholder(tf.bool, [])
count = tf.Variable(0)

my_op = control_flow_ops.cond(pred, lambda:count.assign_add(1), lambda:count.assign_sub(2))

sess = tf.InteractiveSession()
tf.initialize_all_variables().run()

my_op.eval(feed_dict={pred: False})
count.eval() # returns -2

my_op.eval(feed_dict={pred: True})
count.eval() # returns -1

不确定为什么每次都创建新操作有效而另一种情况却无效,但我显然宁愿不添加节点,因为图形最终会变得太大。

第二种情况有效,因为您在 cond 中添加了操作:这导致它们有条件地执行。

第一种情况相当于说:

adder = (count += 1)
subtractor = (count -= 2)
if (cond) { adder } else { subtractor }

由于加法器和减法器在条件之外,它们总是被执行。

第二种情况更像是在说

if (cond) { adder = (count += 1) } else { subtractor = (count -= 2) }

在这种情况下,它符合您的预期。

我们意识到副作用和(某种程度上)惰性求值之间的相互作用令人困惑,我们的目标是 medium-term 让事情变得更加统一。但现在要了解的重要一点是,我们没有进行真正的惰性评估:条件获取对在任一分支中使用的条件之外定义的每个量的依赖。

你的第二个版本——assign_add()assign_sub() ops 在传递给 cond() 的 lambda 中创建——是正确的方法。幸运的是,在调用 cond() 期间,两个 lambda 表达式中的每一个都只被评估一次,因此您的图形不会无限制地增长。

本质上 cond() 的作用如下:

  1. 创建一个 Switch 节点,它根据 pred 的值将其输入仅转发到两个输出之一。我们将输出称为 pred_truepred_false。 (它们与 pred 具有相同的值,但这并不重要,因为它从未被直接计算过。)

  2. 构建if_true lambda对应的子图,其中所有节点对pred_true.

  3. 有控制依赖
  4. 构建if_false lambda对应的子图,其中所有节点对pred_false.

  5. 有控制依赖
  6. 将来自两个 lambda 的 return 值列表压缩在一起,并为每个创建一个 Merge 节点。 Merge 节点接受两个输入,其中只有一个预计会产生,并将其转发到其输出。

  7. Return 作为 Merge 节点输出的张量。

这意味着您可以 运行 您的第二个版本,并满足于图表保持固定大小,无论您 运行.

多少步

您的第一个版本不起作用的原因是,当 Tensor 被捕获时(如您的示例中的 addersubtractor),额外的 Switch 节点被添加以强制张量的值仅转发到实际执行的分支的逻辑。这是 TensorFlow 如何在其执行模型中结合 feed-forward 数据流和控制流的产物。结果是捕获的张量(在本例中为 assign_addassign_sub 的结果)将始终被评估,即使它们未被使用,您也会看到它们的副作用。这是我们需要更好地记录的内容,,我们将使其在未来更有用。