为什么使用点积会降低 PyMC3 的性能?

Why is using dot product worsening the performance for PyMC3?

我正在尝试 运行 使用 PyMC3 进行简单的线性回归。下面的代码是一个片段:

import numpy as np
from pymc3 import Model, sample, Normal, HalfCauchy
import pymc3 as pm

X = np.arange(500).reshape(500, 1)
y = np.random.normal(0, 5, [500, 1]) + X

with Model() as multiple_regression_model:

    beta = Normal('beta', mu=0, sd=1000, shape=2)
    sigma = HalfCauchy('sigma', 1000)

    y_hat = beta[0] + X * beta[1]

    exp = Normal('y', y_hat, sigma=sigma, observed=y)

with multiple_regression_model:
    trace = sample(1000, tune=1000)

trace['beta'].mean(axis=0)

上面的代码 运行s 用了大约 6 秒,并给出了 beta 的合理估计值 ([-0.19646408, 1.00053091])

但是当我尝试使用点积时,情况变得非常糟糕:

X = np.arange(500).reshape(500, 1)
y = np.random.normal(0, 5, [500, 1]) + X

X_aug_np = np.squeeze(np.dstack((np.ones((500, 1)), X)))

with Model() as multiple_regression_model:

    beta = Normal('beta', mu=0, sd=1000, shape=2)
    sigma = HalfCauchy('sigma', 1000)

    y_hat = pm.math.dot(X_aug_np, beta)

    exp = Normal('y', y_hat, sigma=sigma, observed=y)

with multiple_regression_model:
    trace = sample(1000, tune=1000)

trace['beta'].mean(axis=0)

现在代码在 56 秒内完成,估计完全不对([249.52363555,-0.0000481])。

我认为使用点积会使事情变得更快。为什么会这样?我是不是做错了什么?

这是一个微妙的形状和广播错误:如果您将 betashape 更改为 (2, 1),那么它会起作用。

为了了解原因,我将两个模型重命名并整理了一下代码:

import numpy as np
import pymc3 as pm

X = np.arange(500).reshape(500, 1)
y = np.random.normal(0, 5, [500, 1]) + X

X_aug_np = np.squeeze(np.dstack((np.ones((500, 1)), X)))

with pm.Model() as basic_model:
    beta = pm.Normal('beta', mu=0, sd=1000, shape=2)
    sigma = pm.HalfCauchy('sigma', 1000)

    y_hat =  beta[0] + X * beta[1]

    exp = pm.Normal('y', y_hat, sigma=sigma, observed=y)


with pm.Model() as matmul_model:
    beta = pm.Normal('beta', mu=0, sd=1000, shape=(2, 1))
    sigma = pm.HalfCauchy('sigma', 1000)

    y_hat = pm.math.dot(X_aug_np, beta)
    exp = pm.Normal('y', y_hat, sigma=sigma, observed=y)

你是怎么发现的?由于看起来模型是相同的,但它们的采样不相似,我 运行

print(matmul_model.check_test_point())
print(basic_model.check_test_point())

计算变量在合理默认值下的对数概率。这不匹配,所以我检查了 exp.tag.test_value.shape,发现它是 (500, 500),而我预计它是 (500, 1)。形状处理在概率编程中非常困难,这是因为 exp 一起广播 y_hatsigmay

作为一个附加问题,如果不设置 cores=1, chains=4,我无法在我的机器上进行 matmul_model 采样。