有没有办法检索随机 torchvision 变换中使用的特定参数?

Is there a way to retrieve the specific parameters used in a random torchvision transform?

我可以在训练期间通过应用随机变换 (rotation/translation/rescaling) 来扩充我的数据,但我不知道选择的值。

我需要知道应用了哪些值。我可以手动设置这些值,但是我失去了火炬视觉转换提供的很多好处。

是否有一种简单的方法来获取这些值并以合理的方式实现它们以在训练期间应用?

这是一个例子。我希望能够打印出旋转角度,translation/rescaling 应用于每个图像:

import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms


RandAffine = transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.8, 1.2))

rotate = transforms.RandomRotation(degrees=45)
shift = RandAffine
composed = transforms.Compose([rotate,
                               shift])

# Apply each of the above transforms on sample.
fig = plt.figure()
sample = np.zeros((28,28))
sample[5:15,7:20] = 255
sample = transforms.ToPILImage()(sample.astype(np.uint8))
title = ['None', 'Rot','Aff','Comp']
for i, tsfrm in enumerate([None,rotate, shift, composed]):
    if tsfrm:
        t_sample = tsfrm(sample)
    else:
        t_sample = sample
    ax = plt.subplot(1, 5, i + 2)
    plt.tight_layout()
    ax.set_title(title[i])
    ax.imshow(np.reshape(np.array(list(t_sample.getdata())), (-1,28)), cmap='gray')    

plt.show()

恐怕没有简单的解决方法:Torchvision 的随机变换实用程序的构建方式是在调用时对变换参数进行采样。它们是 独特的 随机变换,从某种意义上说,(1) 所使用的参数是用户无法访问的,(2) 相同的随机变换不可重复。

从 Torchvision 0.8.0 开始,随机变换通常由两个主要函数构建:

  • get_params:将根据变换的超参数(您在初始化变换运算符时提供的内容,即参数的取值范围)进行采样

  • forward:应用转换时执行的函数。重要的部分是它从 get_params 获取参数,然后使用关联的确定性函数将其应用于输入。对于 RandomRotation, F.rotate will get called. Similarly, RandomAffine will use F.affine.

您的问题的一个解决方案是自己对 get_params 的参数进行采样,然后调用函数 - deterministic - API 代替。因此,您不会为此使用 RandomRotationRandomAffine 或任何其他 Random* 转换。


例如,让我们看一下T.RandomRotation(为简洁起见,我删除了注释)。

class RandomRotation(torch.nn.Module):
    def __init__(
        self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, 
        center=None, fill=None, resample=None):
        # ...

    @staticmethod
    def get_params(degrees: List[float]) -> float:
        angle = float(torch.empty(1).uniform_(float(degrees[0]), \
            float(degrees[1])).item())
        return angle

    def forward(self, img):
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F._get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]
        angle = self.get_params(self.degrees)

        return F.rotate(img, angle, self.resample, self.expand, self.center, fill)

    def __repr__(self):
        # ...

考虑到这一点,这里有一个可能的覆盖来修改 T.RandomRotation:

class RandomRotation(T.RandomRotation):
    def __init__(*args, **kwargs):
        super(RandomRotation, self).__init__(*args, **kwargs) # let super do all the work

        self.angle = self.get_params(self.degrees) # initialize your random parameters

    def forward(self): # override T.RandomRotation's forward
        fill = self.fill
        if isinstance(img, Tensor):
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F._get_image_num_channels(img)
            else:
                fill = [float(f) for f in fill]

        return F.rotate(img, self.angle, self.resample, self.expand, self.center, fill)

我基本上复制了T.RandomRotationforward函数,唯一的区别是参数在__init__中采样( 一次)而不是在 forward 内部( 每次调用)。 Torchvision 的实现涵盖了所有情况,您通常不需要复制完整的 forward。在某些情况下,您几乎可以直接调用功能版本。例如,如果不需要设置 fill 参数,则可以丢弃该部分,只使用:

class RandomRotation(T.RandomRotation):
    def __init__(*args, **kwargs):
        super(RandomRotation, self).__init__(*args, **kwargs) # let super do all the work

        self.angle = self.get_params(self.degrees) # initialize your random parameters

    def forward(self): # override T.RandomRotation's forward
        return F.rotate(img, self.angle, self.resample, self.expand, self.center)

如果您想覆盖其他随机变换,您可以查看 the source code。 API 是不言自明的,您在为每个转换实现覆盖时应该不会有太多问题。