我有一个随机(非确定性,不可重复)数据的来源,我想在Dataset和Dataloader中包装PyTorch训练。我该怎么做呢?
__len__
没有定义,因为源是无限的(可能有重复)。__getitem__
没有定义,因为源是不确定的。
当定义自定义数据集类时,您通常会子类化torch.utils.data.Dataset
并定义__len__()
和__getitem__()
。
但是,对于需要顺序访问而不是随机访问的情况,可以使用可迭代风格的数据集。要做到这一点,您可以创建torch.utils.data.IterableDataset
的子类并定义__iter__()
。无论__iter__()
返回什么,都应该是一个合适的迭代器;它应该保持状态(如果有必要)并定义__next__()
来获取序列中的下一项。当没有东西可读时,__next__()
应该raise StopIteration
。与无限的数据集在你的情况中,它不需要这样做。
下面是一个例子:
import torch
class MyInfiniteIterator:
def __next__(self):
return torch.randn(10)
class MyInfiniteDataset(torch.utils.data.IterableDataset):
def __iter__(self):
return MyInfiniteIterator()
dataset = MyInfiniteDataset()
dataloader = torch.utils.data.DataLoader(dataset, batch_size = 32)
for batch in dataloader:
# ... Do some stuff here ...
# ...
# if some_condition:
# break