在 tensorflow 2.0 beta 中从 tf.data.Dataset 检索下一个元素
retrieving the next element from tf.data.Dataset in tensorflow 2.0 beta
在 tensorflow 2.0-beta 之前,要从 tf.data.Dataset 中检索第一个元素,我们可以使用如下所示的迭代器:
#!/usr/bin/python
import tensorflow as tf
train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
iterator = train_dataset.make_one_shot_iterator()
with tf.Session() as sess:
# 1.0 will be printed.
print (sess.run(iterator.get_next()))
在 tensorflow 2.0-beta 中,上面的 one-shot-iterator 似乎已被弃用。要打印出整个元素,我们可以使用以下 for 方法。
#!/usr/bin/python
import tensorflow as tf
train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
for data in train_dataset:
# 1.0, 2.0, 3.0, and 4.0 will be printed.
print (data.numpy())
但是,如果我们只想从 tf.data.Dataset 中检索一个元素,那么我们如何使用 tensorflow 2.0 beta 呢?好像不支持next(train_dataset)
。如上所示,使用旧的一次性迭代器可以轻松完成,但使用基于 for 的新方法不是很明显。
欢迎任何建议。
您可以 .take(1)
来自数据集:
for elem in train_dataset.take(1):
print (elem.numpy())
启用即时执行(TF 2.0 中的默认设置)的有效方法是:
elem = next(iter(train_dataset))
说明:数据集有一个 __iter__
成员函数来支持 for elem in dataset:
方法。这 returns 一个迭代器。 Python 函数 iter
就是这样做的:基本上调用 __iter__
函数。 next
然后 returns 迭代器产生的第一个元素。
我还没有找到适用于非急切执行的解决方案,因为目前会引发 RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.
在 tensorflow 2.0-beta 之前,要从 tf.data.Dataset 中检索第一个元素,我们可以使用如下所示的迭代器:
#!/usr/bin/python
import tensorflow as tf
train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
iterator = train_dataset.make_one_shot_iterator()
with tf.Session() as sess:
# 1.0 will be printed.
print (sess.run(iterator.get_next()))
在 tensorflow 2.0-beta 中,上面的 one-shot-iterator 似乎已被弃用。要打印出整个元素,我们可以使用以下 for 方法。
#!/usr/bin/python
import tensorflow as tf
train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
for data in train_dataset:
# 1.0, 2.0, 3.0, and 4.0 will be printed.
print (data.numpy())
但是,如果我们只想从 tf.data.Dataset 中检索一个元素,那么我们如何使用 tensorflow 2.0 beta 呢?好像不支持next(train_dataset)
。如上所示,使用旧的一次性迭代器可以轻松完成,但使用基于 for 的新方法不是很明显。
欢迎任何建议。
您可以 .take(1)
来自数据集:
for elem in train_dataset.take(1):
print (elem.numpy())
启用即时执行(TF 2.0 中的默认设置)的有效方法是:
elem = next(iter(train_dataset))
说明:数据集有一个 __iter__
成员函数来支持 for elem in dataset:
方法。这 returns 一个迭代器。 Python 函数 iter
就是这样做的:基本上调用 __iter__
函数。 next
然后 returns 迭代器产生的第一个元素。
我还没有找到适用于非急切执行的解决方案,因为目前会引发 RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.