我有一个数据集,我必须以这样一种方式处理它,它可以与PyTorch的卷积神经网络一起工作(我对PyTorch完全陌生)。数据存储在一个数据框中,该数据框中有一个用于图片的列(28 x 28的ndarray,有int32项)和一个用于类标签的列。图像的像素仅采用+1和-1的值(因为它是经典二维Ising模型的模拟数据)。数据框看起来像这样。
我导入了以下内容(其中很多内容现在不相关,但为了完整性,我包含了所有内容)。"data_loader"是一个自定义的py文件):
import numpy as np
import matplotlib.pyplot as plt
import data_loader
import pandas as pd
import torch
import torchvision.transforms as T
from torchvision.utils import make_grid
from torch.nn import Module
from torch.nn import Conv2d
from torch.nn import Linear
from torch.nn import MaxPool2d
from torch.nn import ReLU
from torch.nn import LogSoftmax
from torch import flatten
from sklearn.metrics import classification_report
import time as time
from torch.utils.data import DataLoader, Dataset
然后,我想让它在正确的形状,以使它对PyTorch有用。为此,我定义了以下类
class MetropolisDataset(Dataset):
def __init__(self, data_frame, transform=None):
self.data_frame = data_frame
self.transform = transform
def __len__(self):
return len(self.data_frame)
def __getitem__(self,idx):
if torch.is_tensor(idx):
idx = idx.tolist()
label = self.data_frame['label'].iloc[idx]
image = self.data_frame['image'].iloc[idx]
image = np.array(image)
if self.transform:
image = self.transform(image)
return (image, label)
我把这个类的实例称为:
train_set = MetropolisDataset(data_frame = df_train,
transform = T.Compose([
T.ToPILImage(),
T.ToTensor()]))
validation_set = MetropolisDataset(data_frame = df_validation,
transform = T.Compose([
T.ToPILImage(),
T.ToTensor()]))
test_set = MetropolisDataset(data_frame = df_test,
transform = T.Compose([
T.ToPILImage(),
T.ToTensor()]))
这里还没有出现问题,因为我能够从上面定义的类的这些实例中读出和显示图像。
然后,据我所知,有必要让它通过PyTorch的DataLoader,我这样做:
batch_size = 64
train_dl = DataLoader(train_set, batch_size, shuffle=True, num_workers=3, pin_memory=True)
validation_dl = DataLoader(validation_set, batch_size, shuffle=True, num_workers=3, pin_memory=True)
test_dl = DataLoader(test_set, batch_size, shuffle=True, num_workers=3, pin_memory=True)
但是,如果我想使用DataLoader的这些实例,就不会发生任何事情。我既没有得到错误,也没有得到任何计算。我试着运行CNN,但它似乎没有计算任何东西。我还尝试用本文提供的代码展示一些示例图像,但还是出现了同样的问题。示例代码为:
def show_images(images, nmax=10):
fig, ax = plt.subplots(figsize=(8, 8))
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(make_grid((images.detach()[:nmax]), nrow=8).permute(1, 2, 0))
def show_batch(dl, nmax=64):
for images in dl:
show_images(images, nmax)
break
show_batch(test_dl)
似乎在我的MetropolisDataset
类或DataLoader本身的实现中存在一些错误。如何解决这个问题?
正如评论中提到的,由于我是在Jupyter笔记本中工作,因此通过将num_workers设置为零可以部分解决这个问题。然而,这留下了一个进一步的问题,当我想应用DataLoader来运行CNN时,我得到了错误。问题是,我的数据确实由int32数字而不是float32组成。我没有包括进一步的代码,因为这与我的数据直接相关——然而,这个问题(经常)只是一个错误的数据类型。