Tensorflow:什么时候用列表在 sess.run 中完成变量赋值?

Tensorflow: When are variable assignments done in sess.run with a list?

我一直以为变量赋值是在给定sess.run的列表中的所有操作之后完成的,但是下面的代码returns不同的执行结果不同。好像是在列表中随机运行操作,在列表操作运行之后赋值给变量

a = tf.Variable(0)
b = tf.Variable(1)
c = tf.Variable(1)
update_a = tf.assign(a, b + c)
update_b = tf.assign(b, c + a)
update_c = tf.assign(c, a + b)

with tf.Session() as sess:
  sess.run(initialize_all_variables)
  for i in range(5):
    a_, b_, c_ = sess.run([update_a, update_b, update_c])

我想知道变量赋值的时间。 哪个是正确的:"update_x -> assign x -> ... -> udpate_z -> assign z" 或 "update_x -> udpate_y -> udpate_z -> assign a, b, c"? (其中 (x, y, z) 是 (a, b, c) 的排列) 另外,如果有后一种赋值的实现方法(链表中的所有操作都完成后才赋值),请告诉我如何实现。

update_aupdate_bupdate_c这三个操作在数据流图中没有相互依赖关系,因此TensorFlow可以选择以任意顺序执行它们。 (在当前的实现中,它们三个都可能在不同的线程上并行执行。)第二个问题是默认情况下会缓存变量的读取,因此在您的程序中, update_b 中分配的值(即 c + a)可能会使用 a 的原始值或更新后的值,具体取决于首次读取变量的时间。

如果您想确保操作以特定顺序发生,您可以在 with tf.control_dependencies([...]): 块中使用 with tf.control_dependencies([...]): blocks to enforce that operations created within the block happen after operations named in the list. You can use tf.Variable.read_value() 来明确读取变量的点。

因此,如果你想确保 update_a 发生在 update_b 之前并且 update_b 发生在 update_c 之前,你可以这样做:

update_a = tf.assign(a, b + c)

with tf.control_dependencies([update_a]):
  update_b = tf.assign(b, c + a.read_value())

with tf.control_dependencies([update_b]):
  update_c = tf.assign(c, a.read_value() + b.read_value())

基于你的这个例子,

v = tf.Variable(0)
c = tf.constant(3)
add = tf.add(v, c)
update = tf.assign(v, add)
mul = tf.mul(add, update)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    res = sess.run([mul, mul])
    print(res)

输出:[9, 9]

您得到 [9, 9],这实际上是我们要求它做的。可以这样想:

在运行期间,一旦从列表中取出mul,它就会查找this的定义并找到tf.mul(add, update)。现在,它需要导致 tf.add(v, c)add 的值。因此,它插入 vc 的值,得到 add 的值作为 3。

好的,现在我们需要 update 的值定义为 tf.assign(v, add)。我们有 add(它刚才计算为 3)和 v 的值。因此,它将 v 的值更新为 3,这也是 update.

的值

现在,addupdate 的值为 3。因此,乘法在 mul 中产生 9。

根据我们得到的结果,我认为,对于列表中的下一项(操作),它只是 returns 刚刚计算的 mul 的值。我不确定它是再次执行这些步骤还是只是 returns 它刚刚为 mul 计算的相同(缓存?)值意识到我们有结果或这些操作并行发生(对于每个列表中的元素)。也许@mrry 或@YaroslavBulatov 可以对这部分发表评论?


引用@mrry 的评论:

When you call sess.run([x, y, z]) once, TensorFlow executes each op that those tensors depend on one time only (unless there's a tf.while_loop() in your graph). If a tensor appears twice in the list (like mul in your example), TensorFlow will execute it once and return two copies of the result. To run the assignment more than once, you must either call sess.run() multiple times, or use tf.while_loop() to put a loop in your graph.