有没有一些简单的方法可以将图像预处理应用于 tf.data.Dataset?

Is there some simple way to apply image preprocess to tf.data.Dataset?

我想知道将图像预处理(例如旋转、水平翻转、填充、裁剪等)应用于 tf.data API

制作的数据的有效方法

我的数据如下:

import tensorflow as tf

// train_data -> numpy array, (50000, 32, 32, 3)
// test_data -> numpy array, (10000, 32, 32, 3)

train_generator = tf.data.Dataset.from_tensor_slices((train_data, train_labels)).shuffle(50000).batch(128)
test_generator = tf.data.Dataset.from_tensor_slices((test_data, test_labels)).batch(128)

那么有什么好的方法可以对我的数据集应用一些预处理吗?

我知道 Keras API 中的 ImageDataGenerator 很简单,但我想知道如何在 tf.data API 中处理图像。

以下是如何应用此类转换的示例:

def parse_data_train(image, label):
    # Function that we will use to parse the training data
    image = tf.image.random_crop(image, [WIDTH, HEIGHT, NUM_CHANNELS])
    image = tf.image.random_flip_left_right(image)

    return image, label

def parse_data_val_test(image, label):
    # Function that we will use to parse the validation/test data
    image = tf.image.resize_with_crop_or_pad(image, WIDTH, HEIGHT)

    return image, label


WIDTH, HEIGHT, NUM_CHANNELS = 10, 10, 3
train_data = np.random.rand(100, 32, 32, 3)
train_labels = np.random.rand(100, 10)
test_data = np.random.rand(10, 32, 32, 3)
test_labels = np.random.rand(10, 10)

# Creating the training dataset
train_dataset = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
# Shuffle of the dataset upon creation and shuffle it after each epoch
train_dataset = train_dataset.shuffle(buffer_size=train_data.shape[0], reshuffle_each_iteration=True)
# Apply the transformations on the dataset
train_dataset = train_dataset.map(parse_data_train)
# Create the batches
train_dataset = train_dataset.batch(10)

# Create the test dataset
test_dataset = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
# No need to shuffle since we just validate/test on this dataset
# Apply the transformations for the validation/test dataset
test_dataset = test_dataset.map(parse_data_val_test)
# Create the batches
test_dataset = test_dataset.batch(10)

总而言之,我建议依赖 tf.data.Dataset.map(),因为您可以轻松创建自己的方法,您可以在其中堆叠要应用于数据集每个样本的转换。