Pytorch GAN模型不训练:矩阵乘法错误

Pytorch GAN model doesn't train: matrix multiplication error

我正在尝试构建一个基本的 GAN 来熟悉 Pytorch。我对 Keras 有一些(有限的)经验,但由于我必须在 Pytorch 中做一个更大的项目,所以我想首先使用 'basic' 网络进行探索。

我正在使用 Pytorch Lightning。我想我已经添加了所有必要的组件。我尝试分别通过生成器和鉴别器传递一些噪声,我认为输出具有预期的形状。尽管如此,当我尝试训练 GAN 时出现运行时错误(下面的完整回溯):

RuntimeError: mat1 and mat2 shapes cannot be multiplied (7x9 and 25x1)

我注意到 7 是批处理的大小(通过打印出批处理尺寸),即使我将 batch_size 指定为 64。除此之外,老实说,我不知道在哪里开始:错误回溯对我没有帮助。

很可能,我犯了很多错误。但是,我希望你们中的一些人能够从代码中发现当前的错误,因为乘法错误似乎指向某处的维数问题。这是代码。

import os

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from skimage import io
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import make_grid
from torchvision.transforms import Resize, ToTensor, ToPILImage, Normalize  

class DoppelDataset(Dataset):
    """
    Dataset class for face data
    """

    def __init__(self, face_dir: str, transform=None):

        self.face_dir = face_dir
        self.face_paths = os.listdir(face_dir)
        self.transform = transform

    def __len__(self):

        return len(self.face_paths)

    def __getitem__(self, idx):

        if torch.is_tensor(idx):
            idx = idx.tolist()

        face_path = os.path.join(self.face_dir, self.face_paths[idx])
        face = io.imread(face_path)

        sample = {'image': face}

        if self.transform:
            sample = self.transform(sample['image'])

        return sample


class DoppelDataModule(pl.LightningDataModule):

    def __init__(self, data_dir='../data/faces', batch_size: int = 64, num_workers: int = 0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transforms = transforms.Compose([
            ToTensor(),
            Resize(100),
            Normalize(mean=(123.26290927634774, 95.90498110733365, 86.03763122875182),
                      std=(63.20679012922922, 54.86211954409834, 52.31266645797249))
        ])

    def setup(self, stage=None):
        # Initialize dataset
        doppel_data = DoppelDataset(face_dir=self.data_dir, transform=self.transforms)

        # Train/val/test split
        n = len(doppel_data)
        train_size = int(.8 * n)
        val_size = int(.1 * n)
        test_size = n - (train_size + val_size)

        self.train_data, self.val_data, self.test_data = random_split(dataset=doppel_data,
                                                                      lengths=[train_size, val_size, test_size])

    def train_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.val_data, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(dataset=self.test_data, batch_size=self.batch_size, num_workers=self.num_workers)


class DoppelGenerator(nn.Sequential):
    """
    Generator network that produces images based on latent vector
    """

    def __init__(self, latent_dim: int):
        super().__init__()

        def block(in_channels: int, out_channels: int, padding: int = 1, stride: int = 2, bias=False):
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=stride,
                                   padding=padding, bias=bias),
                nn.BatchNorm2d(num_features=out_channels),
                nn.ReLU(True)
            )

        self.model = nn.Sequential(
            block(latent_dim, 512, padding=0, stride=1),
            block(512, 256),
            block(256, 128),
            block(128, 64),
            block(64, 32),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.model(input)


class DoppelDiscriminator(nn.Sequential):
    """
    Discriminator network that classifies images in two categories
    """

    def __init__(self):
        super().__init__()

        def block(in_channels: int, out_channels: int):
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1,
                          bias=False),
                nn.BatchNorm2d(num_features=out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )

        self.model = nn.Sequential(
            block(3, 64),
            block(64, 128),
            block(128, 256),
            block(256, 512),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Flatten(),
            nn.Linear(25, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.model(input)


class DoppelGAN(pl.LightningModule):

    def __init__(self,
                 channels: int,
                 width: int,
                 height: int,
                 lr: float = 0.0002,
                 b1: float = 0.5,
                 b2: float = 0.999,
                 batch_size: int = 64,
                 **kwargs):

        super().__init__()

        # Save all keyword arguments as hyperparameters, accessible through self.hparams.X)
        self.save_hyperparameters()

        # Initialize networks
        # data_shape = (channels, width, height)
        self.generator = DoppelGenerator(latent_dim=self.hparams.latent_dim, )
        self.discriminator = DoppelDiscriminator()

        self.validation_z = torch.randn(8, self.hparams.latent_dim,1,1)

    def forward(self, input):
        return self.generator(input)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, optimizer_idx):
        images = batch

        # Sample noise (batch_size, latent_dim,1,1)
        z = torch.randn(images.size(0), self.hparams.latent_dim,1,1)

        # Train generator
        if optimizer_idx == 0:

            # Generate images (call generator -- see forward -- on latent vector)
            self.generated_images = self(z)

            # Log sampled images (visualize what the generator comes up with)
            sample_images = self.generated_images[:6]
            grid = make_grid(sample_images)
            self.logger.experiment.add_image('generated_images', grid, 0)

            # Ground truth result (ie: all fake)
            valid = torch.ones(images.size(0), 1)

            # Adversarial loss is binary cross-entropy
            generator_loss = self.adversarial_loss(self.discriminator(self(z)), valid)
            tqdm_dict = {'gen_loss': generator_loss}

            output = {
                'loss': generator_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            }
            return output

        # Train discriminator: classify real from generated samples
        if optimizer_idx == 1:

            # How well can it label as real?
            valid = torch.ones(images.size(0), 1)
            real_loss = self.adversarial_loss(self.discriminator(images), valid)

            # How well can it label as fake?
            fake = torch.zeros(images.size(0), 1)
            fake_loss = self.adversarial_loss(
                self.discriminator(self(z).detach()), fake)

            # Discriminator loss is the average of these
            discriminator_loss = (real_loss + fake_loss) / 2
            tqdm_dict = {'d_loss': discriminator_loss}
            output = {
                'loss': discriminator_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            }
            return output

    def configure_optimizers(self):
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        # Optimizers
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr, betas=(b1, b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))

        # Return optimizers/schedulers (currently no scheduler)
        return [opt_g, opt_d], []

    def on_epoch_end(self):

        # Log sampled images
        sample_images = self(self.validation_z)
        grid = make_grid(sample_images)
        self.logger.experiment.add_image('generated_images', grid, self.current_epoch)


