这段代码能识别 MNIST 集吗? (K-近邻法)

Will this code work to recognise the MNIST set? (K-NN method)

我不确定下面的代码是否会执行,因为它在“计算预测”上停留了很长时间。如果它不起作用,我应该更改什么?

import struct
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.special import expit
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score



clf = KNeighborsClassifier()

def load_data():
    with open('train-labels-idx1-ubyte', 'rb') as labels:
        magic, n = struct.unpack('>II', labels.read(8))
        train_labels = np.fromfile(labels, dtype=np.uint8)
    with open('train-images-idx3-ubyte', 'rb') as imgs:
        magic, num, nrows, ncols = struct.unpack('>IIII', imgs.read(16))
        train_images = np.fromfile(imgs, dtype=np.uint8).reshape(num, 784)
    with open('t10k-labels-idx1-ubyte', 'rb') as labels:
        magic, n = struct.unpack('>II', labels.read(8))
        test_labels = np.fromfile(labels, dtype=np.uint8)
    with open('t10k-images-idx3-ubyte', 'rb') as imgs:
        magic, num, nrows, ncols = struct.unpack('>IIII', imgs.read(16))
        test_images = np.fromfile(imgs, dtype=np.uint8).reshape(num, 784)
    return train_images, train_labels, test_images, test_labels


def knn(train_x, train_y, test_x, test_y):
    clf.fit(train_x, train_y)
    print("Compute predictions")
    predicted = clf.predict(test_x)
    print("Accuracy: ", accuracy_score(test_y, predicted))

train_x, train_y, test_x, test_y = load_data()
knn(train_x, train_y, test_x, test_y)

it has been stuck on "Computing Prediction" for a long time

我建议您使用一组非常有限的数据来测试是否一切正常,然后再 运行 使用整个数据集。这样你就可以确保代码有意义。

测试代码后,您可以安全地继续使用整个数据集进行训练。

这样,您就可以很容易地辨别出代码耗时是因为某些代码问题还是仅仅因为数据量大(也许代码没问题,但您可能会意识到,例如 10 个样本,它花费的时间比您愿意 to/can 等待的时间长,因此您可以相应地进行调整 - 否则您正在处理的内容太多了。

话虽如此,如果代码没问题但花费的时间太长,我也建议像 Soumya 一样在 Colab 上尝试 运行。你那里有一些很好的硬件,会话长达 12 小时,并且可以让你的电脑同时测试其他代码!