TF Dataset iterator advances only once instead of twice

I am using TensorFlow 1.4.1 and learning the TensorFlow Dataset API. The section that describes consuming values from an iterator gives the following example:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
dataset2 = tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]), tf.random_uniform([4, 100])))
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

iterator = dataset3.make_initializable_iterator()

sess.run(iterator.initializer)
next1, (next2, next3) = iterator.get_next()

...with the following guidance:

Note that evaluating any of next1, next2, or next3 will advance the iterator for all components. A typical consumer of an iterator will include all components in a single expression.

I decided to test this behavior with the following simple example.

import tensorflow as tf

dataset1 = tf.data.Dataset.range(5)
dataset2 = tf.data.Dataset.range(5)
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

iterator = dataset3.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    next1, next2 = iterator.get_next()

    A = next1
    B = next1 + next2

    while True:
        try:
            a, b = sess.run([A,B])
            print(a,b)
        except tf.errors.OutOfRangeError:
            print('done')
            break

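As you can see, I evaluate next1 in both expressions, A and B. Based on the quote above, I expected the iterator to advance once for each of these two evaluations, and to see this printout:
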
(0, 2)
(2, 6)

However, what I actually get is:

(0, 0)
(1, 2)
(2, 4)
(3, 6)
(4, 8)

Why does the iterator advance only once? What would a working example look like that shows the behavior I expected?

Confusion often arises when a TensorFlow graph contains an operation that changes state, such as iterator.get_next(). The rule is fairly simple:

Each stateful operation in a graph (that is not in a tf.while_loop() or tf.cond()) will execute exactly once per Session.run() call.

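Applying that rule: there is only one iterator.get_next() operation in the graph, so the iterator advances only once per Session.run() call, and the same element is used to compute both A and B.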
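
To make the rule concrete, here is a toy model in plain Python. It is not TensorFlow and not how TensorFlow is implemented; the names ToySession, GetNext, and Add are hypothetical. The only assumption it encodes is the rule quoted above: within one run() call, each node executes at most once and its value is reused by every fetch that depends on it.

```python
class GetNext:
    """Stateful toy node: every execution pulls one element from `source`."""

    def __init__(self, source):
        self.source = source
        self.inputs = []

    def compute(self, args):
        return next(self.source)


class Add:
    """Stateless toy node: adds the values of its two input nodes."""

    def __init__(self, x, y):
        self.inputs = [x, y]

    def compute(self, args):
        return args[0] + args[1]


class ToySession:
    """Evaluates fetches, executing each node at most once per run() call."""

    def run(self, fetches):
        cache = {}  # node -> value; discarded when this run() call returns

        def evaluate(node):
            if node not in cache:
                args = [evaluate(n) for n in node.inputs]
                cache[node] = node.compute(args)
            return cache[node]

        return [evaluate(f) for f in fetches]


get_next = GetNext(iter(range(5)))  # the only stateful node in the "graph"
A = get_next
B = Add(get_next, get_next)         # stands in for next1 + next2

sess = ToySession()
first = sess.run([A, B])
second = sess.run([A, B])
print(first)   # [0, 0]
print(second)  # [1, 2]
```

Because A and B share the single GetNext node, the source advances once per run() call, which mirrors the (0, 0), (1, 2), ... output in the question.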

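To get the behavior you want, you need to create a second iterator.get_next() operation. In addition, to make the result deterministic, there must be a control dependency between the two iterator.get_next() operations. For example, the following program exhibits the behavior you expected: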

import tensorflow as tf

dataset1 = tf.data.Dataset.range(5)
dataset2 = tf.data.Dataset.range(5)
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

iterator = dataset3.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    next1, next2 = iterator.get_next()
    A = next1

    # Get a second element from `iterator` and add a control dependency to
    # ensure that it is consumed *after* `A` is computed.
    with tf.control_dependencies([A]):
        next3, next4 = iterator.get_next()
    B = next3 + next4

    while True:
        try:
            a, b = sess.run([A,B])
            print(a,b)
        except tf.errors.OutOfRangeError:
            print('done')
            break