我只是在 sklearn KNN 分类器中发现了一个错误,还是一切都按预期工作?

Did I just find a bug in sklearn KNN classifier or does everything work as intended?

我一直在研究 python sklearn k 最近邻分类器,我认为它不能正常工作 - k 大于 1 的结果是错误的。我试图可视化不同的 k-nn 方法如何随我的示例代码而变化。

代码有点长,但不是很复杂。继续 运行 自己获取照片。我以大约 10 个点的列形式生成样本二维数据。大多数代码都是关于以动画方式在图表上很好地绘制它。所有分类都发生在for循环中的“main”中调用构建库对象KNeighborsClassifier之后。

我尝试了不同的算法方法,怀疑是 kd 树问题,但我得到了相同的结果(交换算法="kdtree" 为 "brute" 或球树)

这是我得到的结果图:

result of classifier with k=3 and uniform weights, kdtrees

图片评论: 正如您在第 3 列中看到的那样,x=2 周围的所有区域都应该是红色的,例如 x=-4 周围的区域应该是蓝色的,因为下一个最近的红点在相邻的列中。我相信这不是分类器的行为方式,我不确定是我做错了什么还是库方法错误。我试图审查它的代码,但同时决定提出这个问题。我也不熟悉 C-Python 它写在.

来源和版本:我使用 scikit-learn documentation 和 mathplotlib 示例编写了代码。我 运行 python sklearn 3.6.1 和 0.18.1 版本。

奖励问题:k-neighbors 的答案是使用 kd-trees approximate 还是 definite?根据我的理解,它可以很容易地为 k=1 完美地工作,但你不确定 k 大于 1 时答案是否总是正确的。

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn import neighbors
import random


random.seed(905) # 905
# interesting seed 2293
def generate_points(sizex, sizey):
    # sizex = 4
    # sizey = 10
    apart = 5
    # generating at which X coordinate my data column will be
    columns_x = [random.normalvariate(0, 5) for i in range(sizex)]
    columns_y = list()
    # randomising for each column the Y coordinate at which it starts
    for i in range(sizex):
        y_column = [random.normalvariate(-50, 100) for j in range(sizey)]
        y_column.sort()
        columns_y.append(y_column)

    # preparing lists of datapoints with classification
    datapoints = np.ndarray((sizex * sizey, 2))
    dataclass = list()

    # genenerating random split for each column
    for i in range(sizex):
        division = random.randint(0, sizey)
        for j in range(sizey):
            datapoints[i * sizey + j][0] = columns_x[i]
            datapoints[i * sizey + j][1] = -j * apart
            dataclass.append(j < division)

    return datapoints, dataclass


if __name__ == "__main__":
    datapoints, dataclass = generate_points(4, 10)

    #### VISUALISATION PART ####
    x_min, y_min = np.argmin(datapoints, axis=0)
    x_min, y_min = datapoints[x_min][0], datapoints[y_min][1]
    x_max, y_max = np.argmax(datapoints, axis=0)
    x_max, y_max = datapoints[x_max][0], datapoints[y_max][1]
    x_range = x_max - x_min
    y_range = y_max - y_min
    x_min -= 0.15*x_range
    x_max += 0.15*x_range
    y_min -= 0.15*y_range
    y_max += 0.15*y_range

    mesh_step_size = .1

    # Create color maps
    cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF']) # for meshgrid
    cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF']) # for points

    plt.ion() # plot interactive mode
    for weights in ['uniform', 'distance']: # two types of algorithm
        for k in range(1, 13, 2): # few k choices
            # we create an instance of Neighbours Classifier and fit the data.
            clf = neighbors.KNeighborsClassifier(k, weights=weights, algorithm="kd_tree")
            clf.fit(datapoints, dataclass)

            # Plot the decision boundary. For that, we will assign a color to each
            # point in the mesh [x_min, x_max]x[y_min, y_max].
            xx, yy = np.meshgrid(np.arange(x_min, x_max, mesh_step_size),
                                 np.arange(y_min, y_max, mesh_step_size))
            Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

            # Put the result into a color plot
            Z = Z.reshape(xx.shape)

            plt.figure(1)
            plt.pcolormesh(xx, yy, Z, cmap=cmap_light)

            # Plot also the training points
            plt.scatter(datapoints[:, 0], datapoints[:, 1], c=dataclass, cmap=cmap_bold, marker='.')
            plt.xlim(xx.min(), xx.max())
            plt.ylim(yy.min(), yy.max())
            plt.title("K-NN classifier (k = %i, weights = '%s')"
                      % (k, weights))

            plt.draw()
            input("Press Enter to continue...")
            plt.clf()

此外,我决定在发布前设置种子,这样我们都会得到相同的结果,请随意设置随机种子。

你的输出似乎没问题。

从您的图表中可能不明显的一点是,点之间的水平距离实际上 比垂直距离短。即使两个相邻列之间的最远水平间距为 4.something,而任意两个相邻行之间的垂直间距为 5.

对于分类为红色的点,它们在训练集中的 3 个最近邻居中的大多数 确实是红色的。如果接下来的两个邻居是红色的,那么它们是否非常接近蓝点并不重要。对于分类为靠近红色点的蓝色点也是如此。