PyTorch:如何对多个图像应用相同的随机变换?

PyTorch : How to apply the same random transformation to multiple image?

我正在为包含多对图像的数据集编写一个简单的转换。作为数据扩充,我想对每一对应用一些随机变换,但该对中的图像应该以相同的方式进行变换。 例如,给定一对两张图像 AB,如果 A 水平翻转,则 B 必须水平翻转为 A。那么下一对 CD 应该与 AB 进行不同的转换,但是 CD 的转换方式相同。我正在尝试下面的方法

import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

img_a = Image.open("sample_ajpg") # note that two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")

transform = transforms.RandomChoice(
    [transforms.RandomHorizontalFlip(), 
     transforms.RandomVerticalFlip()]
)
random.seed(0)
display(transform(img_a))
display(transform(img_b))

random.seed(1)
display(transform(img_c))
display(transform(img_d))

然而、上面的代码并没有选择相同的转换,正如我测试的那样,它依赖于transform被调用的次数。

有什么方法可以强制 transforms.RandomChoice 在指定时使用相同的转换?

我不知道有修复随机输出的函数。 也许尝试不同的逻辑,比如自己创建随机化以便能够重用相同的转换。 逻辑:

  • 生成一个随机数
  • 根据数字对两张图片应用变换
  • 生成另一个随机数
  • 对另外两张图片做同样的事情 试试这个:
import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

img_a = Image.open("sample_ajpg") # note that two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")

if random.random() > 0.5:
        image_a_flipped = transforms.functional_pil.vflip(img_a)
        image_b_flipped = transforms.functional_pil.vflip(img_b)
else:
    image_a_flipped = transforms.functional_pil.hflip(img_a)
    image_b_flipped = transforms.functional_pil.hflip(img_b)

if random.random() > 0.5:
        image_c_flipped = transforms.functional_pil.vflip(img_c)
        image_d_flipped = transforms.functional_pil.vflip(img_d)
else:
    image_c_flipped = transforms.functional_pil.hflip(img_c)
    image_d_flipped = transforms.functional_pil.hflip(img_d)
    
display(image_a_flipped)
display(image_b_flipped)

display(image_c_flipped)
display(image_d_flipped)

通常的解决方法是在第一张图像上应用变换,检索该变换的参数,然后在其余图像上应用具有这些参数的确定性变换。但是,这里 RandomChoice 没有提供 API 来获取应用变换的参数,因为它涉及可变数量的变换。 在那些情况下,我通常会实现对原始函数的覆盖。

torchvision implementation,就这么简单:

class RandomChoice(RandomTransforms):
    def __call__(self, img):
        t = random.choice(self.transforms)
        return t(img)

这里有两种可能的解决方案。

  1. 您可以在 __init__ 而不是 __call__ 上从转换列表中采样:

    import random
    import torchvision.transforms as T
    
    class RandomChoice(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.t = random.choice(self.transforms)
    
        def __call__(self, img):
            return self.t(img)
    

    所以你可以这样做:

    transform = T.RandomChoice([
         T.RandomHorizontalFlip(), 
         T.RandomVerticalFlip()
    ])
    display(transform(img_a)) # both img_a and img_b will
    display(transform(img_b)) # have the same transform
    
    transform = T.RandomChoice([
        T.RandomHorizontalFlip(), 
        T.RandomVerticalFlip()
    ])
    display(transform(img_c)) # both img_c and img_d will
    display(transform(img_d)) # have the same transform
    

  1. 或者更好的是,批量转换图像:

    import random
    import torchvision.transforms as T
    
    class RandomChoice(torch.nn.Module):
        def __init__(self, transforms):
           super().__init__()
           self.transforms = transforms
    
        def __call__(self, imgs):
            t = random.choice(self.transforms)
            return [t(img) for img in imgs]
    

    允许做的事情:

    transform = T.RandomChoice([
         T.RandomHorizontalFlip(), 
         T.RandomVerticalFlip()
    ])
    
    img_at, img_bt = transform([img_a, img_b])
    display(img_at) # both img_a and img_b will
    display(img_bt) # have the same transform
    
    img_ct, img_dt = transform([img_c, img_d])
    display(img_ct) # both img_c and img_d will
    display(img_dt) # have the same transform
    

简单地说,将随机化部分从 PyTorch 中取出到 if 语句中。 下面的代码使用 vflip。对于水平或其他变换也是如此。

import random
import torchvision.transforms.functional as TF

if random.random() > 0.5:
    image = TF.vflip(image)
    mask  = TF.vflip(mask)

这个问题已经在 PyTorch forum. Several solutions' pros and cons were discussed on the official GitHub repository page 中讨论过了。 PyTorch 维护者建议了这种简单的方法。

不要使用 torchvision.transforms.RandomVerticalFlip(p=1)。使用 torchvision.transforms.functional.vflip

函数式转换可让您对转换管道进行细粒度控制。与上述转换相反,函数式转换不包含用于其参数的随机数生成器。这意味着您必须 specify/generate 所有参数,但您可以重复使用函数转换。

我意识到 OP 要求使用 torchvision 的解决方案,我认为 @Ivan 的 很好地解决了这个问题。

然而,对于那些没有绑定到特定增强库的人,我想指出 Albumentations 似乎可以在 native fashion 中很好地处理这些情况,它允许用户传递多个源图像、框等进入相同的转换。 return 的结构是字典

import albumentations as A

transform = A.Compose(
    transforms=[
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5)],
    additional_targets={'image0': 'image', 'image1': 'image'}
)
transformed = transform(image=image, image0=image0, image1=image1)

现在您可以访问 transformed['image0']transformed['image1'] 等,所有这些都将应用随机参数

引用 Random transforms for both input and target? 我认为这可能是最简洁的方法。在应用任何转换之前保存随机状态,并为每个后续调用恢复它

t = transforms.RandomRotation(degrees=360)
state = torch.get_rng_state()
x = t(x)
torch.set_rng_state(state)
y = t(y)