Trying to backward through the graph a second time with GANs model

I'm trying to set up a simple GAN training loop, but I get the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.

for epoch in range(N_EPOCHS):
    # gets data for the generator
    for i, batch in enumerate(dataloader, 0):

        # passing target images to the Discriminator
        global_disc.zero_grad()
        output_disc = global_disc(batch.to(device))
        error_target = loss(output_disc, torch.ones(output_disc.shape).cuda())
        error_target.backward()

        # apply mask to the images
        batch = apply_mask(batch)

        # passes fake images to the Discriminator
        global_output, local_output = gen(batch.to(device))
        output_disc = global_disc(global_output.detach())
        error_fake = loss(output_disc, torch.zeros(output_disc.shape).to(device))
        error_fake.backward()

        # combines the errors
        error_total = error_target + error_fake
        optimizer_disc.step()

        # updates the generator
        gen.zero_grad()
        error_gen = loss(output_disc, torch.ones(output_disc.shape).to(device))
        error_gen.backward()
        optimizer_gen.step()

        break
    break

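As far as I can tell, my operations are in the right order, I'm zeroing the gradients, and I'm detaching the generator's output before it goes into the discriminator.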

This article was helpful, but I'm still running into something I don't understand.

Two main points come to mind:

  1. You should feed your generator noise rather than real inputs:

    global_output, local_output = gen(noise.to(device))
    

Here noise should have the appropriate shape (it is the input to your generator).

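  2. To optimize the generator, you need to recompute the discriminator output, because it has already been backpropagated through. The output_disc you reuse was computed from global_output.detach(), so its graph was freed by error_fake.backward() and it carries no gradient back to the generator anyway. Just add this line to recompute output_disc: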

    # updates the generator
    gen.zero_grad()
    output_disc = global_disc(global_output)
    # ...
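
The error itself can be reproduced in isolation, which may make the cause easier to see. This is a minimal sketch with a hypothetical tensor `w` that has nothing to do with the GAN models above: the first `backward()` frees the values saved for the multiplication, so the second call on the same graph fails.

```python
import torch

# Hypothetical toy tensor, only used to illustrate the error.
w = torch.ones(3, requires_grad=True)
y = (w * w).sum()   # the multiplication saves its inputs for the backward pass

y.backward()        # first backward: saved intermediate values are freed afterwards

raised = False
message = ""
try:
    y.backward()    # second backward through the same graph
except RuntimeError as e:
    raised = True
    message = str(e)

print(raised)
```

In the training loop from the question, `error_gen` is built from the same `output_disc` that `error_fake.backward()` already consumed, which is the same situation.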
    

See this tutorial provided by PyTorch for a full walkthrough.
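
Putting both points together, one training step might look like the sketch below. The tiny `nn.Sequential` models, the tensor shapes, the learning rate, and the random `real` batch are all assumptions made so the example runs on its own; the generator here returns a single tensor, unlike the `gen` in the question, which returns `global_output` and `local_output`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
device = torch.device("cpu")

# Hypothetical stand-ins for the models in the question.
gen = nn.Sequential(nn.Linear(8, 16), nn.Tanh()).to(device)
global_disc = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid()).to(device)
loss = nn.BCELoss()
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=2e-4)
optimizer_disc = torch.optim.Adam(global_disc.parameters(), lr=2e-4)

real = torch.randn(4, 16, device=device)   # placeholder for a batch of real data
noise = torch.randn(4, 8, device=device)   # generator input

# Discriminator step on real data
global_disc.zero_grad()
output_real = global_disc(real)
error_target = loss(output_real, torch.ones_like(output_real))
error_target.backward()

# Discriminator step on fake data; detach so no gradient reaches the generator
global_output = gen(noise)
output_fake = global_disc(global_output.detach())
error_fake = loss(output_fake, torch.zeros_like(output_fake))
error_fake.backward()
optimizer_disc.step()

# Generator step: recompute the discriminator output WITHOUT detach,
# so a fresh graph connects the loss back to the generator
gen.zero_grad()
output_disc = global_disc(global_output)
error_gen = loss(output_disc, torch.ones_like(output_disc))
error_gen.backward()
optimizer_gen.step()
```

The second forward pass through `global_disc` builds a new graph, so `error_gen.backward()` has its own saved values to work with and `retain_graph=True` is not needed.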