Python PyTorch Pyro - 多元分布

Python PyTorch Pyro - Multivariate Distributions

如何在 Pyro 中对多元分布进行采样?我只想要一个 (M, N) Beta 发行版,但以下内容不起作用:

impor torch
import pyro
with pyro.plate("theta_plate", M):
    theta = pyro.sample("theta",
                        pyro.distributions.Beta(concentration0=torch.ones(N),
                                                concentration1=torch.ones(N)))

对于 PyTorch 和 Pyro 发行版,语法是相同的:

import pyro.distributions as dist

samples = dist.Beta(2, 2).sample([200]) # Will draw 200 samples.

除非您只想对分布进行抽样,否则您不需要盘子概念。

美元to_event(n)申报相关样本。

import torch
import pyro
import pyro.distributions as dist

def model(N, M):
    with pyro.plate("theta_plate", M):
        theta = pyro.sample("theta", dist.Beta(torch.ones(N),1.).to_event(1))
    return theta


if __name__ == '__main__':
    print(model(10,12).shape) # (10,12)