如何读取和显示 MNIST 数据集?

How to read and display MNIST dataset?

下面的代码将 mnist 数据集打开为 csv

import numpy as np
import csv
import matplotlib.pyplot as plt

with open('C:/Z_Uni/Individual_Project/Python_Projects/NeuralNet/MNIST_Dataset/mnist_train.csv/mnist_train.csv', 'r') as csv_file:
    for data in csv.reader(csv_file):
        # The first column is the label
        label = data[0]

        # The rest of columns are pixels
        pixels = data[1:]

        # Make those columns into a array of 8-bits pixels
        # This array will be of 1D with length 784
        # The pixel intensity values are integers from 0 to 255
        pixels = np.array(pixels, dtype='uint8')

        print(pixels.shape)
        # Reshape the array into 28 x 28 array (2-dimensional array)
        pixels = pixels.reshape((28, 28))
        print(pixels.shape)
        # Plot
        plt.title('Label is {label}'.format(label=label))
        plt.imshow(pixels, cmap='gray')
        plt.show()

        break # This stops the loop, I just want to see one

我从某人那里得到了上面的代码,但无法让它显示 mnist 数字。

我收到错误:

回溯(最近调用最后): 文件“C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\Test_View_Mnist.py”,第 16 行,位于 像素 = np.array(像素,dtype='uint8') ValueError:以 10 为底的 int() 的无效文字:'1x1'

当我删除 dtype='unit8' 我收到错误:

回溯(最近调用最后): 文件“C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\Test_View_Mnist.py”,第 24 行,位于 plt.imshow(像素,cmap='gray') 文件“C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib_api\deprecation.py”,第 456 行,在包装器中 return func(*args, **kwargs) imshow 中的文件“C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\pyplot.py”,第 2640 行 _ret = gca().imshow( 文件“C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib_api\deprecation.py”,第 456 行,在包装器中 return func(*args, **kwargs) 文件“C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib_init.py”,第 1412 行,位于内部 return func(ax, *map(sanitize_sequence, args), **kwargs) imshow

中的文件“C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\axes_axes.py”,第 5488 行
im.set_data(X)

文件“C:\Z_Uni\Individual_Project\Python_Projects\NeuralNet\source\lib\site-packages\matplotlib\image.py”,第 706 行,在 set_data raise TypeError("dtype {} 的图像数据无法转换为" TypeError: dtype

进程已完成,退出代码为 1

有人可以解释为什么会发生此错误以及如何解决吗? 谢谢

这里有两个问题。 (1) 您需要跳过第一行,因为它们是标签。 (1x1)、(1x2) 等。 (2) 您需要 int64 数据类型。下面的代码将解决这两个问题。 next(csvreader) 跳过第一行。

import numpy as np
import csv
import matplotlib.pyplot as plt

with open('./mnist_test.csv', 'r') as csv_file:
    csvreader = csv.reader(csv_file)
    next(csvreader)
    for data in csvreader:
        
        # The first column is the label
        label = data[0]

        # The rest of columns are pixels
        pixels = data[1:]

        # Make those columns into a array of 8-bits pixels
        # This array will be of 1D with length 784
        # The pixel intensity values are integers from 0 to 255
        pixels = np.array(pixels, dtype = 'int64')
        print(pixels.shape)
        # Reshape the array into 28 x 28 array (2-dimensional array)
        pixels = pixels.reshape((28, 28))
        print(pixels.shape)
        # Plot
        plt.title('Label is {label}'.format(label=label))
        plt.imshow(pixels, cmap='gray')
        plt.show()