python 中的嵌套 For 循环错误

Nested For Loops error in python

尝试使用欧几里德距离手动实现 KNN 分类器,编写了两种类型的代码,一种给出了正确的结果,另一种给出了完全不同的图片,谁能帮我解决这个问题,我在哪里犯了错误?

第一个代码(有效):

def squared_dist(x,y):
    return np.sum(np.square(x-y))

def my_classifier(x):
    distances = [squared_dist(x,train_data[i,]) for i in range(len(train_labels))]
    index = np.argmin(distances)
    return train_labels[index]

[my_classifier(test_data[i,]) for i in range(len(test_data))]

上面一行给了我准确的预测

第二个代码(无效):

想在一行中实现两个 for 循环,所以写成

def my_classifier2(x):
    distances = [squared_dist(x[j,],train_data[i,]) for j in range(len(test_labels)) for i in range(len(train_labels))]
    index = np.argmin(distances)
    return train_labels[index]

my_classifier2(test_data)

上面的代码报错了 索引 2492731 超出轴 0 的范围,大小为 7500

任何人都可以向我解释第二个代码中出了什么问题以及如何解决它吗?

在您的第一个代码中,distances 列表将仅包含 len(train_labels) 个元素。但是,在您的第二个代码中,distances 将具有 len(train_labels)*len(test_labels) 元素,因为您将 train_datatest_data 中所有行之间的所有成对距离放入其中。因此,来自 np.argmin 的索引可能超过 train_labels.

的长度

以下是您可能的修复方法

def my_classifier2(x):
    distances = [[squared_dist(x[j,],train_data[i,]) for i in range(len(train_labels))] for j in range(len(test_labels))]

    indices = np.argmin(distances, axis=1)
    return train_labels[indices]

现在 distances 将包含 len(test_labels) 行。每行是相应测试数据点与所有列车数据点之间的 len(train_labels) 个距离值的列表。

np.argminaxis=1 将找到每一行的最小索引。最后,train_labels[indices] 为您提供了最终结果的 len(test_labels) 个元素的数组。

更新

假设您有 5 个火车示例:

[[a, a, a],
 [b, b, b],
 [c, c, c],
 [d, d, d]]

和2个测试示例:

 [[x, x, x]
  [y, y, y]]

在第一个代码中,您的函数 my_classifier 一次只处理一个测试示例。例如,当您调用 my_classifier([x, x, x]) 时,您有:

distances = [xa, xb, xc, xd]

当我说xa时,它表示[x, x, x]和[a, a, a]之间的距离。这里 distanceslen(train_labels) 个元素。

对于第二个代码,您的代码生成

distances = [xa, xb, xc, xd, ya ,yb ,yc, yd]

而应该是

distances = [[xa, xb, xc, xd],
             [ya, yb, yc, yd]]

现在 distances 有 2 行 (len(test_labels)),每行有 4 个元素 (len(train_labels))。当你在这个 distances 上应用 np.argminaxis=1 时,你会得到

indices = [i1, i2]

其中 i1 是第一行的最小索引,i2 是第二行的最小索引。当你把indicestrain_labels,它会挑出[train_labels[i1], train_labels[i2]],也就是答案