Pytorch DataLoader更改dict返回值



给定一个Pytorch数据集,该数据集读取JSON文件如下:

import csv
from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader2, DataLoader
class MyDataset(IterableDataset):
def __init__(self, jsonfilename):
self.filename = jsonfilename
def __iter__(self):
with open(self.filename) as fin:
reader = csv.reader(fin)
headers = next(reader)
for line in reader:
yield dict(zip(headers, line))

content = """imagefile,label
train/0/16585.png,0
train/0/56789.png,0"""
with open('myfile.json', 'w') as fout:
fout.write(content)
ds = MyDataset("myfile.json")

当我循环浏览数据集时,返回值是json的每一行的dict,例如

ds = MyDataset("myfile.json")
for i in ds:
print(i)

[out]:

{'imagefile': 'train/0/16585.png', 'label': '0'}
{'imagefile': 'train/0/56789.png', 'label': '0'}

但是,当我将数据集读取到DataLoader中时,它会以列表的形式返回dict的值,而不是值本身,例如

ds = MyDataset("myfile.json")
x = DataLoader(dataset=ds)
for i in x:
print(i)

[out]:

{'imagefile': ['train/0/16585.png'], 'label': ['0']}
{'imagefile': ['train/0/56789.png'], 'label': ['0']}

问(第1部分(:为什么DataLoader将dict的值更改为列表

以及

Q(第2部分(:当使用DataLoader运行__iter__时,如何使DataLoader只返回dict的值,而不是值列表?是否有一些参数/选项可以在DataLoader中使用

原因是torch.utils.data.DataLoader中的默认整理行为,它决定了如何合并批次中的数据样本。默认情况下,使用torch.utils.data.default_collatecollate函数,该函数将映射转换为:

映射[K,V_i]->映射[K,default_collate([V_1,V_2,…](]

和字符串为:

str->str(不变(

注意,如果在示例中将batch_size设置为2,则会得到:

{'imagefile': ['train/0/16585.png', 'train/0/56789.png'], 'label': ['0', '0']}

作为这些转换的结果。

假设您不需要批处理,您可以通过设置batch_size=None禁用它来获得所需的输出。此处的详细信息:加载批处理和非批处理数据。

有关详细信息,请参阅@GoodDeeds的回答!https://stackoverflow.com/a/73824234/610569


以下答案适用于TL;DR阅读器:

Q: 为什么DataLoader会将dict的值更改为列表

A: 因为有一个隐含的假设,即DataLoader对象的__iter__应该返回一批数据,而不是单个数据。

问(第2部分(:当使用DataLoader运行iter时,如何使DataLoader只返回dict的值,而不是值列表?是否有一些参数/选项可以在DataLoader中使用

A: 由于隐含的批返回行为,最好修改{key: [value1, value2, ...]中的返回数据批,而不是试图强制DataLoader返回{key: value1}

要更好地理解批处理假设,请尝试batch_size参数:

x = DataLoader(dataset=ds, batch_size=2)
for i in x:
print(i)

[out]:

{'imagefile': ['train/0/16585.png', 'train/0/56789.png'], 'label': ['0', '0']}

最新更新