PyMC3 Multinomial Model doesn't work with non-integer observe data

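I'm trying to fit a fairly simple multinomial model with PyMC3. If I set the 'noise' value to 0.0, it works perfectly. However, when I change it to any other value, e.g. 0.01, find_MAP() raises an error, and if I don't use find_MAP(), sampling hangs.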

Is there a reason the multinomial data has to be sparse?
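On the sparsity question: the multinomial pmf, P(x | n, p) = n! / (x_1! ... x_k!) * p_1^x_1 * ... * p_k^x_k, is defined for count vectors of nonnegative integers that sum to n. Dense count vectors are fine; what matters is integer counts that add up to n. The plain-Python sketch below illustrates this. The helper multinomial_logpmf is hypothetical and is not PyMC3's implementation; which of these checks PyMC3 itself applies depends on its version. It returns -inf for inputs outside the support, the same kind of value the find_MAP() error reports.

```python
import math


def multinomial_logpmf(x, n, p):
    """Textbook log-pmf of Multinomial(n, p) at count vector x (hypothetical helper).

    Returns -inf when x is outside the support: a negative or non-integer
    entry, or entries that do not sum to n.
    """
    if any(xi < 0 or xi != int(xi) for xi in x):
        return -math.inf
    if sum(x) != n:
        return -math.inf
    # log n! - sum(log x_i!) + sum(x_i * log p_i), using lgamma(m + 1) = log m!
    logp = math.lgamma(n + 1)
    for xi, pi in zip(x, p):
        logp -= math.lgamma(xi + 1)
        if xi > 0:  # skip 0 * log(p_i) so p_i = 0 is allowed when x_i = 0
            logp += xi * math.log(pi)
    return logp


p = [0.2, 0.1, 0.3, 0.4]
print(multinomial_logpmf([1, 0, 0, 0], 1, p))             # sparse integer counts: log(0.2)
print(multinomial_logpmf([3, 1, 4, 2], 10, p))            # dense integer counts: finite
print(multinomial_logpmf([0.25, 0.24, 0.25, 0.26], 1.0, p))  # non-integer entries: -inf
```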

import numpy as np
import pymc3 as pm

print('pymc3 version: ' + pm.__version__)

sample_size = 10
number_of_experiments = 1

true_probs = [0.2, 0.1, 0.3, 0.4]
k = len(true_probs)

# noise = 0.0 works; any non-zero value (e.g. 0.01) triggers the error below
noise = 0.0
y = np.random.multinomial(n=number_of_experiments, pvals=true_probs, size=sample_size) + noise
# Normalise each row to sum to 1 (rows become non-integer once noise > 0)
y_denominator = np.sum(y, axis=1)
y = y / y_denominator[:, None]

with pm.Model() as multinom_test:
    probs = pm.Dirichlet('probs', a=np.ones(k), shape=k)
    # One Multinomial likelihood per observed row
    for i in range(sample_size):
        data = pm.Multinomial('data_%d' % i,
                              n=y[i].sum(),
                              p=probs,
                              observed=y[i])

with multinom_test:
    start = pm.find_MAP()
    trace = pm.sample(5000, pm.Slice())

trace['probs'].mean(0)

Error:

ValueError: Optimization error: max, logp or dlogp at max have non-finite values. Some values may be outside of distribution support.
max: {'probs_stickbreaking_': array([  0.00000000e+00,  -4.47034834e-08,   0.00000000e+00])}
logp: array(-inf)
dlogp: array([  0.00000000e+00,   2.98023221e-08,   0.00000000e+00])
Check that 1) you don't have hierarchical parameters, these will lead to points with infinite density. 2) your distribution logp's are properly specified. Specific issues:
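The error says logp is -inf, i.e. some value fell outside the distribution's support. The sketch below reproduces the question's preprocessing with NumPy only (no PyMC3) and shows that noise = 0.0 leaves the one-hot rows as integers, while noise = 0.01 turns every entry into a non-integer fraction. It also shows that exact floating-point equality is fragile. If the installed PyMC3 version compares the row sum to n with exact equality (an assumption; check that version's Multinomial.logp), rounding alone could push a point outside the support. The helper name preprocess is hypothetical, and the seed is only there for reproducibility.

```python
import math

import numpy as np


def preprocess(counts, noise):
    """Question's preprocessing (hypothetical helper): add noise, then normalise rows."""
    y = counts + noise
    return y / y.sum(axis=1)[:, None]


true_probs = [0.2, 0.1, 0.3, 0.4]
rng = np.random.default_rng(0)  # seeded only so the sketch is reproducible
counts = rng.multinomial(1, true_probs, size=10)  # one-hot integer rows

for noise in (0.0, 0.01):
    y = preprocess(counts, noise)
    # True for noise = 0.0 (rows stay one-hot), False for noise = 0.01
    print(noise, np.all(y == np.round(y)))

# Exact equality on floats depends on rounding, so an exact "sums to n" test is fragile
print(0.1 + 0.2 + 0.3 == 0.6)              # exact comparison: False
print(math.isclose(0.1 + 0.2 + 0.3, 0.6))  # tolerant comparison: True
```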

This works for me:

import numpy as np
import pymc3 as pm

sample_size = 10
number_of_experiments = 100

true_probs = [0.2, 0.1, 0.3, 0.4]
k = len(true_probs)
noise = 0.01
# Keep the raw counts (n = 100 per row) instead of normalising them to proportions
y = np.random.multinomial(n=number_of_experiments, pvals=true_probs, size=sample_size) + noise

with pm.Model() as multinom_test:
    a = pm.Dirichlet('a', a=np.ones(k))
    for i in range(sample_size):
        data_pred = pm.Multinomial('data_pred_%s' % i, n=number_of_experiments, p=a, observed=y[i])
    trace = pm.sample(50000, pm.Metropolis())
    # trace = pm.sample(1000)  # also works with NUTS

pm.traceplot(trace[500:]);
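As a sanity check on the sampler output: with a Dirichlet(alpha) prior on p and independent multinomial rows of integer counts, the posterior is Dirichlet(alpha + column sums of the counts), so its mean has a closed form. This is a minimal sketch under that conjugacy assumption. The helper dirichlet_posterior_mean is hypothetical, not a PyMC3 function, and the simulated data only stand in for real counts. Its result can be compared against trace['a'].mean(0) from the model above.

```python
import numpy as np


def dirichlet_posterior_mean(counts, alpha):
    """Posterior mean of p for a Dirichlet(alpha) prior and multinomial count rows.

    Conjugacy: posterior = Dirichlet(alpha + sum of counts over rows).
    """
    counts = np.asarray(counts)
    post = np.asarray(alpha, dtype=float) + counts.sum(axis=0)
    return post / post.sum()


true_probs = [0.2, 0.1, 0.3, 0.4]
rng = np.random.default_rng(1)  # seeded only so the sketch is reproducible
counts = rng.multinomial(100, true_probs, size=10)  # 10 rows of n = 100 integer counts

# Flat Dirichlet(1, 1, 1, 1) prior, matching a=np.ones(k) in the model above
print(dirichlet_posterior_mean(counts, np.ones(4)))
```

With a flat prior and 1000 total counts, the closed-form mean should sit close to the empirical proportions, so a large gap between it and the trace mean points at the model or the sampler rather than the data.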