如何正确地将数据扩充应用于 TFRecord 数据集?

How do I correctly apply data augmentation to a TFRecord Dataset?

我正在尝试在解析后对 TFRecord 数据集应用数据扩充。但是,当我在映射增强函数之前和之后检查数据集的大小时,大小是相同的。我知道解析函数正在运行并且数据集是正确的,因为我已经使用它们来训练模型。所以我只包含了映射函数的代码,然后计算示例。

这是我使用的代码:

num_ex = 0

def flip_example(image, label):
    flipped_image = flip(image)
    return flipped_image, label


dataset = tf.data.TFRecordDataset('train.TFRecord').map(parse_function)
for x in dataset:
    num_ex += 1

num_ex = 0
dataset = dataset.map(flip_example)

#Size of dataset
for x in dataset:
    num_ex += 1

在这两种情况下,num_ex = 324 而不是预期的非增强 324 和增强 648。我还成功地测试了翻转函数,所以问题似乎出在函数与数据集的交互方式上。我该如何正确实施这种增强?

当您使用 tf.data API 应用数据扩充时,它是 即时 完成的,这意味着每个示例都按实施方式进行转换在你的方法中。以这种方式增加数据并不意味着管道中的示例数量会发生变化。

如果您想使用每个示例 n 次,只需添加 dataset = dataset.repeat(count=n)。您可能需要更新代码以使用 tf.image.random_flip_left_right,否则每次翻转都以相同的方式完成。

在您第二次检查 num_ex 的示例中,数据集仅包含翻转图像,因此 324。 此外,如果您有一个大于 324 的大型数据集,您可能需要研究在线数据扩充。在这种情况下,在训练过程中,数据集在每个时期都会以不同的方式进行扩充,并且您只能在扩充数据上进行训练,而不是在原始数据集上进行训练。这有助于训练好的模型更好地泛化。 (https://www.tensorflow.org/tutorials/images/data_augmentation)