我有一个使用预训练模型进行预测的问题,该模型包含用于手写文本识别的编码器和解码器。我所做的是:
checkpoint = torch.load("Model/SPAN/SPAN-PT-RA_rimes.pt",map_location=torch.device('cpu'))
encoder_state_dict = checkpoint['encoder_state_dict']
decoder_state_dict = checkpoint['decoder_state_dict']
img = torch.LongTensor(img).unsqueeze(1).to(torch.device('cpu'))
global_pred = decoder_state_dict(encoder_state_dict(img))
生成如下错误:
TypeError: 'collections.OrderedDict' object is not callable
我将非常感谢你的帮助!^ _ ^encoder_state_dict
和decoder_state_dict
不是火炬模型,而是张量的集合(字典),其中包括您加载的检查点的预训练参数。
将输入(例如转换后的输入图像)提供给这样的张量集合是没有意义的。实际上,您应该使用这些stat_dicts(即,预训练张量的集合)将它们加载到映射到网络的模型对象的参数中。参见torch.nn.Module
类