I want to create a Python script based on a notebook, so that I can use the same .pkl file at runtime.
On this line:
learn = load_learner('model.pkl', cpu=True)
I get this error:
(project) me@ubuntu-pcs:~/PycharmProjects/project$ python main.py
Traceback (most recent call last):
File "main.py", line 6, in <module>
from src.train.train_model import train
File "/home/me/PycharmProjects/project/src/train/train_model.py", line 17, in <module>
learn = load_learner('yasmine-sftp/export_2.pkl', cpu=True) # to run on GPU
File "/home/me/miniconda3/envs/project/lib/python3.6/site-packages/fastai/learner.py", line 384, in load_learner
res = torch.load(fname, map_location='cpu' if cpu else None, pickle_module=pickle_module)
File "/home/me/miniconda3/envs/project/lib/python3.6/site-packages/torch/serialization.py", line 607, in load
return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
File "/home/me/miniconda3/envs/project/lib/python3.6/site-packages/torch/serialization.py", line 882, in _load
result = unpickler.load()
File "/home/me/miniconda3/envs/project/lib/python3.6/site-packages/torch/serialization.py", line 875, in find_class
return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'Tf' on <module '__main__' from 'main.py'>
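This happens because, to open the .pkl file, I need the original function that was used when training it.
Thankfully, looking back at the notebook, Tf(o) is there: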
def Tf(o):
    return '/mnt/scratch2/DLinTHDP/PathLAKE/Version_4_fastai/Dataset/CD8/Train/masks/'+f'{o.stem}_P{o.suffix}'
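However, wherever I put Tf(o) in my Python script, I still get the same error.
Where should I put Tf(o)?
The error message, <module '__main__' from 'main.py'>, seems to suggest putting it under main() or under if __name__ ....
I have tried everywhere. Importing Tf(o) does not work either.
main.py: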
import glob
from pathlib import Path
from train_model import train

ROOT = Path("folder/path")  # Detection Folder

def main(root: Path):
    train(root)

if __name__ == '__main__':
    main(ROOT)
train_model.py:
from pathlib import Path
from fastai.vision.all import *

folder_path = Path('.')
learn = load_learner('model.pkl', cpu=True)  # AttributeError
learn.load('model_3C_34_CELW_V_1.1')  # weights

def train(root: Path):
    # ...
I cannot inspect the file:
(project) me@ubuntu-pcs:~/PycharmProjects/project$ python -m pickletools -a model.pkl
Traceback (most recent call last):
File "/home/me/miniconda3/envs/project/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/me/miniconda3/envs/project/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/me/miniconda3/envs/project/lib/python3.6/pickletools.py", line 2830, in <module>
args.indentlevel, annotate)
File "/home/me/miniconda3/envs/project/lib/python3.6/pickletools.py", line 2394, in dis
for opcode, arg, pos in genops(pickle):
File "/home/me/miniconda3/envs/project/lib/python3.6/pickletools.py", line 2242, in _genops
arg = opcode.arg.reader(data)
File "/home/me/miniconda3/envs/project/lib/python3.6/pickletools.py", line 373, in read_stringnl_noescape
return read_stringnl(f, stripquotes=False)
File "/home/me/miniconda3/envs/project/lib/python3.6/pickletools.py", line 359, in read_stringnl
data = codecs.escape_decode(data)[0].decode("ascii")
UnicodeDecodeError: 'ascii' codec can't decode byte 0x80 in position 63: ordinal not in range(128)
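A possible way around this, offered as a sketch rather than a confirmed fix: the traceback above goes through `_load(opened_zipfile, ...)`, which suggests the .pkl is a zip archive written by `torch.save`, so `pickletools` is being pointed at zip bytes instead of a pickle stream. Assuming the archive holds a member whose name ends in `data.pkl`, the references it needs can be listed without unpickling anything. The helper name `list_pickle_globals` is made up for this illustration.

```python
import pickletools
import zipfile


def list_pickle_globals(path):
    """Return the (module, name) pairs referenced by the pickle inside a zip checkpoint."""
    with zipfile.ZipFile(path) as archive:
        # Assumption: the pickle stream is stored in a member ending in "data.pkl".
        member = next(n for n in archive.namelist() if n.endswith("data.pkl"))
        data = archive.read(member)

    found = []
    recent_strings = []
    for opcode, arg, _pos in pickletools.genops(data):
        if opcode.name == "GLOBAL":
            # GLOBAL carries "module name" as one space-separated string.
            module, name = arg.split(" ", 1)
            found.append((module, name))
        elif opcode.name == "STACK_GLOBAL" and len(recent_strings) >= 2:
            # Heuristic for newer protocols: the two most recent strings
            # are usually the module and the name.
            found.append((recent_strings[-2], recent_strings[-1]))
        elif isinstance(arg, str):
            recent_strings.append(arg)
    return found
```

Calling `list_pickle_globals('model.pkl')` would then be expected to include a `('__main__', 'Tf')` entry if the learner really depends on the notebook function.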
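The problem
The reason I got this error is that the Tf() function was used to train the model.pkl file in the same namespace (because it was done in a notebook file).
This article says:
pickle is lazy and does not serialize class definitions or function definitions. Instead, it saves a reference of how to find the class (the module it lives in and its name).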
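The quoted behaviour can be observed with the standard library alone. The sketch below pickles a function (`json.dumps`, chosen only because it is importable everywhere) and records what the unpickler is asked to look up; `LoggingUnpickler` and `lookups` are names invented for this example.

```python
import io
import json
import pickle

# Pickling a function stores only where to find it, not its code.
payload = pickle.dumps(json.dumps)

lookups = []


class LoggingUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Record every (module, name) reference the pickle asks for.
        lookups.append((module, name))
        return super().find_class(module, name)


restored = LoggingUnpickler(io.BytesIO(payload)).load()
```

After loading, `lookups` holds the single pair `("json", "dumps")`. A function defined in a notebook is recorded the same way, but under the module `__main__`, which is why the script has to provide `Tf` under that module name.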
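The solution
pickle has an extension called dill that does serialize Python objects, functions, and so on (not just references). It is available on PyPI.
Following ecatkins (Edward Atkins), I solved a similar problem by modifying load_learner in fastai.basic_train.py (fastai==1.0.61): https://forums.fast.ai/t/error-loading-saved-model-with-custom-loss-function/37627/7
load_learner fails when the class used to define the model cannot be found in its original module. So, override the pickle class loader (find_class) so that it searches a specified module.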
# Assumes the fastai v1 (1.0.61) names used below are in scope, e.g. via
# `from fastai.vision import *` plus `from fastai.basic_train import load_callback`.
import imp, sys

# Load a private copy of pickle so the stock module is left untouched.
pickle2 = imp.load_module('pickle2', *imp.find_module('pickle'))

# The module where the class is now found.
MODULE = "MY.MODULE"

class CustomUnpickler(pickle2.Unpickler):
    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except AttributeError:
            # Names pickled from a notebook are recorded under "__main__"; retry in MODULE.
            if module == "__main__":
                print(f"load_learner can't find {name} in original module {module}; getting it from {MODULE}")
                module = MODULE
            return super().find_class(module, name)

# Modified load_learner from fastai.basic_train.py (fastai==1.0.61), according to ecatkins Edward Atkins
# https://forums.fast.ai/t/error-loading-saved-model-with-custom-loss-function/37627/7
def load_learner2(path:PathOrStr, file:PathLikeOrBinaryStream='export.pkl', test:ItemList=None, tfm_y=None, **db_kwargs):
    "Load a `Learner` object saved with `export_state` in `path/file` with empty data, optionally add `test` and load on `cpu`. `file` can be file-like (file or buffer)"
    source = Path(path)/file if is_pathlike(file) else file
    # state = torch.load(source, map_location='cpu') if defaults.device == torch.device('cpu') else torch.load(source)
    # Use custom class loader here
    pickle2.Unpickler = CustomUnpickler
    state = torch.load(source, map_location='cpu', pickle_module=pickle2) if defaults.device == torch.device('cpu') else torch.load(source, pickle_module=pickle2)
    model = state.pop('model')
    src = LabelLists.load_state(path, state.pop('data'))
    if test is not None: src.add_test(test, tfm_y=tfm_y)
    data = src.databunch(**db_kwargs)
    cb_state = state.pop('cb_state')
    clas_func = state.pop('cls')
    res = clas_func(data, model, **state)
    res.callback_fns = state['callback_fns'] #to avoid duplicates
    res.callbacks = [load_callback(c,s, res) for c,s in cb_state.items()]
    return res
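The same find_class fallback can be tried without fastai or torch. The following is a minimal sketch with the standard library only: `label_helpers` is a hypothetical module standing in for wherever `Tf` lives now, the body of `Tf` is a simplified stand-in, and the payload is a hand-written protocol-0 pickle that contains nothing but a reference to `__main__.Tf`.

```python
import io
import pickle
import sys
import types


def _build_helpers():
    # Hypothetical module that now holds the function the pickle refers to.
    module = types.ModuleType("label_helpers")

    def Tf(o):
        # Simplified stand-in for the notebook's label function.
        return f"{o}_P"

    module.Tf = Tf
    sys.modules[module.__name__] = module
    return module


helpers = _build_helpers()


class RemappingUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        try:
            return super().find_class(module, name)
        except AttributeError:
            if module != "__main__":
                raise
            # Names recorded under "__main__" are retried in the helper module.
            return super().find_class("label_helpers", name)


# GLOBAL opcode ("c"), module line, name line, then STOP (".").
payload = b"c__main__\nTf\n."
restored = RemappingUnpickler(io.BytesIO(payload)).load()
```

Here `restored("img")` gives `"img_P"`, resolved through `label_helpers` when `__main__` has no `Tf` of its own. Lookups that fail for any other module are re-raised unchanged, so unrelated errors are not hidden.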
Another approach is to manually replicate the namespace of the training environment.
In my case
global Tf
Tf = None
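A variation on the same idea, sketched under assumptions: rather than relying on where the definition happens to sit, attach the function to the `__main__` module explicitly before anything is unpickled. The `Tf` body below is a simplified stand-in, and the payload is a hand-written protocol-0 pickle that only references `__main__.Tf`, used here in place of the real model file.

```python
import pickle
import sys


def Tf(o):
    # Stand-in for the notebook's label function; the real one builds a mask path.
    return f"{o}_P"


# Make the name resolvable where the pickle expects it: on the __main__ module.
setattr(sys.modules["__main__"], "Tf", Tf)

# GLOBAL opcode ("c"), module line, name line, then STOP (".").
payload = b"c__main__\nTf\n."
restored = pickle.loads(payload)
```

In the layout from the question, the assignment would presumably have to run before `load_learner` is called. Because `train_model.py` calls it at import time, that means either at the top of `train_model.py` or in `main.py` above the `from train_model import train` line.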