Custom dataset loader in PyTorch



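I'm working on COVID-19 classification with a dataset from Kaggle. It has a folder named `dataset` containing three folders, normal, pneumonia and COVID-19, each holding the images of that class. I'm stuck writing `__getitem__` for a custom PyTorch dataset. The dataset has 189 COVID-19 images, but with this `__getitem__` I get 920 COVID-19 images. Please help.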

```python
import os
import random
import shutil

class_names = ['normal', 'viral', 'covid']
root_dir = 'COVID-19 Radiography Database'
source_dirs = ['NORMAL', 'Viral Pneumonia', 'COVID-19']

# Only run the split once: after renaming, the original folders no longer exist
if os.path.isdir(os.path.join(root_dir, source_dirs[1])):
    os.mkdir(os.path.join(root_dir, 'test'))

    # Rename the Kaggle folders to the short class names
    for i, d in enumerate(source_dirs):
        os.rename(os.path.join(root_dir, d), os.path.join(root_dir, class_names[i]))

    for c in class_names:
        os.mkdir(os.path.join(root_dir, 'test', c))

    # Move 30 randomly chosen PNG images per class into the test folders
    for c in class_names:
        images = [x for x in os.listdir(os.path.join(root_dir, c)) if x.lower().endswith('png')]
        selected_images = random.sample(images, 30)
        for image in selected_images:
            source_path = os.path.join(root_dir, c, image)
            target_path = os.path.join(root_dir, 'test', c, image)
            shutil.move(source_path, target_path)
```

The code above creates a test set with 30 images per class.

```python
import os
import random

import torch
from PIL import Image


class ChestXRayDataset(torch.utils.data.Dataset):
    def __init__(self, image_dirs, transform):
        def get_images(class_name):
            images = [x for x in os.listdir(image_dirs[class_name]) if
                      x[-3:].lower().endswith('png')]
            print(f'Found {len(images)} {class_name} examples')
            return images

        self.images = {}
        self.class_names = ['normal', 'viral', 'covid']

        for class_name in self.class_names:
            self.images[class_name] = get_images(class_name)

        self.image_dirs = image_dirs
        self.transform = transform

    def __len__(self):
        return sum([len(self.images[class_name]) for class_name in self.class_names])

    def __getitem__(self, index):
        # The class is picked at random on every call, independently of index
        class_name = random.choice(self.class_names)
        index = index % len(self.images[class_name])
        image_name = self.images[class_name][index]
        image_path = os.path.join(self.image_dirs[class_name], image_name)
        image = Image.open(image_path).convert('RGB')
        return self.transform(image), self.class_names.index(class_name)
```

**I'm stuck in this `__getitem__`.**

The images in the dataset folders are arranged as follows:

**Code for the confusion matrix:**

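```python
import torch
from tqdm import tqdm_notebook

# model, dl_train and device are assumed to be defined elsewhere in the notebook
nb_classes = 3
confusion_matrix = torch.zeros(nb_classes, nb_classes)
with torch.no_grad():
    for data in tqdm_notebook(dl_train, total=len(dl_train), unit='batch'):
        img, lab = data
        print(lab)
        img, lab = img.to(device), lab.to(device)
        _, output = torch.max(model(img), 1)
        print(output)

        # Rows are true labels, columns are predicted labels
        for t, p in zip(lab.view(-1), output.view(-1)):
            confusion_matrix[t.long(), p.long()] += 1
```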

The confusion matrix output shows that only one class is being trained.
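As a cross-check of the per-element loop, the same confusion matrix can be computed in one vectorized call with `torch.bincount`, by encoding each (true, predicted) pair as `true * nb_classes + predicted`. This is a minimal sketch: `confusion_matrix_bincount` is a hypothetical helper name, and the tensors below are made-up example labels, not results from the question.

```python
import torch


def confusion_matrix_bincount(targets, preds, nb_classes):
    # Encode each (true, predicted) pair as a single index in [0, nb_classes**2)
    encoded = targets.long() * nb_classes + preds.long()
    counts = torch.bincount(encoded, minlength=nb_classes * nb_classes)
    # Rows are true classes, columns are predicted classes
    return counts.reshape(nb_classes, nb_classes)


# Made-up labels for illustration only
targets = torch.tensor([0, 0, 1, 2, 2, 2])
preds = torch.tensor([0, 1, 1, 2, 0, 2])
cm = confusion_matrix_bincount(targets, preds, 3)
print(cm.tolist())  # [[1, 1, 0], [0, 1, 0], [1, 0, 2]]
```

Summing the per-batch matrices over `dl_train` gives the same totals as the `+= 1` loop. Likewise, `torch.bincount(lab, minlength=3)` per batch is a quick way to check how many samples of each class the loader actually yields.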

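Storing the images in a dictionary makes things more complicated than using lists. More importantly, the dataset should not contain any randomness: shuffling belongs in the DataLoader, not in the dataset. In your `__getitem__`, `random.choice` picks one of the three classes uniformly at random on every call, regardless of `index`, and `index % len(...)` then wraps the index into that class's list. So roughly a third of all samples come back as COVID-19, however few COVID-19 images there actually are, and the same 189 images are returned over and over.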
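To see the effect without any images or PyTorch, the sketch below mimics both `__getitem__` strategies over one pass of indices and counts the labels. The per-class counts for `normal` and `viral` are hypothetical placeholders (only the 189 COVID-19 figure comes from the question), and the file names are fake strings.

```python
import random
from collections import Counter

# Hypothetical per-class image counts; only covid's 189 comes from the question
counts = {'normal': 1300, 'viral': 1300, 'covid': 189}
class_names = ['normal', 'viral', 'covid']
images = {c: [f'{c}_{i}.png' for i in range(counts[c])] for c in class_names}
dataset_len = sum(counts.values())  # what __len__ returns in both versions


def random_class_getitem(index, rng):
    # Mirrors the question's __getitem__: the class is drawn at random, ignoring index
    class_name = rng.choice(class_names)
    index = index % len(images[class_name])
    return images[class_name][index], class_names.index(class_name)


# Mirrors the list-based approach: flat parallel lists, so index alone picks the sample
flat_images = [name for c in class_names for name in images[c]]
flat_labels = [class_names.index(c) for c in class_names for _ in images[c]]


def index_based_getitem(index):
    return flat_images[index], flat_labels[index]


rng = random.Random(0)
covid = class_names.index('covid')
buggy = Counter(random_class_getitem(i, rng)[1] for i in range(dataset_len))
fixed = Counter(index_based_getitem(i)[1] for i in range(dataset_len))
print('random-class covid samples:', buggy[covid])  # roughly a third of dataset_len
print('index-based covid samples:', fixed[covid])   # 189
```

In the random-class version the COVID-19 count tracks `dataset_len / 3`, not the number of COVID-19 files, because the short list is simply revisited via the modulo.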

Use the following instead:

```python
import os

import torch
from PIL import Image


class ChestXRayDataset(torch.utils.data.Dataset):
    def __init__(self, image_dirs, transform):
        def get_images(class_name):
            images = [x for x in os.listdir(image_dirs[class_name]) if
                      x[-3:].lower().endswith('png')]
            print(f'Found {len(images)} {class_name} examples')
            return images

        self.images = []
        self.labels = []
        self.class_names = ['normal', 'viral', 'covid']

        for class_name in self.class_names:
            images = get_images(class_name)
            # This is a list containing all the images
            self.images.extend(images)
            # This is a list containing all the corresponding image labels
            self.labels.extend([class_name] * len(images))

        self.image_dirs = image_dirs
        self.transform = transform

    def __len__(self):
        return len(self.images)

    # Returns the image and its label at position `index`
    def __getitem__(self, index):
        # The image at position `index` across all classes
        image_name = self.images[index]
        # Its label
        class_name = self.labels[index]
        image_path = os.path.join(self.image_dirs[class_name], image_name)
        image = Image.open(image_path).convert('RGB')
        return self.transform(image), self.class_names.index(class_name)
```

If you iterate over it, for example with

```python
ds = ChestXRayDataset(image_dirs, transform)
for x, y in ds:
    print(x.shape, y)
```

you should see all the images and labels, in order.

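In practice, however, you would use a PyTorch `DataLoader` and pass the `ds` object to it with the `shuffle` argument set to `True`. The DataLoader then handles shuffling by calling `__getitem__` with shuffled `index` values.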
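A minimal sketch of that behavior, assuming a hypothetical `ToyDataset` that stands in for `ChestXRayDataset` (12 integer samples stored in class order, like the flat lists above). Without shuffling, batches come out in storage order, so early batches contain a single class; with `shuffle=True` the DataLoader permutes the indices once per epoch, and every index is still visited exactly once.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in: 5 'normal', 5 'viral', 2 'covid' samples, stored in class order."""

    def __init__(self):
        self.labels = [0] * 5 + [1] * 5 + [2] * 2

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # Deterministic: the returned sample depends only on index
        return torch.tensor(index), self.labels[index]


ds = ToyDataset()

# No shuffling: the first batch is all class 0
first_batch = next(iter(DataLoader(ds, batch_size=4, shuffle=False)))
print(first_batch[1].tolist())  # [0, 0, 0, 0]

# shuffle=True: the sampler hands __getitem__ a random permutation of the indices
torch.manual_seed(0)
seen_indices, seen_labels = [], []
for idx, label in DataLoader(ds, batch_size=4, shuffle=True):
    seen_indices.extend(idx.tolist())
    seen_labels.extend(label.tolist())

print(sorted(seen_indices) == list(range(len(ds))))  # True: each index exactly once
print(seen_labels.count(2))  # 2: class counts are unchanged by shuffling
```

Because the shuffling lives in the sampler, the per-class counts the model sees per epoch always match the files on disk.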