PyTorch Dataset / Dataloader来自随机源



我有一个随机(非确定性,不可重复)数据的来源,我想在Dataset和Dataloader中包装PyTorch训练。我该怎么做呢?

__len__没有定义,因为源是无限的(可能有重复)。
__getitem__没有定义,因为源是不确定的。

当定义自定义数据集类时,您通常会子类化torch.utils.data.Dataset并定义__len__()__getitem__()

但是,对于需要顺序访问而不是随机访问的情况,可以使用可迭代风格的数据集。要做到这一点,您可以创建torch.utils.data.IterableDataset的子类并定义__iter__()。无论__iter__()返回什么,都应该是一个合适的迭代器;它应该保持状态(如果有必要)并定义__next__()来获取序列中的下一项。当没有东西可读时,__next__()应该raise StopIteration。与无限的数据集在你的情况中,它不需要这样做。

下面是一个例子:

import torch
class MyInfiniteIterator:
def __next__(self):
return torch.randn(10)
class MyInfiniteDataset(torch.utils.data.IterableDataset):
def __iter__(self):
return MyInfiniteIterator()
dataset = MyInfiniteDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size = 32)
for batch in dataloader:
# ... Do some stuff here ...
# ...
# if some_condition:
#     break

最新更新