仅下载某些 MNIST 数字

Downloading only certain MNIST digits

我正在尝试只下载项目手写数字的 MNIST 数据库的一部分。具体来说,我只想将数字 0、1、2 和 3 发送到神经网络。

我目前正在加载这样的数据(基于"Neural Networks and Deep Learning" by Michal Daniel Dobrzanski):

import cPickle
import gzip
import numpy as np

def load_data():
    f = gzip.open('src/mnist.pkl.gz', 'rb')
    training_data, validation_data, test_data = cPickle.load(f)
    f.close()
    return (training_data, validation_data, test_data)

def load_data_wrapper():
    tr_d, va_d, te_d = load_data()
    training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
    training_results = [vectorized_result(y) for y in tr_d[1]]
    training_data = zip(training_inputs, training_results)
    validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
    validation_data = zip(validation_inputs, va_d[1])
    test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
    test_data = zip(test_inputs, te_d[1])
    return (training_data, validation_data, test_data)

我尝试构建一个函数,在发送到 load_data_wrapper() 之前从 load_data() 创建新数据集(通过在 load_data_wrapper() 中将 tr_d, va_d, te_d = load_data() 更改为 tr_d, va_d, te_d = digitTest()),运气不好,请看下面:

def digitTest():
    tr_d, va_d, te_d = load_data()
    tr_d = list(tr_d)
    va_d = list(va_d)
    te_d = list(te_d)

    newTrD = []
    newTrD.append([])
    newTrD.append([])

    newVaD = []
    newVaD.append([])
    newVaD.append([])

    newTeD = []
    newTeD.append([])
    newTeD.append([])

    for index,label in enumerate(tr_d[1]):
        if tr_d[1][index] < 4:
            newTrD[0].append(tr_d[0][index])
            newTrD[1].append(tr_d[1][index])

    for index,label in enumerate(va_d[1]):
        if va_d[1][index] < 4:
            newVaD[0].append(va_d[0][index])
            newVaD[1].append(va_d[1][index])

    for index,label in enumerate(te_d[1]):
        if te_d[1][index] < 4:
            newTeD[0].append(te_d[0][index])
            newTeD[1].append(te_d[1][index])

    return (newTrD, newVaD, newTeD)

有没有可能达到我想要的效果?我怎样才能做到这一点?请注意,当从 load_data 函数解析时,数据存储在元组中。

我从来没有使用 cPickle 加载 mnist 数据集,我不知道它是什么 returns。 阅读您的代码似乎您做对了,但是如果您说它不起作用,我想 cPickle returns 数据的内容或方式有些问题。

我没有 python 2 所以我不能调试你的代码但是:

我倾向于自己做这些事情:

def loadSet(values_path, labels_path):
    labels = []
    # labels:
    # 0000     32 bit integer  0x00000803(2051) magic number
    # 0008     32 bit integer  28               number of labels
    # 0009     unsigned byte   ??               label
    # 0010     unsigned byte   ??               label
    # ....     unsigned byte   ??               label

    with open(labels_path, 'rb') as f:
        m_number = int.from_bytes(f.read(4,), 'big')
        num_labels = int.from_bytes(f.read(4), 'big')
        for i in range(num_labels):
            labels.append(int.from_bytes(f.read(1), 'big'))

    images = []
    # images:
    # 0000     32 bit integer  0x00000803(2051) magic number
    # 0004     32 bit integer  60000            number of images
    # 0008     32 bit integer  28               number of rows
    # 0012     32 bit integer  28               number of columns
    # 0016     unsigned byte   ??               pixel
    # 0020     unsigned byte   ??               pixel
    # ....     unsigned byte   ??               pixel

    with open(values_path, 'rb') as f:
        m_number = int.from_bytes(f.read(4), 'big')
        num_images = int.from_bytes(f.read(4), 'big')
        num_rows = int.from_bytes(f.read(4), 'big')
        num_cols = int.from_bytes(f.read(4), 'big')
        for i in range(num_images):
            image = []
            for x in range(num_rows * num_cols):
                image.append(int.from_bytes(f.read(1), 'big'))
            images.append(image)

此函数将从文件中加载一组 mnist 标签和值。您可以在 http://yann.lecun.com/exdb/mnist/ 获取数据集,您必须解压缩文件。 标签将是 'train-labels.idx1-ubyte'。只需将训练标签和图像或测试标签和图像的路径传递到函数中,它就会加载这些值。

Return 值是两个列表的元组:

([number], [pixels])

其中 pixels 是一个列表本身。

此外,如果文件不存在或(可能)文件格式不正确,除了抛出异常之外,这不会进行错误检查,因此您可能需要考虑以某种方式进行检查。

我也不习惯 numpy,我通常使用 C++ 和 java,但您确实可以很容易地将这些值转换为 numpy 数组 - 只需阅读主题即可。

现在过滤这些非常容易,您现在应该可以将您的方法用于 digitTest。

正如您可能会看到的那样,如果您使用原始的 mnist 数据集,您只会得到训练和测试图像。这里发生的事情是你拿了其中一组的一部分并将其用作 - 我不完全确定你在这里的措辞 - 测试数据来评估培训进度。训练完成后,您可以使用 't10k' 文件来验证您的网络训练效果如何。这里重要的是,如果你从这些 t10k 图像中分离你的测试数据,你不会再次使用它们,只有剩下的部分是为了验证网络尚未看到的数据训练。