在 tensorflow 2.0 beta 中从 tf.data.Dataset 检索下一个元素

retrieving the next element from tf.data.Dataset in tensorflow 2.0 beta

在 tensorflow 2.0-beta 之前,要从 tf.data.Dataset 中检索第一个元素,我们可以使用如下所示的迭代器:

#!/usr/bin/python

import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
iterator = train_dataset.make_one_shot_iterator()
with tf.Session() as sess:
    # 1.0 will be printed.
    print (sess.run(iterator.get_next()))

在 tensorflow 2.0-beta 中,上面的 one-shot-iterator 似乎已被弃用。要打印出整个元素,我们可以使用以下 for 方法。

#!/usr/bin/python

import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

for data in train_dataset:
    # 1.0, 2.0, 3.0, and 4.0 will be printed.
    print (data.numpy())

但是,如果我们只想从 tf.data.Dataset 中检索一个元素,那么我们如何使用 tensorflow 2.0 beta 呢?好像不支持next(train_dataset)。如上所示,使用旧的一次性迭代器可以轻松完成,但使用基于 for 的新方法不是很明显。

欢迎任何建议。

您可以 .take(1) 来自数据集:

for elem in train_dataset.take(1):
  print (elem.numpy())

启用即时执行(TF 2.0 中的默认设置)的有效方法是:

elem = next(iter(train_dataset))

说明:数据集有一个 __iter__ 成员函数来支持 for elem in dataset: 方法。这 returns 一个迭代器。 Python 函数 iter 就是这样做的:基本上调用 __iter__ 函数。 next 然后 returns 迭代器产生的第一个元素。

我还没有找到适用于非急切执行的解决方案,因为目前会引发 RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.