__getitem__的 idx 如何在 PyTorch 的 DataLoader 中工作?



我目前正在尝试使用PyTorch的DataLoader来处理数据,以将其输入到我的深度学习模型中,但遇到了一些困难。

我需要的数据的形状是(minibatch_size=32, rows=100, columns=41)。我编写的自定义Dataset类中的__getitem__代码如下所示:

def __getitem__(self, idx):
x = np.array(self.train.iloc[idx:100, :])
return x

我之所以这样写,是因为我希望DataLoader一次处理形状为(100, 41)的输入实例,而我们有32个这样的单个实例。

然而,我注意到,与我最初认为的相反,DataLoader传递给函数的idx参数不是连续的(这一点至关重要,因为我的数据是时间序列数据(。例如,打印值给了我这样的东西:

idx = 206000
idx = 113814
idx = 80597
idx = 3836
idx = 156187
idx = 54990
idx = 8694
idx = 190555
idx = 84418
idx = 161773
idx = 177725
idx = 178351
idx = 89217
idx = 11048
idx = 135994
idx = 15067

这是正常行为吗?我发布这个问题是因为返回的数据批次不是我最初想要的。

在使用DataLoader之前,我用来预处理数据的原始逻辑是:

  1. txtcsv文件中读取数据
  2. 计算数据中有多少批次,并相应地对数据进行切片。例如,由于一个输入实例的形状为(100, 41),其中32个形成一个小批量,因此我们通常会得到大约100个左右的批量,并相应地重塑数据
  3. 一个输入的形状是(32, 100, 41)

我不确定我应该如何处理DataLoader钩子方法。任何提示或建议都将不胜感激。提前谢谢。

定义idx的是samplerbatch_sampler,正如您在这里看到的(开源项目是您的朋友(。在这段代码(以及comment/docstring(中,您可以看到samplerbatch_sampler之间的区别。如果你看这里,你会看到索引是如何选择的:

def __next__(self):
index = self._next_index()
# and _next_index is implemented on the base class (_BaseDataLoaderIter)
def _next_index(self):
return next(self._sampler_iter)
# self._sampler_iter is defined in the __init__ like this:
self._sampler_iter = iter(self._index_sampler)
# and self._index_sampler is a property implemented like this (modified to one-liner for simplicity):
self._index_sampler = self.batch_sampler if self._auto_collation else self.sampler

注意这是_SingleProcessDataLoaderIter的实现;您可以在这里找到_MultiProcessingDataLoaderIter(如您所见,使用哪一个取决于num_workers值(。回到采样器,假设您的数据集不是_DatasetKind.Iterable,并且您没有提供自定义采样器,这意味着您正在使用(dataloader.py#L212-L215(:

if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
if batch_size is not None and batch_sampler is None:
# auto_collation without custom batch_sampler
batch_sampler = BatchSampler(sampler, batch_size, drop_last)

让我们来看看默认的BatchSampler是如何构建批次的:

def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch

非常简单:它从采样器中获取索引,直到达到所需的batch_size。

现在,"__getitem__的idx如何在PyTorch的DataLoader中工作?"这个问题可以通过查看每个默认采样器的工作方式来回答。

  • SequentialSampler(这是完整的实现——非常简单,不是吗?(:
class SequentialSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
  • RandomSampler(我们只看__iter__的实现(:
def __iter__(self):
n = len(self.data_source)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())

因此,由于您没有提供任何代码,我们只能假设:

  1. 您正在DataLoader中使用shuffle=True
  2. 您正在使用自定义采样器
  3. 您的数据集是_DatasetKind.Iterable

最新更新