在图像批次中打乱补丁

Shuffle patches in image batch

我正在尝试创建一个 transform 来批量打乱每个图像的补丁。 我打算以与 torchvision:

中其余转换相同的方式使用它
trans = transforms.Compose([
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ShufflePatches(patch_size=(16,16)) # our new transform
        ])

更具体地说,输入是一个 BxCxHxW 张量。我想将批处理中的每个图像拆分为大小为 patch_size 的非重叠块,将它们打乱,然后重新组合成单​​个图像。

给定图像(大小224x224):

使用 ShufflePatches(patch_size=(112,112)) 我想生成输出图像:

我认为解决方案与 torch.unfoldtorch.fold 有关,但未能进一步解决。

如有任何帮助,我们将不胜感激!

确实 在这种情况下似乎很合适。

import torch
import torch.nn.functional as nnf

class ShufflePatches(object):
  def __init__(self, patch_size):
    self.ps = patch_size

  def __call__(self, x):
    # divide the batch of images into non-overlapping patches
    u = nnf.unfold(x, kernel_size=self.ps, stride=self.ps, padding=0)
    # permute the patches of each image in the batch
    pu = torch.cat([b_[:, torch.randperm(b_.shape[-1])][None,...] for b_ in u], dim=0)
    # fold the permuted patches back together
    f = nnf.fold(pu, x.shape[-2:], kernel_size=self.ps, stride=self.ps, padding=0)
    return f

这是一个补丁大小为 16 的示例: