DCGAN的两个问题:data normalization和fake/real batch
Two questions on DCGAN: data normalization and fake/real batch
我正在分析在图像生成中使用 DCGAN + Reptile 的元学习 class。
关于这段代码我有两个问题。
第一个问题:为什么在 DCGAN 训练期间(第 74 行)
training_batch = torch.cat ([real_batch, fake_batch])
是由真例(real_batch)和假例(fake_batch)组成的training_batch?为什么要通过混合真实和虚假图像来进行训练?我见过很多 DCGAN,但从未以这种方式完成训练。
第二个问题:为什么在训练时使用normalize_data函数(第49行)和unnormalize_data函数(第55行)?
def normalize_data(data):
data *= 2
data -= 1
return data
def unnormalize_data(data):
data += 1
data /= 2
return data
项目使用的是Mnist数据集,如果我想使用像CIFAR10这样的颜色数据集,是否需要修改那些归一化?
如果你仔细阅读文档(查看 def initialize_gan(self):
函数),你会发现
self.meta_g == Generator
self.meta_d == Discriminator
并且在您引用的行中,fake_batch 被定义为生成器的一部分:
fake_batch = self.meta_g(torch.tensor(np.random.normal(size=(self.batch_size, self.z_shape)), dtype=torch.float, device=device))
training_batch = torch.cat([real_batch, fake_batch])
因此,因为它是一个 GAN,你给鉴别器提供了假的和真实的图像,鉴别器必须弄清楚它是哪一个。
关于你的第二个问题,我假设,但我不完全确定这两个函数用于生成假图像?我会仔细检查一下。
有帮助吗?
训练 GAN 涉及给鉴别器提供真实和虚假的例子。通常,您会看到它们是在两个不同的场合给出的。默认情况下,torch.cat
在第一个维度 (dim=0
) 上连接张量,这是批次维度。因此它只是将批量大小加倍,其中前半部分是真实图像,后半部分是假图像。
为了计算损失,他们调整了目标,使得前半部分(原始批量大小)被分类为真实的,后半部分被分类为假的。来自 initialize_gan
:
self.discriminator_targets = torch.tensor([1] * self.batch_size + [-1] * self.batch_size, dtype=torch.float, device=device).view(-1, 1)
图像用 [0, 1] 之间的浮点值表示。归一化会改变它以产生 [-1, 1] 之间的值。 GAN 通常在生成器中使用 tanh,因此假图像的值介于 [-1, 1] 之间,因此真实图像应该在同一范围内,否则判别器很难区分假图像和真实图像.
如果要显示这些图像,需要先将它们去标准化,即将它们转换为[0, 1]之间的值。
The project uses the Mnist dataset, if I wanted to use a color dataset like CIFAR10, do I have to modify those normalizations?
不,你不需要改变它们,因为彩色图像也有它们在 [0, 1] 之间的值,只是有更多的值,代表 3 个通道 (RGB)。
我正在分析在图像生成中使用 DCGAN + Reptile 的元学习 class。
关于这段代码我有两个问题。
第一个问题:为什么在 DCGAN 训练期间(第 74 行)
training_batch = torch.cat ([real_batch, fake_batch])
是由真例(real_batch)和假例(fake_batch)组成的training_batch?为什么要通过混合真实和虚假图像来进行训练?我见过很多 DCGAN,但从未以这种方式完成训练。
第二个问题:为什么在训练时使用normalize_data函数(第49行)和unnormalize_data函数(第55行)?
def normalize_data(data):
data *= 2
data -= 1
return data
def unnormalize_data(data):
data += 1
data /= 2
return data
项目使用的是Mnist数据集,如果我想使用像CIFAR10这样的颜色数据集,是否需要修改那些归一化?
如果你仔细阅读文档(查看 def initialize_gan(self):
函数),你会发现
self.meta_g == Generator
self.meta_d == Discriminator
并且在您引用的行中,fake_batch 被定义为生成器的一部分:
fake_batch = self.meta_g(torch.tensor(np.random.normal(size=(self.batch_size, self.z_shape)), dtype=torch.float, device=device))
training_batch = torch.cat([real_batch, fake_batch])
因此,因为它是一个 GAN,你给鉴别器提供了假的和真实的图像,鉴别器必须弄清楚它是哪一个。
关于你的第二个问题,我假设,但我不完全确定这两个函数用于生成假图像?我会仔细检查一下。
有帮助吗?
训练 GAN 涉及给鉴别器提供真实和虚假的例子。通常,您会看到它们是在两个不同的场合给出的。默认情况下,torch.cat
在第一个维度 (dim=0
) 上连接张量,这是批次维度。因此它只是将批量大小加倍,其中前半部分是真实图像,后半部分是假图像。
为了计算损失,他们调整了目标,使得前半部分(原始批量大小)被分类为真实的,后半部分被分类为假的。来自 initialize_gan
:
self.discriminator_targets = torch.tensor([1] * self.batch_size + [-1] * self.batch_size, dtype=torch.float, device=device).view(-1, 1)
图像用 [0, 1] 之间的浮点值表示。归一化会改变它以产生 [-1, 1] 之间的值。 GAN 通常在生成器中使用 tanh,因此假图像的值介于 [-1, 1] 之间,因此真实图像应该在同一范围内,否则判别器很难区分假图像和真实图像.
如果要显示这些图像,需要先将它们去标准化,即将它们转换为[0, 1]之间的值。
The project uses the Mnist dataset, if I wanted to use a color dataset like CIFAR10, do I have to modify those normalizations?
不,你不需要改变它们,因为彩色图像也有它们在 [0, 1] 之间的值,只是有更多的值,代表 3 个通道 (RGB)。