目前我正在尝试加载我的预训练GoogleNet。然而,似乎有一个大小不匹配的问题,我试图改变num_classes来解决它,但无济于事,仍然存在问题。
import os
from tkinter import Variable
from matplotlib import image, transforms
import torch
import torchvision
from torch import nn, optim
checkpoint = torch.load("extraFile/Kaggle_googlenet.pth")
model = torchvision.models.googlenet(pretrained=True, num_classes = 3)
model.load_state_dict(checkpoint)
model.eval()
def predict_image(image_path):
transformation = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image_tensor = transformation(image).float()
image_tensor = image_tensor.unsqueeze_(0)
if torch.cuda.is_available():
image_tensor.cuda()
input = Variable(image_tensor)
output = model(input)
index = output.data.numpy().argmax()
return index
if __name__ == "main":
imagefile = "a.png"
imagepath = os.path.join(os.getcwd(),imagefile)
prediction = predict_image(imagepath)
print("Predicted Class: ",prediction)
错误是
File "c:UserssoongDocumentsFYPFuzzy-Integral-Covid-Detection-mainextra.py", line 9, in <module>
model = torchvision.models.googlenet(pretrained=True, num_classes = 3)
File "C:ProgramDataAnaconda3libsite-packagestorchvisionmodelsgooglenet.py", line 52, in googlenet
model.load_state_dict(state_dict)
File "C:ProgramDataAnaconda3libsite-packagestorchnnmodulesmodule.py", line 1051, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:nt{}'.format(
RuntimeError: Error(s) in loading state_dict for GoogLeNet:
size mismatch for aux1.fc2.weight: copying a param with shape torch.Size([1000, 1024]) from checkpoint, the shape in current model is torch.Size([3, 1024]).
size mismatch for aux1.fc2.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for aux2.fc2.weight: copying a param with shape torch.Size([1000, 1024]) from checkpoint, the shape in current model is torch.Size([3, 1024]).
size mismatch for aux2.fc2.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([3]).
size mismatch for fc.weight: copying a param with shape torch.Size([1000, 1024]) from checkpoint, the shape in current model is torch.Size([3, 1024]).
size mismatch for fc.bias: copying a param with shape torch.Size([1000]) from checkpoint, the shape in current model is torch.Size([3]).
如果你使用pretrained=True
选项,它将加载一个在ImageNet上训练的模型,因此在1000个类上。
你要做的就是加载它,然后根据类的大小重新初始化所有的层。
这里是一个教程,其中完成了不同的架构,遗憾的是不是为GoogleNet。但是错误应该已经告诉你,你必须用更低的类数来初始化哪些层,所以我将从aux1.fc2
,aux2.fc2
,fc
层开始。
最好的开始方式是先打印架构,然后看一下模型。寻找最后的层,其中包括1000以某种形式,这是一个强有力的指标,它使用类大小。
googlenet = torchvision.models.googlenet(pretrained=True)
print(googlenet)
从我所看到的,只有完全连接的(fc
)层似乎实际存在,aux1
和aux2
对我来说是None
,所以这已经可以做
googlenet = torchvision.models.googlenet(pretrained=True)
googlenet.fc = torch.nn.Linear(in_features=1024, out_features=3, bias=True)
请记住,为了使用这个,您需要对某些数据集上定义的类进行微调。
如果你只是想在一些数据上运行分类,并希望在ImageNet中包含类,你不必做任何特别的事情,只需加载没有num_classes
的模型并运行它并将ImageNet类映射到你需要的类。