如何正确加载pytorch pth模型并与另一个模型结合?



我正在寻找一个与pytorch一起工作的OCR预训练模型。我尝试https://github.com/clovaai/CRAFT-pytorch,但它不支持pytorch hub。我不能加载pth模型,因为它只有权值。如何加载模型?

我的第一个模型是在自定义数据上训练的yolov5模型,所以它应该裁剪图像并将其发送到下一个模型。下一个模式应该是OCR,主要是数字识别。但是我不能运行craft-pytorch

model = torch.hub.load('.', 'custom', path='runs/train/exp2/weights/best.pt', source='local', force_reload=True)
# It throws an error with pytorch hub
model_ocr = torch.hub.load('clovaai/CRAFT-pytorch', 'craft_mlt_25k.pth')
# Tried with torch load, but pth files have only weights not model
ocr_model = torch.load('runs/ocr/craft_mlt_25k.pth')
cap = cv2.VideoCapture('../Dataset/test/09-10.mp4')
while(cap.isOpened()):
ret, frame = cap.read()
results = model(frame)
crops = results.crop(save=False)    
for crop in crops:
if 'number' in crop['label']:
ocr_result = model_ocr(crop['im'])
ocr_crop = ocr_result.crop(save=False)

如果您有模型的源代码,您可以创建它并加载它的权重,但是根据我对这个问题的理解,您希望在没有源代码的情况下完成它。

可以直接调用torch.load:

state_dict = torch.load('model_file.pth')

state_dict应该只是一个(类型的)字典。您可以按照自己的意愿检查和拆分它:

print(state_dict.keys())
with torch.no_grad():
my_model.some_weight[:] = state_dict['some_other_name']

(注意在不可信的数据上调用torch.load是不安全的)

相关内容