需要帮助理解 CGAN 中的标签输入
Need help understanding the label input in a CGAN
我正在尝试实现 CGAN。我知道在卷积生成器和鉴别器模型中,可以通过追加表示标签的通道(深度)来扩充输入张量。因此,如果数据中有 10 个类别,生成器和鉴别器的输入深度都是基础深度 + 10。
但是,我正在网上阅读各种实现,但我似乎无法找到他们实际获取此标签的位置。当然,CGAN 不能是无监督的,因为您需要获取要输入的标签。例如在 cifar10 中,如果您在青蛙的真实图像上训练鉴别器,则需要 'frog' 注释。
这是我正在研究的一段代码:
class CGAN(object):
    """Conditional GAN trainer: builds G/D, fixed evaluation inputs, and trains them.

    The class label reaches both networks as a one-hot encoding derived from
    the labels ``y_`` that the data loader yields with each image batch
    (see ``train``).
    """

    def __init__(self, args):
        """Store hyper-parameters and build networks, optimizers and fixed samples.

        Args:
            args: object exposing the attributes ``epoch``, ``batch_size``,
                ``save_dir``, ``result_dir``, ``dataset``, ``log_dir``,
                ``gpu_mode``, ``gan_type``, ``input_size``, ``lrG``, ``lrD``,
                ``beta1`` and ``beta2``.
        """
        # parameters
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type
        self.input_size = args.input_size
        self.z_dim = 62
        self.class_num = 10
        self.sample_num = self.class_num ** 2

        # load dataset
        self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
        # Peek at one batch only to read its dimension 1
        # (presumably the image channel count -- confirm against dataloader).
        data = next(iter(self.data_loader))[0]

        # networks init
        self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size, class_num=self.class_num)
        self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size, class_num=self.class_num)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()
        else:
            self.BCE_loss = nn.BCELoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        print('-----------------------------------------------')

        # fixed noise & condition
        # sample_num == class_num ** 2 rows, organised as class_num groups of
        # class_num rows: every row in a group shares one noise vector, and
        # within a group the labels run 0 .. class_num - 1.
        self.sample_z_ = torch.zeros((self.sample_num, self.z_dim))
        for i in range(self.class_num):
            start = i * self.class_num
            # A (1, z_dim) draw broadcast over the whole group.
            self.sample_z_[start:start + self.class_num] = torch.rand(1, self.z_dim)

        # Labels 0..class_num-1 repeated class_num times, as a column of indices.
        temp_y = torch.arange(self.class_num).repeat(self.class_num).unsqueeze(1)
        # scatter_ writes a 1 at each row's label index -> one-hot encoding.
        self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, temp_y.type(torch.LongTensor), 1)
        if self.gpu_mode:
            self.sample_z_, self.sample_y_ = self.sample_z_.cuda(), self.sample_y_.cuda()

    def train(self):
        """Run the adversarial training loop for ``self.epoch`` epochs.

        Fills ``self.train_hist`` with one ``D_loss`` and one ``G_loss`` entry
        per processed batch, one ``per_epoch_time`` entry per epoch and a
        single ``total_time`` entry for the whole run.
        """
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        # BCE targets: 1 for "real", 0 for "fake"; sized for a full batch.
        self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
        if self.gpu_mode:
            self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()

        # Number of full batches per epoch; anything the loader yields beyond
        # this is skipped because every buffer here is sized for batch_size.
        steps_per_epoch = len(self.data_loader.dataset) // self.batch_size

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            # y_ holds the labels the data loader supplies with the images x_.
            for step, (x_, y_) in enumerate(self.data_loader):
                if step == steps_per_epoch:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))
                # One-hot encode the batch labels: (batch_size, class_num).
                y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
                # Broadcast the one-hot vector over the spatial dims so it can
                # be fed to D alongside the image: (batch, class_num, H, W).
                y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(self.batch_size, self.class_num, self.input_size, self.input_size)
                if self.gpu_mode:
                    x_, z_, y_vec_, y_fill_ = x_.cuda(), z_.cuda(), y_vec_.cuda(), y_fill_.cuda()

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_, y_fill_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_ = self.G(z_, y_vec_)
                # detach: the D step must not back-propagate into G (its
                # gradients are zeroed before the G step anyway).
                D_fake = self.D(G_.detach(), y_fill_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_, y_vec_)
                D_fake = self.D(G_, y_fill_)
                # G is rewarded when D labels its samples as real.
                G_loss = self.BCE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)

        self.train_hist['total_time'].append(time.time() - start_time)
