Python 中的简单 k-means 算法

Simple k-means algorithm in Python

下面是k-means算法的一个非常简单的实现。

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)

DIM = 2
N = 2000
num_cluster = 4
iterations = 3

x = np.random.randn(N, DIM)
y = np.random.randint(0, num_cluster, N)

mean = np.zeros((num_cluster, DIM))
for t in range(iterations):
    for k in range(num_cluster):
        mean[k] = np.mean(x[y==k], axis=0)
    for i in range(N):
        dist = np.sum((mean - x[i])**2, axis=1)
        pred = np.argmin(dist)
        y[i] = pred

for k in range(num_cluster):
    plt.scatter(x[y==k,0], x[y==k,1])
plt.show()

以下是代码生成的两个示例输出:

第一个示例 (num_cluster = 4) 看起来符合预期。然而,第二个示例 (num_cluster = 11) 仅显示在集群上,这显然不是我想要的。代码的工作取决于我定义的 类 的数量和迭代次数。

到目前为止,我找不到代码中的错误。不知何故集群消失了,但我不知道为什么。

有人看到我的错误吗?

您得到的是一簇,因为实际上只有一簇。
你的代码中没有任何东西可以避免集群消失,事实是这也会发生在 4 个集群上,但在更多的迭代之后。
我 运行 你的代码有 4 个集群和 1000 次迭代,它们都被吞没在一个大的主导集群中。
想一想,你的大集群超过了一个临界点,并且一直在增长,因为其他点逐渐接近它而不是之前的平均值。
如果您达到平衡(或静止)点,则不会发生这种情况,此时集群之间没有任何移动。但它显然有点罕见,而且你试图估计的集群越多就越罕见。


澄清:当有 4 个 "real" 集群并且您试图估计 4 个集群时,同样的事情也会发生。但这将意味着相当讨厌的初始化,可以通过智能地聚合多个 运行domly 种子运行来避免。
还有常见的"tricks"比如取初始均值相距较远,或者在不同的预估高密度位置的中心等。不过这也开始涉及了,你应该更深入地阅读k-意味着这个目的。

看来确实有 NaN 的身影。 使用种子=1,迭代次数=2,集群的数量从最初的 4 个减少到有效的 3 个。在下一次迭代中,这在技术上会直线下降到 1。

NaN 意味着有问题的质心的坐标然后会导致奇怪的事情。为了排除那些变空的有问题的集群,一个(可能有点太懒了)选项是将相关坐标设置为 Inf,从而使其成为比仍在游戏中的那些 "more distant than any other" 点(只要 'input'坐标不能是Inf)。 下面的代码片段是对此的快速说明以及我用来查看正在发生的事情的一些调试消息:

[...]
for k in range(num_cluster):
    mean[k] = np.mean(x[y==k], axis=0)
    # print mean[k]
    if any(np.isnan(mean[k])):
        # print "oh no!"
        mean[k] = [np.Inf] * DIM
[...]

通过此修改,发布的算法似乎以更稳定的方式工作(即,到目前为止我无法破解它)。

另请参阅 Quora link also mentioned among the comments about the split opinions, and the book "The Elements of Statistical Learning" for example here - 算法在相关方面也没有明确定义。

K-means 对初始条件也非常敏感。也就是说,k-means 可以并且将会丢弃集群(但丢弃到一个是很奇怪的)。在您的代码中,您将随机簇分配给这些点。

这就是问题所在:如果我对您的数据进行几个随机子样本,它们将具有大致相同的平均点。每次迭代,非常相似的质心将彼此靠近并且更有可能下​​降。

相反,我更改了您的代码以在您的数据集中选择 num_cluster 个点作为初始质心(更高的方差)。这似乎产生了更稳定的结果(没有观察到在几十次运行中下降到一个集群的行为):

import numpy as np
import matplotlib.pyplot as plt

DIM = 2
N = 2000
num_cluster = 11
iterations = 3

x = np.random.randn(N, DIM)
y = np.zeros(N)
# initialize clusters by picking num_cluster random points
# could improve on this by deliberately choosing most different points
for t in range(iterations):
    if t == 0:
        index_ = np.random.choice(range(N),num_cluster,replace=False)
        mean = x[index_]
    else:
        for k in range(num_cluster):
            mean[k] = np.mean(x[y==k], axis=0)
    for i in range(N):
        dist = np.sum((mean - x[i])**2, axis=1)
        pred = np.argmin(dist)
        y[i] = pred

for k in range(num_cluster):
    fig = plt.scatter(x[y==k,0], x[y==k,1])
plt.show()