Tensorflow Data API - prefetch

I am trying out TF's new Dataset API, and I am not sure how prefetch works. In the code below

def dataset_input_fn(...):
    dataset = tf.data.TFRecordDataset(filenames, compression_type="ZLIB")
    dataset = dataset.map(lambda x: parser(...))
    dataset = dataset.map(lambda x, y: image_augmentation(...),
                          num_parallel_calls=num_threads)

    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    iterator = dataset.make_one_shot_iterator()

Does it matter where between the lines above I put dataset = dataset.prefetch(batch_size)? Or should it perhaps go after every operation that would have been using output_buffer_size if the dataset were coming from tf.contrib.data?
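
Concretely, the two placements I am weighing look roughly like this (just a sketch; using batch_size as the prefetch buffer size is an arbitrary placeholder):

    # Option A: a single prefetch at one point in the pipeline, e.g. after batch():
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(batch_size)

    # Option B: a prefetch after every op that used to take output_buffer_size
    # in tf.contrib.data, e.g. after each map():
    dataset = dataset.map(lambda x: parser(...))
    dataset = dataset.prefetch(batch_size)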

In a discussion on GitHub I found a comment by mrry:

Note that in TF 1.4 there will be a Dataset.prefetch() method that makes it easier to add prefetching at any point in the pipeline, not just after a map(). (You can try it by downloading the current nightly build.)

For example, Dataset.prefetch() will start a background thread to populate an ordered buffer that acts like a tf.FIFOQueue, so that downstream pipeline stages need not block. However, the prefetch() implementation is much simpler, because it doesn't need to support as many different concurrent operations as a tf.FIFOQueue.

So this means that prefetch can be placed after any command, and it applies to the preceding command. So far I have noticed the biggest performance gains by putting it only at the very end.
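
For example, putting it only at the very end of the pipeline from the question would look roughly like this (a sketch, assuming parser and image_augmentation take and return (image, label) pairs as in the question; prefetch(1) is enough to keep one batch ahead of the consumer):

    dataset = tf.data.TFRecordDataset(filenames, compression_type="ZLIB")
    dataset = dataset.map(lambda x: parser(x))
    dataset = dataset.map(lambda x, y: image_augmentation(x, y),
                          num_parallel_calls=num_threads)
    dataset = dataset.shuffle(buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    # A single prefetch as the last transformation: a background thread keeps
    # the next batch ready while the current one is being consumed.
    dataset = dataset.prefetch(1)
    iterator = dataset.make_one_shot_iterator()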

There is also another discussion where mrry explains more about prefetch and the buffers.

Update 2018/10/01

Since version 1.7.0, the Dataset API (in contrib) has an option prefetch_to_device. Note that this transformation has to be the last one in the pipeline, and contrib will be gone when TF 2.0 arrives. To have prefetch work on multiple GPUs, use MultiDeviceIterator (for an example see #13610) multi_device_iterator_ops.py.
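
As a rough sketch (assuming TF 1.7+ with contrib still present, and a single GPU named "/gpu:0"), it is applied as the final transformation:

    dataset = dataset.batch(batch_size)
    # prefetch_to_device must be the last transformation in the pipeline:
    # it prefetches elements directly into the target device's memory.
    dataset = dataset.apply(tf.contrib.data.prefetch_to_device("/gpu:0"))
    iterator = dataset.make_one_shot_iterator()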

https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/data/prefetch_to_device