将数据集映射到增强函数不会保留我的原始样本
Mapping dataset to augmentation function does not preserve my original samples
我应该如何实现一个扩充管道,其中我的数据集得到扩展而不是替换图像与扩充样本,也就是说,如何使用地图调用来扩充和保留原始样本?
我检查过的线程: ,
我目前使用的代码:
records_path = DATA_DIR+'/'+'TFRecords'+TRAIN+'train_0.tfrecord'
# Create a dataset
dataset = tf.data.TFRecordDataset(filenames=records_path)
dataset = dataset.map(parsing_fn).cache().map(lambda image, label: (tf.image.central_crop(image,0.5),label))
dataset = dataset.shuffle(100)
dataset = dataset.batch(2)
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
我期待上面的代码通过批次迭代我会得到原始图像及其裁剪版本, 除此之外,我想我还没有正确理解 缓存方法的行为方式。
然后我使用下面的代码来展示图像,绘制随机裁剪的图像。
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
for i in range(10):
image,label = iterator.get_next()
img_array = image[0].numpy()
plt.imshow(img_array)
plt.show()
print('label: ', label[0])
img_array = image[1].numpy()
plt.imshow(img_array)
plt.show()
print('label: ', label[1])
在您的情况下,cache()
允许在内存中应用 parsing_fn
后保留数据集。它只有助于提高性能。遍历整个数据集后,所有图像都会保存在内存中。因此,下一次迭代会更快,因为您不必再次对其应用 parsing_fn
。
如果您打算在遍历数据集时获取原始图像及其裁剪,您需要做的是 return 在您的 map()
函数中 return 图像及其裁剪:
dataset = dataset.map(parsing_fn).cache().map(lambda image, label: (tf.image.central_crop(image,0.5), image ,label))
然后,在您的迭代中,您可以同时检索:
for i in range(10):
crop, image, label = iterator.get_next()
我应该如何实现一个扩充管道,其中我的数据集得到扩展而不是替换图像与扩充样本,也就是说,如何使用地图调用来扩充和保留原始样本?
我检查过的线程:
我目前使用的代码:
records_path = DATA_DIR+'/'+'TFRecords'+TRAIN+'train_0.tfrecord'
# Create a dataset
dataset = tf.data.TFRecordDataset(filenames=records_path)
dataset = dataset.map(parsing_fn).cache().map(lambda image, label: (tf.image.central_crop(image,0.5),label))
dataset = dataset.shuffle(100)
dataset = dataset.batch(2)
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
我期待上面的代码通过批次迭代我会得到原始图像及其裁剪版本, 除此之外,我想我还没有正确理解 缓存方法的行为方式。
然后我使用下面的代码来展示图像,绘制随机裁剪的图像。
iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
for i in range(10):
image,label = iterator.get_next()
img_array = image[0].numpy()
plt.imshow(img_array)
plt.show()
print('label: ', label[0])
img_array = image[1].numpy()
plt.imshow(img_array)
plt.show()
print('label: ', label[1])
在您的情况下,cache()
允许在内存中应用 parsing_fn
后保留数据集。它只有助于提高性能。遍历整个数据集后,所有图像都会保存在内存中。因此,下一次迭代会更快,因为您不必再次对其应用 parsing_fn
。
如果您打算在遍历数据集时获取原始图像及其裁剪,您需要做的是 return 在您的 map()
函数中 return 图像及其裁剪:
dataset = dataset.map(parsing_fn).cache().map(lambda image, label: (tf.image.central_crop(image,0.5), image ,label))
然后,在您的迭代中,您可以同时检索:
for i in range(10):
crop, image, label = iterator.get_next()