PyTorch 生成对抗网络 (GAN) 的训练生成器
Training Generator of Generative Adversarial Network (GAN) in PyTorch
我正致力于在 PyTorch 1.5.0 中实现生成对抗网络 (GAN)。
为了计算生成器的损失,我计算了判别器错误分类全真实小批量和全(生成器生成的)假小批量的负概率。然后,我依次反向传播这两个部分,最后应用阶跃函数。
计算和反向传播作为生成的假数据错误分类函数的损失部分似乎很简单,因为在该损失项的反向传播过程中,反向路径通过生成器谁首先制作了假数据。
然而,全真实数据小批量的分类不涉及通过生成器传递数据。因此,我想知道以下代码片段是否仍会计算生成器的梯度,或者它是否根本不会计算任何梯度(因为向后路径不通过生成器并且鉴别器在更新生成器时处于评估模式)?
# Update generator #
net.generator.train()
net.discriminator.eval()
net.generator.zero_grad()
# All-real minibatch
x_real = get_all_real_minibatch()
y_true = torch.full((batch_size,), label_fake).long() # Pretend true targets were fake
y_pred = net.discriminator(x_real) # Produces softmax probability distribution over (0=label_fake,1=label_real)
loss_real = NLLLoss(torch.log(y_pred), y_true)
loss_real.backward()
optimizer_generator.step()
如果这不能按预期工作,我该如何让它工作?提前致谢!
没有梯度传播到生成器,因为没有使用生成器的任何参数执行计算。处于 eval 模式的鉴别器不会阻止梯度传播到生成器,尽管如果您使用的层在 eval 模式下与训练模式相比表现不同,例如 dropout,它们会略有不同。
真实图像的错误分类不是训练生成器的一部分,因为它不会从这些信息中获得任何信息。从概念上讲,生成器应该从鉴别器未能正确分类真实图像的事实中学到什么?生成器的唯一任务是创建一个假图像,使鉴别器认为它是真实的,因此与生成器唯一相关的信息是鉴别器是否能够识别假图像。如果鉴别器确实能够识别假图像,则生成器需要自我调整以创建更具说服力的假图像。
当然这不是二进制情况,但生成器总是试图改进假图像,以便鉴别器更加确信它是真实图像。生成器的目标不是让鉴别器产生怀疑(0.5 的概率是真的还是假的),而是鉴别器完全相信它是真的,即使它是假的。这就是为什么他们是敌对的,而不是合作的。
我正致力于在 PyTorch 1.5.0 中实现生成对抗网络 (GAN)。
为了计算生成器的损失,我计算了判别器错误分类全真实小批量和全(生成器生成的)假小批量的负概率。然后,我依次反向传播这两个部分,最后应用阶跃函数。
计算和反向传播作为生成的假数据错误分类函数的损失部分似乎很简单,因为在该损失项的反向传播过程中,反向路径通过生成器谁首先制作了假数据。
然而,全真实数据小批量的分类不涉及通过生成器传递数据。因此,我想知道以下代码片段是否仍会计算生成器的梯度,或者它是否根本不会计算任何梯度(因为向后路径不通过生成器并且鉴别器在更新生成器时处于评估模式)?
# Update generator #
net.generator.train()
net.discriminator.eval()
net.generator.zero_grad()
# All-real minibatch
x_real = get_all_real_minibatch()
y_true = torch.full((batch_size,), label_fake).long() # Pretend true targets were fake
y_pred = net.discriminator(x_real) # Produces softmax probability distribution over (0=label_fake,1=label_real)
loss_real = NLLLoss(torch.log(y_pred), y_true)
loss_real.backward()
optimizer_generator.step()
如果这不能按预期工作,我该如何让它工作?提前致谢!
没有梯度传播到生成器,因为没有使用生成器的任何参数执行计算。处于 eval 模式的鉴别器不会阻止梯度传播到生成器,尽管如果您使用的层在 eval 模式下与训练模式相比表现不同,例如 dropout,它们会略有不同。
真实图像的错误分类不是训练生成器的一部分,因为它不会从这些信息中获得任何信息。从概念上讲,生成器应该从鉴别器未能正确分类真实图像的事实中学到什么?生成器的唯一任务是创建一个假图像,使鉴别器认为它是真实的,因此与生成器唯一相关的信息是鉴别器是否能够识别假图像。如果鉴别器确实能够识别假图像,则生成器需要自我调整以创建更具说服力的假图像。
当然这不是二进制情况,但生成器总是试图改进假图像,以便鉴别器更加确信它是真实图像。生成器的目标不是让鉴别器产生怀疑(0.5 的概率是真的还是假的),而是鉴别器完全相信它是真的,即使它是假的。这就是为什么他们是敌对的,而不是合作的。