结合条件和控制依赖
Combining conditionals and control dependencies
我正在尝试执行一段条件代码,该代码又依赖于另一个先执行的操作。这项工作的简单版本,如下所示:
x = tf.Variable(0.)
x_op = tf.assign(x, 1.)
with tf.control_dependencies([x_op]):
true_fun = lambda: tf.assign_add(x, 3.)
false_fun = lambda: tf.constant([])
pred = tf.constant(True)
cond_op = control_flow_ops.cond(pred, true_fun, false_fun)
其中评估 cond_op
将 x
设置为预期的 4.0
。然而,这个更复杂的版本不起作用:
def rest(x): tf.gather(x, tf.range(1, tf.size(x)))
x = tf.Variable([0., 1.])
x_op = tf.assign(x, [0., 1., 2.], validate_shape=False)
with tf.control_dependencies([x_op]):
true_fun = lambda: tf.assign(x, rest(x), validate_shape=False)
false_fun = lambda: tf.constant([])
pred = tf.constant(True)
cond_op = control_flow_ops.cond(pred, true_fun, false_fun)
特别是 x
被分配 [1.]
而不是 [1., 2.]
。我要遵循的逻辑是 x
首先分配给 [0., 1., 2.]
,然后 然后 被修剪为 [1., 2.]
。顺便说一下,这似乎与 x
的大小变化有关,因为如果在初始 x_op
分配中 x
被分配 [1., 2.]
而不是 [0., 1., 2.]
,然后评估 cond_op
导致 x
被分配 [2.]
,这是正确的行为。 IE。它首先更新为 [1., 2.]
,然后修剪为 [2.]
。
请注意,with tf.control_dependencies
仅适用于在块内创建的操作。当您在块内调用 rest(x)
时,您所指的 x
仍然是旧的 x
,它是 tf.Variable
函数的 return 值,它只是Tensor
保存变量的初始值。您可以通过调用 rest(x_op)
来传递新值。这里是完整的工作片段:
import tensorflow as tf
def rest(x): return tf.gather(x, tf.range(1, tf.size(x)))
x = tf.Variable([0., 1.])
x_op = tf.assign(x, [0., 1., 2.], validate_shape=False)
with tf.control_dependencies([x_op]):
true_fun = lambda: tf.assign(x, rest(x_op), validate_shape=False)
false_fun = lambda: tf.constant([])
pred = tf.constant(True)
cond_op = tf.cond(pred, true_fun, false_fun)
with tf.Session(""):
x.initializer.run()
print(cond_op.eval())
我正在尝试执行一段条件代码,该代码又依赖于另一个先执行的操作。这项工作的简单版本,如下所示:
x = tf.Variable(0.)
x_op = tf.assign(x, 1.)
with tf.control_dependencies([x_op]):
true_fun = lambda: tf.assign_add(x, 3.)
false_fun = lambda: tf.constant([])
pred = tf.constant(True)
cond_op = control_flow_ops.cond(pred, true_fun, false_fun)
其中评估 cond_op
将 x
设置为预期的 4.0
。然而,这个更复杂的版本不起作用:
def rest(x): tf.gather(x, tf.range(1, tf.size(x)))
x = tf.Variable([0., 1.])
x_op = tf.assign(x, [0., 1., 2.], validate_shape=False)
with tf.control_dependencies([x_op]):
true_fun = lambda: tf.assign(x, rest(x), validate_shape=False)
false_fun = lambda: tf.constant([])
pred = tf.constant(True)
cond_op = control_flow_ops.cond(pred, true_fun, false_fun)
特别是 x
被分配 [1.]
而不是 [1., 2.]
。我要遵循的逻辑是 x
首先分配给 [0., 1., 2.]
,然后 然后 被修剪为 [1., 2.]
。顺便说一下,这似乎与 x
的大小变化有关,因为如果在初始 x_op
分配中 x
被分配 [1., 2.]
而不是 [0., 1., 2.]
,然后评估 cond_op
导致 x
被分配 [2.]
,这是正确的行为。 IE。它首先更新为 [1., 2.]
,然后修剪为 [2.]
。
请注意,with tf.control_dependencies
仅适用于在块内创建的操作。当您在块内调用 rest(x)
时,您所指的 x
仍然是旧的 x
,它是 tf.Variable
函数的 return 值,它只是Tensor
保存变量的初始值。您可以通过调用 rest(x_op)
来传递新值。这里是完整的工作片段:
import tensorflow as tf
def rest(x): return tf.gather(x, tf.range(1, tf.size(x)))
x = tf.Variable([0., 1.])
x_op = tf.assign(x, [0., 1., 2.], validate_shape=False)
with tf.control_dependencies([x_op]):
true_fun = lambda: tf.assign(x, rest(x_op), validate_shape=False)
false_fun = lambda: tf.constant([])
pred = tf.constant(True)
cond_op = tf.cond(pred, true_fun, false_fun)
with tf.Session(""):
x.initializer.run()
print(cond_op.eval())