在 pytorch 中绘制转换后的(增强的)图像

Plot the transformed (augmented) images in pytorch

我想使用一种图像增强技术(例如旋转或水平翻转)并将其应用于 CIFAR-10 数据集的一些图像并在 PyTorch 中绘制它们。

我知道我们可以使用下面的代码来增强图像:

from torchvision import models, datasets, transforms
from torchvision.datasets import CIFAR10

data_transforms = transforms.Compose([
        # add augmentations
        transforms.RandomHorizontalFlip(p=0.5),
        # The output of torchvision datasets are PILImage images of range [0, 1].
        # We transform them to Tensors of normalized range [-1, 1]
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

然后当我想加载 Cifar10 数据集时,我使用了上面的转换:

train_set = CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=data_transforms['train'])

据我所知,当使用这段代码时,所有 CIFAR10 数据集都被转换。

问题

我的问题是如何对数据集中的某些图像使用数据转换或增强技术并绘制它们?例如 10 张图像及其增强图像。

when this code is used, all CIFAR10 datasets are transformed

实际上,仅当用户通过 __getitem__ 函数或数据加载器获取数据集中的图像时,才会调用转换管道。所以在这个时间点,train_set 不包含增强图像,它们是动态转换的。


您将需要构建另一个没有扩充的数据集。

>>> non_augmented = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True)

>>> train_set = CIFAR10(
...     root='./data/',
...     train=True,
...     download=True,
...     transform=data_transforms)

将一些图像堆叠在一起:

>>> imgs = torch.stack((*[non_augmented[i][0] for i in range(10)],
                        *[train_set[i][0] for i in range(10)]))

>>> imgs.shape
torch.Size([20, 3, 32, 32])

然后 torchvision.utils.make_grid 可用于创建所需的布局:

>>> grid = torchvision.utils.make_grid(imgs, nrow=10)

给你!

>>> transforms.ToPILImage()(grid)