从高斯混合模型中采样数据点 python

Sampling data points from a Gaussian Mixture Model python

我对 python 和 GMM 真的很陌生。我最近学习了 GMM 并尝试实现 here

中的代码

我在使用 运行 gmm.sample() 方法时遇到了一些问题:

gmm16 = GaussianMixture(n_components=16, covariance_type='full', random_state=0)    
Xnew = gmm16.sample(400,random_state=42)
plt.scatter(Xnew[:, 0], Xnew[:, 1])

错误显示:

TypeError: sample() got an unexpected keyword argument 'random_state'

我查看了最新的文档,发现方法sample应该只包含n,表示要生成的样本数。但是当我删除'random_state=42'时,出现了新的错误:

代码:

Xnew = gmm16.sample(400)
plt.scatter(Xnew[:, 0], Xnew[:, 1])

错误:

TypeError: tuple indices must be integers or slices, not tuple

有人在实现 Jake VanderPlas 的代码时遇到过这个问题吗?我该如何解决?

我的 Jupyter:

The version of the notebook server is: 5.7.4

Python 3.7.1 (default, Dec 14 2018, 13:28:58)

Type 'copyright', 'credits' or 'license' for more information

IPython 7.2.0 -- An enhanced Interactive Python. Type '?' for help.

您的问题在于将数据输入散点图的方式。 特别是您在元组中有一个 numpy 数组,并且您的索引方式不正确。 试试这个。

plt.scatter(Xnew[0][:,0], Xnew[0][:,1])

基本上我们所拥有的是第一个索引 Xnew[0] 将指向您想要的元组中的元素(numpy 数组),第二个将根据需要对其进行切片。 [:,1] 这里我们取所有的行和第二列。

你得到 TypeError 因为 sample 方法 returns 一个 tuple,参见 here

这应该可以完成工作:

Xnew, Ynew = gmm16.sample(400)  # if Ynew is valuable
plt.scatter(Xnew[:, 0], Xnew[:, 1])

Xnew, _ = gmm16.sample(400)  # if Ynew isn't valuable
plt.scatter(Xnew[:, 0], Xnew[:, 1])