我正在使用多个数据集。我有多个损失,每个损失都必须在这些数据集的一个子集上进行评估。我想从每个数据集生成一个批次,并评估所有适当批次的每个损失。一些损失是成对的(需要加载相应数据点的对(,而另一些则在单个数据点上计算。我需要以这样一种方式设计它,以便轻松添加新数据集。是否有任何内置的pytorch可以帮助解决这个问题?在 pytorch 中设计它的最佳方式是什么?提前谢谢。
从您的问题中不清楚您的设置到底是什么。
但是,您可以有多个 Dataset
实例,每个数据集一个。
在数据集之上,可以实现"标记数据集",即为所有样本添加"标记"的数据集:
class TaggedDataset(data.Dataset):
def __init__(dataset, tag):
super(TaggedDataset, self).__init__()
self.ds_ = dataset
self.tag_ = tag
def __len__(self):
return len(self.ds_)
def __getitem__(self, index):
return self.ds_[index], self.tag_
为每个数据集提供不同的tag
,将它们全部连接成一个ConcatDataset
,并围绕它进行常规DataLoader
。
现在,在您的训练代码中
for input, label, tag in my_tagged_loader:
# process each input according to the dataset tag it got.