Pytorch的数据加载器混洗何时发生



我已经多次使用pytorch数据加载器的shuffle选项。但我想知道这种洗牌是什么时候发生的,它是否在迭代过程中动态执行。以以下代码为例:

namesDataset = NamesDataset()
namesTrainLoader = DataLoader(namesDataset, batch_size=16, shuffle=True)
for batch_data in namesTrainLoader:
print(batch_data)

当我们定义"namesTrainLoader"时,这是否意味着洗牌已经完成,接下来的迭代将基于固定的数据顺序?在定义了namesTrainLoader之后,for循环中会有任何随机性吗?

我试图用一些特殊值替换一半的"batch_data":

for batch_data in namesTrainLoader:
batch_data[:8] = special_val
pre = model(batch_data)

假设会有无限多的时代,"模型"最终会看到"namesTrainLoader"中的所有数据吗?或者"namesTrainLoader"的一半数据实际上被"模型"丢失了?

创建迭代器时会发生混洗。在for循环的情况下,这发生在for循环开始之前。

您可以使用以下命令手动创建迭代器:

# Iterator gets created, the data has been shuffled at this point.
data_iterator = iter(namesTrainLoader)

默认情况下,如果设置shuffle=True(不提供自己的采样器(,则数据加载器使用torch.utils.data.RandomSampler。它的实现非常直接,您可以通过查看RandomSampler.__iter__方法来查看迭代器创建时数据的混洗位置:

def __iter__(self):
n = len(self.data_source)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())

return语句是进行混洗的重要部分。它只是创建索引的随机排列。

这意味着,每次完全使用迭代器时,您都会看到整个数据集,只是每次的顺序不同。因此,没有数据丢失(不包括drop_last=True的情况(,您的模型将在每个历元看到所有数据。

您可以在此处查看PyTorch对torch.utils.data.DataLoader的实现。

如果指定shuffle=True,则将使用torch.utils.data.RandomSampler(否则为SequentialSampler(。

当创建DataLoader的实例时,任何东西都不会被打乱,它只是实例化对象和其他类似设置的必要私有成员。

当您在迭代过程中发出特殊的__iter__方法时,会返回一个名为_SingleProcessDataLoader(self)的特殊对象,它是数据的生成器(可能是批处理、混洗等,假设您不使用多处理(。

要找到所有私有和助手相关的方法有点像兔子洞,但它基本上是使用底层sampler来获取索引,这些索引用于从torch.utils.data.Dataset中获取样本。

采样器一直运行到耗尽,过程重复(通常是一个历元(。

在namesTrainLoader之后的for循环中会有任何随机性吗定义?

在每个周期/epoch开始时,RandomSampler会打乱索引,因此是的,它将在每个epoch之前随机化(当调用__iter__并返回新的_SingleProcessDataLoader(self)时(,这可以无限期地进行。

[…]"model"最终会看到"namesTrainLoader"中的所有数据吗?

是的,它很可能最终会看到所有数据点

相关内容

  • 没有找到相关文章

最新更新