是什么导致 VAE(变分自动编码器)即使在训练后也输出随机噪声?

What could cause a VAE(Variational AutoEncoder) to output random noise even after training?

我已经在 CIFAR10 数据集上训练了一个 VAE。然而,当我尝试从 VAE 生成图像时,我得到的只是一堆灰色噪声。此 VAE 的实现遵循书中的实现 Generative Deep Learning,但代码使用 PyTorch 而不是 TensorFlow。

可以找到包含训练和生成的笔记本 here, while the actual implementation of the VAE can be found here

我试过:

  1. 禁用辍学。
  2. 增加潜在维度space。

None 的方法显示出任何改进。

我已经确认:

  1. 输入大小与输出大小匹配
  2. 随着训练过程中损失的减少,反向传播运行成功。

感谢您为 Colab 笔记本提供代码和 link! +1! 另外,您的代码写得很好并且易于阅读。除非我遗漏了什么,否则我认为您的代码有两个问题:

  1. 数据归一化
  2. VAE损失的实现。

关于 1.,您的 CIFAR10DataModule class 使用 mean = 0.5std = 0.5 对 CIFAR10 图像的 RGB 通道进行归一化.由于像素值最初在 [0,1] 范围内,因此归一化图像的像素值在 [-1,1] 范围内。但是,您的 Decoder class 对重建图像应用了 nn.Sigmoid() 激活。因此,您重建的图像具有 [0,1] 范围内的像素值。我建议删除这种均值标准归一化,以便“真实”图像和重建图像的像素值都在 [0,1] 范围内。

关于 2.:因为您处理的是 RGB 图像,所以 MSE 损失是有意义的。 MSE 损失背后的想法是“高斯解码器”。该解码器假设“真实图像”的像素值是由独立的高斯分布生成的,其平均值是重建图像的像素值(即解码器的输出)并具有给定的方差。您执行的重建损失(即 r_loss = F.mse_loss(predictions, targets))相当于固定方差。使用考虑 BCE 损失的 this paper, we can do better and obtain an analytic expression for the "optimal value" of this variance parameter. Finally, the reconstruction loss should be summed over all pixels (reduction = 'sum'). To understand why, have a look at analytic expression of the reconstruction loss (see, for instance, this blog post 的想法。

这是重构后的 LitVAE class 的样子:

class LitVAE(pl.LightningModule):
    def __init__(self,
                 learning_rate: float = 0.0005,
                 **kwargs) -> None:
        """
        Parameters
        ----------
        - `learning_rate: float`:
            learning rate for the optimizer
        - `**kwargs`:
            arguments to pass to the variational autoencoder constructor
        """
        super(LitVAE, self).__init__()
        
        self.learning_rate = learning_rate 

        self.vae = VariationalAutoEncoder(**kwargs)

    def forward(self, x) -> _tensor_size_3_t: 
        return self.vae(x)

    def training_step(self, batch, batch_idx):
        r_loss, kl_loss, sigma_opt = self.shared_step(batch)
        loss = r_loss + kl_loss
        
        self.log("train_loss_step", loss)
        return {"loss": loss, 'log':{"r_loss": r_loss / len(batch[0]), "kl_loss": kl_loss / len(batch[0]), 'sigma_opt': sigma_opt}}

    def training_epoch_end(self, outputs) -> None:
        # add computation graph
        if(self.current_epoch == 0):
            sample_input = torch.randn((1, 3, 32, 32))
            sample_model = LitVAE(**MODEL_PARAMS)
            
            self.logger.experiment.add_graph(sample_model, sample_input)
            
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("train_loss_epoch", epoch_loss, self.current_epoch)

    def validation_step(self, batch, batch_idx):
        r_loss, kl_loss, _ = self.shared_step(batch)
        loss = r_loss + kl_loss

        self.log("valid_loss_step", loss)

        return {"loss": loss}

    def validation_epoch_end(self, outputs) -> None:
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("valid_loss_epoch", epoch_loss, self.current_epoch)

    def test_step(self, batch, batch_idx):
        r_loss, kl_loss, _ = self.shared_step(batch)
        loss = r_loss + kl_loss
        
        self.log("test_loss_step", loss)
        return {"loss": loss}

    def test_epoch_end(self, outputs) -> None:
        epoch_loss = self.average_metric(outputs, "loss")
        self.logger.experiment.add_scalar("test_loss_epoch", epoch_loss, self.current_epoch)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.learning_rate)
        
    def shared_step(self, batch) -> torch.TensorType: 
        # images are both samples and targets thus original 
        # labels from the dataset are not required
        true_images, _ = batch

        # perform a forward pass through the VAE 
        # mean and log_variance are used to calculate the KL Divergence loss 
        # decoder_output represents the generated images 
        mean, log_variance, generated_images = self(true_images)

        r_loss, kl_loss, sigma_opt = self.calculate_loss(mean, log_variance, generated_images, true_images)
        return r_loss, kl_loss, sigma_opt

    def calculate_loss(self, mean, log_variance, predictions, targets):
        mse = F.mse_loss(predictions, targets, reduction='mean')
        log_sigma_opt = 0.5 * mse.log()
        r_loss = 0.5 * torch.pow((targets - predictions) / log_sigma_opt.exp(), 2) + log_sigma_opt
        r_loss = r_loss.sum()
        kl_loss = self._compute_kl_loss(mean, log_variance)
        return r_loss, kl_loss, log_sigma_opt.exp()

    def _compute_kl_loss(self, mean, log_variance): 
        return -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())

    def average_metric(self, metrics, metric_name):
        avg_metric = torch.stack([x[metric_name] for x in metrics]).mean()
        return avg_metric

经过 10 个 epoch 后,重建的图像是这样的: