pyro:如何指定条件分布
pyro: how to specify conditional distribution
我正在尝试使用 pyro 来指定贝叶斯网络。我有一个连续的子节点 D
,它有三个离散节点是父节点,每个节点有 10 种可能的状态:
所以,我首先将我的离散节点定义为:
import torch
import pyro
import pyro.distributions as dist
def model(data):
A = pyro.sample("A", dist.Dirichlet(torch.ones(10)))
B = pyro.sample("B", dist.Dirichlet(torch.ones(10)))
C = pyro.sample("C", dist.Dirichlet(torch.ones(10)))
现在,据我所知,我需要定义 P(D|A, B, C)
。我想将其建模为正态分布,但不确定如何进行这种调节。我的计划是将先验放在这个分布参数上,然后使用 MCMC 或 HMC 来估计后验分布并学习模型参数。
但是,不确定如何继续模型定义。
pyro
的好处是模型定义非常 pythonic。底层 PyTorch 机制可以为您跟踪依赖关系。
您只需要使用样本 A
、B
和 C
并计算条件 p(D|A,B,C)
的参数
def cond_mean(a, b, c):
return # use a,b,c to compute mean
def cond_scale(a, b, c):
return # use a,b,c to compute scale
def model(data):
A = pyro.sample("A", dist.Dirichlet(torch.ones(10)))
B = pyro.sample("B", dist.Dirichlet(torch.ones(10)))
C = pyro.sample("C", dist.Dirichlet(torch.ones(10)))
D = pyro.sample("D", dist.Normal(loc=cond_mean(A, B, C), scale=cond_scale(A, B, C)
...
我正在尝试使用 pyro 来指定贝叶斯网络。我有一个连续的子节点 D
,它有三个离散节点是父节点,每个节点有 10 种可能的状态:
所以,我首先将我的离散节点定义为:
import torch
import pyro
import pyro.distributions as dist
def model(data):
A = pyro.sample("A", dist.Dirichlet(torch.ones(10)))
B = pyro.sample("B", dist.Dirichlet(torch.ones(10)))
C = pyro.sample("C", dist.Dirichlet(torch.ones(10)))
现在,据我所知,我需要定义 P(D|A, B, C)
。我想将其建模为正态分布,但不确定如何进行这种调节。我的计划是将先验放在这个分布参数上,然后使用 MCMC 或 HMC 来估计后验分布并学习模型参数。
但是,不确定如何继续模型定义。
pyro
的好处是模型定义非常 pythonic。底层 PyTorch 机制可以为您跟踪依赖关系。
您只需要使用样本 A
、B
和 C
并计算条件 p(D|A,B,C)
def cond_mean(a, b, c):
return # use a,b,c to compute mean
def cond_scale(a, b, c):
return # use a,b,c to compute scale
def model(data):
A = pyro.sample("A", dist.Dirichlet(torch.ones(10)))
B = pyro.sample("B", dist.Dirichlet(torch.ones(10)))
C = pyro.sample("C", dist.Dirichlet(torch.ones(10)))
D = pyro.sample("D", dist.Normal(loc=cond_mean(A, B, C), scale=cond_scale(A, B, C)
...