似乎 y_vec_ 和 y_fill_ 是图像的标签,但在用 y_fill_ 为鉴别器标记真实图像的地方,它被定义为 y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(self.batch_size, self.class_num, self.input_size, self.input_size)
它似乎没有从数据集中获取任何关于标签的信息?
它是如何给鉴别器正确的标签的?
谢谢!
y_fill_ 基于 y_vec_,而 y_vec_ 基于 y_(即数据加载器随每个小批量图像一起返回的标签),因此它们读取的正是当前小批量的标签信息。您可能是被 scatter 操作弄糊涂了:这段代码所做的基本上就是把标签转换为独热(one-hot)编码。
我正在尝试实现 CGAN。我知道在卷积生成器和鉴别器模型中,可以通过追加表示标签的通道(深度)来扩充输入张量。因此,如果数据中有 10 个类别,生成器和鉴别器的输入深度都是基础深度 + 10。
但是,我正在网上阅读各种实现,但我似乎无法找到他们实际获取此标签的位置。当然,CGAN 不能是无监督的,因为您需要获取要输入的标签。例如在 cifar10 中,如果您在青蛙的真实图像上训练鉴别器,则需要 'frog' 注释。
这是我正在研究的一段代码:
class CGAN(object):
    """Conditional GAN trainer: builds G/D, fixed evaluation inputs, and trains them.

    The class label reaches both networks as a one-hot encoding derived from
    the labels ``y_`` that the data loader yields with each image batch
    (see ``train``).
    """

    def __init__(self, args):
        """Store hyper-parameters and build networks, optimizers and fixed samples.

        Args:
            args: object exposing the attributes ``epoch``, ``batch_size``,
                ``save_dir``, ``result_dir``, ``dataset``, ``log_dir``,
                ``gpu_mode``, ``gan_type``, ``input_size``, ``lrG``, ``lrD``,
                ``beta1`` and ``beta2``.
        """
        # parameters
        self.epoch = args.epoch
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type
        self.input_size = args.input_size
        self.z_dim = 62
        self.class_num = 10
        self.sample_num = self.class_num ** 2

        # load dataset
        self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
        # Peek at one batch only to read its dimension 1
        # (presumably the image channel count -- confirm against dataloader).
        data = next(iter(self.data_loader))[0]

        # networks init
        self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size, class_num=self.class_num)
        self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size, class_num=self.class_num)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()
        else:
            self.BCE_loss = nn.BCELoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        print('-----------------------------------------------')

        # fixed noise & condition
        # sample_num == class_num ** 2 rows, organised as class_num groups of
        # class_num rows: every row in a group shares one noise vector, and
        # within a group the labels run 0 .. class_num - 1.
        self.sample_z_ = torch.zeros((self.sample_num, self.z_dim))
        for i in range(self.class_num):
            start = i * self.class_num
            # A (1, z_dim) draw broadcast over the whole group.
            self.sample_z_[start:start + self.class_num] = torch.rand(1, self.z_dim)

        # Labels 0..class_num-1 repeated class_num times, as a column of indices.
        temp_y = torch.arange(self.class_num).repeat(self.class_num).unsqueeze(1)
        # scatter_ writes a 1 at each row's label index -> one-hot encoding.
        self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, temp_y.type(torch.LongTensor), 1)
        if self.gpu_mode:
            self.sample_z_, self.sample_y_ = self.sample_z_.cuda(), self.sample_y_.cuda()

    def train(self):
        """Run the adversarial training loop for ``self.epoch`` epochs.

        Fills ``self.train_hist`` with one ``D_loss`` and one ``G_loss`` entry
        per processed batch, one ``per_epoch_time`` entry per epoch and a
        single ``total_time`` entry for the whole run.
        """
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        # BCE targets: 1 for "real", 0 for "fake"; sized for a full batch.
        self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
        if self.gpu_mode:
            self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()

        # Number of full batches per epoch; anything the loader yields beyond
        # this is skipped because every buffer here is sized for batch_size.
        steps_per_epoch = len(self.data_loader.dataset) // self.batch_size

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            # y_ holds the labels the data loader supplies with the images x_.
            for step, (x_, y_) in enumerate(self.data_loader):
                if step == steps_per_epoch:
                    break

                z_ = torch.rand((self.batch_size, self.z_dim))
                # One-hot encode the batch labels: (batch_size, class_num).
                y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
                # Broadcast the one-hot vector over the spatial dims so it can
                # be fed to D alongside the image: (batch, class_num, H, W).
                y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(self.batch_size, self.class_num, self.input_size, self.input_size)
                if self.gpu_mode:
                    x_, z_, y_vec_, y_fill_ = x_.cuda(), z_.cuda(), y_vec_.cuda(), y_fill_.cuda()

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_, y_fill_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_ = self.G(z_, y_vec_)
                # detach: the D step must not back-propagate into G (its
                # gradients are zeroed before the G step anyway).
                D_fake = self.D(G_.detach(), y_fill_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.item())

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_, y_vec_)
                D_fake = self.D(G_, y_fill_)
                # G is rewarded when D labels its samples as real.
                G_loss = self.BCE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.item())

                G_loss.backward()
                self.G_optimizer.step()

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)

        self.train_hist['total_time'].append(time.time() - start_time)
似乎 y_vec_ 和 y_fill_ 是图像的标签,但在用 y_fill_ 为鉴别器标记真实图像的地方,它被定义为 y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(self.batch_size, self.class_num, self.input_size, self.input_size)
它似乎没有从数据集中获取任何关于标签的信息? 它是如何给鉴别器正确的标签的?
谢谢!
y_fill_ 基于 y_vec_,而 y_vec_ 基于 y_(即数据加载器随每个小批量图像一起返回的标签),因此它们读取的正是当前小批量的标签信息。您可能是被 scatter 操作弄糊涂了:这段代码所做的基本上就是把标签转换为独热(one-hot)编码。