试图了解此 Python 函数中发生的事情

Trying to understand what is happening in this Python Function

def closest_centroid(points, centroids):
    """returns an array containing the index to the nearest centroid for each point"""
    distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))
    return np.argmin(distances, axis=0)

谁能解释一下这个函数的具体工作原理?我目前得到 points 看起来像:

31998888119     0.94     34
23423423422     0.45     43
....

等等。在这个 numpy 数组中,points[1] 是长 ID,而 points[2]0.94points[3] 是第一个条目的 34

Centroids 只是从这个特定数组中随机选择的:

def initialize_centroids(points, k):
    """returns k centroids from the initial points"""
    centroids = points.copy()
    np.random.shuffle(centroids)
    return centroids[:k] 

现在我想从忽略 ID 的第一列的 pointscentroids(再次忽略第一列)的值中获取欧几里得距离。我不完全理解行 distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2)) 中的语法。为什么我们要对第三列进行求和,同时对新轴进行减速:np.newaxis?另外,我应该沿着哪个轴使 np.argmin 起作用?

考虑尺寸会有所帮助。让我们假设 k=4 并且有 10 个点,所以 points.shape = (10,3).

接下来,centroids = initialize_centroids(points, 4) returns 一个维度为 (4,3) 的对象。

让我们从内部分解这条线:

distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))

  1. 我们想从每个点中减去每个质心。由于 pointscentroids 是二维的,因此每个 points - centroid 都是二维的。如果只有 1 个质心,那我们就可以了。但是我们有 4 个质心!所以我们需要对每个质心执行points - centroids。因此我们需要另一个维度来存储它。因此添加了 np.newaxis.

  2. 我们对它进行平方是因为它是一个距离,所以我们想将负数转换为正数(也因为我们正在最小化欧氏距离)。

  3. 我们不会对第三列求和。事实上,我们正在对每个点、每个质心求和点和质心之间的差异。

  4. np.argmin() 找到距离最小的质心。因此,对于每个质心,对于每个点,找到最小索引(因此 argmin 而不是 min)。该索引是分配给该点的质心。

这是一个例子:

points = np.array([
[   1, 2, 4],
[   1, 1, 3],
[   1, 6, 2],
[   6, 2, 3],
[   7, 2, 3],
[   1, 9, 6],
[   6, 9, 1],
[   3, 8, 6],
[   10, 9, 6],
[   0, 2, 0],
])

centroids = initialize_centroids(points, 4)

print(centroids)
array([[10,  9,  6],
   [ 3,  8,  6],
   [ 6,  2,  3],
   [ 1,  1,  3]])

distances = (pts - centroids[:, np.newaxis])**2

print(distances)
array([[[ 81,  49,   4],
    [ 81,  64,   9],
    [ 81,   9,  16],
    [ 16,  49,   9],
    [  9,  49,   9],
    [ 81,   0,   0],
    [ 16,   0,  25],
    [ 49,   1,   0],
    [  0,   0,   0],
    [100,  49,  36]],

   [[  4,  36,   4],
    [  4,  49,   9],
    [  4,   4,  16],
    [  9,  36,   9],
    [ 16,  36,   9],
    [  4,   1,   0],
    [  9,   1,  25],
    [  0,   0,   0],
    [ 49,   1,   0],
    [  9,  36,  36]],

   [[ 25,   0,   1],
    [ 25,   1,   0],
    [ 25,  16,   1],
    [  0,   0,   0],
    [  1,   0,   0],
    [ 25,  49,   9],
    [  0,  49,   4],
    [  9,  36,   9],
    [ 16,  49,   9],
    [ 36,   0,   9]],

   [[  0,   1,   1],
    [  0,   0,   0],
    [  0,  25,   1],
    [ 25,   1,   0],
    [ 36,   1,   0],
    [  0,  64,   9],
    [ 25,  64,   4],
    [  4,  49,   9],
    [ 81,  64,   9],
    [  1,   1,   9]]])

print(distances.sum(axis=2))
array([[134, 154, 106,  74,  67,  81,  41,  50,   0, 185],
   [ 44,  62,  24,  54,  61,   5,  35,   0,  50,  81],
   [ 26,  26,  42,   0,   1,  83,  53,  54,  74,  45],
   [  2,   0,  26,  26,  37,  73,  93,  62, 154,  11]])

# The minimum of the first 4 centroids is index 3. The minimum of the second 4 centroids is index 3 again.

print(np.argmin(distances.sum(axis=2), axis=0))
array([3, 3, 1, 2, 2, 1, 1, 1, 0, 3])