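I am running predictions on images. I have written out the names of all the classes, and the test folder contains 20 images. Can you give me some hints as to why I get an error? How can we check the model's indices?
Code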
import numpy as np
import sys, random
import torch
from torchvision import models, transforms
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import glob
# Paths for image directory and model
IMDIR = './test'
MODEL = 'checkpoint/resnet18/Monday_31_May_2021_21h_25m_05s/resnet18-1000-regular.pth'
# Load the model for testing
model = models.resnet18()
model.named_children()
torch.save(model.state_dict, MODEL)
model.eval()
# Class labels for prediction
class_names = ['BC', 'BK', 'CC', 'CL', 'CM', 'DF', 'DG', 'DS', 'HL', 'IF', 'JD', 'JS', 'LD', 'LP', 'LS', 'PO', 'RI',
               'SD', 'SG', 'TO']
# Retrieve random images from the directory (sample size is 1 here; the 3x3 grid below fits up to 9)
files = Path(IMDIR).resolve().glob('*.*')
print(files)
images = random.sample(list(files), 1)
print(images)
# Configure plots
fig = plt.figure(figsize=(9, 9))
rows, cols = 3, 3
# Preprocessing transformations
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    # transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(0.5306, 0.1348)
])
# Enable gpu mode, if cuda available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Perform prediction and plot results
with torch.no_grad():
    for num, img in enumerate(images):
        img = Image.open(img).convert('RGB')
        inputs = preprocess(img).unsqueeze(0).cpu()
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        print(preds)
        label = class_names[preds]
        plt.subplot(rows, cols, num + 1)
        plt.title("Pred: " + label)
        plt.axis('off')
        plt.imshow(img)
plt.show()
'''
Sample run: python test.py test
'''
Traceback
Traceback (most recent call last):
  File "/media/khawar/HDD_Khawar/CVPR/pytorch-cifar100/test_box.py", line 57, in <module>
    label = class_names[preds]
IndexError: list index out of range
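Your error comes from not modifying the final linear layer of the ResNet model. models.resnet18() builds the standard ImageNet architecture, whose last layer (model.fc) outputs 1000 scores, so torch.max(outputs, 1) can return any index from 0 to 999, while class_names has only 20 entries (indices 0 to 19). Any prediction of 20 or higher raises the IndexError.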
I suggest adding this code:
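import torch.nn as nn

# What you have
model = models.resnet18()
# What you need: replace the 1000-class head with one output per class
# (class_names must be defined before this line)
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, len(class_names)))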
This replaces the last linear layer so that it outputs the correct number of nodes: one per class, 20 here.
Sarthak