使用 Mongo 数据库的 PyTorch DataLoader

PyTorch DataLoader using Mongo DB

我想知道使用连接到 MongoDB 的 DataLoader 是否明智,以及如何实现。

背景

我在(本地)MongoDB 中有大约 2000 万份文件。远远超出内存所能容纳的文档。我想在数据上训练一个深度神经网络。到目前为止,我都是先将数据导出到文件系统,子文件夹命名为文件的类。但我觉得这种方法很荒谬。如果数据已经很好地保存在数据库中,为什么要先导出(然后删除)。

问题一:

我说的对吗?直接连接到 MongoDB 有意义吗?还是有理由不这样做(例如,数据库通常太慢等)?如果数据库太慢(为什么?),能否以某种方式预取数据?

问题二:

如何实现 PyTorch DataLoader? 我在网上只找到了很少的代码片段 ([1] and [2]),这让我怀疑我的方法。

代码片段

我访问 MongoDB 的一般方式如下。我认为这没什么特别的。

import pymongo
from pymongo import MongoClient

myclient = pymongo.MongoClient("mongodb://localhost:27017/")
mydb = myclient["xyz"]
mycol = mydb["xyz_documents"]

query = {
    # some filters
}

results = mycol.find(query)

# results is now a cursor that can run through all docs
# Assume, for the sake of this example, that each doc contains a class name and some image that I want to train a classifier on

简介

这个有点开放性,但让我们尝试一下,如果我有什么地方不对,请纠正我。

So far, I have been exporting the data to the file system first, with subfolders named as the classes of the documents.

IMO 这不明智,因为:

  • 你实际上是在复制数据
  • 任何时候你想训练一个新的只给定的代码和数据库都必须重复这个操作
  • 您可以一次访问多个数据点并将它们缓存在 RAM 中供以后重用,而无需多次从硬盘读取(这非常繁重)

Am I right? Would it make sense to directly connect to the MongoDB?

上面给出的,可能是的(尤其是当涉及到清晰和可移植的实现时)

Or are there reasons not to do it (e.g. DBs generally being to slow etc.)?

AFAIK 数据库在这种情况下不应该变慢,因为它会缓存对它的访问,但不幸的是我不是数据库专家。许多加快访问速度的技巧都是开箱即用的数据库。

can one prefetch the data somehow?

是的,如果您只想获取数据,您可以一次加载大部分数据(比如 1024 条记录),然后从中加载 return 批数据(比如 batch_size=128)

实施

How would one implement a PyTorch DataLoader? I have found only very few code snippets online ([1] and [2]) which makes me doubt with my approach.

我不确定你为什么要这样做。如您列出的示例所示,您应该选择 torch.utils.data.Dataset

我将从类似于 here 的简单非优化方法开始,所以:

  • 打开与 __init__ 中的数据库的连接,并在使用时一直保持它(我会从 torch.utils.data.Dataset 创建一个上下文管理器,以便在 epoch 完成后关闭连接)
  • 我不会将结果转换为 list(特别是因为显而易见的原因你不能将它放入 RAM)因为它错过了生成器的要点
  • 我将在此数据集中执行批处理(有一个参数 batch_size here)。
  • 我不确定 __getitem__ 函数,但它似乎可以一次 return 多个数据点,因此我会使用它,它应该允许我们使用 num_workers>0 (鉴于 mycol.find(query) returns 数据每次都以相同的顺序)

鉴于此,我会按照这些思路去做:

class DatabaseDataset(torch.utils.data.Dataset):
    def __init__(self, query, batch_size, path: str, database: str):
        self.batch_size = batch_size

        client = pymongo.MongoClient(path)
        self.db = client[database]
        self.query = query
        # Or non-approximate method, if the approximate method
        # returns smaller number of items you should be fine
        self.length = self.db.estimated_document_count()

        self.cursor = None

    def __enter__(self):
        # Ensure that this find returns the same order of query every time
        # If not, you might get duplicated data
        # It is rather unlikely (depending on batch size), shouldn't be a problem
        # for 20 million samples anyway
        self.cursor = self.db.find(self.query)
        return self

    def shuffle(self):
        # Find a way to shuffle data so it is returned in different order
        # If that happens out of the box you might be fine without it actually
        pass

    def __exit__(self, *_, **__):
        # Or anything else how to close the connection
        self.cursor.close()

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        # Read takes long, hence if you can load a batch of documents it should speed things up
        examples = self.cursor[index * batch_size : (index + 1) * batch_size]
        # Do something with this data
        ...
        # Return the whole batch
        return data, labels

现在批处理由 DatabaseDataset 处理,因此 torch.utils.data.DataLoader 可以有 batch_size=1。您可能需要挤压额外的维度。

由于 MongoDB 使用锁(这并不奇怪,但请参阅 herenum_workers>0 应该不是问题。

可能的用法(示意图):

with DatabaseDataset(...) as e:
    dataloader = torch.utils.data.DataLoader(e, batch_size=1)
    for epoch in epochs:
        for batch in dataloader:
            # And all the stuff
            ...
        dataset.shuffle() # after each epoch

记住在这种情况下改组实施!(也可以在上下文管理器中完成改组,您可能想手动关闭连接或类似的东西)。