Alternate training between two datasets

I am trying to alternate between an augmented and a non-augmented dataset on each epoch (e.g. augment in one epoch, don't augment in the next), but I don't know how to do it. My approach was to re-create the DataLoader at every epoch, but I think this is wrong, because when I print the index inside __getitem__ of my Dataset class, I see many duplicate indices.

Here is my training code:

for i in range(epoch):

    train_loss = 0.0
    valid_loss = 0.0
    since = time.time()
    scheduler.step(i)
    lr = scheduler.get_lr()

    #######################################################
    #Training Data
    #######################################################

    model_test.train()
    k = 1
    tx=""
    lx=""
    random_ = random.randint(0, 1)  # 0: no augmentation, 1: augmentation
    print(f"augmentation flag: {random_}")
    if random_ == 0:
        tx = torchvision.transforms.Compose([
            # torchvision.transforms.Resize((128, 128)),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

        lx = torchvision.transforms.Compose([
            # torchvision.transforms.Resize((128, 128)),
            torchvision.transforms.Grayscale(),
            torchvision.transforms.ToTensor(),
            # torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0))
        ])
    else:
        tx = torchvision.transforms.Compose([
            # torchvision.transforms.Resize((128, 128)),
            torchvision.transforms.CenterCrop(96),
            torchvision.transforms.RandomRotation((-10, 10)),
            # torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        lx = torchvision.transforms.Compose([
            # torchvision.transforms.Resize((128, 128)),
            torchvision.transforms.CenterCrop(96),
            torchvision.transforms.RandomRotation((-10, 10)),
            torchvision.transforms.Grayscale(),
            torchvision.transforms.ToTensor(),
            # torchvision.transforms.Lambda(lambda x: torch.cat([x, 1 - x], dim=0))
        ])
    Training_Data = Images_Dataset_folder(t_data, l_data, tx, lx)
    train_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=train_sampler,
                                               num_workers=num_workers, pin_memory=pin_memory)

    valid_loader = torch.utils.data.DataLoader(Training_Data, batch_size=batch_size, sampler=valid_sampler,
                                               num_workers=num_workers, pin_memory=pin_memory)

    
    for x,y in train_loader:

        x, y = x.to(device), y.to(device)
       
        # To inspect the input images and their augmentations (check the data flowing into the net)
        input_images(x, y, i, n_iter, k)

       # grid_img = torchvision.utils.make_grid(x)
        #writer1.add_image('images', grid_img, 0)

       # grid_lab = torchvision.utils.make_grid(y)

        opt.zero_grad()

        y_pred = model_test(x)
        lossT = calc_loss(y_pred, y)     # Dice_loss Used

        train_loss += lossT.item() * x.size(0)
        lossT.backward()
      #  plot_grad_flow(model_test.named_parameters(), n_iter)
        opt.step()
        x_size = lossT.item() * x.size(0)
        k = 2

Here is my Dataset code:

    def __init__(self, images_dir, labels_dir, transformI=None, transformM=None):
        self.images = sorted(os.listdir(images_dir))
        self.labels = sorted(os.listdir(labels_dir))
        self.images_dir = images_dir
        self.labels_dir = labels_dir
        self.transformI = transformI
        self.transformM = transformM
        self.tx=self.transformI
        self.lx=self.transformM

        

    def __len__(self):

        return len(self.images)

    def __getitem__(self, i):
        # debug: log which indices the loader requests
        with open("/content/x.txt", "a") as o:
            o.write(str(i) + "\n")
        i1 = Image.open(os.path.join(self.images_dir, self.images[i]))
        l1 = Image.open(os.path.join(self.labels_dir, self.labels[i]))

        seed = np.random.randint(0, 2 ** 32)  # make a seed with numpy generator

        # apply this seed to the image transforms
        random.seed(seed)
        torch.manual_seed(seed)
        img = self.tx(i1)

        # apply the same seed to the target/label transforms
        random.seed(seed)
        torch.manual_seed(seed)
        label = self.lx(l1)

        return img, label

How can I achieve what I want? Thanks in advance.

Instantiating the dataset and data loader on every epoch does not seem to be the way to go. Instead, you probably want to instantiate two dataset + data loader pairs, each with its corresponding augmentation pipeline.

Here is a basic framework to illustrate:

First, define the transform pipelines inside the dataset itself:

import torchvision.transforms as T
from torch.utils.data import Dataset

class Images_Dataset_folder(Dataset):
    def __init__(self, images_dir, labels_dir, augment=False):
        super().__init__()
        self.tx, self.lx = self._augmentations() if augment else self._no_augmentations()

    def __len__(self):
        pass

    def __getitem__(self, i):
        pass

    def _augmentations(self):
        tx = T.Compose([
            T.CenterCrop(96),
            T.RandomRotation((-10, 10)),
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

        lx = T.Compose([
            T.CenterCrop(96),
            T.RandomRotation((-10, 10)),
            T.Grayscale(),
            T.ToTensor()])

        return tx, lx

    def _no_augmentations(self):
        tx = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])

        lx = T.Compose([
            T.Grayscale(),
            T.ToTensor()])

        return tx, lx

Then you can build your training loop:

# augmented images dataset
aug_trainset = Images_Dataset_folder(t_data, l_data, augment=True)
aug_dataloader = DataLoader(aug_trainset, batch_size=batch_size)

# non-augmented images dataset
unaug_trainset = Images_Dataset_folder(t_data, l_data, augment=False)
unaug_dataloader = DataLoader(unaug_trainset, batch_size=batch_size)

# on each iteration of the outer loop, go through both loaders in turn
for i in range(epochs // 2):
    # call the train loop on the augmented data loader
    train(model, aug_dataloader)

    # call the train loop on the non-augmented data loader
    train(model, unaug_dataloader)

That said, with this setup you effectively traverse the dataset twice per iteration of the outer loop: once over the non-augmented images and once over the augmented ones.
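If you'd rather keep a single pass per epoch and alternate between whole epochs (your original goal), you can pick the loader by epoch parity. A minimal sketch, with strings standing in for the real DataLoader objects and a `history` list standing in for the `train(model, loader)` call, so it runs on its own:

```python
# Stand-ins for the two real DataLoader objects defined above
aug_dataloader = "augmented_loader"
unaug_dataloader = "plain_loader"

history = []
for epoch in range(4):
    # even epochs -> augmented data, odd epochs -> plain data
    loader = aug_dataloader if epoch % 2 == 0 else unaug_dataloader
    history.append(loader)  # in real code: train(model, loader)
```

This keeps exactly one pass over the data per epoch while still alternating the pipelines.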

If you want to iterate only once, the simplest solution I can come up with is to set a random flag inside __getitem__ that decides whether the current image gets augmented.
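A minimal, torch-free sketch of that idea (the class name `RandomAugDataset` is hypothetical; items are plain lists and "augmentation" is just reversing the list, standing in for applying the augmented transform pair in the real Images_Dataset_folder):

```python
import random

class RandomAugDataset:
    """Per-item random augmentation flag, sketched without torch."""

    def __init__(self, items, p_augment=0.5):
        self.items = items
        self.p_augment = p_augment

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        x = self.items[i]
        # random flag: decide per item whether to augment
        if random.random() < self.p_augment:
            x = list(reversed(x))  # stand-in for the augmented pipeline
        return x

ds = RandomAugDataset([[1, 2, 3], [4, 5, 6]])
sample = ds[0]
```

Note that with this approach each epoch mixes augmented and non-augmented samples, rather than alternating whole epochs.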


Side note: you don't want to use your training data in the validation set!
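One common pattern for a proper split is to shuffle the indices once and carve off a held-out validation slice; this is a sketch, with `valid_frac` and `seed` chosen arbitrarily:

```python
import random

def split_indices(n, valid_frac=0.2, seed=42):
    # Shuffle once with a fixed seed so the split is reproducible,
    # then carve off the first chunk as the validation set.
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_valid = int(n * valid_frac)
    return idx[n_valid:], idx[:n_valid]

train_idx, valid_idx = split_indices(100)
```

The two index lists can then be fed to torch.utils.data.SubsetRandomSampler (as your train_sampler and valid_sampler), which guarantees the train and validation loaders never share samples.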