如何用pytorch在cifar10或stl10中加载一种类型的图像



这是一个非常简单的问题,我只是想从标准pytorch图像数据集中选择一类特定的图像(例如"汽车"(。目前数据加载程序如下所示:

def cycle(iterable):
while True:
for x in iterable:
yield x
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.STL10('drive/My Drive/training/stl10', split='train+unlabeled', transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])),
shuffle=True, batch_size=8)
train_iterator = iter(cycle(train_loader))
class_names = ['airplane', 'bird', 'car', 'cat', 'deer', 'dog', 'horse', 'monkey', 'ship', 'truck']
train_iterator = iter(cycle(train_loader))

迭代器返回一批所有类型的混洗图像,但我希望能够选择返回哪些类型的图像,例如,只返回鹿或船只的图像

完成!

def cycle(iterable):
while True:
for x in iterable:
yield x
# Return only images of certain class (eg. aeroplanes = class 0)
def get_same_index(target, label):
label_indices = []
for i in range(len(target)):
if target[i] == label:
label_indices.append(i)
return label_indices
# STL10 dataset
train_dataset = torchvision.datasets.STL10('drive/My Drive/training/stl10', split='train+unlabeled', download=True, transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor()]))
label_class = 1# birds
# Get indices of label_class
train_indices = get_same_index(train_dataset.labels, label_class)
bird_set = torch.utils.data.Subset(train_dataset, train_indices)
train_loader = torch.utils.data.DataLoader(dataset=bird_set, shuffle=True,
batch_size=batch_size, drop_last=True)
train_iterator = iter(cycle(train_loader))

最新更新