I'm trying to build a model with PyTorch, and I want to use a custom dataset. I have a dataset.py that defines a class MyDataset, which is a subclass of torch.utils.data.Dataset. Here is the file:
# dataset.py
import torch
from tqdm import tqdm
import numpy as np
import re
from torch.utils.data import Dataset
from pathlib import Path


class MyDataset(Dataset):
    def __init__(self, path, size=10000):
        if not Path(path).exists():
            raise FileNotFoundError
        self.data = []
        self.load_data(path, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

    def load_data(self, path, size):
        # Loading data from csv files and some preparation.
        # Each sample is in the format (int_tag1, int_tag2, feature_dictionary),
        # then the sample is appended to self.data.
        pass
Then I try to test this dataset with a DataLoader in a test file, dataset_test.py:

from torch.utils.data import DataLoader
from dataset import MyDataset
path = 'dataset/sample_train.csv'
size = 1000
dataset = MyDataset(path, size)
dataloader = DataLoader(dataset, batch_size=1000)
for v in dataloader:
    print(v)
I get the following output:
730600it [11:08, 1093.11it/s]
1000it [00:00, 20325.47it/s]
Traceback (most recent call last):
File "dataset_test.py", line 12, in <module>
for v in dataloader:
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 561, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
return self.collate_fn(data)
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 84, in default_collate
return [default_collate(samples) for samples in transposed]
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 84, in <listcomp>
return [default_collate(samples) for samples in transposed]
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in default_collate
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in <dictcomp>
return {key: default_collate([d[key] for d in batch]) for key in elem}
File "/home/usr/.local/lib/python3.6/site-packages/torch/utils/data/_utils/collate.py", line 74, in <listcomp>
return {key: default_collate([d[key] for d in batch]) for key in elem}
KeyError: '210'
The first two lines are probably output produced while loading the data. (I'm not sure, because I don't print anything myself, but I use tqdm when loading the data, so I assume it's tqdm's output?)
Then I get this KeyError. I'd like to know which part I need to modify. I think the dataset class itself is fine, since there were no errors while reading the data from the file. Is it because the samples are in the wrong format, so that the DataLoader cannot load them properly from the dataset? Is there any requirement on the format? I have read other people's code, but I haven't found anything mentioning a required sample format for the Dataset class.
Edit: a single sample looks like this:
('0', '0', {'210': '9093445', '216': '9154780', '301': '9351665', '205': '4186222', '206': '8316799', '207': '8416205', '508': '9355039', '121': '3438658', '122': '3438762', '101': '31390', '124': '3438769', '125': '3438774', '127': '3438782', '128': '3864885', '129': '3864887', '150_14': '3941161', '127_14': '3812616', '109_14': '449068', '110_14': '569621'})
The first two '0's are the labels, and the dictionary that follows contains the features.
As @Shai mentioned, you get this error from the DataLoader's default collate_fn when the keys of feature_dictionary are not identical across all samples in a batch: as the traceback shows, default_collate takes the keys of the first sample in the batch and looks them up in every other sample, so any missing key raises a KeyError (here, '210').
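For example, with a plain batch of dicts (a minimal standalone sketch, not your data) you can reproduce the behaviour directly; in newer PyTorch versions default_collate is also exposed as torch.utils.data.default_collate:

from torch.utils.data._utils.collate import default_collate

# Works: every dict in the batch has the same keys.
default_collate([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])

# Raises KeyError: 'b', because the second dict is missing a key
# that the first one has.
default_collate([{'a': 1, 'b': 2}, {'a': 3}])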
As a workaround, you can write a custom collate_fn like the one below, and it works:
class MyDataset(Dataset):
    # ... your code ...
    def collate_fn(self, batch):
        tag1_batch = []
        tag2_batch = []
        feat_dict_batch = []
        for tag1, tag2, feat_dict in batch:
            tag1_batch.append(tag1)
            tag2_batch.append(tag2)
            feat_dict_batch.append(feat_dict)
        return tag1_batch, tag2_batch, feat_dict_batch
path = 'dataset/sample_train.csv'
size = 1000
dataset = MyDataset(path, size)
dataloader = DataLoader(dataset, batch_size=3, collate_fn=dataset.collate_fn)
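With this collate_fn, each batch is simply a tuple of three Python lists and no key matching is attempted, so iteration works even when the feature dictionaries have different keys. A quick check, just to show the batch structure, could be:

for tag1_batch, tag2_batch, feat_dict_batch in dataloader:
    # Each is a plain list of length batch_size (3 here for a full batch).
    print(len(tag1_batch), len(tag2_batch), len(feat_dict_batch))
    break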
I ran into a similar problem. In my case, I noticed that pd.Series was what triggered the KeyError. I converted my data (both targets and features) to np.array / torch.Tensor and it worked.
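As a rough sketch of what that conversion can look like (the DataFrame and column names here are hypothetical, not from the question), the idea is to return tensors from __getitem__ instead of pd.Series objects:

import pandas as pd
import torch
from torch.utils.data import Dataset


class FrameDataset(Dataset):
    # Hypothetical dataset over a DataFrame; column names are placeholders.
    def __init__(self, df: pd.DataFrame, feature_cols, label_col='label'):
        self.df = df
        self.feature_cols = feature_cols
        self.label_col = label_col

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]  # a pd.Series
        # Convert to plain tensors before returning, instead of the Series itself.
        features = torch.tensor(row[self.feature_cols].to_numpy(dtype='float32'))
        target = torch.tensor(int(row[self.label_col]))
        return features, target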