VAE with a discriminator: compile problem
Unlike a vanilla generative model, the input to this VAE is an RGB image. If I compile self.combined using the add_loss method, the loss oscillates between roughly 15000 and -22000. Compiling with mse works fine.
# patch_discriminator, VAE and DataLoader are defined elsewhere in the project
import datetime

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K

def __init__(self, type='landmark'):
    self.latent_dim = 128
    self.input_shape = (128, 128, 3)
    self.batch_size = 1
    self.original_dim = self.latent_dim * self.latent_dim

    # PatchGAN output: the discriminator downsamples the input by 2**4
    patch = int(self.input_shape[0] / 2 ** 4)
    self.disc_patch = (patch, patch, 1)
    optimizer = tf.keras.optimizers.Adam(0.0002, 0.5)

    # The discriminator trains on its own; it is frozen inside the combined model
    pd = patch_discriminator(type)
    self.discriminator = pd.discriminator()
    self.discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    self.discriminator.trainable = False

    vae = VAE(self.latent_dim, type=type)
    encoder = vae.inference_net()
    decoder = vae.generative_net()

    # Placeholder target; replaced with the real batch in the training loop
    if type == 'image':
        self.orig_out = tf.random.normal(shape=(self.batch_size, 128, 128, 3))
    else:
        self.orig_out = tf.random.normal(shape=(self.batch_size, 128, 128, 1))

    vae_input = tf.keras.layers.Input(shape=self.input_shape)
    self.encoder_out = encoder(vae_input)            # [z_mean, z_log_var, z]
    self.decoder_out = decoder(self.encoder_out[2])
    self.generator = tf.keras.Model(vae_input, self.decoder_out)

    vae_loss = self.compute_loss()
    self.generator.add_loss(vae_loss)
    self.generator.compile(optimizer=optimizer)

    valid = self.discriminator([self.decoder_out, self.decoder_out])
    self.combined = tf.keras.Model(vae_input, valid)
    self.combined.add_loss(vae_loss)
    self.combined.compile(optimizer=optimizer)
    # self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    self.dl = DataLoader()
compute_loss computes the VAE loss (reconstruction term plus KL term). Initially self.orig_out is set to a random normal tensor; it is updated in the training loop below.
def compute_loss(self):
    # tf.keras.losses.BinaryCrossentropy expects (y_true, y_pred) in that order
    bce = tf.keras.losses.BinaryCrossentropy()
    reconstruction_loss = bce(self.orig_out, self.decoder_out)
    reconstruction_loss = self.original_dim * reconstruction_loss

    # KL divergence between the approximate posterior and a unit Gaussian
    z_mean = self.encoder_out[0]
    z_log_var = self.encoder_out[1]
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5

    vae_loss = K.mean(reconstruction_loss + kl_loss)
    return vae_loss
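One detail worth flagging: the loss tensor returned here closes over whatever self.orig_out referenced when the graph was built, so reassigning the attribute later has no effect on the compiled loss. A minimal sketch of one workaround, assuming the landmark target shape of (1, 128, 128, 1) (the names here are illustrative, not from the original code), is to keep the target in a tf.Variable and assign into it each batch:

import tensorflow as tf

# Hypothetical sketch: a persistent target buffer instead of a plain tensor.
# The variable object stays wired into the loss graph, so values assigned
# into it are visible at train time, while a plain reassignment is not.
orig_out = tf.Variable(tf.zeros((1, 128, 128, 1)), trainable=False)

def update_target(batch):
    # batch must already have the variable's shape, e.g. (1, 128, 128, 1)
    orig_out.assign(tf.cast(batch, tf.float32))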
The training loop:
def train(self, batch_size=1, epochs=10):
    start_time = datetime.datetime.now()

    # PatchGAN labels: one label per output patch
    valid = np.ones((batch_size,) + self.disc_patch)
    fake = np.zeros((batch_size,) + self.disc_patch)
    threshold = epochs // 10

    for epoch in range(epochs):
        for batch_i, (imA, imB, n_batches) in enumerate(
                self.dl.load_batch(target='landmark', batch_size=batch_size)):
            self.orig_out = tf.convert_to_tensor(imB, dtype=tf.float32)
            fakeA = self.generator.predict(imA)

            # Train the discriminator on real and generated pairs
            d_real_loss = self.discriminator.train_on_batch([imB, imB], valid)
            d_fake_loss = self.discriminator.train_on_batch([imB, fakeA], fake)
            d_loss = 0.5 * np.add(d_real_loss, d_fake_loss)

            # Train the combined model; no target can be passed because the
            # loss was attached via add_loss
            combined_loss = self.combined.train_on_batch(imA)
            # combined_loss = self.combined.train_on_batch(imA, valid)

            elapsed_time = datetime.datetime.now() - start_time
            print(f"[Epoch {epoch}/{epochs}] [Batch {batch_i}/{n_batches}] "
                  f"[D loss: {d_loss}] [G loss: {combined_loss}] time: {elapsed_time}")
If I compile self.combined with the add_loss() method and the KL loss, I cannot pass targets to train_on_batch, as shown above. As a result the generator does not learn and produces random output. How can I compile a VAE with a discriminator using the KL loss?
I don't know if this is the right answer, but it may be easier to model the VAE with TensorFlow directly, since that handles the custom training loop.
You can follow this , which may contain some information related to your problem.
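A minimal sketch of what such a custom training loop could look like for this setup, assuming encoder, decoder and discriminator models with the shapes used above (all of these names are placeholders, not the original code):

import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(2e-4, 0.5)
bce = tf.keras.losses.BinaryCrossentropy()

@tf.function
def train_generator_step(imA, imB):
    # One generator update: VAE loss plus adversarial loss, assuming
    # encoder(imA) -> [z_mean, z_log_var, z] and decoder(z) -> image
    with tf.GradientTape() as tape:
        z_mean, z_log_var, z = encoder(imA, training=True)
        fakeA = decoder(z, training=True)

        # Reconstruction term, scaled by the number of pixels
        rec_loss = 128 * 128 * bce(imB, fakeA)

        # KL divergence to a unit Gaussian, summed over the latent dimension
        kl_loss = -0.5 * tf.reduce_sum(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)

        # Adversarial term: push the discriminator's patches toward "real"
        patch_out = discriminator([fakeA, fakeA], training=False)
        adv_loss = bce(tf.ones_like(patch_out), patch_out)

        total = tf.reduce_mean(rec_loss + kl_loss) + adv_loss

    vars_ = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(total, vars_)
    optimizer.apply_gradients(zip(grads, vars_))
    return total

Because the losses are recomputed from the current batch on every step, the target no longer has to be passed in through an attribute, and the KL term needs no add_loss/compile workaround.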