How to use map() on a DataLoader dataset



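I am trying to train a pretrained Vision Transformer (ViT) on a new dataset. The dataset consists of JPG images sorted into folders (train, val, test) and has 4 classes. I want to use map() on the dataset for preprocessing. I added __getitem__ and __len__ to make it a map-style dataset, but I still get this error: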

AttributeError: 'DataLoader' object has no attribute 'map'

Here is the code:

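import glob

import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTFeatureExtractor, default_data_collator
from datasets import Features, ClassLabel, Array3D


class MyDataset(Dataset):
    def __init__(self, path, transform):
        self.files = glob.glob(path)
        print(type(self.files))
        self.transform = transform
        # The parent folder name is used as the class label.
        self.labels = [filepath.split('/')[-2] for filepath in self.files]

    def __getitem__(self, item):
        file = self.files[item]
        label = self.labels[item]
        file = Image.open(file)
        file = self.transform(file)
        return file, label

    def __len__(self):
        return len(self.files)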


transform=transforms.Compose([transforms.ToTensor()])

train_data = MyDataset(train_path, transform)
val_data = MyDataset(val_path, transform)
test_data = MyDataset(test_path, transform)

train = DataLoader(train_data , batch_size=1, shuffle=True, num_workers=3)
val = DataLoader(val_data , batch_size=1, shuffle=True)
test = DataLoader(test_data , batch_size=1, shuffle=True)


feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')
data_collator = default_data_collator


def preprocess_images(examples):
    images = examples['img']
    images = [np.array(image, dtype=np.uint8) for image in images]
    images = [np.moveaxis(image, source=-1, destination=0) for image in images]
    inputs = feature_extractor(images=images)
    examples['pixel_values'] = inputs['pixel_values']

    return examples


features = Features({
    'label': ClassLabel(
        names=['class1', 'class2', 'class3', 'class4']),
    'img': Array3D(dtype="int64", shape=(3, 32, 32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

preprocessed_train_ds = train.map(preprocess_images, batched=True, features=features)
preprocessed_val_ds = val.map(preprocess_images, batched=True, features=features)
preprocessed_test_ds = test.map(preprocess_images, batched=True, features=features)

What else can I do?

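A DataLoader has no map() method; map() belongs to Hugging Face datasets.Dataset objects, which is why the call fails. When you instantiate the DataLoader for the train, test, and val datasets, you can instead pass your preprocessing function through the collate_fn argument. You will have to update the function to fit your needs: a collate_fn receives a list of samples (here, the (image, label) tuples returned by __getitem__), not a dict with an 'img' key.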

For example:

DataLoader(train_data, collate_fn=preprocess_images, batch_size=1, shuffle=True, num_workers=3)

From the PyTorch DataLoader documentation:

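collate_fn (callable, optional): merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.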
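
To make the "update the function" step more concrete, here is a rough sketch of what a collate function for this dataset might look like. It is not the only way to do it and rests on a few assumptions: make_collate_fn is a hypothetical helper name, the dataset yields (image, label) pairs with the folder name as the label, and processor is any callable that accepts images=... and returns a mapping with a 'pixel_values' entry (the feature extractor from the question is called that way). The stub at the bottom only stands in for the real extractor so the snippet runs without downloading a model.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


def make_collate_fn(
    processor: Callable[..., Any],
    class_names: Sequence[str],
) -> Callable[[List[Tuple[Any, str]]], Dict[str, Any]]:
    """Build a collate_fn that preprocesses a list of (image, label) samples.

    `processor` is assumed to accept `images=[...]` and return a mapping
    with a 'pixel_values' entry. `class_names` fixes the label order.
    """
    label_to_id = {name: index for index, name in enumerate(class_names)}

    def collate(batch: List[Tuple[Any, str]]) -> Dict[str, Any]:
        # The DataLoader hands over a list of whatever __getitem__ returned.
        images = [image for image, _ in batch]
        labels = [label_to_id[label] for _, label in batch]
        # Run the whole mini-batch through the processor in one call.
        inputs = processor(images=images)
        return {"pixel_values": inputs["pixel_values"], "labels": labels}

    return collate


if __name__ == "__main__":
    # Stand-in for the real feature extractor (hypothetical, for illustration).
    def stub_processor(images):
        return {"pixel_values": [f"processed:{image}" for image in images]}

    collate = make_collate_fn(stub_processor, ["class1", "class2", "class3", "class4"])
    print(collate([("a.jpg", "class2"), ("b.jpg", "class4")]))
```

With the real objects you would pass something like collate_fn=make_collate_fn(feature_extractor, ['class1', 'class2', 'class3', 'class4']) to each DataLoader. Depending on what the extractor expects, you may also need the dataset to return images in a form it accepts (for example PIL images rather than already-converted tensors) and to convert the outputs to tensors before feeding the model.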
