Different results in computing KL Divergence using Pytorch Distributions vs manually
I noticed that the KL-Divergence term KL(Q(x)||P(x)) is computed differently when using

mean(Q(x) * (log Q(x) - log P(x)))

versus

torch.distributions.kl_divergence(Q, P)

where

Q = torch.distributions.Normal(some mean, some sigma)
P = torch.distributions.Normal(0, 1)

When I plot the KL divergence loss, I get two similar but different curves (plot linked here).
Can anyone point out what causes this difference?
Full code below:
import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt

def kl_1(log_qx, log_px):
    """
    inputs: [B, z_dim] torch
    """
    return (log_qx.exp() * (log_qx - log_px)).mean()

# ground-truth (target) P(x)
P = dist.Normal(0, 1)
mus = np.arange(-5, 5, 0.1)
sigma = 1
N = 100
kls = {"1": [], "2": []}

for mu in mus:
    # prediction (current) Q(x)
    Q = dist.Normal(mu, sigma)
    # sample from Q
    qx = Q.sample((N,))
    # log prob
    log_qx = Q.log_prob(qx)
    log_px = P.log_prob(qx)
    # kl 1
    kl1 = kl_1(log_qx, log_px)
    kls['1'].append(kl1.numpy())
    # kl 2
    kl2 = dist.kl_divergence(Q, P)
    kls['2'].append(kl2.numpy())

plt.figure()
plt.scatter(mus, kls['1'], label="Q*(logQ-logP)")
plt.scatter(mus, kls['2'], label="kl_divergence")
plt.xlabel("mean of Q(x)")
plt.ylabel("computed KL Divergence")
plt.legend()
plt.show()
If you compute the expectation as an integral over dx, then each sample has to be weighted by the probability density. If instead you use samples drawn from the given distribution, you can approximate the expectation directly by the sample mean; this corresponds to integrating over d c_q(x), since d c_q(x) = q(x) dx, where c_q(x) is the cumulative distribution function and q(x) is the probability density function of the variable Q.
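As a worked sketch of that statement (my notation, not from the original answer):

KL(Q || P) = \int q(x) (\log q(x) - \log p(x)) \, dx
           = E_{x \sim Q}[\log q(x) - \log p(x)]
           \approx \frac{1}{N} \sum_{i=1}^{N} (\log q(x_i) - \log p(x_i)),  with x_i \sim Q.

Because the samples are drawn from Q, the density q(x) is already accounted for; multiplying by exp(log_qx) inside the mean weights each sample by its density a second time. Dropping that factor gives the corrected code below: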
import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt

def kl_1(log_qx, log_px):
    """
    inputs: [B, z_dim] torch
    """
    # samples are already drawn from Q, so no extra density weighting is needed
    return (log_qx - log_px).mean()

# ground-truth (target) P(x)
P = dist.Normal(0, 1)
mus = np.arange(-5, 5, 0.1)
sigma = 1
N = 100
kls = {"1": [], "2": []}

for mu in mus:
    # prediction (current) Q(x)
    Q = dist.Normal(mu, sigma)
    # sample from Q
    qx = Q.sample((N,))
    # log prob
    log_qx = Q.log_prob(qx)
    log_px = P.log_prob(qx)
    # kl 1: Monte Carlo estimate
    kl1 = kl_1(log_qx, log_px)
    kls['1'].append(kl1.numpy())
    # kl 2: closed-form KL from torch.distributions
    kl2 = dist.kl_divergence(Q, P)
    kls['2'].append(kl2.numpy())

plt.figure()
plt.scatter(mus, kls['1'], label="mean(logQ-logP)")
plt.scatter(mus, kls['2'], label="kl_divergence")
plt.xlabel("mean of Q(x)")
plt.ylabel("computed KL Divergence")
plt.legend()
plt.show()
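For reference, torch.distributions.kl_divergence(Q, P) for two Normal distributions evaluates the closed-form expression, which for P = N(0, 1) is KL(N(mu, sigma^2) || N(0, 1)) = (sigma^2 + mu^2 - 1)/2 - log(sigma). A minimal sketch (not from the original post, with an arbitrary mu/sigma and sample size chosen for illustration) comparing the Monte Carlo estimate, PyTorch's result, and the analytic formula:

import torch
import torch.distributions as dist

mu, sigma, N = 2.0, 1.0, 100_000  # hypothetical example values

P = dist.Normal(0.0, 1.0)
Q = dist.Normal(mu, sigma)

# Monte Carlo estimate: mean of (log q(x) - log p(x)) over samples x ~ Q
x = Q.sample((N,))
kl_mc = (Q.log_prob(x) - P.log_prob(x)).mean()

# PyTorch's closed-form KL divergence between the two Normals
kl_torch = dist.kl_divergence(Q, P)

# Analytic formula for KL(N(mu, sigma^2) || N(0, 1))
kl_analytic = (sigma**2 + mu**2 - 1) / 2 - torch.log(torch.tensor(sigma))

print(kl_mc.item(), kl_torch.item(), kl_analytic.item())
# The three values agree up to Monte Carlo noise in the first one.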