Pytorch:GAN 训练中的不同行为,具有不同但概念上等效的代码

Pytorch : different behaviours in GAN training with different, but conceptually equivalent, code

我正在尝试在 Pytorch 中实现一个简单的 GAN。以下训练代码有效:

    for epoch in range(max_epochs):  # loop over the dataset multiple times
        print(f'epoch: {epoch}')
        running_loss = 0.0

        for batch_idx,(data,_) in enumerate(data_gen_fn):
   
            # data preparation
            real_data            = data
            input_shape          = real_data.shape
            inputs_generator     = torch.randn(*input_shape).detach() 

            # generator forward
            fake_data            = generator(inputs_generator).detach()
            # discriminator forward
            optimizer_generator.zero_grad()
            optimizer_discriminator.zero_grad()

            #################### ALERT CODE #######################
            predictions_on_real = discriminator(real_data)
            predictions_on_fake = discriminator(fake_data)

            predictions = torch.cat((predictions_on_real,
                                     predictions_on_fake), dim=0)
           #########################################################

            # loss discriminator
            labels_real_fake           = torch.tensor([1]*batch_size + [0]*batch_size)
            loss_discriminator_batch   = criterion_discriminator(predictions, 
                                                          labels_real_fake)
            # update discriminator
            loss_discriminator_batch.backward()
            optimizer_discriminator.step()


            # generator
            # zero the parameter gradients
            optimizer_discriminator.zero_grad()
            optimizer_generator.zero_grad()

            fake_data            = generator(inputs_generator) # make again fake data but without detaching
            predictions_on_fake  = discriminator(fake_data) # D(G(encoding))
            
            # loss generator           
            labels_fake          = torch.tensor([1]*batch_size)
            loss_generator_batch = criterion_generator(predictions_on_fake, 
                                                       labels_fake)
  
            loss_generator_batch.backward()  # dL(D(G(encoding)))/dW_{G,D}
            optimizer_generator.step()

如果我为每次迭代绘制生成的图像,我发现生成的图像看起来像真实的图像,因此训练过程似乎运行良好。

但是,如果我尝试更改 ALERT CODE 部分中的代码,即,而不是:

   #################### ALERT CODE #######################
   predictions_on_real = discriminator(real_data)
   predictions_on_fake = discriminator(fake_data)

   predictions = torch.cat((predictions_on_real,
                            predictions_on_fake), dim=0)
   #########################################################

我使用以下:

   #################### ALERT CODE #######################
   predictions = discriminator(torch.cat( (real_data, fake_data), dim=0))
   #######################################################

这在概念上是相同的(简而言之,不是在[=15=上做两个不同的forward,前者在[=16=上,后者在[=17上=] 数据,最后连接结果,用新代码我首先连接 realfake 数据,最后我只对连接的数据进行一次前向传递。

然而,此代码版本不起作用,即生成的图像似乎总是随机噪声。

对此行为有任何解释吗?

为什么我们的结果不同?

在同一批次或不同批次中提供输入,可以 如果模型包含批次中不同元素之间的依赖关系,则可能会有所不同。到目前为止,当前深度学习模型中最常见的来源是 批量归一化 。正如您所提到的,鉴别器确实包括 batchnorm,因此这可能是不同行为的原因。这是一个例子。使用单个数字和 4 的批量大小:

features = [1., 2., 5., 6.]
print("mean {}, std {}".format(np.mean(features), np.std(features)))

print("normalized features", (features - np.mean(features)) / np.std(features))

>>>mean 3.5, std 2.0615528128088303
>>>normalized features [-1.21267813 -0.72760688  0.72760688  1.21267813]

现在我们将批次分成两部分。第一部分:

features = [1., 2.]
print("mean {}, std {}".format(np.mean(features), np.std(features)))

print("normalized features", (features - np.mean(features)) / np.std(features))

>>>mean 1.5, std 0.5
>>>normalized features [-1.  1.]

第二部分:

features = [5., 6.]
print("mean {}, std {}".format(np.mean(features), np.std(features)))

print("normalized features", (features - np.mean(features)) / np.std(features))

>>>mean 5.5, std 0.5
>>>normalized features [-1.  1.]

正如我们所见,在 split-batch 版本中,这两个批次被归一化为 完全相同的 数字,即使输入非常不同。另一方面,在 joint-batch 版本中,较大的数字仍然大于较小的数字,因为它们使用相同的统计数据进行了归一化。

为什么这很重要?

对于深度学习,总是很难说,尤其是 GAN 及其复杂的训练动态。 可能 的解释是,正如我们在上面的示例中看到的那样,即使原始输入完全不同,单独的批次也会在归一化后产生更多相似的特征。这可能有助于早期训练,因为生成器倾向于输出“垃圾”,其统计数据与真实数据有很大不同。

对于联合批次,这些不同的统计数据使鉴别器很容易区分真实数据和生成数据,我们最终会遇到鉴别器“压倒”生成器的情况。

然而,通过使用单独的批次,不同的归一化会导致生成的数据和真实数据看起来更相似,这使得鉴别器的任务变得不那么琐碎,并允许生成器学习。