如何从 PyTorch 中的数据加载器获取整个数据集



如何从DataLoader加载整个数据集?我只得到一批数据集。

这是我的代码

dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=64)
images, labels = next(iter(dataloader))

您可以设置batch_size = len(dataset).请注意,这可能需要大量内存,具体取决于您的数据集。

我不确定您是想在网络训练以外的其他地方使用数据集(例如检查图像(,还是想在训练期间迭代批处理。

循环访问数据集

要么按照乌斯曼·阿里的回答(可能会溢出(你的记忆,要么你可以这样做

for i in range(len(dataset)): # or i, image in enumerate(dataset)
images, labels = dataset[i] # or whatever your dataset returns

之所以能够编写dataset[i],是因为您在Dataset类中实现了__len____getitem__(只要它是 PytorchDataset类的子类(。

从数据加载器获取所有批次

我理解您的问题的方式是您想检索所有批次来训练网络。你应该明白,iter给你一个数据加载器的迭代器(如果你不熟悉迭代器的概念,请参阅维基百科条目(。next告诉迭代器给你下一项。

因此,与遍历列表的迭代器相反,数据加载器始终返回下一项。列表迭代器在某个时间点停止。我假设你有一些类似纪元的东西,每个纪元有一些步骤。那么你的代码将如下所示

for i in range(epochs):
# some code
for j in range(steps_per_epoch):
images, labels = next(iter(dataloader))
prediction = net(images)
loss = net.loss(prediction, labels)
...

小心next(iter(dataloader)).如果你想遍历一个列表,这也可能有效,因为 Python 会缓存对象,但每次再次从索引 0 开始时,你最终可能会得到一个新的迭代器。为了避免这种情况,将迭代器取出到顶部,如下所示:

iterator = iter(dataloader)
for i in range(epochs):
for j in range(steps_per_epoch):
images, labels = next(iterator)

另一种选择是直接获取整个数据集,而无需使用数据加载器,如下所示:

images, labels = dataset[:]

相关内容

  • 没有找到相关文章

最新更新