pyro:如何指定条件分布

pyro: how to specify conditional distribution

我正在尝试使用 pyro 来指定贝叶斯网络。我有一个连续的子节点 D,它有三个离散节点是父节点,每个节点有 10 种可能的状态:

所以,我首先将我的离散节点定义为:

import torch
import pyro
import pyro.distributions as dist

def model(data):
    A = pyro.sample("A", dist.Dirichlet(torch.ones(10)))
    B = pyro.sample("B", dist.Dirichlet(torch.ones(10)))
    C = pyro.sample("C", dist.Dirichlet(torch.ones(10)))
    

现在,据我所知,我需要定义 P(D|A, B, C)。我想将其建模为正态分布,但不确定如何进行这种调节。我的计划是将先验放在这个分布参数上,然后使用 MCMC 或 HMC 来估计后验分布并学习模型参数。

但是,不确定如何继续模型定义。

pyro 的好处是模型定义非常 pythonic。底层 PyTorch 机制可以为您跟踪依赖关系。

您只需要使用样本 ABC 并计算条件 p(D|A,B,C)

的参数
def cond_mean(a, b, c):
    return # use a,b,c to compute mean

def cond_scale(a, b, c):
    return # use a,b,c to compute scale

def model(data):
    A = pyro.sample("A", dist.Dirichlet(torch.ones(10)))
    B = pyro.sample("B", dist.Dirichlet(torch.ones(10)))
    C = pyro.sample("C", dist.Dirichlet(torch.ones(10)))

    D = pyro.sample("D", dist.Normal(loc=cond_mean(A, B, C), scale=cond_scale(A, B, C)
    ...