运行时错误:4 维权重 [64, 3, 7, 7] 的预期 4 维输入,但得到了大小为 [3, 32, 32] 的三维输



我是PyTorch和神经网络的新手。我试图在CIFAR-10数据集上实现torchvision的resnet-50模型。

import torchvision
import torch
import torch.nn as nn
from torch import optim
import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
from collections import OrderedDict
import matplotlib.pyplot as plt
transformations=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
trainset=torchvision.datasets.CIFAR10(root='./CIFAR10',download=True,transform=transformations,train=True)
testset=torchvision.datasets.CIFAR10(root='./CIFAR10',download=True,transform=transformations,train=False)
trainloader=DataLoader(dataset=trainset,batch_size=4)
testloader=DataLoader(dataset=testset,batch_size=4)
inputs,labels=next(iter(trainset))
inputs.size()
resnet=torchvision.models.resnet50(pretrained=True)
if torch.cuda.is_available():
resnet=resnet.cuda()
inputs,labels=inputs.cuda(),torch.Tensor(labels).cuda()
outputs=resnet(inputs)

输出

--------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-6-904acb410fe4> in <module>()
6   inputs,labels=inputs.cuda(),torch.Tensor(labels).cuda()
7 
----> 8 outputs=resnet(inputs)
5 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py in _conv_forward(self, input, weight)
344                             _pair(0), self.dilation, self.groups)
345         return F.conv2d(input, weight, self.bias, self.stride,
--> 346                         self.padding, self.dilation, self.groups)
347 
348     def forward(self, input):
RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 32, 32] instead

由于某种原因,数据集是否存在问题,如果没有,我如何给出 4 维输入?ResNet-50的火炬视实施是否不适用于CIFAR-10?

目前,您正在迭代数据集,这就是获得(三维(单个图像的原因。您实际上需要遍历数据加载器才能获得 4 维图像批处理。因此,您只需要更改以下行:

inputs,labels=next(iter(trainset))

inputs,labels=next(iter(trainloader))

最新更新