TensorFlow Dataset 的函数 cache() 和 prefetch() 有什么作用?

What do the TensorFlow Dataset's functions cache() and prefetch() do?

我正在学习 TensorFlow 的 Image Segmentation 教程。其中有以下几行:

train_dataset = train.cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE).repeat()
train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
  1. cache() 函数有什么作用? official documentation 非常晦涩且自引用:

Caches the elements in this dataset.

  1. prefetch() 函数有什么作用? official documentation 又很晦涩:

Creates a Dataset that prefetches elements from this dataset.

tf.data.Dataset.cache 转换可以在内存或本地存储中缓存数据集。这将避免在每个时期执行一些操作(如文件打开和数据读取)。下一个 epochs 将重用缓存转换缓存的数据。

您可以在 tensorflow here 中找到有关 cache 的更多信息 here

Prefetch 与训练步骤的预处理和模型执行重叠。当模型执行训练步骤 s 时,输入管道正在读取步骤 s+1 的数据。这样做可以将步进时间减少到训练的最大值(而不是总和)以及提取数据所需的时间。

您可以在 tensorflow here 中找到有关 prefetch 的更多信息 here

希望这能回答您的问题。快乐学习。