Iterable pytorch dataset with multiple workers
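I have a text file that is larger than my RAM, and I would like to create a dataset in PyTorch that reads it line by line, so that I don't have to load it all into memory at once. I found PyTorch's IterableDataset as a potential solution to my problem. It only works as expected with 1 worker; with more than one worker it creates duplicate records. Let me show you an example:
testfile.txt contains: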
0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line
Defining the IterableDataset:
from torch.utils.data import DataLoader, IterableDataset

class CustomIterableDatasetv1(IterableDataset):
    def __init__(self, filename):
        # Store the filename in the object's memory
        self.filename = filename

    def preprocess(self, text):
        # Do something with the text here
        text_pp = text.lower().strip()
        return text_pp

    def line_mapper(self, line):
        # Split the line into text and label and apply preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)
        return text, label

    def __iter__(self):
        # Create an iterator over the lines of the file
        file_itr = open(self.filename)
        # Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        return mapped_itr
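For reference, here is a small standalone sketch of what line_mapper does to a single line of testfile.txt. The free function below is only an illustration that mirrors the method outside the class, so it runs without PyTorch.

```python
def line_mapper(line):
    # Same logic as CustomIterableDatasetv1.line_mapper, written as a free function
    text, label = line.split('-')
    # Only the text part is lowercased and stripped; the label is left untouched
    text = text.lower().strip()
    return text, label

sample = line_mapper("0 - Dummy line\n")
print(sample)  # ('0', ' Dummy line\n')
```

The label keeps its leading space and trailing newline, which is why they show up in the printed batches.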
We can now test it:
base_dataset = CustomIterableDatasetv1("testfile.txt")
# Wrap it in a DataLoader
dataloader = DataLoader(base_dataset, batch_size=1, num_workers=1)
for X, y in dataloader:
    print(X, y)
It outputs:
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
That is correct. But if I change the number of workers to 2, the output becomes:
('0',) (' Dummy line\n',)
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
('9',) (' Dummy line',)
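This is incorrect, since the data loader creates a copy of each sample for every worker.
Is there a way to solve this in PyTorch, so that a data loader can be created that does not load the whole file into memory and still supports multiple workers?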
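The duplication can be reproduced without PyTorch. The sketch below is a simplified model, not the DataLoader itself: it assumes each worker holds its own copy of the dataset and runs __iter__ from the first line, and that results are collected from the workers in turn. The names worker_stream and round_robin are hypothetical helpers introduced only for this illustration.

```python
from itertools import zip_longest

lines = [f"{i} - Dummy line" for i in range(4)]

def worker_stream(lines):
    # Each simulated worker starts its own iterator at the first line
    return iter(lines)

def round_robin(streams):
    # Take one item from each stream in turn until all are exhausted
    missing = object()
    for group in zip_longest(*streams, fillvalue=missing):
        for item in group:
            if item is not missing:
                yield item

# Two workers, no sharding: every line is produced once per worker
unsharded = list(round_robin([worker_stream(lines) for _ in range(2)]))
for line in unsharded:
    print(line)
```

Under these assumptions every line appears twice in a row, which matches the pattern in the output above.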
You have access to the worker identifier inside the Dataset's __iter__ function through the torch.utils.data.get_worker_info util. This means you can step through the iterator and add an offset depending on the worker id. You can wrap an iterator with itertools.islice, which lets you set a start index as well as a step.
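As a quick illustration of the start and step arguments, this plain-Python sketch slices range(10) once per worker id. The value num_workers = 2 is an assumption chosen to match the question; no DataLoader is involved.

```python
from itertools import islice

num_workers = 2

# One shard per worker id: start at the id, then take every num_workers-th item
shards = [
    list(islice(range(10), worker_id, None, num_workers))
    for worker_id in range(num_workers)
]
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```

The shards are disjoint and together cover every item exactly once.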
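Here is a minimal example:
import torch
from itertools import islice
from torch.utils.data import DataLoader, IterableDataset

class DS(IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        # id of the worker process running this copy of the dataset
        uid = torch.utils.data.get_worker_info().id
        # start at the worker id and step by batch_size (equal to num_workers in the call below)
        itr = islice(range(10), uid, None, self.batch_size)
        return itr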
Looping through the dataloader yields unique instances, even when we use num_workers > 1:
>>> for x in DataLoader(DS(batch_size=2), batch_size=2, num_workers=2):
... print(x)
tensor([0, 2])
tensor([1, 3])
tensor([4, 6])
tensor([5, 7])
tensor([8])
tensor([9])
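The batches contain non-consecutive numbers because each worker builds batches from its own shard. The following is a rough plain-Python model of that behaviour, assuming batches are taken from the workers in turn; batched is a hypothetical helper written for this sketch, and lists stand in for tensors.

```python
from itertools import islice, zip_longest

def batched(iterable, batch_size):
    # Group an iterator into lists of at most batch_size items
    it = iter(iterable)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch

num_workers = 2
batch_size = 2

# Each simulated worker shards range(10) by its id, then batches its shard
per_worker = [
    list(batched(islice(range(10), w, None, num_workers), batch_size))
    for w in range(num_workers)
]

# Interleave the workers' batches, one batch per worker in turn
missing = object()
order = [
    b
    for group in zip_longest(*per_worker, fillvalue=missing)
    for b in group
    if b is not missing
]
print(order)  # [[0, 2], [1, 3], [4, 6], [5, 7], [8], [9]]
```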
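In your case, you could do:
def __iter__(self):
    # get this worker's id and the total number of workers
    worker_info = torch.utils.data.get_worker_info()
    uid, num_workers = worker_info.id, worker_info.num_workers
    # create an iterator
    file_itr = open(self.filename)
    # map each element using the line_mapper
    mapped_itr = map(self.line_mapper, file_itr)
    # wrap the iterator so each worker reads every num_workers-th line
    step_itr = islice(mapped_itr, uid, None, num_workers)
    return step_itr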
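One caveat: the step passed to islice needs to be the number of workers, not the batch size. The two happen to be equal in the minimal example. The sketch below counts how often each line index would be read for a given step; lines_read is a hypothetical helper, and the 3-worker setup is an assumption chosen to make the difference visible.

```python
from collections import Counter
from itertools import islice

def lines_read(num_workers, step, n_lines=10):
    # Count how many times each line index is read across all workers
    counts = Counter()
    for worker_id in range(num_workers):
        counts.update(islice(range(n_lines), worker_id, None, step))
    return counts

# Step taken from a batch size of 2 while running 3 workers: some lines repeat
wrong = lines_read(num_workers=3, step=2)
duplicated = sorted(k for k, v in wrong.items() if v > 1)
print(duplicated)  # [2, 4, 6, 8]

# Step equal to the number of workers: every line is read exactly once
right = lines_read(num_workers=3, step=3)
print(all(right[i] == 1 for i in range(10)))  # True
```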
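I found an answer on the PyTorch discussion forum, https://discuss.pytorch.org/t/iterable-pytorch-dataset-with-multiple-workers/135475/3, where they pointed out that I should use the worker info to slice the iterator, so that each worker reads its own subset of lines.
The new dataset looks like this: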
import itertools
import torch
from torch.utils.data import IterableDataset

class CustomIterableDatasetv1(IterableDataset):
    def __init__(self, filename):
        # Store the filename in the object's memory
        self.filename = filename

    def preprocess(self, text):
        # Do something with the text here
        text_pp = text.lower().strip()
        return text_pp

    def line_mapper(self, line):
        # Split the line into text and label and apply preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)
        return text, label

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        worker_total_num = worker_info.num_workers
        worker_id = worker_info.id
        # Create an iterator over the lines of the file
        file_itr = open(self.filename)
        # Map each element using the line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        # Add multi-worker functionality: worker i reads lines i, i + n, i + 2n, ...
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)
        return mapped_itr
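Note that torch.utils.data.get_worker_info() returns None when called in the main process, which is the case with num_workers=0, so reading .id from it would fail there. Below is a minimal sketch of a fallback. The helper shard_for_worker is hypothetical, and its worker_info argument is assumed to be whatever get_worker_info() returned. A SimpleNamespace stands in for a real worker info object so the snippet runs without PyTorch.

```python
from itertools import islice
from types import SimpleNamespace

def shard_for_worker(iterator, worker_info):
    # worker_info is assumed to be the return value of torch.utils.data.get_worker_info():
    # None in the main process, otherwise an object with .id and .num_workers
    if worker_info is None:
        # Single-process loading: read everything
        return iterator
    return islice(iterator, worker_info.id, None, worker_info.num_workers)

# Main process (no workers): the full sequence comes back
main_process = list(shard_for_worker(iter(range(10)), None))
print(main_process)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Stand-in for worker 1 of 2
fake_worker = SimpleNamespace(id=1, num_workers=2)
worker_one = list(shard_for_worker(iter(range(10)), fake_worker))
print(worker_one)  # [1, 3, 5, 7, 9]
```

Inside __iter__ this would be called as shard_for_worker(mapped_itr, torch.utils.data.get_worker_info()).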
Special thanks to @Ivan, who also pointed out the slicing solution.
With two workers, it returns the same data as with only one worker.