如何将图像转换应用于图像列表并保持正确的尺寸

  • 本文关键字:图像 转换 应用于 列表 pytorch
  • 更新时间 :
  • 英文 :


我使用的是Omniglot数据集,它是一组19280张图像,每张图像都是105 x 105(灰度(。

我用以下转换定义了一个自定义数据集类:

class OmniglotDataset(Dataset):
def __init__(self, X, transform=None):
self.X = X
self.transform = transform
def __len__(self):
return self.X.shape[0]
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img = self.X[idx]
if self.transform:
img = self.transform(img)
return img
img_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
X_train.shape
(19280, 105, 105)
train_dataset = OmniglotDataset(X_train, transform=img_transform)

当我索引单个图像时,它会返回正确的维度:

train_dataset[0].shape
torch.Size([1, 105, 105])

但当我索引多个图像时,它会以错误的顺序返回维度(我希望3 x 105 x 105(:

train_dataset[[1,2,3]].shape
torch.Size([105, 3, 105])

您收到错误,因为尝试将单个图像的转换应用到列表:

获得任何大小的批次的一种更方便的方法是使用Dataloader:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
img_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
omniglot = datasets.Omniglot(root='./data', background=True, download=True, transform = img_transform)
data_loader = DataLoader(omniglot, shuffle=False, batch_size = 8)
for image_batch in data_loader:
# now image_batch contain first eight samples
print(image_batch.shape) # torch.Size([8, 1, 105, 105]) 
break

如果你真的需要以任意顺序获取图像:

from operator import itemgetter
indexes = [1,3,5]
selected_samples = itemgetter(*b)(omniglot) 

最新更新