采用Pytorch数据集的子集



我有一个我想在某些数据集上训练的网络(例如, CIFAR10)。我可以通过

创建数据加载程序对象
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

我的问题如下:假设我想进行几个不同的训练迭代。假设我首先想在所有图像上训练网络,然后在所有图像上,然后在所有图像上均位置等等。为此,我需要能够访问这些图像。不幸的是,trainset似乎不允许使用此类访问。也就是说,尝试进行trainset[:1000]或更一般的trainset[mask]会丢失错误。

我可以做

trainset.train_data=trainset.train_data[mask]
trainset.train_labels=trainset.train_labels[mask]

,然后

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                              shuffle=True, num_workers=2)

但是,这将迫使我在每次迭代中创建完整数据集的新副本(因为我已经更改了trainset.train_data,因此我需要重新定义trainset)。有什么方法可以避免它吗?

理想情况下,我想拥有一些"等效"到

的东西
trainloader = torch.utils.data.DataLoader(trainset[mask], batch_size=4,
                                              shuffle=True, num_workers=2)

torch.utils.data.Subset更容易,支持shuffle,并且不需要编写自己的采样器:

import torchvision
import torch
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=None)
evens = list(range(0, len(trainset), 2))
odds = list(range(1, len(trainset), 2))
trainset_1 = torch.utils.data.Subset(trainset, evens)
trainset_2 = torch.utils.data.Subset(trainset, odds)
trainloader_1 = torch.utils.data.DataLoader(trainset_1, batch_size=4,
                                            shuffle=True, num_workers=2)
trainloader_2 = torch.utils.data.DataLoader(trainset_2, batch_size=4,
                                            shuffle=True, num_workers=2)

您可以为避免重新创建数据集的数据集加载程序定义自定义采样器(只为每个不同采样创建一个新的加载程序)。

class YourSampler(Sampler):
    def __init__(self, mask):
        self.mask = mask
    def __iter__(self):
        return (self.indices[i] for i in torch.nonzero(self.mask))
    def __len__(self):
        return len(self.mask)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
sampler1 = YourSampler(your_mask)
sampler2 = YourSampler(your_other_mask)
trainloader_sampler1 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          sampler = sampler1, shuffle=False, num_workers=2)
trainloader_sampler2 = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          sampler = sampler2, shuffle=False, num_workers=2)

ps:您可以在此处找到更多信息:http://pytorch.org/docs/master/_modules/_modules/torch/torils/data/data/sampler.html#sampler

最新更新