用函数返回的pickle保存对象



如何保存从函数返回的关于定义的方法类的模型?我想为许多类似于(在我的情况下)Rocket类的类制作相同的包装器。

下面的代码会产生一个错误:不能pickle本地对象'sktime_wrapper..SKtimeWrapper'

import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
def sktime_wrapper(method_class):
class SKtimeWrapper(method_class):
def transform(self, X):
X = from_2d_array_to_nested(X)
return super().transform(X)
def fit(self, X, Y):
X = from_2d_array_to_nested(X)
return super().fit(X, Y)
return SKtimeWrapper

model = sktime_wrapper(Rocket)
with open('model.pkl','wb') as f:
pickle.dump(model, f)

如果class被定义为顶级对象,pickle就可以正常工作。下面的代码非常有效,可以毫无问题地保存模型:

import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
class SKtimeWrapper(Rocket):
def transform(self, X):
X = from_2d_array_to_nested(X)
return super().transform(X)
def fit(self, X, Y):
X = from_2d_array_to_nested(X)
return super().fit(X, Y)
model = SKtimeWrapper

with open('model.pkl','wb') as f:
pickle.dump(model, f)

在回答部分之后,我设法使它工作!我希望有人觉得这有用。技巧是使用__reduce__()函数。

Bellow是一个工作示例。注意,对象必须在保存之前初始化。

import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
def sktime_wrapper(method_class):
class SKtimeWrapper(method_class):
PARAM = method_class
def transform(self, X):
X = from_2d_array_to_nested(X)
return super().transform(X)
def fit(self, X, Y):
X = from_2d_array_to_nested(X)
return super().fit(X, Y)
def __reduce__(self):
return (_InitializeParameterized(), (self.PARAM,), self.__dict__)
return SKtimeWrapper
class _InitializeParameterized(object):
"""
When called with the param value as the only argument, returns an
un-initialized instance of the parameterized class. Subsequent __setstate__
will be called by pickle.
"""
def __call__(self, method_class):
# make a simple object which has no complex __init__ (this one will do)
obj = _InitializeParameterized()
obj.__class__ = sktime_wrapper(method_class)
return obj

model = sktime_wrapper(Rocket)()
with open('model.pkl','wb') as f:
pickle.dump(model, f)

最新更新