为 PyTorch 使用大数据集的最有效方法?

Most efficient way to use a large data set for PyTorch?

也许以前有人问过这个问题,但我找不到适合我情况的相关信息。

我正在使用 PyTorch 创建一个用于图像数据回归的 CNN。我没有正式的学术编程背景,所以我的许多方法都是临时的,而且效率非常低。很多时候我可以回顾我的代码并稍后清理,因为效率低下并没有严重到性能受到显着影响。然而,在这种情况下,我使用图像数据的方法需要很长时间,使用大量内存,并且每次我想测试模型的变化时都要这样做。

我所做的基本上是将图像数据加载到 numpy 数组中,将这些数组保存在 .npy 文件中,然后当我想将所述数据用于模型时,我将所有数据导入该文件中.我不认为数据集真的那么大,因为它由 5000 个 3 色通道图像组成,大小为 64x64。然而,我的内存使用率在加载时高达 70%-80%(超出 16gb),每次加载都需要 20-30 秒。

我的猜测是我对加载它的方式很愚蠢,但坦率地说,我不确定标准是什么。我应该以某种方式在需要之前将图像数据放在某个地方,还是应该直接从图像文件加载数据?在任何一种情况下,独立于文件结构的最好、最有效的方法是什么?

如果能提供任何帮助,我将不胜感激。

这里有一个具体的例子来说明我的意思。这假设您已经使用 h5py.

将图像转储到 hdf5 文件 (train_images.hdf5) 中
import h5py
hf = h5py.File('train_images.hdf5', 'r')

group_key = list(hf.keys())[0]
ds = hf[group_key]

# load only one example
x = ds[0]

# load a subset, slice (n examples) 
arr = ds[:n]

# should load the whole dataset into memory.
# this should be avoided
arr = ds[:]

简而言之,ds 现在可以用作动态提供图像的迭代器(即它不会在内存中加载任何内容)。这应该会让整个 运行 时间过得飞快。

for idx, img in enumerate(ds):
   # do something with `img`

为了速度,我建议使用 HDF5LMDB:

Reasons to use LMDB:

LMDB uses memory-mapped files, giving much better I/O performance. Works well with really large datasets. The HDF5 files are always read entirely into memory, so you can’t have any HDF5 file exceed your memory capacity. You can easily split your data into several HDF5 files though (just put several paths to h5 files in your text file). Then again, compared to LMDB’s page caching the I/O performance won’t be nearly as good. [http://deepdish.io/2015/04/28/creating-lmdb-in-python/]

如果您决定使用 LMDB:

ml-pyxis 是一个使用 LMDB 创建和读取深度学习数据集的工具。*(我是该工具的共同作者)

它允许创建二进制 blob (LMDB),并且可以非常快速地读取它们。上面的 link 附带了一些关于如何创建和读取数据的简单示例。包括 python 个生成器/迭代器。

这个 notebook 有一个关于如何创建数据集并在使用 pytorch 时并行读取它的示例。

如果您决定使用 HDF5:

PyTables 是一个用于管理分层数据集的包,旨在高效、轻松地处理极其大量的数据。

https://www.pytables.org/

除了上述答案外,由于 Pytorch 世界的一些最新进展(2020 年),以下内容可能会有用。

您的问题:我应该以某种方式在需要之前将图像数据放在某个地方,还是应该直接从图像文件加载数据?在任何一种情况下,独立于文件结构的最好、最有效的方法是什么?

您可以将图像文件以其原始格式(.jpg、.png 等)保留在本地磁盘或云存储中,但需要添加一个步骤 - 将目录压缩为 tar 文件。请阅读此内容了解更多详情:

Pytorch 博客(2020 年 8 月):用于大型数据集、许多文件、许多 GPU 的高效 PyTorch I/O 库 (https://pytorch.org/blog/efficient-pytorch-io-library-for-large-datasets-many-files-many-gpus/)

此包专为数据文件太大而无法放入内存进行训练的情况而设计。因此,您提供 URL 数据集位置(本地、云、..),它将分批并行地引入数据。

唯一的(当前)要求是数据集必须是 tar 文件格式。

tar文件可以在本地盘也可以在云端。这样,您就不必每次都将整个数据集加载到内存中。您可以使用 torch.utils.data.DataLoader 批量加载随机梯度下降。

无需将图像保存到 npy 并全部加载到内存中。只需加载一批图像路径,然后转换为张量。

下面的代码定义了MassiveDataset,并将其传递给DataLoader,一切顺利。

from torch.utils.data.dataset import Dataset
from typing import Optional, Callable
import os
import multiprocessing

def apply_transform(transform: Callable, data):
    try:
        if isinstance(data, (list, tuple)):
            return [transform(item) for item in data]

        return transform(data)
    except Exception as e:
        raise RuntimeError(f'applying transform {transform}: {e}')


class MassiveDataset(Dataset):
    def __init__(self, filename, transform: Optional[Callable] = None):
        self.offset = []
        self.n_data = 0

        if not os.path.exists(filename):
            raise ValueError(f'filename does not exist: {filename}')

        with open(filename, 'rb') as fp:
            self.offset = [0]
            while fp.readline():
                self.offset.append(fp.tell())
            self.offset = self.offset[:-1]

        self.n_data = len(self.offset)

        self.filename = filename
        self.fd = open(filename, 'rb', buffering=0)
        self.lock = multiprocessing.Lock()

        self.transform = transform

    def __len__(self):
        return self.n_data

    def __getitem__(self, index: int):
        if index < 0:
            index = self.n_data + index
        
        with self.lock:
            self.fd.seek(self.offset[index])
            line = self.fd.readline()

        data = line.decode('utf-8').strip('\n')

        return apply_transform(self.transform, data) if self.transform is not None else data

NB: 打开文件缓冲=0multiprocessing.Lock() 用于避免加载错误数据(通常是文件的一部分和文件的另一部分)。

此外,如果在 DataLoader 中使用多处理,可能会出现这样的异常 TypeError: cannot serialize '_io.BufferedReader' object。这是由 multiprocessing 中使用的 pickle 模块引起的,它无法序列化 _io.BufferedReader,但是 dill can. Replacing multiprocessing with multiprocess,一切正常(与 multiprocessing 相比有重大变化,增强的序列化是用 dill 完成的)

this issue

中讨论了同样的事情