火炬视觉转换中的增强功能未按预期工作

Augmentation in torch vision transform is not working as expected

我正在使用 pytorch 开发 CNN。我的模型在没有增强的情况下在训练和测试集上都提供了很好的准确性,但是我想学习增强,所以我使用了 torchvision 变换来增强,并且在应用增强模型之后开始表现最差并且损失根本没有减少。所以我尝试调试并观察到增强图像看起来 distorted/unexpected 有人可以帮我解决这个问题吗?

自定义数据集

class traindataset(Dataset):
    def __init__(self,data,train_end_idx,augmentation = None):
        '''
        data: data is a pandas dataframe generated from csv file where it has columns-> [name,labels,col 1,col2,...,col784]. shape of data->(10000, 786)
        
        '''
        self.data=data
        self.augmentation=augmentation
        self.train_end=train_end_idx
        self.target=self.data.iloc[:self.train_end,1].values
        self.image=self.data.iloc[:self.train_end,2:].values#contains full data
        
    def __len__(self):
        return len(self.target);
    def __getitem__(self,idx):
        
        self.target=self.target
        self.ima=self.image[idx].reshape(1,784) #only takes the selected index
        if self.augmentation is not None:
            self.ima = self.augmentation(self.ima)
        
        return torch.tensor(self.target[idx]),self.ima
                                        

使用了增强

torchvision_transform = transforms.Compose([
    np.uint8,
    transforms.ToPILImage(),
    transforms.Resize((28,28)),
    transforms.RandomRotation([45,135]),
    transforms.ToTensor()
    ])  

增强图像(图片PFA)

transformed=torchvision_transform(x)
plt.imshow(transformed.squeeze().numpy(), interpolation='nearest')
plt.show()
            

正常图像

x=data.iloc[:1,2:].values
plt.imshow(x.reshape(28,28), interpolation='nearest')
plt.show()

第一张图片有增强,第二张图片没有增强。 如果你愿意,你可以在不下载任何东西的情况下使用代码 here

似乎 transforms.Resize() 函数没有正确重塑张量。首先重塑似乎可以解决问题并生成正确的图像(您在相册部分执行了此步骤)。

transformed = torchvision_transform(x.reshape(28,28))