如何从 Tensorflow tf.data.Dataset 中无休止地阅读?

How can I read endlessly from a Tensorflow tf.data.Dataset?

我正在将我的旧数据层(使用队列)切换到 "new" 并推荐数据集 API。我是第一次使用它,所以我提供了代码示例以防我遇到根本性错误。

我从一个生成器创建我的数据集(它将读取一个文件,并提供 n 个样本)。这是一个小数据集 n_iterations >> n_samples,所以我只是想一遍又一遍地阅读这个数据集,最好是打乱顺序。

sample_set = tf.data.Dataset.from_generator( data_generator(filename),  
    (tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1]))
)

使用数据生成器:

class data_generator:
    def __init__(self, filename):
        self.filename= filename

    def __call__(self):
        with filename.open() as f:
           for idx in f: yield img[idx], label[idx]

要实际使用数据,我知道我需要定义一个 Iterator

sample = sample_set.make_one_shot_iterator().get_next()

然后我们设置读取数据

while True:
    try: my_sample = sess.run(sample)
    except tf.errors.OutOfRangeError: break   # this happens after dset is read once

但所有可用的迭代器似乎都是 "finite",因为它们只读取数据集一次。

是否有一种简单的方法可以无限地读取数据集?

reinitializable Iterator 将在同一数据集上重新初始化,因此此代码将一遍又一遍地读取同一数据集:

sample = tf.data.Iterator.from_structure(sample_set.output_types,
                                         sample_set.output_shapes).get_next()

sample_it.make_initializer(sample_set)     # create initialize op

with tf.Session(config=config) as sess:
    sess.run(sample_set_init_op)           # initialize in the beginning

    while True:
        try: 
             my_sample = sess.run(sample)
        except tf.errors.OutOfRangeError:
             sess.run(sample_set_init_op)  # re-initialize on same dataset

数据集有 repeat and shuffle 方法。

BUF_SIZE = 100 # choose it depending on your data
sample_set = tf.data.Dataset.from_generator( data_generator(filename),  
    (tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]), 
    tf.TensorShape([256,256,1]))
).repeat().shuffle(BUF_SIZE)

如果您不向其传递显式 countDataset.repeat() 转换将无休止地重复数据集:

sample_set = tf.data.Dataset.from_generator(
    data_generator(filename), (tf.uint8, tf.uint8),
    (tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1])))

# Repeats `sample_set` endlessly.
sample_set = sample_set.repeat()

sample = sample_set.make_one_shot_iterator().get_next()