RuntimeError: size mismatch when loading state_dict for GoogLeNet



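I am currently trying to load my pretrained GoogLeNet. However, there seems to be a size mismatch. I tried changing num_classes to fix it, but to no avail; the problem persists.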

import os
from tkinter import Variable
from matplotlib import image, transforms
import torch
import torchvision
from torch import nn, optim

checkpoint = torch.load("extraFile/Kaggle_googlenet.pth")
model = torchvision.models.googlenet(pretrained=True, num_classes = 3)
model.load_state_dict(checkpoint)
model.eval()

def predict_image(image_path):
    transformation = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    image_tensor = transformation(image).float()
    image_tensor = image_tensor.unsqueeze_(0)
    if torch.cuda.is_available():
        image_tensor.cuda()
    input = Variable(image_tensor)
    output = model(input)
    index = output.data.numpy().argmax()
    return index

if __name__ == "main":
    imagefile = "a.png"
    imagepath = os.path.join(os.getcwd(),imagefile)
    prediction = predict_image(imagepath)
    print("Predicted Class: ",prediction)

The error is:

  File "c:\Users\soong\Documents\FYP\Fuzzy-Integral-Covid-Detection-main\extra.py", line 9, in <module>
    model = torchvision.models.googlenet(pretrained=True, num_classes = 3)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torchvision\models\googlenet.py", line 52, in googlenet
    model.load_state_dict(state_dict)
  File "C:\ProgramData\Anaconda3\lib\site-packages\torch\nn\modules\module.py", line 1051, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GoogLeNet:
    size mismatch for aux1.fc2.weight: copying a param with shape torch.Size([1000, 1024]) from checkpoint, the shape in current model is torch.Size([3, 1024]).
    size mismatch for aux1.fc2.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([3]).
    size mismatch for aux2.fc2.weight: copying a param with shape torch.Size([1000, 1024]) from checkpoint, the shape in current model is torch.Size([3, 1024]).
    size mismatch for aux2.fc2.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([3]).
    size mismatch for fc.weight: copying a param with shape torch.Size([1000, 1024]) from checkpoint, the shape in current model is torch.Size([3, 1024]).
    size mismatch for fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([3]).

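If you use the pretrained=True option, it loads a model that was trained on ImageNet, and therefore on 1000 classes. Asking for num_classes=3 at the same time builds 3-class output layers, and the 1000-class ImageNet weights cannot be copied into them.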

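What you need to do is load it with the default 1000 classes first, and then re-initialize the class-dependent layers with your own number of classes.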

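There is a tutorial that does this for several architectures, though unfortunately not for GoogLeNet. But the error should already tell you which layers you have to initialize with the lower class count, so I would start with the aux1.fc2, aux2.fc2 and fc layers.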
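One way to act on that error message is to copy only the tensors whose shapes agree and leave the class-dependent layers at their fresh initialization. This is a minimal sketch of that idea, not something torchvision does for you: `load_matching_weights` and `make_net` are hypothetical helpers, and a tiny `nn.Sequential` stands in for GoogLeNet so the example runs quickly.

```python
import torch
from torch import nn


def load_matching_weights(model, state_dict):
    """Copy tensors whose name and shape match; return the keys that were skipped."""
    own = model.state_dict()
    compatible = {k: v for k, v in state_dict.items()
                  if k in own and v.shape == own[k].shape}
    skipped = [k for k in state_dict if k not in compatible]
    # strict=False tolerates the keys we deliberately left out.
    model.load_state_dict(compatible, strict=False)
    return skipped


def make_net(num_classes):
    # Toy stand-in: a shared "backbone" layer followed by a class-sized head.
    return nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, num_classes))


source = make_net(1000)   # plays the role of the 1000-class checkpoint
target = make_net(3)      # plays the role of the 3-class model
skipped = load_matching_weights(target, source.state_dict())
print(skipped)  # ['2.weight', '2.bias']
```

The skipped layers keep their random initialization, so they still have to be trained before the model is useful.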

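The best way to start is to print the architecture and take a look at the model. Look for the final layers that contain 1000 in some form; that is a strong indicator that they depend on the number of classes.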

googlenet = torchvision.models.googlenet(pretrained=True)
print(googlenet)

From what I can see, only the fully connected (fc) layer actually seems to exist; aux1 and aux2 are None for me, so this should already do it:

googlenet = torchvision.models.googlenet(pretrained=True)
googlenet.fc = torch.nn.Linear(in_features=1024, out_features=3, bias=True)

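Keep in mind that the new layer is randomly initialized, so to use this model you need to fine-tune it on a dataset that has your classes.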

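If you just want to run classification on some data and the classes you need are already included in ImageNet, you don't have to do anything special: load the model without num_classes, run it, and map the ImageNet classes to the ones you need.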
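The mapping step could be sketched like this. Everything here is hypothetical: the group names and the ImageNet index values in `GROUPS` are placeholders, `map_to_custom_classes` is an illustrative helper, and hand-made logits stand in for the output of the 1000-class model.

```python
import torch

# Hypothetical grouping of ImageNet class indices into custom classes.
# The index values are placeholders, not a real label mapping.
GROUPS = {"class_a": [0, 1, 2], "class_b": [10, 11], "class_c": [500]}


def map_to_custom_classes(logits, groups):
    """Sum ImageNet probabilities per group and pick the most likely group."""
    probs = torch.softmax(logits, dim=1)
    names = list(groups)
    grouped = torch.stack(
        [probs[:, groups[name]].sum(dim=1) for name in names], dim=1)
    return [names[i] for i in grouped.argmax(dim=1).tolist()]


# Stand-in for model(images): two samples, 1000 ImageNet logits each.
logits = torch.zeros(2, 1000)
logits[0, 11] = 5.0    # first sample favours an index in class_b
logits[1, 500] = 5.0   # second sample favours the index in class_c
predictions = map_to_custom_classes(logits, GROUPS)
print(predictions)  # ['class_b', 'class_c']
```

Summing probabilities per group is one design choice; taking the top-1 ImageNet class and looking it up in a dictionary would also work.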
