解析 Yann LeCun 的 MNIST IDX 文件格式

Parsing Yann LeCun's MNIST IDX file format

我想了解如何打开this version of the MNIST data set。例如训练集标签文件train-labels-idx1-ubyte定义为:

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label

而且我在网上找到了一些代码似乎可以工作,但不明白它是如何工作的:

with open('train-labels-idx1-ubyte', 'rb') as f:
    bytes = f.read(8)
    magic, size = struct.unpack(">II", bytes)

print(magic) # 2049
print(size)  # 60000

我的理解是 struct.unpack 将第二个参数解释为两个 4 字节整数的大端字节串(参见 here)。但是,当我实际打印 bytes 的值时,我得到:

b'\x00\x00\x08\x01\x00\x00\xea`'

第一个四字节整数有意义:

b'\x00\x00\x08\x01'

前两个字节为0,下一个表示数据为无符号字节。而0x01表示标签的一维向量。假设到目前为止我的理解是正确的,接下来的三个(四个?)字节发生了什么:

...\x00\x00\xea`

这如何转化为 60,000?

要了解它的工作原理,您需要将其转换为二进制表示形式。

如您所述,Python 正确提取了正确的信息:

>>> import struct
>>> with open('train-labels-idx1-ubyte', 'rb') as f:
...     data = f.read(8)
... 
>>> print(data)
b'\x00\x00\x08\x01\x00\x00\xea`'
>>> print(struct.unpack('>II', data))
(2049, 60000)

在字符串的头部,有两个4字节的整数。如果我们遍历 data:

,我们可以看到它们的二进制和十进制表示
>>> for char in data:
...     print('{0:08b} - {0:3d} - {1:s}'.format(char, str(bytes([char]))))
... 
00000000 -   0 - b'\x00'
00000000 -   0 - b'\x00'
00001000 -   8 - b'\x08'
00000001 -   1 - b'\x01'
00000000 -   0 - b'\x00'
00000000 -   0 - b'\x00'
11101010 - 234 - b'\xea'
01100000 -  96 - b'`'

最简单的部分是知道前 4 个字节是第一个整数(幻数),接下来的 4 个字节是第二个整数(项目数)。

然后,给定这最后 4 个字节,有两种方法可以构造它们所代表的整数值。

第一个选项(MNIST 中使用的选项)是大端或高端。这意味着,首先找到最重要的字节:

00000000 00000000 11101010 01100000

如果你检查这个二进制数的十进制值,它是 60,000,即 MNIST 数据集中的项目数。

此外,我们可以将其解释为小端。在这种情况下,首先找到 LESS 有效字节:

01100000 11101010 00000000 00000000

它的十进制表示是数字 1,625,948,160。

因此,如果您简单地将 \x00\x00\xea` 中的每个字节转换为二进制,并找到整个二进制数的十进制表示(如果是小端,则恢复字节的顺序),您将得到整数值他们代表。

我编写了以下代码以防有人需要解析整个图像数据集(如问题标题中所示),而不仅仅是前两个字节。

import numpy as np
import struct

with open('samples/t10k-images-idx3-ubyte','rb') as f:
    magic, size = struct.unpack(">II", f.read(8))
    nrows, ncols = struct.unpack(">II", f.read(8))
    data = np.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>'))
    data = data.reshape((size, nrows, ncols))

这假定您解压缩了 .gz 文件。您还可以使用压缩文件,如 Marktodisco's answer 所示,通过添加 import gzip,使用 gzip.open(...) 而不是 open(...),并使用 np.frombuffer(f.read(), ...) 而不是 np.fromfile(f, ...).

为了检查,请显示第一个数字。在我的例子中是 7.

import matplotlib.pyplot as plt
plt.imshow(data[0,:,:], cmap='gray')
plt.show()

另外,下面的代码读取带有标签的文件

with open('samples/t10k-labels-idx1-ubyte','rb') as f:
    magic, size = struct.unpack(">II", f.read(8))
    data = np.fromfile(f, dtype=np.dtype(np.uint8).newbyteorder('>'))
    data = data.reshape((size,)) # (Optional)
print(data)
# Prints: [7 2 1 ... 4 5 6]

根据您的标准,最后一次整形可以是 (size,)(1, size)

Carlos 的回答很好,但如果文件仍然是 .gz 格式,它就会中断。当我 运行 代码时,出现以下错误:

ValueError: cannot reshape array of size 1648861 into shape (10000,28,28)

由于 raw data 默认使用 .gz 扩展名下载,我修改了 Carlos 的代码。见下文。

import gzip
import struct
import numpy as np

with gzip.open('t10k-images-idx3-ubyte.gz','rb') as f:
    magic, size = struct.unpack(">II", f.read(8))
    nrows, ncols = struct.unpack(">II", f.read(8))
    data = np.frombuffer(f.read(), dtype=np.dtype(np.uint8).newbyteorder('>'))
    data = data.reshape((size, nrows, ncols))

图像仍然正确加载。

import matplotlib.pyplot as plt

plt.imshow(data[0,:,:], cmap='gray')
plt.show()