具有多个工作人员的可迭代 pytorch 数据集

Iterable pytorch dataset with multiple workers

所以我有一个比我的 ram 内存大的文本文件,我想在 PyTorch 中创建一个逐行读取的数据集,这样我就不必一次将它全部加载到内存中。我发现 pytorch IterableDataset 作为我的问题的潜在解决方案。它仅在使用 1 个工作人员时按预期工作,如果使用多个工作人员,它将创建重复的记录。让我举个例子:

testfile.txt 包含:

0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line

定义 IterableDataset:

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):
        
        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        return text, label


    def __iter__(self):

        #Create an iterator
        file_itr = open(self.filename)

        #Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        
        return mapped_itr

我们现在可以测试它了:

base_dataset = CustomIterableDatasetv1("testfile.txt")
#Wrap it around a dataloader
dataloader = DataLoader(base_dataset, batch_size = 1, num_workers = 1)
for X, y in dataloader:
    print(X,y)

它输出:



('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)

没错。但是如果我将工人数量更改为 2,则输出变为

('0',) (' Dummy line\n',)
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
('9',) (' Dummy line',)

这是不正确的,因为在数据加载器中为每个工作人员创建了每个样本的副本。

有没有办法用pytorch解决这个问题?因此,可以创建一个数据加载器,以不加载内存中的所有文件,并支持多个工作程序。

您可以使用 torch.utils.data.get_worker_info util. This means you can step through the iterator and add an offset depending on the worker id. You can wrap an iterator with itertools.islice 访问 Dataset__iter__ 函数中的工作标识符,它允许您执行 start 索引以及step.

这是一个最小的例子:

class DS(IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        uid = torch.utils.data.get_worker_info().id
        itr = islice(range(10), uid, None, self.batch_size)

即使我们使用 num_workers > 1:

,遍历数据加载器也会产生唯一的实例
>>> for x in DataLoader(DS(batch_size=2), batch_size=2, num_workers=2):
...     print(x)
tensor([0, 2])
tensor([1, 3])
tensor([4, 6])
tensor([5, 7])
tensor([8])
tensor([9])

在你的情况下你可以这样做:

    def __iter__(self):
        # create an iterator
        file_itr = open(self.filename)

        # map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
    
        # wrap the iterator
        step_itr = islice(mapped_itr, uid, None, self.batch_size)

        return step_itr

所以我在 torch 讨论论坛 https://discuss.pytorch.org/t/iterable-pytorch-dataset-with-multiple-workers/135475/3 中找到了答案,他们指出我应该使用工作人员信息来连续切片到批量大小。

新数据集如下所示:

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):

        #Store the filename in object's memory
        self.filename = filename

    def preprocess(self, text):

        ### Do something with text here
        text_pp = text.lower().strip()
        ###

        return text_pp

    def line_mapper(self, line):
        
        #Splits the line into text and label and applies preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)

        return text, label


    def __iter__(self):
        worker_total_num = torch.utils.data.get_worker_info().num_workers
        worker_id = torch.utils.data.get_worker_info().id
        #Create an iterator
        file_itr = open(self.filename)

        #Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        
        #Add multiworker functionality
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)

        return mapped_itr

特别感谢@Ivan 也指出了切片的解决方法

有两个工人returns只有一个工人的数据相同