当我在 Chainer 中将重建损失函数 F.bernoulli_nll 更改为 F.mean_squared_error 时,VAE 无法学习
VAE does not learn when I change reconstruction loss functions F.bernoulli_nll to F.mean_squared_error in Chainer
我想在使用 chainer5.0.0 的 VAE 中使用 mean_squared_error 而不是 F.bernoulli_nll 作为重建损失函数。
我是Chainer5.0.0用户。
我已经实现了 VAE(变分自动编码器)。我参考了下面的日文文章。
- https://qiita.com/kenmatsu4/items/b029d697e9995d93aa24
- https://qiita.com/kenchin110100/items/7ceb5b8e8b21c551d69a
- https://github.com/maguro27/VAE-CIFAR10_chainer
class VAE(chainer.Chain):
def __init__(self, n_in, n_latent, n_h, act_func=F.tanh):
super(VAE, self).__init__()
self.act_func = act_func
with self.init_scope():
# encoder
self.le1 = L.Linear(n_in, n_h)
self.le2 = L.Linear(n_h, n_h)
self.le3_mu = L.Linear(n_h, n_latent)
self.le3_ln_var = L.Linear(n_h, n_latent)
# decoder
self.ld1 = L.Linear(n_latent, n_h)
self.ld2 = L.Linear(n_h, n_h)
self.ld3 = L.Linear(n_h, n_in)
def __call__(self, x, sigmoid=True):
return self.decode(self.encode(x)[0], sigmoid)
def encode(self, x):
h1 = self.act_func(self.le1(x))
h2 = self.act_func(self.le2(h1))
mu = self.le3_mu(h2)
ln_var = self.le3_ln_var(h2)
return mu, ln_var
def decode(self, z, sigmoid=True):
h1 = self.act_func(self.ld1(z))
h2 = self.act_func(self.ld2(h1))
h3 = self.ld3(h2)
if sigmoid:
return F.sigmoid(h3)
else:
return h3
def get_loss_func(self, C=1.0, k=1):
def lf(x):
mu, ln_var = self.encode(x)
batchsize = len(mu.data)
# reconstruction error
rec_loss = 0
for l in six.moves.range(k):
z = F.gaussian(mu, ln_var)
z.name = "z"
rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) / (k * batchsize)
self.rec_loss = rec_loss
self.rec_loss.name = "reconstruction error"
self.latent_loss = C * gaussian_kl_divergence(mu, ln_var) / batchsize
self..name = "latent loss"
self.loss = self.rec_loss + self.latent_loss
self.loss.name = "loss"
return self.loss
return lf
我使用了这段代码并且我的 VAE 已经过 MNIST 和 Fashion-MNIST 数据集的训练。我检查了我的 VAE 在训练后输出与输入图像相似的图像。
rec_loss是Reconstruct Loss,意思是解码后的图像距离输入图像有多远。我认为我们可以使用 mean_squared_error 而不是 F.bernoulli_nll.
所以我更改了如下代码。
rec_loss += F.mean_squared_error(x, self.decode(z)) / k
但是在更改我的代码后,训练结果很奇怪。输出图像相同,这意味着输出图像不依赖于输入图像。
有什么问题?
我用日语问了这个问题(https://ja.whosebug.com/questions/55477/chainer%E3%81%A7vae%E3%82%92%E4%BD%9C%E3%82%8B%E3%81%A8%E3%81%8D%E3%81%ABloss%E9%96%A2%E6%95%B0%E3%82%92bernoulli-nll%E3%81%A7%E3%81%AF%E3%81%AA%E3%81%8Fmse%E3%82%92%E4%BD%BF%E3%81%86%E3%81%A8%E5%AD%A6%E7%BF%92%E3%81%8C%E9%80%B2%E3%81%BE%E3%81%AA%E3%81%84)。但是没有人回答,所以我在这里提交这个问题。
解决方案?
当我替换
rec_loss += F.mean_squared_error(x, self.decode(z)) / k
来自
rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))
,问题已解决
但是为什么呢?
它们应该是相同的,除了使用 F.mean(F.sum....
的后一个代码仅沿小批量轴取平均值(因为它已经在输入数据维度上求和,在展平 MNIST 的情况下为 784),而前者在小批量轴和输入数据维度上的平均值。这意味着在扁平化 MNIST 的情况下,后者的损失要大 784 倍?我假设 k
是 1
。
我想在使用 chainer5.0.0 的 VAE 中使用 mean_squared_error 而不是 F.bernoulli_nll 作为重建损失函数。
我是Chainer5.0.0用户。 我已经实现了 VAE(变分自动编码器)。我参考了下面的日文文章。
- https://qiita.com/kenmatsu4/items/b029d697e9995d93aa24
- https://qiita.com/kenchin110100/items/7ceb5b8e8b21c551d69a
- https://github.com/maguro27/VAE-CIFAR10_chainer
class VAE(chainer.Chain):
def __init__(self, n_in, n_latent, n_h, act_func=F.tanh):
super(VAE, self).__init__()
self.act_func = act_func
with self.init_scope():
# encoder
self.le1 = L.Linear(n_in, n_h)
self.le2 = L.Linear(n_h, n_h)
self.le3_mu = L.Linear(n_h, n_latent)
self.le3_ln_var = L.Linear(n_h, n_latent)
# decoder
self.ld1 = L.Linear(n_latent, n_h)
self.ld2 = L.Linear(n_h, n_h)
self.ld3 = L.Linear(n_h, n_in)
def __call__(self, x, sigmoid=True):
return self.decode(self.encode(x)[0], sigmoid)
def encode(self, x):
h1 = self.act_func(self.le1(x))
h2 = self.act_func(self.le2(h1))
mu = self.le3_mu(h2)
ln_var = self.le3_ln_var(h2)
return mu, ln_var
def decode(self, z, sigmoid=True):
h1 = self.act_func(self.ld1(z))
h2 = self.act_func(self.ld2(h1))
h3 = self.ld3(h2)
if sigmoid:
return F.sigmoid(h3)
else:
return h3
def get_loss_func(self, C=1.0, k=1):
def lf(x):
mu, ln_var = self.encode(x)
batchsize = len(mu.data)
# reconstruction error
rec_loss = 0
for l in six.moves.range(k):
z = F.gaussian(mu, ln_var)
z.name = "z"
rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) / (k * batchsize)
self.rec_loss = rec_loss
self.rec_loss.name = "reconstruction error"
self.latent_loss = C * gaussian_kl_divergence(mu, ln_var) / batchsize
self..name = "latent loss"
self.loss = self.rec_loss + self.latent_loss
self.loss.name = "loss"
return self.loss
return lf
我使用了这段代码并且我的 VAE 已经过 MNIST 和 Fashion-MNIST 数据集的训练。我检查了我的 VAE 在训练后输出与输入图像相似的图像。
rec_loss是Reconstruct Loss,意思是解码后的图像距离输入图像有多远。我认为我们可以使用 mean_squared_error 而不是 F.bernoulli_nll.
所以我更改了如下代码。
rec_loss += F.mean_squared_error(x, self.decode(z)) / k
但是在更改我的代码后,训练结果很奇怪。输出图像相同,这意味着输出图像不依赖于输入图像。
有什么问题?
我用日语问了这个问题(https://ja.whosebug.com/questions/55477/chainer%E3%81%A7vae%E3%82%92%E4%BD%9C%E3%82%8B%E3%81%A8%E3%81%8D%E3%81%ABloss%E9%96%A2%E6%95%B0%E3%82%92bernoulli-nll%E3%81%A7%E3%81%AF%E3%81%AA%E3%81%8Fmse%E3%82%92%E4%BD%BF%E3%81%86%E3%81%A8%E5%AD%A6%E7%BF%92%E3%81%8C%E9%80%B2%E3%81%BE%E3%81%AA%E3%81%84)。但是没有人回答,所以我在这里提交这个问题。
解决方案?
当我替换
rec_loss += F.mean_squared_error(x, self.decode(z)) / k
来自
rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))
,问题已解决
但是为什么呢?
它们应该是相同的,除了使用 F.mean(F.sum....
的后一个代码仅沿小批量轴取平均值(因为它已经在输入数据维度上求和,在展平 MNIST 的情况下为 784),而前者在小批量轴和输入数据维度上的平均值。这意味着在扁平化 MNIST 的情况下,后者的损失要大 784 倍?我假设 k
是 1
。