RuntimeError 由于在具有跳过连接的 GAN 生成器架构中就地操作

RuntimeError due to inplace operation in GAN generator architecture with skip connections

我用于执行图像着色的 GAN 模型出现以下错误。它使用图像着色中常见的 LAB 颜色 space。生成器为给定的 L 通道生成 a 和 b 通道。鉴别器在连接后被馈送到所有三个通道。

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 64, 128, 128]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

我认为错误是由于跳过连接造成的,但我不能完全确定。如有任何帮助,我们将不胜感激!

这是模型:

class NetGen(nn.Module):
    '''Generator'''
    def __init__(self):
        super(NetGen, self).__init__()

        self.conv1 = nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False)
        self.bnorm1 = nn.BatchNorm2d(64)
        self.relu1 = nn.LeakyReLU(0.1)

        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
        self.bnorm2 = nn.BatchNorm2d(128)
        self.relu2 = nn.LeakyReLU(0.1)

        self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False)
        self.bnorm3 = nn.BatchNorm2d(256)
        self.relu3 = nn.LeakyReLU(0.1)

        self.conv4 = nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False)
        self.bnorm4 = nn.BatchNorm2d(512)
        self.relu4 = nn.LeakyReLU(0.1)

        self.conv5 = nn.Conv2d(512, 512, 3, stride=2, padding=1, bias=False)
        self.bnorm5 = nn.BatchNorm2d(512)
        self.relu5 = nn.LeakyReLU(0.1)

        self.deconv6 = nn.ConvTranspose2d(512, 512, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm6 = nn.BatchNorm2d(512)
        self.relu6 = nn.ReLU()

        self.deconv7 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm7 = nn.BatchNorm2d(256)
        self.relu7 = nn.ReLU()

        self.deconv8 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm8 = nn.BatchNorm2d(128)
        self.relu8 = nn.ReLU()

        self.deconv9 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.bnorm9 = nn.BatchNorm2d(64)
        self.relu9 = nn.ReLU()

        self.deconv10 = nn.ConvTranspose2d(64, 2, 3, stride=2, padding=1, output_padding=1, bias=False)
        self.tanh = nn.Tanh()
        

    def forward(self, x):
        h = x
        h = self.conv1(h)
        h = self.bnorm1(h)
        h = self.relu1(h) 
        pool1 = h

        h = self.conv2(h)
        h = self.bnorm2(h)
        h = self.relu2(h) 
        pool2 = h

        h = self.conv3(h) 
        h = self.bnorm3(h)
        h = self.relu3(h)
        pool3 = h

        h = self.conv4(h) 
        h = self.bnorm4(h)
        h = self.relu4(h)
        pool4 = h

        h = self.conv5(h) 
        h = self.bnorm5(h)
        h = self.relu5(h)

        h = self.deconv6(h)
        h = self.bnorm6(h)
        h = self.relu6(h) 
        h += pool4

        h = self.deconv7(h)
        h = self.bnorm7(h)
        h = self.relu7(h) 
        h += pool3

        h = self.deconv8(h)
        h = self.bnorm8(h)
        h = self.relu8(h)
        h += pool2

        h = self.deconv9(h)
        h = self.bnorm9(h)
        h = self.relu9(h)
        h += pool1

        h = self.deconv10(h)
        h = self.tanh(h) 
        return h

class NetDis(nn.Module):
    '''Discriminator'''
    def __init__(self):
        super(NetDis, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.1),

            nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.1),

            nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.1),

            nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),

            nn.Conv2d(512, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),


            nn.Conv2d(512, 512, 8, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.1),

            nn.Conv2d(512, 1, 1, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.main(x)

这里是权重初始化函数:

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

这里是训练和验证代码:

class Trainer:
    def __init__(self, epochs, batch_size, learning_rate, num_workers):
        self.epochs = epochs
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_workers = num_workers
        self.train_paths = train_paths
        self.val_paths = val_paths        
        self.real_label = 1
        self.fake_label = 0

    def train(self):             
        train_dataset = ColorizeData(paths=self.train_paths)
        train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, num_workers=self.num_workers,pin_memory=True, drop_last = True)
        # Model
        model_G = NetGen().to(device)
        model_D = NetDis().to(device)

        model_G.apply(weights_init)
        model_D.apply(weights_init)

        optimizer_G = torch.optim.Adam(model_G.parameters(),
                             lr=self.learning_rate, betas=(0.5, 0.999),
                             eps=1e-8, weight_decay=0)
        optimizer_D = torch.optim.Adam(model_D.parameters(),
                             lr=self.learning_rate, betas=(0.5, 0.999),
                             eps=1e-8, weight_decay=0)
        
        criterion = nn.BCELoss()
        L1 = nn.L1Loss()

        model_G.train()
        model_D.train()


        # train loop
        for epoch in range(self.epochs):
            print("Starting Training Epoch " + str(epoch + 1))
            for i, data in enumerate(tqdm(train_dataloader)):                                                    
                inputs, input_ab, input_l = data
                inputs = inputs.to(device)
                input_ab = input_ab.to(device)
                input_l = input_l.to(device)


                model_D.zero_grad()
                label = torch.full((self.batch_size,), self.real_label, dtype=torch.float, device=device)
                output = model_D(torch.cat([input_l, input_ab], dim=1))
                errD_real = criterion(torch.squeeze(output), label)
                errD_real.backward()

                fake = model_G(input_l)
                label.fill_(self.fake_label)

                output = model_D(torch.cat([input_l, fake.detach()], dim=1))
                errD_fake = criterion(torch.squeeze(output), label)
                errD_fake.backward()
                errD = errD_real + errD_fake
                optimizer_D.step()

                model_G.zero_grad()
                label.fill_(self.real_label)  
                output = model_D(torch.cat([input_l, fake], dim=1))
                errG = criterion(torch.squeeze(output), label)
                errG_L1 = L1(fake.view(fake.size(0),-1), input_ab.view(input_ab.size(0),-1))
                errG = errG + 100 * errG_L1
                errG.backward()
                optimizer_G.step()  


            print(f'Training: Epoch {epoch + 1} \t\t Discriminator Loss: {\
                errD / len(train_dataloader)}  \t\t Generator Loss: {\
                errG / len(train_dataloader)}')
            
            if (epoch + 1) % 1 == 0:
                errD_val, errG_val, val_len = self.validate(model_D, model_G, criterion, L1)
                print(f'Validation: Epoch {epoch + 1} \t\t Discriminator Loss: {\
                        errD_val / val_len}  \t\t Generator Loss: {\
                        errG_val / val_len}')
                
            torch.save(model_G.state_dict(), '../Results/Model_GAN/Generator/saved_model_' + str(epoch + 1) + '.pth')
            torch.save(model_D.state_dict(), '../Results/Model_GAN/Discriminator/saved_model_' + str(epoch + 1) + '.pth')


    def validate(self, model_D, model_G, criterion, L1):

        model_G.eval()
        model_D.eval()
        with torch.no_grad():
            valid_loss = 0.0
            val_dataset = ColorizeData(paths=self.val_paths)
            val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, drop_last = True)
            for i, data in enumerate(val_dataloader):
                inputs, input_ab, input_l = data
                inputs = inputs.to(device)
                input_ab = input_ab.to(device)
                input_l = input_l.to(device)

                label = torch.full((self.batch_size,), self.real_label, dtype=torch.float, device=device)
                output = model_D(torch.cat([input_l, input_ab], dim=1))
                errD_real = criterion(torch.squeeze(output), label)

                fake = model_G(input_l)
                label.fill_(self.fake_label)
                output = model_D(torch.cat([input_l, fake.detach()], dim=1))
                errD_fake = criterion(torch.squeeze(output), label)
                
                errD = errD_real + errD_fake

                label.fill_(self.real_label)  
                output = model_D(torch.cat([input_l, fake], dim=1))
                errG = criterion(torch.squeeze(output), label)
                errG_L1 = L1(fake.view(fake.size(0),-1), input_ab.view(input_ab.size(0),-1))
                errG = errG + 100 * errG_L1

        return errD, errG, len(val_dataloader)

