PyTorch - 如何使用 Avg 2d Pooling 作为数据集转换?

PyTorch - How to use Avg 2d Pooling as a dataset transform?

在 Pytorch 中,我有一个 2D 图像数据集(或者,1 通道图像),我想应用平均 2D 池化作为转换。我该怎么做呢?以下不起作用:

    omniglot_dataset = torchvision.datasets.Omniglot(
        root=data_dir,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.CenterCrop((80, 80)),
            # torchvision.transforms.Resize((10, 10))
            torch.nn.functional.avg_pool2d(kernel_size=3, strides=1),
        ])
    )

转换必须是可调用对象。但是 torch.nn.functional.avg_pool2d 不是 return 一个可调用对象,而只是一个你可以调用来处理的函数,这就是为什么它们被打包在 torch.nn.functional 下的原因所有泛函都接收输入和参数。您需要使用其他版本:

torch.nn.AvgPool2d(kernel_size=3, stride=1)

其中 return 是一个可调用对象,可以调用它来处理给定的输入,例如:

pooler = torch.nn.AvgPool2d(kernel_size=3, stride=1)
output = pooler(input)

通过此处的此更改,您可以看到不同的版本如何使用可调用版本:

import torchvision
import torch
import matplotlib.pyplot as plt

omniglotv1 = torchvision.datasets.Omniglot(
        root='./dataset/',
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.CenterCrop((80, 80))
        ])
    )

x1, y = omniglotv1[0]
print(x1.size())      # torch.Size([1, 80, 80])

omniglotv2 = torchvision.datasets.Omniglot(
        root='./dataset/',
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.CenterCrop((80, 80)),
            torch.nn.AvgPool2d(kernel_size=3, stride=1)
        ])
    )

x2, y = omniglotv2[0]
print(x2.size())      # torch.Size([1, 78, 78])

pooler = torch.nn.AvgPool2d(kernel_size=3, stride=1)
omniglotv3 = torchvision.datasets.Omniglot(
        root='./dataset/',
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.CenterCrop((80, 80)),
            pooler
        ])
    )

x3, y = omniglotv3[0]
print(x3.size())      # torch.Size([1, 78, 78])

在这里,我只是添加了一个用于图像打印的简短代码,以查看转换后的效果:

x_img   = x1.squeeze().cpu().numpy()
ave_img = x2.squeeze().cpu().numpy()
combined = np.zeros((158,80))
combined[0:80,0:80] = x_img
combined[80:,0:78] = ave_img
plt.imshow(combined)
plt.show()

yutasrobot 上面的回答非常令人满意。我在 PyTorch 论坛上收到的另一个答案可以在 https://discuss.pytorch.org/t/how-to-use-avg-2d-pooling-as-a-dataset-transform/117995/2.

找到

"""

您可以使用 transforms.Lambda 调用函数 API:

transform=torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop((80, 80)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=1)),
])

img = transforms.ToPILImage()(torch.randn(3, 224, 224))
out = transform(img)

"""