将数据集映射到增强函数不会保留我的原始样本

Mapping dataset to augmentation function does not preserve my original samples

我应该如何实现一个扩充管道,其中我的数据集得到扩展而不是替换图像与扩充样本,也就是说,如何使用地图调用来扩充和保留原始样本?

我检查过的线程: ,

我目前使用的代码:

records_path = DATA_DIR+'/'+'TFRecords'+TRAIN+'train_0.tfrecord'
# Create a dataset
dataset = tf.data.TFRecordDataset(filenames=records_path)
dataset = dataset.map(parsing_fn).cache().map(lambda image, label: (tf.image.central_crop(image,0.5),label))
dataset = dataset.shuffle(100)
dataset = dataset.batch(2)
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) 

期待上面的代码通过批次迭代我会得到原始图像及其裁剪版本, 除此之外,我想我还没有正确理解 缓存方法的行为方式。

然后我使用下面的代码来展示图像,绘制随机裁剪的图像。

iterator = tf.compat.v1.data.make_one_shot_iterator(dataset) 

for i in range(10):
    image,label = iterator.get_next()
    img_array = image[0].numpy()    
    plt.imshow(img_array)
    plt.show()
    print('label: ', label[0])

    img_array = image[1].numpy()    
    plt.imshow(img_array)
    plt.show()
    print('label: ', label[1])

在您的情况下,cache() 允许在内存中应用 parsing_fn 后保留数据集。它只有助于提高性能。遍历整个数据集后,所有图像都会保存在内存中。因此,下一次迭代会更快,因为您不必再​​次对其应用 parsing_fn

如果您打算在遍历数据集时获取原始图像及其裁剪,您需要做的是 return 在您的 map() 函数中 return 图像及其裁剪:

dataset = dataset.map(parsing_fn).cache().map(lambda image, label: (tf.image.central_crop(image,0.5), image ,label))

然后,在您的迭代中,您可以同时检索:

for i in range(10):
    crop, image, label = iterator.get_next()