TensorFlow TFRecordDataset 随机播放 buffer_size 行为

TensorFlow TFRecordDataset shuffle buffer_size behavior

我不清楚 tf.TFRecordDataset 中的 buffer_size 参数的作用。假设我们有以下代码:

dataset = dataset.shuffle(buffer_size=10000).repeat().batch(batch_size)

这是否意味着只有前 10k 个样本将被使用并永远重复,或者我将遍历整个数据集?如果不是,它到底是做什么用的?那么这段代码呢?

dataset = dataset.repeat().shuffle(buffer_size=10000).batch(batch_size)

我注意到 ,但它没有说明 buffer_size

这个 可能有助于更好地理解 shuffle 方法的 buffer_size 参数。

简而言之,数据集的缓冲区中总是有超过 buffer_size 个元素,并且每次添加元素时都会随机播放此缓冲区。

因此,缓冲区大小为 1 就像不洗牌一样,缓冲区长度为数据集的长度就像传统的洗牌一样。


要了解洗牌和重复数据集之间的正确顺序,请查看官方 performance guide

最佳做法通常是先洗牌再重复,因为这将确保您在每个时期都能看到整个数据集。