转向数值稳定的 log-sum-exp 会导致极大的损失值
Moving to numerically stable log-sum-exp leads to extremely large loss values
我正在研究一个使用 LSTM 和 MDN 来预测某些分布的网络。我为这些 MDN 使用的损失函数涉及尝试使我的目标数据适合预测分布。我正在尝试计算这些目标数据的 log_probs 的 log-sum-exp 以计算损失。当我使用标准的 log-sum-exp 时,我得到了合理的初始损失值(大约 50-70),即使后来它遇到了一些 NaN 和中断。根据我在网上阅读的内容,需要一个数值稳定的 log-sum-exp 版本来避免这个问题。然而,一旦我使用稳定版本,我的损失值就会飙升至 15-20k 的数量级。他们确实接受了培训,但最终他们也会导致 NaN。
注意:我没有在 PyTorch 中使用 logsumexp 函数,因为我需要根据我的混合成分进行加权求和。
def log_sum_exp(self,value, weights, dim=None):
eps = 1e-20
m, idx = torch.max(value, dim=dim, keepdim=True)
return m.squeeze(dim) + torch.log(torch.sum(torch.exp(value-m)*(weights.unsqueeze(2)),
dim=dim) + eps)
def mdn_loss(self, pi, sigma, mu, target):
eps = 1e-20
target = target.unsqueeze(1)
m = torch.distributions.Normal(loc=mu, scale=sigma)
probs = m.log_prob(target)
# Size of probs is batch_size x num_mixtures x num_out_features
# Size of pi is batch_size x num_mixtures
loss = -self.log_sum_exp(probs, pi, dim=1)
return loss.mean()
添加 anomaly_detection 后,NaN 似乎出现在:
概率 = m.log_prob(目标)
仅仅通过转向数值稳定的版本就看到了这些巨大的初始损失值,这让我相信我当前的实现中存在一些错误。请帮忙。
问题已解决。当为这些值计算 log_probs 时,我的目标有一些大值导致溢出计算。去掉一些奇怪的数据点,把数据归一化,loss立马降下来。
我正在研究一个使用 LSTM 和 MDN 来预测某些分布的网络。我为这些 MDN 使用的损失函数涉及尝试使我的目标数据适合预测分布。我正在尝试计算这些目标数据的 log_probs 的 log-sum-exp 以计算损失。当我使用标准的 log-sum-exp 时,我得到了合理的初始损失值(大约 50-70),即使后来它遇到了一些 NaN 和中断。根据我在网上阅读的内容,需要一个数值稳定的 log-sum-exp 版本来避免这个问题。然而,一旦我使用稳定版本,我的损失值就会飙升至 15-20k 的数量级。他们确实接受了培训,但最终他们也会导致 NaN。
注意:我没有在 PyTorch 中使用 logsumexp 函数,因为我需要根据我的混合成分进行加权求和。
def log_sum_exp(self,value, weights, dim=None):
eps = 1e-20
m, idx = torch.max(value, dim=dim, keepdim=True)
return m.squeeze(dim) + torch.log(torch.sum(torch.exp(value-m)*(weights.unsqueeze(2)),
dim=dim) + eps)
def mdn_loss(self, pi, sigma, mu, target):
eps = 1e-20
target = target.unsqueeze(1)
m = torch.distributions.Normal(loc=mu, scale=sigma)
probs = m.log_prob(target)
# Size of probs is batch_size x num_mixtures x num_out_features
# Size of pi is batch_size x num_mixtures
loss = -self.log_sum_exp(probs, pi, dim=1)
return loss.mean()
添加 anomaly_detection 后,NaN 似乎出现在: 概率 = m.log_prob(目标)
仅仅通过转向数值稳定的版本就看到了这些巨大的初始损失值,这让我相信我当前的实现中存在一些错误。请帮忙。
问题已解决。当为这些值计算 log_probs 时,我的目标有一些大值导致溢出计算。去掉一些奇怪的数据点,把数据归一化,loss立马降下来。