如何从 GAN 训练生成器?

How to train generator from GAN?

在阅读了 GAN 教程和代码示例后,我仍然不明白生成器是如何训练的。假设我们有一个简单的案例: - 生成器输入是噪声,输出是灰度图像 10x10 - 鉴别器输入是图像 10x10,输出是从 0 到 1 的单个值(假或真)

训练鉴别器很简单 - 将其输出视为真实的并期望它为 1。获取假输出并期望为 0。我们在这里使用真实输出大小 - 单个值。

但训练生成器不同 - 我们采用假输出(1 个值)并将其作为一个预期输出。但这听起来更像是再次训练鉴别器。生成器的输出是图像 10x10 我们如何只用 1 个单值来训练它?在这种情况下反向传播如何工作?

要训练生成器,您必须反向传播整个组合模型,同时冻结鉴别器的权重,以便仅更新生成器。

为此,我们必须计算 d(g(z; θg); θd),其中 θg 和 θd 是生成器和鉴别器的权重。要更新生成器,我们可以计算梯度 wrt。仅到 θg ∂loss(d(g(z; θg); θd)) / ∂θg,然后使用正常梯度下降更新 θg。

在 Keras 中,这可能看起来像这样(使用函数 API):

genInput = Input(input_shape)
discriminator = ...
generator = ...

discriminator.trainable = True
discriminator.compile(...)

discriminator.trainable = False
combined = Model(genInput, discriminator(generator(genInput)))
combined.compile(...)

通过将trainable设置为False,已经编译的模型不受影响,只会冻结未来编译的模型。因此,鉴别器可以作为独立模型进行训练,但在组合模型中被冻结。

然后,训练你的 GAN:

X_real = ...
noise = ...
X_gen = generator.predict(noise)

# This will only train the discriminator
loss_real = discriminator.train_on_batch(X_real, one_out)
loss_fake = discriminator.train_on_batch(X_gen, zero_out)

d_loss = 0.5 * np.add(loss_real, loss_fake)

noise = ...
# This will only train the generator.
g_loss = self.combined.train_on_batch(noise, one_out)

我想理解生成器训练过程的最好方法是修改所有训练循环。

对于每个纪元:

  1. 更新鉴别器:

    • 转发真实图像mini-batch通过鉴别器;

    • 计算判别器损失并计算反向传播的梯度;

    • 通过生成器生成假图像mini-batch;

    • 前向生成的假mini-batch通过鉴别器;

    • 计算鉴别器损失并推导反向传播的梯度;

    • 添加(真实mini-batch渐变,假mini-batch渐变)

    • 更新判别器(使用 Adam 或 SGD)。

  2. 更新生成器:

    • 翻转目标:生成器将假图像标记为真实图像。注意:此步骤确保对生成器使用 cross-entropy 最小化。如果我们继续实施 GAN minmax 游戏,它有助于克服生成器梯度消失的问题。

    • 转发假图片mini-batch通过更新的判别器;

    • 根据更新后的判别器输出计算生成器损失,例如:

    损失函数(判别器估计的假图为真图的概率,1)。
    注意:这里的 1 表示 Generator label for fake images as real.

    • 更新生成器(使用 Adam 或 SGD)

希望对您有所帮助。从训练过程可以看出,GAN选手有点"cooperative, in the sense that the discriminator estimates the ratio of data to model distribution densities and then freely shares this information with the generator. From this point of view, the discriminator is more like a teacher instructing the generator in how to improve than an adversary"(引自I.Goodfellow tutorial)。