cond 可以支持有副作用的 TF 操作吗?
Can cond support TF ops with side effects?
tf.cond
的(源代码)文档不清楚在评估谓词时要执行的功能是否会产生副作用。我做了一些测试,但得到的结果相互矛盾。例如下面的代码不起作用:
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
pred = tf.placeholder(tf.bool, [])
count = tf.Variable(0)
adder = count.assign_add(1)
subtractor = count.assign_sub(2)
my_op = control_flow_ops.cond(pred, lambda: adder, lambda: subtractor)
sess = tf.InteractiveSession()
tf.initialize_all_variables().run()
my_op.eval(feed_dict={pred: True})
count.eval() # returns -1
my_op.eval(feed_dict={pred: False})
count.eval() # returns -2
即无论谓词的计算结果是什么,两个函数都得到 运行,因此最终结果是减去 1。另一方面,这段代码片段确实有效,唯一的区别是我添加了新的每次调用 my_op
时对图表进行操作:
pred = tf.placeholder(tf.bool, [])
count = tf.Variable(0)
my_op = control_flow_ops.cond(pred, lambda:count.assign_add(1), lambda:count.assign_sub(2))
sess = tf.InteractiveSession()
tf.initialize_all_variables().run()
my_op.eval(feed_dict={pred: False})
count.eval() # returns -2
my_op.eval(feed_dict={pred: True})
count.eval() # returns -1
不确定为什么每次都创建新操作有效而另一种情况却无效,但我显然宁愿不添加节点,因为图形最终会变得太大。
第二种情况有效,因为您在 cond 中添加了操作:这导致它们有条件地执行。
第一种情况相当于说:
adder = (count += 1)
subtractor = (count -= 2)
if (cond) { adder } else { subtractor }
由于加法器和减法器在条件之外,它们总是被执行。
第二种情况更像是在说
if (cond) { adder = (count += 1) } else { subtractor = (count -= 2) }
在这种情况下,它符合您的预期。
我们意识到副作用和(某种程度上)惰性求值之间的相互作用令人困惑,我们的目标是 medium-term 让事情变得更加统一。但现在要了解的重要一点是,我们没有进行真正的惰性评估:条件获取对在任一分支中使用的条件之外定义的每个量的依赖。
你的第二个版本——assign_add()
和 assign_sub()
ops 在传递给 cond()
的 lambda 中创建——是正确的方法。幸运的是,在调用 cond()
期间,两个 lambda 表达式中的每一个都只被评估一次,因此您的图形不会无限制地增长。
本质上 cond()
的作用如下:
创建一个 Switch
节点,它根据 pred
的值将其输入仅转发到两个输出之一。我们将输出称为 pred_true
和 pred_false
。 (它们与 pred
具有相同的值,但这并不重要,因为它从未被直接计算过。)
构建if_true
lambda对应的子图,其中所有节点对pred_true
.
有控制依赖
构建if_false
lambda对应的子图,其中所有节点对pred_false
.
有控制依赖
将来自两个 lambda 的 return 值列表压缩在一起,并为每个创建一个 Merge
节点。 Merge
节点接受两个输入,其中只有一个预计会产生,并将其转发到其输出。
Return 作为 Merge
节点输出的张量。
这意味着您可以 运行 您的第二个版本,并满足于图表保持固定大小,无论您 运行.
多少步
您的第一个版本不起作用的原因是,当 Tensor
被捕获时(如您的示例中的 adder
或 subtractor
),额外的 Switch
节点被添加以强制张量的值仅转发到实际执行的分支的逻辑。这是 TensorFlow 如何在其执行模型中结合 feed-forward 数据流和控制流的产物。结果是捕获的张量(在本例中为 assign_add
和 assign_sub
的结果)将始终被评估,即使它们未被使用,您也会看到它们的副作用。这是我们需要更好地记录的内容,,我们将使其在未来更有用。
tf.cond
的(源代码)文档不清楚在评估谓词时要执行的功能是否会产生副作用。我做了一些测试,但得到的结果相互矛盾。例如下面的代码不起作用:
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
pred = tf.placeholder(tf.bool, [])
count = tf.Variable(0)
adder = count.assign_add(1)
subtractor = count.assign_sub(2)
my_op = control_flow_ops.cond(pred, lambda: adder, lambda: subtractor)
sess = tf.InteractiveSession()
tf.initialize_all_variables().run()
my_op.eval(feed_dict={pred: True})
count.eval() # returns -1
my_op.eval(feed_dict={pred: False})
count.eval() # returns -2
即无论谓词的计算结果是什么,两个函数都得到 运行,因此最终结果是减去 1。另一方面,这段代码片段确实有效,唯一的区别是我添加了新的每次调用 my_op
时对图表进行操作:
pred = tf.placeholder(tf.bool, [])
count = tf.Variable(0)
my_op = control_flow_ops.cond(pred, lambda:count.assign_add(1), lambda:count.assign_sub(2))
sess = tf.InteractiveSession()
tf.initialize_all_variables().run()
my_op.eval(feed_dict={pred: False})
count.eval() # returns -2
my_op.eval(feed_dict={pred: True})
count.eval() # returns -1
不确定为什么每次都创建新操作有效而另一种情况却无效,但我显然宁愿不添加节点,因为图形最终会变得太大。
第二种情况有效,因为您在 cond 中添加了操作:这导致它们有条件地执行。
第一种情况相当于说:
adder = (count += 1)
subtractor = (count -= 2)
if (cond) { adder } else { subtractor }
由于加法器和减法器在条件之外,它们总是被执行。
第二种情况更像是在说
if (cond) { adder = (count += 1) } else { subtractor = (count -= 2) }
在这种情况下,它符合您的预期。
我们意识到副作用和(某种程度上)惰性求值之间的相互作用令人困惑,我们的目标是 medium-term 让事情变得更加统一。但现在要了解的重要一点是,我们没有进行真正的惰性评估:条件获取对在任一分支中使用的条件之外定义的每个量的依赖。
你的第二个版本——assign_add()
和 assign_sub()
ops 在传递给 cond()
的 lambda 中创建——是正确的方法。幸运的是,在调用 cond()
期间,两个 lambda 表达式中的每一个都只被评估一次,因此您的图形不会无限制地增长。
本质上 cond()
的作用如下:
创建一个
Switch
节点,它根据pred
的值将其输入仅转发到两个输出之一。我们将输出称为pred_true
和pred_false
。 (它们与pred
具有相同的值,但这并不重要,因为它从未被直接计算过。)构建
if_true
lambda对应的子图,其中所有节点对pred_true
. 有控制依赖
构建
if_false
lambda对应的子图,其中所有节点对pred_false
. 有控制依赖
将来自两个 lambda 的 return 值列表压缩在一起,并为每个创建一个
Merge
节点。Merge
节点接受两个输入,其中只有一个预计会产生,并将其转发到其输出。Return 作为
Merge
节点输出的张量。
这意味着您可以 运行 您的第二个版本,并满足于图表保持固定大小,无论您 运行.
多少步您的第一个版本不起作用的原因是,当 Tensor
被捕获时(如您的示例中的 adder
或 subtractor
),额外的 Switch
节点被添加以强制张量的值仅转发到实际执行的分支的逻辑。这是 TensorFlow 如何在其执行模型中结合 feed-forward 数据流和控制流的产物。结果是捕获的张量(在本例中为 assign_add
和 assign_sub
的结果)将始终被评估,即使它们未被使用,您也会看到它们的副作用。这是我们需要更好地记录的内容,