if __name__ == '__main__':

    # Global parameter
    image_dim = 128
    latent_dim = 100
    batch_size = 64

    # Initialize dataset
    tfs = transforms.Compose([
        ToPILImage(),
        Resize(image_dim),
        ToTensor()
    ])
    doppel_dataset = DoppelDataset(face_dir='../data/faces', transform=tfs)

    # Initialize data module
    doppel_data_module = DoppelDataModule(batch_size=batch_size)

    # Build models
    generator = DoppelGenerator(latent_dim=latent_dim)
    discriminator = DoppelDiscriminator()

    # Test generator
    x = torch.rand(batch_size, latent_dim, 1, 1)
    y = generator(x)
    print(f'Generator: x {x.size()} --> y {y.size()}')

    # Test discriminator
    x = torch.rand(batch_size, 3, 128, 128)
    y = discriminator(x)
    print(f'Discriminator: x {x.size()} --> y {y.size()}')

    # Build GAN
    doppelgan = DoppelGAN(batch_size=batch_size, channels=3, width=image_dim, height=image_dim, latent_dim=latent_dim)

    # Fit GAN
    trainer = pl.Trainer(gpus=0, max_epochs=5, progress_bar_refresh_rate=1)
    trainer.fit(model=doppelgan, datamodule=doppel_data_module)

完整追溯:

Traceback (most recent call last):
  File "/usr/local/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-28805d67d74b>", line 1, in <module>
    runfile('/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py', wdir='/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger')
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py", line 298, in <module>
    trainer.fit(model=doppelgan, datamodule=doppel_data_module)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 510, in fit
    results = self.accelerator_backend.train()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 57, in train
    return self.train_or_test()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in train_or_test
    results = self.trainer.train()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 561, in train
    self.train_loop.run_training_epoch()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 550, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 718, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 485, in optimizer_step
    model_ref.optimizer_step(
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py", line 1298, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 286, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 144, in __optimizer_step
    optimizer.step(closure=closure, *args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/optim/adam.py", line 66, in step
    loss = closure()
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 708, in train_step_and_backward_closure
    result = self.training_step_and_backward(
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 806, in training_step_and_backward
    result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 319, in training_step
    training_step_output = self.trainer.accelerator_backend.training_step(args)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py", line 62, in training_step
    return self._step(self.trainer.model.training_step, args)
  File "/usr/local/lib/python3.9/site-packages/pytorch_lightning/accelerators/cpu_accelerator.py", line 58, in _step
    output = model_step(*args)
  File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py", line 223, in training_step
    real_loss = self.adversarial_loss(self.discriminator(images), valid)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/Users/wouter/Documents/OneDrive/Hardnose/Projects/Coding/0002_DoppelGANger/doppelganger/gan.py", line 154, in forward
    return self.model(input)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/container.py", line 117, in forward
    input = module(input)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/usr/local/lib/python3.9/site-packages/torch/nn/functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 and mat2 shapes cannot be multiplied (7x9 and 25x1)

这道乘法题来自DoppelDiscriminator。有一个线性层

    nn.Linear(25, 1),

应该是

    nn.Linear(9, 1),

根据错误信息。