编辑
正如@manaclan 所建议的,这里是我用来 运行 管道的代码:

trainer = Trainer(epochs = 100, batch_size = 64, learning_rate = 0.0002, num_workers = 2)
trainer.train()

这是数据加载器:

class ColorizeData(Dataset):
    def __init__(self, paths):
        self.input_transform = T.Compose([T.ToTensor(),
                                          T.Resize(size=(256,256)),
                                          T.Grayscale(),
                                          T.Normalize((0.5), (0.5))
                                          ])
        self.lab_transform = T.Compose([T.ToTensor(),
                                          T.Resize(size=(256,256)),
                                          T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                          ])
        self.paths = paths

    def __len__(self) -> int:
        return len(self.paths)
    
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image = Image.open(self.paths[index]).convert("RGB")
        input_image = self.input_transform(image)
        image_lab = rgb2lab(image)
        image_lab = self.lab_transform(image_lab)
        image_l = image_lab[0, :, :]
        image_ab = image_lab[1:3, :, :]
        return (input_image.float(), image_ab.float(), image_l.float().reshape(1, 256, 256))

这里是导入:

from typing import Tuple
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch
import numpy as np
import os
import torch.nn as nn
import torchvision.models as models
import torchvision
import torch.nn.functional as functional
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from PIL import Image
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io

