在Pytorch中,如何混洗DataLoader



我有一个包含10000个样本的数据集,其中的类以有序的方式存在。首先,我将数据加载到ImageFolder中,然后加载到DataLoader中,我想将此数据集拆分为train-val测试集。我知道DataLoader类有一个shuffle参数,但这对我来说不好,因为它只会在枚举时对数据进行shuffle。我知道RandomSampler函数,但有了它,我只能从数据集中随机获取n个数据量,而且我无法控制取出的数据,所以一个样本可能同时存在于train、test和val集中。

有没有一种方法可以打乱DataLoader中的数据?我唯一需要的就是洗牌,然后我可以对数据进行子集处理。

Subset数据集类获取索引(https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset)。您可能可以利用它来获得以下功能。从本质上讲,您可以通过打乱索引,然后选择数据集的子集来逃脱惩罚。

# suppose dataset is the variable pointing to whole datasets
N = len(dataset)
# generate & shuffle indices
indices = numpy.arange(N)
indices = numpy.random.permutation(indices)
# there are many ways to do the above two operation. (Example, using np.random.choice can be used here too
# select train/test/val, for demo I am using 70,15,15
train_indices = indices [:int(0.7*N)]
val_indices = indices[int(0.7*N):int(0.85*N)]
test_indices = indices[int(0.85*N):]
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)
test_dataset = Subset(dataset, test_indices)

最新更新