在谷歌合作中使用PyTorch加载模型的问题



我正试图在google_collaboratory中加载模型,以对其进行评估并生成所有统计结果。

我的尝试

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torch.backends.cudnn as cudnn
import numpy as np
import torch.nn as nn
import os
def load_checkpoint(filepath):
checkpoint = torch.load(filepath)
model = fc_model.Network(checkpoint['input_size'],
checkpoint['output_size'],
checkpoint['hidden_layers'])
model.load_state_dict(checkpoint['state_dict'])

return model
PATH = "/content/gdrive/MyDrive/best.pt"
state_dict = load_checkpoint(PATH)

错误

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-24-0515f2edfa1a> in <module>()
18 
19 PATH = "/content/gdrive/MyDrive/best.pt"
---> 20 state_dict = load_checkpoint(PATH)
2 frames
/usr/local/lib/python3.7/dist-packages/torch/serialization.py in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
849     unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
850     unpickler.persistent_load = persistent_load
--> 851     result = unpickler.load()
852 
853     torch._utils._validate_loaded_sparse_tensors()
ModuleNotFoundError: No module named 'models'
---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.
To view examples of installing some common dependencies, click the
"Open Examples" button below.
---------------------------------------------------------------------------

我试着安装一些库,但它给了我同样的东西。无论如何,在谷歌合作库中加载模型都是一样的。

这个问题是,当您节省重量时,实际上使用的是torch.save(model而不是model.state_dict()

解决这一问题的一种方法是导入CCD_;就像你在训练时一样;。这一点很重要,因为当保存整个模型时,它会保存名称引用和权重。

如果models是一个文件,也许您需要上传它。如果它是一个物体,那么只要把它放在一个单元格里,它就会工作。

最新更新