无法理解 tensorflow 文档中使用的 GAN 模型的损失函数
Can't understand the loss functions for the GAN model used in the tensorflow documentation
我无法理解 tensorflow 文档中 GAN 模型中的损失函数。为什么将 tf.ones_like()
用于 real_loss 而将 tf.zeros_like()
用于 假输出 ??
def discriminator_loss(real_output,fake_output):
real_loss = cross_entropy(tf.ones_like(real_output),real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output),fake_output)
total_loss = real_loss + fake_loss
return total_loss
我们有以下损失函数,我们需要以 mini-max 方式最小化(或者 min-max,如果你想这样称呼的话)。
- generator_loss = -log(generated_labels)
- discriminator_loss = -log(real_labels) - log(1 - generated_labels)
其中 real_output
= real_labels 且 fake_output
= generated_labels.
现在,考虑到这一点,让我们看看 TensorFlow 文档中的代码片段代表什么:
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
计算为
- real_loss = -1 * log(real_output) - (1 - 1) * log(1 - real_output) = -log(real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output),fake_output)
计算为
- fake_loss = -0 * log(fake_output) - (1 - 0) * log(1 - fake_output) = -log(1 - fake_output )
total_loss = real_loss + fake_loss
计算为
- total_loss = -log(real_output) - log(1 - fake_output)
显然,我们得到了我们想要最小化的 mini-max 游戏中判别器的损失函数。
我无法理解 tensorflow 文档中 GAN 模型中的损失函数。为什么将 tf.ones_like()
用于 real_loss 而将 tf.zeros_like()
用于 假输出 ??
def discriminator_loss(real_output,fake_output):
real_loss = cross_entropy(tf.ones_like(real_output),real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output),fake_output)
total_loss = real_loss + fake_loss
return total_loss
我们有以下损失函数,我们需要以 mini-max 方式最小化(或者 min-max,如果你想这样称呼的话)。
- generator_loss = -log(generated_labels)
- discriminator_loss = -log(real_labels) - log(1 - generated_labels)
其中 real_output
= real_labels 且 fake_output
= generated_labels.
现在,考虑到这一点,让我们看看 TensorFlow 文档中的代码片段代表什么:
real_loss = cross_entropy(tf.ones_like(real_output), real_output)
计算为- real_loss = -1 * log(real_output) - (1 - 1) * log(1 - real_output) = -log(real_output)
fake_loss = cross_entropy(tf.zeros_like(fake_output),fake_output)
计算为- fake_loss = -0 * log(fake_output) - (1 - 0) * log(1 - fake_output) = -log(1 - fake_output )
total_loss = real_loss + fake_loss
计算为- total_loss = -log(real_output) - log(1 - fake_output)
显然,我们得到了我们想要最小化的 mini-max 游戏中判别器的损失函数。