I am running into a gradient computation inplace error

I am running this code (https://github.com/ayu-22/BPPNet-Back-Projected-Pyramid-Network/blob/master/Single_Image_Dehazing.ipynb) on a custom dataset, but I am running into this error:

    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 512, 4, 4]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Please refer to the code linked above to see where the error occurs.

I am running this model on a custom dataset; the data loader part is pasted below.

    import torchvision.transforms as transforms

    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        #transforms.RandomResizedCrop(256),
        #transforms.RandomHorizontalFlip(),
        #transforms.ColorJitter(),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    import os
    from PIL import Image
    from torch.utils.data import Dataset

    class Flare(Dataset):
        def __init__(self, flare_dir, wf_dir, transform=None):
            self.flare_dir = flare_dir
            self.wf_dir = wf_dir
            self.transform = transform
            self.flare_img = os.listdir(flare_dir)
            self.wf_img = os.listdir(wf_dir)

        def __len__(self):
            return len(self.flare_img)

        def __getitem__(self, idx):
            f_img = Image.open(os.path.join(self.flare_dir, self.flare_img[idx])).convert("RGB")
            for i in self.wf_img:
                if self.flare_img[idx].split('.')[0][4:] == i.split('.')[0]:
                    wf_img = Image.open(os.path.join(self.wf_dir, i)).convert("RGB")
                    break
            f_img = self.transform(f_img)
            wf_img = self.transform(wf_img)
            return f_img, wf_img

    flare_dir = '../input/flaredataset/Flare/Flare_img'
    wf_dir = '../input/flaredataset/Flare/Without_Flare_'
    flare_img = os.listdir(flare_dir)
    wf_img = os.listdir(wf_dir)
    wf_img.sort()
    flare_img.sort()
    print(wf_img[0])
    train_ds = Flare(flare_dir, wf_dir, train_transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_ds,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)

To get a better idea of the dataset class, you can compare my dataset class with the notebook linked above.
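As a side note, the linear scan in `__getitem__` re-walks the whole without-flare directory for every sample. The same pairing rule (the flare stem minus its first 4 characters equals the without-flare stem) can be precomputed once. A minimal sketch, with a hypothetical helper name:

```python
def build_pair_index(flare_names, wf_names):
    """Pair each flare filename with its without-flare counterpart.

    Mirrors the matching rule from __getitem__ above:
    self.flare_img[idx].split('.')[0][4:] == i.split('.')[0]
    """
    wf_by_stem = {name.split('.')[0]: name for name in wf_names}
    pairs = {}
    for f_name in flare_names:
        stem = f_name.split('.')[0][4:]  # drop the 4-character flare prefix
        if stem in wf_by_stem:
            pairs[f_name] = wf_by_stem[stem]
    return pairs
```

Building this dict once in `__init__` makes each `__getitem__` lookup O(1) and lets you detect missing pairs up front instead of hitting an `UnboundLocalError` when no match is found.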

Your code is getting stuck in the backward pass of your GAN network.

The backward pass you defined looks like this:

    def backward(self, unet_loss, dis_loss):
        dis_loss.backward(retain_graph = True)
        self.dis_optimizer.step()

        unet_loss.backward()
        self.unet_optimizer.step()

So in your backward pass you first backpropagate dis_loss, which is the combination of the discriminator and adversarial losses, and then you backpropagate unet_loss, which combines the SSIM and content losses together with the loss on the discriminator's output. Because you call the discriminator's optimizer step before unet_loss.backward(), the discriminator's weights are modified in place while the backward graph of unet_loss still needs them, so PyTorch gets confused and gives you this error. I suggest you change the code as follows:

    def backward(self, unet_loss, dis_loss):
        dis_loss.backward(retain_graph = True)
        unet_loss.backward()

        self.dis_optimizer.step()
        self.unet_optimizer.step()

This will get your training running! You can also experiment with retain_graph=True.
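The failure mode can be reproduced in a few lines with toy modules standing in for the UNet and discriminator (the names below are illustrative, not from the notebook): stepping the discriminator's optimizer between the two backward calls bumps the in-place version counter of its weights, which the second backward still needs.

```python
import torch
import torch.nn as nn

def train_step(step_between_backwards):
    gen = nn.Linear(4, 4)   # stand-in for the UNet/generator
    dis = nn.Linear(4, 1)   # stand-in for the discriminator
    gen_opt = torch.optim.SGD(gen.parameters(), lr=0.1)
    dis_opt = torch.optim.SGD(dis.parameters(), lr=0.1)

    x = torch.randn(2, 4)
    fake = gen(x)
    dis_loss = dis(fake.detach()).mean()  # discriminator-side loss
    gen_loss = dis(fake).mean()           # generator loss, also flows through dis

    dis_loss.backward(retain_graph=True)
    if step_between_backwards:
        dis_opt.step()      # in-place update of dis.weight -> version bump
    gen_loss.backward()     # backward through dis needs dis.weight at its old version
    if not step_between_backwards:
        dis_opt.step()
    gen_opt.step()

train_step(step_between_backwards=False)     # fixed ordering: runs fine
try:
    train_step(step_between_backwards=True)  # original ordering
except RuntimeError as e:
    print(type(e).__name__, "- reproduced the inplace error")
```

The only change between the failing and working runs is whether `dis_opt.step()` happens before or after the second `backward()`, which is exactly the reordering suggested above.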

Great work on BPPNet!