MNIST 数据集上的 Pytorch 转换

Pytorch transformation on MNIST dataset

我目前有一个弱监督项目,我需要在数据集前面放一个 "masking"。我现在的问题是我不知道该怎么做。让我用一些代码和图像进一步解释。

我正在使用必须以 this 方式编辑的 MNIST 数据集。如您所见,中间的正方形被切掉了。下面的代码用于使用 for 循环编辑 MNIST。

for i in range(int(image_size/2-5),int(image_size/2+3)):
   for j in range(int(image_size/2-5),int(image_size/2+3)):
      image[i][j] = 0

但是,我目前不确定应该如何在数据加载器转换中使用它。数据加载器和转换的代码如下所示:

transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=True, transform=transform, download=True
)

test_dataset = torchvision.datasets.MNIST(
    root="~/torch_datasets", train=False, transform=transform, download=True
)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, num_workers=4, pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=32, shuffle=False, num_workers=4
)

def imshow(img):
    #img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(train_loader)
images, labels = dataiter.next()
imshow(torchvision.utils.make_grid(images))

那么,是否有直接的方法将转换应用到 torchvision.transforms.Compose 中的完整数据集?

您可以将任何自定义转换定义为函数并在转换管道中使用 torchvision.transforms.Lambda

def erase_middle(image: torch.Tensor) -> torch.Tensor:
    for i in range(int(image_size/2-5),int(image_size/2+3)):
        for j in range(int(image_size/2-5),int(image_size/2+3)):
            image[:, i, j] = 0
    return image

transform = torchvision.transforms.Compose(
    [
        # First transform it to a tensor
        torchvision.transforms.ToTensor(),
        # Then erase the middle
        torchvision.transforms.Lambda(erase_middle),
    ]
)

erase_middle 可以变得更通用,这样它适用于大小不一且不一定是正方形的图像。

def erase_middle(image: torch.Tensor) -> torch.Tensor:
    _, height, width = image.size()
    x_start = width // 2 - 5
    x_end = width // 2 + 3
    y_start = height // 2 - 5
    y_end = height // 2 + 3
    # Using slices achieves the same as the for loops
    image[:, y_start:y_end, x_start:x_end] = 0
    return image