from torchvision.transforms.functional import resize

要重现错误,只需使用任何彩色图像数据集。 我有以下代码从文件夹“数据集”中获取我的训练、测试和验证图像:

path = "../Dataset/"
paths = np.array(glob.glob(path + "/*.jpg"))
rand_indices = np.random.permutation(len(paths))          # Number of images in dataset
train_indices, val_indices, test_indices = rand_indices[:3600], rand_indices[3600:4000], rand_indices[4000:]
train_paths = paths[train_indices]
val_paths = paths[val_indices]
test_paths = paths[test_indices]

注意:我正在使用 Google Colab,这可能是一个潜在的问题?另外,我使用的是 torch 版本 1.10.0+cu111。 之前确实是用了生成器没有skip connections的sequential model,然后就没有这个错误了。

也许可以尝试将层的输出直接用于跳过连接,如下所示:

def forward(self, x):
    h = x
    h = self.conv1(h)
    h = self.bnorm1(h)
    h1 = self.relu1(h) 

    h = self.conv2(h1)
    h = self.bnorm2(h)
    h2 = self.relu2(h) 

    h = self.conv3(h2) 
    h = self.bnorm3(h)
    h3 = self.relu3(h)

    h = self.conv4(h3) 
    h = self.bnorm4(h)
    h4 = self.relu4(h)

    h = self.conv5(h4) 
    h = self.bnorm5(h)
    h = self.relu5(h)

    h = self.deconv6(h5)
    h = self.bnorm6(h)
    h = self.relu6(h) 
    h += h4

    h = self.deconv7(h)
    h = self.bnorm7(h)
    h = self.relu7(h) 
    h += h3

    h = self.deconv8(h)
    h = self.bnorm8(h)
    h = self.relu8(h)
    h += h2

    h = self.deconv9(h)
    h = self.bnorm9(h)
    h = self.relu9(h)
    h += h1

    h = self.deconv10(h)
    h = self.tanh(h) 
    return h

很明显,问题是写成 h += poolX. 的就地跳过连接,因为 h = h + poolX 修复了这个更新。某些层的梯度计算需要h,所以就地修改会搞砸。