Proper way to stop a TensorFlow Dataset `from_generator`?

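I want to use a TensorFlow Dataset built with `from_generator` to access a formatted file. Almost everything works, except that I don't know how to stop the Dataset iterator when the generator runs out of data (once you go out of range, the generator just returns empty lists forever).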

My actual code is complicated, but I can simulate the situation with this small program:

import tensorflow as tf

def make_batch_generator_fn(batch_size=10, dset_size=100):
    feats, targs = range(dset_size), range(1, dset_size + 1)

    def batch_generator_fn():
        start_idx, stop_idx = 0, batch_size
        while True:
            # if stop_idx > dset_size: --- stop action?
            yield feats[start_idx: stop_idx], targs[start_idx: stop_idx]
            start_idx, stop_idx = start_idx + batch_size, stop_idx + batch_size

    return batch_generator_fn

def test(batch_size=10):
    dgen = make_batch_generator_fn(batch_size)
    features_shape, targets_shape = [None], [None]
    ds = tf.data.Dataset.from_generator(
        dgen, (tf.int32, tf.int32),
        (tf.TensorShape(features_shape), tf.TensorShape(targets_shape))
    )
    feats, targs = ds.make_one_shot_iterator().get_next()

    with tf.Session() as sess:
        counter = 0
        try:
            while True:
                f, t = sess.run([feats, targs])
                print(f, t)
                counter += 1
                if counter > 15:
                    break
        except tf.errors.OutOfRangeError:
            print('end of dataset at counter = {}'.format(counter))

if __name__ == '__main__':
    test()

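If I know the number of records in advance I can adjust the number of batches, but I don't always know it. I tried putting some code at the spot in the snippet above marked with the commented-out `stop action?` line. In particular, I tried raising an `IndexError`, but TensorFlow doesn't like that, even if I explicitly catch it in the executing code. I also tried raising `tf.errors.OutOfRangeError`, but I'm not sure how to instantiate it: the constructor takes three arguments, `node_def`, `op`, and `message`, and I'm not sure what is normally passed for `node_def` and `op`.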

I would appreciate any thoughts or comments on this problem. Thanks!

Return from the generator when your stop criterion is met:

def make_batch_generator_fn(batch_size=10, dset_size=100):
    feats, targs = range(dset_size), range(1, dset_size + 1)

    def batch_generator_fn():
        start_idx, stop_idx = 0, batch_size
        while True:
            if stop_idx > dset_size:
                return
            else:
                yield feats[start_idx: stop_idx], targs[start_idx: stop_idx]
                start_idx, stop_idx = start_idx + batch_size, stop_idx + batch_size

    return batch_generator_fn
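
One caveat with the `stop_idx > dset_size` check: when `dset_size` is not a multiple of `batch_size`, the generator returns before emitting the trailing partial batch. If those leftover records matter, a variant that tests `start_idx` instead might look like the sketch below. It is plain Python with no TensorFlow dependency, and the name `make_partial_batch_generator_fn` is made up for this illustration.

```python
def make_partial_batch_generator_fn(batch_size=10, dset_size=100):
    # Hypothetical variant of make_batch_generator_fn that keeps the
    # final, shorter batch instead of dropping it.
    feats, targs = range(dset_size), range(1, dset_size + 1)

    def batch_generator_fn():
        start_idx = 0
        # Stop only once every record has been emitted.
        while start_idx < dset_size:
            stop_idx = start_idx + batch_size
            # Slicing a range past its end simply truncates the slice.
            yield list(feats[start_idx:stop_idx]), list(targs[start_idx:stop_idx])
            start_idx = stop_idx
        # Falling off the end of the function ends the generator,
        # just like an explicit `return`.

    return batch_generator_fn


# 25 records in batches of 10: two full batches plus one of 5.
batches = list(make_partial_batch_generator_fn(batch_size=10, dset_size=25)())
print(len(batches), len(batches[-1][0]))
```

Whether a short final batch is acceptable depends on what consumes the data, so treat this as one option rather than a replacement for the code above.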

Returning from the generator matches the behavior specified in the Python 3 documentation:

In a generator function, the return statement indicates that the generator is done and will cause StopIteration to be raised. The returned value (if any) is used as an argument to construct StopIteration and becomes the StopIteration.value attribute.
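
To see that behavior in isolation, here is a small plain-Python sketch (no TensorFlow involved; `countdown` is an illustrative name, not something from the question). A `for` loop or `list()` absorbs the `StopIteration` silently, so a consumer of the generator just sees the data end.

```python
def countdown(n):
    # Illustrative generator: yields n, n-1, ..., 1 and then returns.
    while True:
        if n <= 0:
            return "done"  # becomes StopIteration.value
        yield n
        n -= 1


# Normal iteration consumes values until the generator returns.
values = list(countdown(3))

# Driving the generator by hand exposes the StopIteration.
gen = countdown(1)
first = next(gen)
try:
    next(gen)
    stop_value = None
except StopIteration as exc:
    stop_value = exc.value

print(values, first, stop_value)
```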

Something along these lines also works when the dataset size is known:

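import tensorflow as tf

# Placeholder values; substitute your own. Iterating a Dataset in a plain
# for loop needs eager execution (the default in TensorFlow 2.x).
dataset_size = 100  # your dataset size
batch_size = 10  # your batch size
dataset = tf.data.Dataset.range(dataset_size).batch(batch_size)  # your tf.data.Dataset
steps_per_epoch = dataset_size // batch_size

# zip() stops at the shorter input, so at most steps_per_epoch batches are read.
for data, _ in zip(dataset, range(steps_per_epoch)):
    pass  # your train_step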

The iteration stops after `steps_per_epoch` batches.
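
The same capping idea can be written with `itertools.islice`, which stops after a fixed number of items even when the underlying iterable never ends. The sketch below uses a plain Python generator, `endless_batches`, as a hypothetical stand-in for the dataset, so it runs without TensorFlow.

```python
import itertools


def endless_batches(batch_size):
    # Hypothetical stand-in for a dataset that never signals its end.
    start = 0
    while True:
        yield list(range(start, start + batch_size))
        start += batch_size


dataset_size = 100  # assumed to be known for this sketch
batch_size = 10
steps_per_epoch = dataset_size // batch_size

seen = []
for batch in itertools.islice(endless_batches(batch_size), steps_per_epoch):
    seen.append(batch)  # a real loop would run the train step here

print(len(seen), seen[-1][-1])
```

`tf.data.Dataset` also offers a `take(count)` method that limits a dataset to its first `count` elements. Either way, capping the step count only helps when the record count is known up front, which the question says is not always the case.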