在 PyTorch 中显示增强图像的示例

Display examples of augmented images in PyTorch

我想显示一些增强训练图像样本。

我的转换包括标准 ImageNet transforms.Normalize,如下所示:

train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomResizedCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406],
                                                            [0.229, 0.224, 0.225])])

但是,由于 Normalise,图像显示的颜色很奇怪。

This answer 说我需要访问原始图像,这在加载时应用转换时很困难:

image_datasets['train'] = datasets.ImageFolder(train_dir, transform=train_transforms)

我将如何在使用标准化图像进行计算的同时以它们通常的颜色显示一些示例增强图像?

我已经遇到了同样的问题。我的解决方案是创建一个不同的 torch.Dataset 具有数据扩充但没有规范化。

Here I create the Dataset. Here 我有一个 class 实现了增强功能。我有两个成员:self.tf_augmentself.tf_transform。前者仅应用数据增强,后者应用数据增强和归一化。

我建议两个选项:

  1. 创建一个单独的 "transformation" 阶段来显示图像并进一步传递它而不做任何更改。免费奖励是您可以在转换列表的任何阶段插入。
    import cv2
    import numpy as np
    def TransformShow(name="img", wait=100):
        def transform_show(img):
            cv2.imshow(name, np.array(img))
            cv2.waitKey(wait)
            return img
        return transform_show

ToTensor() 之前插入这个 "transformer":

                                       transforms.RandomHorizontalFlip(),
                                       TransformShow("window_name", delay_in_ms),
                                       transforms.ToTensor(),

使用零 delay_in_ms 等待按键。

我这里使用OpenCV来显示图像。它也可以只用 Pillow/PIL 来完成,但我不喜欢它的处理方式。

  1. 撤消规范化并显示图像。
def show_image(img, name="img", wait=100):
    mean = np.array([0.485, 0.456, 0.406])
    std =  np.array([0.229, 0.224, 0.225])
    cv2.imshow(name, img.cpu().numpy().transpose((1,2,0)) * std + mean)
    cv2.waitKey(wait)

然后称其为

        show_image(data[0], "unaug", 1)
  1. 最后一种方法可以用快速的两行近似,但颜色有些扭曲:
    cv2.imshow("approx", data[0].cpu().numpy().transpose((1,2,0)) * 0.225 + 0.45)
    cv2.waitKey(10)

只需撤消归一化操作即可,即乘以标准差并相加 均值。

请看pytorch教程中的imshow方法: https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html#visualize-a-few-images

为了回答我自己的问题,我想到了以下内容:

# Undo transforms.Normalize
def denormalise(image):
    image = image.numpy().transpose(1, 2, 0)  # PIL images have channel last
    mean = [0.485, 0.456, 0.406]
    stdd = [0.229, 0.224, 0.225]
    image = (image * stdd + mean).clip(0, 1)
    return image


example_rows = 2
example_cols = 5

sampler = torch.utils.data.RandomSampler(image_datasets['train'],
                                         num_samples=example_rows * example_cols)

# Get a batch of images and labels  
images, indices = next(iter(sampler)) 

plt.rcParams['figure.dpi'] = 120  # Increase size of pyplot plots

# Show a grid of example images    
fig, axes = plt.subplots(example_rows, example_cols, figsize=(9, 5)) #  sharex=True, sharey=True)
axes = axes.flatten()
for ax, image, index in zip(axes, images, indices):
    ax.imshow(denormalise(image))
    ax.set_axis_off()
    ax.set_title(class_names[index], fontsize=7)

fig.subplots_adjust(wspace=0.02, hspace=0)
fig.suptitle('Augmented training set images', fontsize=20)
plt.show()

这是基于 PyTorch's Transfer Learning Tutorial's code 但在每张图片上方显示标题并且通常看起来更好。