TF Dataset iterator advances only once instead of twice
I am using TensorFlow 1.4.1 and learning the TensorFlow Dataset API. The section of the documentation that describes consuming values from an iterator contains the following example:
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
iterator = dataset3.make_initializable_iterator()
sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()
...along with the following guidance:
Note that evaluating any of next1, next2, or next3 will advance the iterator for all components. A typical consumer of an iterator will include all components in a single expression.
I decided to test this behavior with the following simple example:
import tensorflow as tf

dataset1 = tf.data.Dataset.range(5)
dataset2 = tf.data.Dataset.range(5)
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

iterator = dataset3.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    next1, next2 = iterator.get_next()
    A = next1
    B = next1 + next2
    while True:
        try:
            a, b = sess.run([A, B])
            print(a, b)
        except tf.errors.OutOfRangeError:
            print('done')
            break
As you can see, I evaluate next1 in both of the expressions A and B. Based on the quote above, I expected the iterator to advance for both evaluations, so that the output would contain:
(0, 2)
(2, 6)
Instead, this is what I got:
(0, 0)
(1, 2)
(2, 4)
(3, 6)
(4, 8)
Why does the iterator advance only once, and what would a working example of the behavior I expected look like?
Confusion often arises when the TensorFlow graph contains an operation that mutates state, such as iterator.get_next(). The rule is fairly simple:
Each stateful operation in a graph (that is not in a tf.while_loop() or tf.cond()) will execute exactly once per Session.run() call.
Applying that rule: your graph contains only one iterator.get_next() operation, so the iterator advances only once per Session.run() call, and the same element is used to compute both A and B.
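To see the rule in isolation, here is a minimal sketch (my own illustration, not part of the original answer) that uses tf.assign_add as the stateful operation. Even though two fetched tensors depend on it, a single Session.run() call increments the variable only once:

import tensorflow as tf

v = tf.Variable(0)
inc = tf.assign_add(v, 1)    # the only stateful op in this graph
a = inc                      # both fetches reference the same
b = inc * 10                 # assign_add operation

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([a, b]))  # assign_add runs once for this call: [1, 10]
    print(sess.run([a, b]))  # and once more for the next call: [2, 20]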
To get the behavior you want, you need to create a second iterator.get_next() operation. Furthermore, for deterministic behavior, we need to ensure there is a control dependency between the two iterator.get_next() operations. For example, the following program exhibits the behavior you want:
import tensorflow as tf

dataset1 = tf.data.Dataset.range(5)
dataset2 = tf.data.Dataset.range(5)
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

iterator = dataset3.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    next1, next2 = iterator.get_next()
    A = next1
    # Get a second element from `iterator` and add a control dependency to
    # ensure that it is consumed *after* `A` is computed.
    with tf.control_dependencies([A]):
        next3, next4 = iterator.get_next()
    B = next3 + next4
    while True:
        try:
            a, b = sess.run([A, B])
            print(a, b)
        except tf.errors.OutOfRangeError:
            print('done')
            break
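With the second iterator.get_next() in place, each Session.run([A, B]) call consumes two elements from the 5-element dataset, so this version should print roughly the output the question expected (0 2, then 2 6) before the iterator is exhausted on the third call and 'done' is printed.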