Pyro 在使用 NUTS (MCMC) 采样器时改变离散潜在变量的维度

Pyro changes dimension of Discrete latent variable when using NUTS (MCMC) sampler

感谢您花时间阅读我的问题,如下所示。

我需要帮助的问题是,当我 运行 使用 NUTS 采样器的模型时,我的二项式分布输出的维度在第二次迭代期间发生了(自动)变化。因此,我的其余代码(此处未给出)会引发尺寸不匹配错误。 如果我 运行 模型仅通过调用函数(不使用 Sampler)来发挥作用,即使我不断重复调用该函数,它也能很好地工作。但是当我使用采样器时它失败了。

我使用下面提到的更小、更简单的代码复制了该问题(此代码不代表我的实际代码,但复制了该问题)。

import pyro
import pyro.distributions as dist
import torch
import pyro.poutine as poutine
from pyro.infer import MCMC, NUTS

Pyro的版本是1.5,PyTorch是1.7

def model ():
        
    print("***** Start ****")
    prior = torch.ones(5) / 5
    print("Prior", prior)
    
    a = pyro.sample("a", dist.Binomial(1, prior))
    print("A", a)
    
    b = pyro.sample("b", dist.Binomial(1, a)) 
    print("B", b)
    
    print("***** End *****")
    
    return b

def conditioned_model(model, data):
    print("**** Condition Model **** ")
    return poutine.condition(model, data = {"b":data})()

data = model()
***** Start ****
Prior tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
A tensor([0., 1., 0., 0., 0.])
B tensor([0., 1., 0., 0., 0.])
***** End *****

nuts_kernel = NUTS(conditioned_model, jit_compile=False)
mcmc = MCMC(nuts_kernel,
            num_samples=1,
            warmup_steps=1,
            num_chains=1)
mcmc.run(model, data)
Warmup:   0%|          | 0/2 [00:00, ?it/s]
**** Condition Model **** 
***** Start ****
Prior tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
A tensor([1., 0., 0., 0., 1.])
B tensor([0., 1., 0., 0., 0.])
***** End *****
**** Condition Model **** 
***** Start ****
Prior tensor([0.2000, 0.2000, 0.2000, 0.2000, 0.2000])
A tensor([0., 1.])
B tensor([0., 1., 0., 0., 0.])
***** End *****

在上面的输出中,请注意变量 A 的维数。最初它的大小为 5,后来变为 2。由于我在 DINA 模型中的剩余代码给出了错误。

在上面的代码中,变量A是基于prior变量和prior[=61]的维度=] 是 5。然后据我了解,A 应该始终是 5。请帮助我理解为什么它变为 2 以及如何避免这种情况发生。

此外,我无法理解的是 B 的维数始终保持为 5。在上面的代码中,B 是以 A 作为输入,但 B 即使 A 改变其维度也不会改变维度.

非常感谢您的帮助。

我发现了另一个关于这个问题的讨论。

在我看来,我的代码中的问题是 NUTS 试图整合离散随机变量。因此,我无法应用基于离散随机变量的条件流。有关详细信息,请参阅此处:Error with NUTS when random variable is used in control flow