tf.scatter_update() 如何在 while_loop() 内部工作
How does the tf.scatter_update() work inside the while_loop()
我正在尝试使用 tf.scatter_update()
更新 tf.while_loop()
中的 tf.Variable
。但是,结果是初始值而不是更新后的值。这是我正在尝试做的示例代码:
from __future__ import print_function
import tensorflow as tf
def cond(sequence_len, step):
return tf.less(step,sequence_len)
def body(sequence_len, step):
begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
begin = tf.scatter_update(begin,1,step,use_locking=None)
tf.get_variable_scope().reuse_variables()
return (sequence_len, step+1)
with tf.Graph().as_default():
sess = tf.Session()
step = tf.constant(0)
sequence_len = tf.constant(10)
_,step, = tf.while_loop(cond,
body,
[sequence_len, step],
parallel_iterations=10,
back_prop=True,
swap_memory=False,
name=None)
begin = tf.get_variable("begin",[3],dtype=tf.int32)
init = tf.initialize_all_variables()
sess.run(init)
print(sess.run([begin,step]))
结果是:[array([0, 0, 0], dtype=int32), 10]
。不过,我觉得结果应该是[0, 0, 10]
。我是不是做错了什么?
这里的问题是循环体中没有任何内容依赖于您的 tf.scatter_update()
操作,因此它永远不会被执行。使其工作的最简单方法是将对更新的控制依赖添加到 return 值:
def body(sequence_len, step):
begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
begin = tf.scatter_update(begin, 1, step, use_locking=None)
tf.get_variable_scope().reuse_variables()
with tf.control_dependencies([begin]):
return (sequence_len, step+1)
请注意,此问题并非 TensorFlow 中的循环独有。如果您刚刚定义了一个名为 begin
的 tf.scatter_update()
操作,但在其上调用了 sess.run()
或依赖它的东西,则更新不会发生。当您使用 tf.while_loop()
时,无法直接 运行 循环体中定义的操作,因此产生副作用的最简单方法是添加控件依赖项。
注意最后的结果是[0, 9, 0]
:每次迭代都把当前步赋值给begin[1]
,最后一次迭代时当前步的值是9
(条件step == 10
).
时为假
我正在尝试使用 tf.scatter_update()
更新 tf.while_loop()
中的 tf.Variable
。但是,结果是初始值而不是更新后的值。这是我正在尝试做的示例代码:
from __future__ import print_function
import tensorflow as tf
def cond(sequence_len, step):
return tf.less(step,sequence_len)
def body(sequence_len, step):
begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
begin = tf.scatter_update(begin,1,step,use_locking=None)
tf.get_variable_scope().reuse_variables()
return (sequence_len, step+1)
with tf.Graph().as_default():
sess = tf.Session()
step = tf.constant(0)
sequence_len = tf.constant(10)
_,step, = tf.while_loop(cond,
body,
[sequence_len, step],
parallel_iterations=10,
back_prop=True,
swap_memory=False,
name=None)
begin = tf.get_variable("begin",[3],dtype=tf.int32)
init = tf.initialize_all_variables()
sess.run(init)
print(sess.run([begin,step]))
结果是:[array([0, 0, 0], dtype=int32), 10]
。不过,我觉得结果应该是[0, 0, 10]
。我是不是做错了什么?
这里的问题是循环体中没有任何内容依赖于您的 tf.scatter_update()
操作,因此它永远不会被执行。使其工作的最简单方法是将对更新的控制依赖添加到 return 值:
def body(sequence_len, step):
begin = tf.get_variable("begin",[3],dtype=tf.int32,initializer=tf.constant_initializer(0))
begin = tf.scatter_update(begin, 1, step, use_locking=None)
tf.get_variable_scope().reuse_variables()
with tf.control_dependencies([begin]):
return (sequence_len, step+1)
请注意,此问题并非 TensorFlow 中的循环独有。如果您刚刚定义了一个名为 begin
的 tf.scatter_update()
操作,但在其上调用了 sess.run()
或依赖它的东西,则更新不会发生。当您使用 tf.while_loop()
时,无法直接 运行 循环体中定义的操作,因此产生副作用的最简单方法是添加控件依赖项。
注意最后的结果是[0, 9, 0]
:每次迭代都把当前步赋值给begin[1]
,最后一次迭代时当前步的值是9
(条件step == 10
).