Iterable PyTorch dataset with multiple workers



So I have a text file bigger than my RAM. I would like to create a dataset in PyTorch that reads it line by line, so I don't have to load the whole file into memory at once. I found PyTorch's IterableDataset as a potential solution for my problem. It only works as expected when using one worker; if using more than one worker it creates duplicate records. Let me show you an example:

Having a testfile.txt containing:

0 - Dummy line
1 - Dummy line
2 - Dummy line
3 - Dummy line
4 - Dummy line
5 - Dummy line
6 - Dummy line
7 - Dummy line
8 - Dummy line
9 - Dummy line

Defining an IterableDataset:

from torch.utils.data import IterableDataset

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):
        # Store the filename in the object's memory
        self.filename = filename

    def preprocess(self, text):
        ### Do something with text here
        text_pp = text.lower().strip()
        ###
        return text_pp

    def line_mapper(self, line):
        # Split the line into text and label and apply preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)
        return text, label

    def __iter__(self):
        # Create an iterator over the file
        file_itr = open(self.filename)
        # Map each line using line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        return mapped_itr

We can now test it:

from torch.utils.data import DataLoader

base_dataset = CustomIterableDatasetv1("testfile.txt")
# Wrap it in a dataloader
dataloader = DataLoader(base_dataset, batch_size=1, num_workers=1)
for X, y in dataloader:
    print(X, y)

Output:


('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)

Which is correct. But if I change the number of workers to 2, the output becomes:

('0',) (' Dummy line\n',)
('0',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('1',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('2',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('3',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('4',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('5',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('6',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('7',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('8',) (' Dummy line\n',)
('9',) (' Dummy line',)
('9',) (' Dummy line',)

Which is incorrect, because a copy of each sample is created for every worker in the dataloader.
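To see why this happens: each worker process gets its own copy of the dataset and runs the full __iter__ over the entire file. A minimal sketch of my own (the DebugDataset class is hypothetical, not from the original post) makes the duplication visible by printing the worker id inside __iter__:

import torch
from torch.utils.data import IterableDataset, DataLoader

class DebugDataset(IterableDataset):
    def __iter__(self):
        info = torch.utils.data.get_worker_info()
        # Every worker iterates over the same full range, hence the duplicates
        for i in range(3):
            print(f"worker {info.id} yields {i}")
            yield i

for x in DataLoader(DebugDataset(), batch_size=1, num_workers=2):
    pass
# Each of 0, 1, 2 is printed once by worker 0 and once by worker 1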

Is there a way to fix this with PyTorch, so that a dataloader can be created that does not load the whole file into memory and supports multiple workers?

So I found an answer in the PyTorch discussion forum, https://discuss.pytorch.org/t/iterable-pytorch-dataset-with-multiple-workers/135475/3, where they pointed out that I should use the worker info to slice the iterator into consecutive strides, one per worker.

The new dataset looks like this:

import itertools
import torch
from torch.utils.data import IterableDataset

class CustomIterableDatasetv1(IterableDataset):

    def __init__(self, filename):
        # Store the filename in the object's memory
        self.filename = filename

    def preprocess(self, text):
        ### Do something with text here
        text_pp = text.lower().strip()
        ###
        return text_pp

    def line_mapper(self, line):
        # Split the line into text and label and apply preprocessing to the text
        text, label = line.split('-')
        text = self.preprocess(text)
        return text, label

    def __iter__(self):
        worker_total_num = torch.utils.data.get_worker_info().num_workers
        worker_id = torch.utils.data.get_worker_info().id
        # Create an iterator over the file
        file_itr = open(self.filename)
        # Map each line using line_mapper
        mapped_itr = map(self.line_mapper, file_itr)
        # Add multiworker functionality: each worker takes every
        # worker_total_num-th line, starting at its own worker_id
        mapped_itr = itertools.islice(mapped_itr, worker_id, None, worker_total_num)
        return mapped_itr

Special thanks to @Ivan, who also pointed out the slicing solution.

When queried with two workers, it returns the same data as with only one worker.
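For example, re-running the earlier test with num_workers=2 (my own check, following the original setup) now yields each line exactly once; since the dataloader pulls batches from the workers round-robin, the output order even matches the single-worker run:

dataloader = DataLoader(CustomIterableDatasetv1("testfile.txt"),
                        batch_size=1, num_workers=2)
for X, y in dataloader:
    print(X, y)
# ('0',) (' Dummy line\n',)
# ('1',) (' Dummy line\n',)
# ...
# ('9',) (' Dummy line',)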

You can access the worker identifier inside the Dataset's __iter__ function using the torch.utils.data.get_worker_info util. This means you can step through the iterator and add an offset depending on the worker id. You can wrap the iterator with itertools.islice, which allows you to step over it with a start index as well as a step size.
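As a quick illustration of islice's start/step semantics (a sketch of my own, not from the answer), two workers with starts 0 and 1 and a step of 2 partition a range cleanly:

from itertools import islice

list(islice(range(10), 0, None, 2))  # [0, 2, 4, 6, 8]  (worker 0 of 2)
list(islice(range(10), 1, None, 2))  # [1, 3, 5, 7, 9]  (worker 1 of 2)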

Here is a minimal example:

import torch
from itertools import islice
from torch.utils.data import IterableDataset

class DS(IterableDataset):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def __iter__(self):
        uid = torch.utils.data.get_worker_info().id
        # Each worker starts at its own id and strides by batch_size
        itr = islice(range(10), uid, None, self.batch_size)
        return itr

Looping over the dataloader yields unique instances, even with num_workers > 1 (note that each batch is assembled by a single worker, which is why the elements within a batch are strided):
>>> for x in DataLoader(DS(batch_size=2), batch_size=2, num_workers=2):
...     print(x)
tensor([0, 2])
tensor([1, 3])
tensor([4, 6])
tensor([5, 7])
tensor([8])
tensor([9])

In your case, you could do:

def __iter__(self):
    uid = torch.utils.data.get_worker_info().id
    # create an iterator over the file
    file_itr = open(self.filename)
    # map each line using line_mapper
    mapped_itr = map(self.line_mapper, file_itr)
    # wrap the iterator: start at this worker's id, stride by batch_size
    step_itr = islice(mapped_itr, uid, None, self.batch_size)
    return step_itr
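One caveat (my own note): striding by self.batch_size only partitions the data correctly because batch_size happens to equal num_workers in the example above. A more robust sketch strides by the actual worker count, as the accepted solution does:

def __iter__(self):
    worker_info = torch.utils.data.get_worker_info()
    file_itr = open(self.filename)
    mapped_itr = map(self.line_mapper, file_itr)
    # stride by the number of workers, not the batch size
    step_itr = islice(mapped_itr, worker_info.id, None, worker_info.num_workers)
    return step_itr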
