对 `tf.cond` 的行为感到困惑

Confused by the behavior of `tf.cond`

我的图表中需要一个条件控制流。如果 predTrue,该图应该调用更新变量的操作,然后 returns 它,否则它 returns 变量不变。简化版是:

pred = tf.constant(True)
x = tf.Variable([1])
assign_x_2 = tf.assign(x, [2])
def update_x_2():
  with tf.control_dependencies([assign_x_2]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

但是,我发现 pred=Truepred=False 都会导致相同的结果 y=[2],这意味着当未选择 update_x_2 时也会调用赋值操作通过 tf.cond。怎么解释呢?以及如何解决这个问题?

TL;DR: 如果您希望 tf.cond() 在其中一个分支中执行副作用(如赋值),您必须创建执行的操作传递给 tf.cond().

的函数 内部 的副作用

tf.cond() 的行为有点不直观。由于 TensorFlow 图中的执行在图中向前流动,因此您在 either 分支中引用的所有操作都必须在评估条件之前执行。这意味着 true 和 false 分支都接收到对 tf.assign() op 的控制依赖性,因此 y 总是设置为 2,即使 pred 是 False

解决方案是在定义真正分支的函数内创建 tf.assign() 操作。例如,您可以按如下方式构建代码:

pred = tf.placeholder(tf.bool, shape=[])
x = tf.Variable([1])
def update_x_2():
  with tf.control_dependencies([tf.assign(x, [2])]):
    return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval(feed_dict={pred: False}))  # ==> [1]
  print(y.eval(feed_dict={pred: True}))   # ==> [2]
pred = tf.constant(False)
x = tf.Variable([1])

def update_x_2():
    assign_x_2 = tf.assign(x, [2])
    with tf.control_dependencies([assign_x_2]):
        return tf.identity(x)
y = tf.cond(pred, update_x_2, lambda: tf.identity(x))
with tf.Session() as session:
  session.run(tf.initialize_all_variables())
  print(y.eval())

这将得到 [1] 的结果。

这个答案和上面的答案完全一样。但我想分享的是你可以把你想用的每一个操作都放在它的分支函数中。因为,根据您的示例代码,张量 x 可以直接由 update_x_2 函数使用。