我需要为以下论文的Python (Pytorch)实现加载CelebA数据集:https://arxiv.org/pdf/1908.10578.pdf加载CelebA数据集的原始代码是用MATLAB使用MatConvNet与autonn编写的(源15论文)。我有源代码,但我不确定我是否可以分享它。
这是我第一次使用Pytorch(版本1.9.0+cu102)并在计算机视觉中做论文实现。
我看了以下相关问题:我如何在谷歌Colab上加载CelebA数据集,使用火炬视觉,而不会耗尽内存?
并测试了用户anurag建议的解决方案:https://stackoverflow.com/a/65528710/15087536
不幸的是,我仍然得到一个语法错误。
代码如下:
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
# Root directory for the dataset
data_root = 'data/celeba'
# Spatial size of training images, images are resized to this size.
image_size = 64
# batch size
batch_size = 50000
transform=transforms.Compose([transforms.Resize(image_size),
transforms.CenterCrop(image_size),transforms.ToTensor(),transforms.Normalize(mean=
[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])
dataset = ImageFolder(data_root,transform) **syntax error**
由于我们不知道你的语法错误,所以我无法评论。
下面我将分享一种可能的方法。
-
您可以使用此链接从Kaggle下载celebA数据集。或者,您也可以使用这些数据创建一个Kaggle内核(不需要下载数据)
-
如果您正在使用google colab,请上传可从笔记本访问的数据。
-
接下来你可以编写一个PyTorch数据集,它将根据分区(train, valid, test)加载图像。
-
我在下面粘贴一个例子。你可以根据自己的需要自定义。
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from skimage import io
class CelebDataset(Dataset):
def __init__(self,data_dir,partition_file_path,split,transform):
self.partition_file = pd.read_csv(partition_file_path)
self.data_dir = data_dir
self.split = split
self.transform = transform
def __len__(self):
self.partition_file_sub = self.partition_file[self.partition_file["partition"].isin(self.split)]
return len(self.partition_file_sub)
def __getitem__(self,idx):
img_name = os.path.join(self.data_dir,
self.partition_file_sub.iloc[idx, 0])
image = io.imread(img_name)
if self.transform:
image = self.transform(image)
return image
- 接下来,您可以创建您的列车和测试加载器。将IMAGE_PATH更改为包含图像的目录
batch_size = celeba_config['batch_size']
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
IMAGE_PATH = '../input/celeba-dataset/img_align_celeba/img_align_celeba'
trainset = CelebDataset(data_dir=IMAGE_PATH,
partition_file_path='../input/celeba-dataset/list_eval_partition.csv',
split=[0,1],
transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=2)
testset = CelebDataset(data_dir=IMAGE_PATH,
partition_file_path='../input/celeba-dataset/list_eval_partition.csv',
split=[2],
transform=transform)
testloader = DataLoader(testset, batch_size=batch_size,
shuffle=True, num_workers=2)