用 pickle 保存由函数返回的对象
Save an object returned by a function with pickle
如何用 pickle 保存一个在函数内部定义、并由该函数返回的模型类(class)?我想用同样的包装器包装许多类,例如(在我的情况下)Rocket 类。
下面的代码会产生错误:
无法序列化(pickle)局部对象:`Can't pickle local object 'sktime_wrapper.<locals>.SKtimeWrapper'`
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
def sktime_wrapper(method_class):
    """Return a subclass of *method_class* whose fit/transform accept 2-D arrays.

    ``X`` is converted with ``from_2d_array_to_nested`` before the call is
    delegated to the corresponding *method_class* method.  The subclass is
    created anew on every call and lives in this function's local scope,
    which is why pickle cannot look it up by name.
    """

    class SKtimeWrapper(method_class):
        def fit(self, X, Y):
            nested = from_2d_array_to_nested(X)
            return super().fit(nested, Y)

        def transform(self, X):
            nested = from_2d_array_to_nested(X)
            return super().transform(nested)

    return SKtimeWrapper
# NOTE: sktime_wrapper returns a class defined inside a function, so this
# pickle.dump call fails: pickle cannot locate that class by name.
model = sktime_wrapper(Rocket)
with open('model.pkl', mode='wb') as f:
    pickle.dump(model, f)
如果类定义在模块顶层,pickle 就能正常工作。下面的代码可以毫无问题地完成保存(注意:这里的 `model` 是类本身,pickle 保存的只是对该类的引用,而不是模型实例):
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
class SKtimeWrapper(Rocket):
    """Rocket variant whose fit/transform accept plain 2-D arrays.

    Each method converts ``X`` with ``from_2d_array_to_nested`` and then
    delegates to the corresponding ``Rocket`` method.  Because this class is
    defined at module top level, pickle can reference it by name.
    """

    def fit(self, X, Y):
        nested = from_2d_array_to_nested(X)
        return super().fit(nested, Y)

    def transform(self, X):
        nested = from_2d_array_to_nested(X)
        return super().transform(nested)
# NOTE(review): this binds the class itself, so pickle stores only a
# reference to SKtimeWrapper (module + name), not a model instance's state.
model = SKtimeWrapper
with open('model.pkl', mode='wb') as f:
    pickle.dump(model, f)
参考回答部分,我成功让它运行了!希望对其他人有所帮助。诀窍是使用 `__reduce__()` 方法。
下面是一个可运行的示例。注意:必须先实例化对象再保存(pickle 的是实例,而不是类)。
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
# Wrapper classes already built, keyed by the wrapped estimator class, so that
# repeated calls (including the one made on every unpickle) return the same
# class object instead of minting a new, distinct class each time.
_SKTIME_WRAPPER_CACHE = {}


def sktime_wrapper(method_class):
    """Return a picklable subclass of *method_class* that accepts 2-D arrays.

    The returned class converts ``X`` with ``from_2d_array_to_nested`` before
    delegating ``fit``/``transform`` to *method_class*.  Because the class is
    created inside this function, pickle cannot locate it by name; instead
    ``__reduce__`` tells pickle to rebuild an instance by calling
    ``_InitializeParameterized()(method_class)`` and then restoring the saved
    ``__dict__`` as its state.

    Results are memoized per *method_class*, so
    ``sktime_wrapper(C) is sktime_wrapper(C)`` holds and objects rebuilt by
    unpickling share their class with wrappers created in the same process.

    Args:
        method_class: the estimator class to wrap (e.g. ``Rocket``); used as a
            dict key, so it must be hashable (classes are).

    Returns:
        The ``SKtimeWrapper`` subclass of *method_class*.
    """
    cached = _SKTIME_WRAPPER_CACHE.get(method_class)
    if cached is not None:
        return cached

    class SKtimeWrapper(method_class):
        # Remember the wrapped class so __reduce__ can recreate this wrapper.
        PARAM = method_class

        def transform(self, X):
            X = from_2d_array_to_nested(X)
            return super().transform(X)

        def fit(self, X, Y):
            X = from_2d_array_to_nested(X)
            return super().fit(X, Y)

        def __reduce__(self):
            # (callable, args, state): the callable is an instance of a
            # top-level class and method_class is pickled by reference, so
            # nothing here requires pickle to find this local class by name.
            return (_InitializeParameterized(), (self.PARAM,), self.__dict__)

    _SKTIME_WRAPPER_CACHE[method_class] = SKtimeWrapper
    return SKtimeWrapper
class _InitializeParameterized:
    """Pickle reconstructor for classes produced by ``sktime_wrapper``.

    An instance of this class is the callable that ``SKtimeWrapper.__reduce__``
    hands to pickle.  Calling it with the wrapped estimator class returns a
    bare instance of ``sktime_wrapper(method_class)`` whose ``__init__`` has
    not been run; pickle then restores the saved state onto it (through
    ``__setstate__`` if the class defines one, otherwise by updating
    ``__dict__``).
    """

    def __call__(self, method_class):
        target_cls = sktime_wrapper(method_class)
        # Start from an instance of this trivial class (it defines no
        # __init__), then re-label it so target_cls.__init__ is never run.
        blank = _InitializeParameterized()
        blank.__class__ = target_cls
        return blank
# Pickle an *instance*: the wrapper's __reduce__ is an instance method, so
# only instances (not the local class itself) become picklable.
model = sktime_wrapper(Rocket)()
with open('model.pkl', mode='wb') as f:
    pickle.dump(model, f)
如何用 pickle 保存一个在函数内部定义、并由该函数返回的模型类(class)?我想用同样的包装器包装许多类,例如(在我的情况下)Rocket 类。
下面的代码会产生错误:无法序列化(pickle)局部对象 `Can't pickle local object 'sktime_wrapper.<locals>.SKtimeWrapper'`
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
def sktime_wrapper(method_class):
    """Return a subclass of *method_class* whose fit/transform accept 2-D arrays.

    ``X`` is converted with ``from_2d_array_to_nested`` before the call is
    delegated to the corresponding *method_class* method.  The subclass is
    created anew on every call and lives in this function's local scope,
    which is why pickle cannot look it up by name.
    """

    class SKtimeWrapper(method_class):
        def fit(self, X, Y):
            nested = from_2d_array_to_nested(X)
            return super().fit(nested, Y)

        def transform(self, X):
            nested = from_2d_array_to_nested(X)
            return super().transform(nested)

    return SKtimeWrapper
# NOTE: sktime_wrapper returns a class defined inside a function, so this
# pickle.dump call fails: pickle cannot locate that class by name.
model = sktime_wrapper(Rocket)
with open('model.pkl', mode='wb') as f:
    pickle.dump(model, f)
如果类定义在模块顶层,pickle 就能正常工作。下面的代码可以毫无问题地完成保存(注意:这里的 `model` 是类本身,pickle 保存的只是对该类的引用,而不是模型实例):
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
class SKtimeWrapper(Rocket):
    """Rocket variant whose fit/transform accept plain 2-D arrays.

    Each method converts ``X`` with ``from_2d_array_to_nested`` and then
    delegates to the corresponding ``Rocket`` method.  Because this class is
    defined at module top level, pickle can reference it by name.
    """

    def fit(self, X, Y):
        nested = from_2d_array_to_nested(X)
        return super().fit(nested, Y)

    def transform(self, X):
        nested = from_2d_array_to_nested(X)
        return super().transform(nested)
# NOTE(review): this binds the class itself, so pickle stores only a
# reference to SKtimeWrapper (module + name), not a model instance's state.
model = SKtimeWrapper
with open('model.pkl', mode='wb') as f:
    pickle.dump(model, f)
参考回答部分,我成功让它运行了!希望对其他人有所帮助。诀窍是使用 `__reduce__()` 方法。
下面是一个可运行的示例。注意:必须先实例化对象再保存(pickle 的是实例,而不是类)。
import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested
# Wrapper classes already built, keyed by the wrapped estimator class, so that
# repeated calls (including the one made on every unpickle) return the same
# class object instead of minting a new, distinct class each time.
_SKTIME_WRAPPER_CACHE = {}


def sktime_wrapper(method_class):
    """Return a picklable subclass of *method_class* that accepts 2-D arrays.

    The returned class converts ``X`` with ``from_2d_array_to_nested`` before
    delegating ``fit``/``transform`` to *method_class*.  Because the class is
    created inside this function, pickle cannot locate it by name; instead
    ``__reduce__`` tells pickle to rebuild an instance by calling
    ``_InitializeParameterized()(method_class)`` and then restoring the saved
    ``__dict__`` as its state.

    Results are memoized per *method_class*, so
    ``sktime_wrapper(C) is sktime_wrapper(C)`` holds and objects rebuilt by
    unpickling share their class with wrappers created in the same process.

    Args:
        method_class: the estimator class to wrap (e.g. ``Rocket``); used as a
            dict key, so it must be hashable (classes are).

    Returns:
        The ``SKtimeWrapper`` subclass of *method_class*.
    """
    cached = _SKTIME_WRAPPER_CACHE.get(method_class)
    if cached is not None:
        return cached

    class SKtimeWrapper(method_class):
        # Remember the wrapped class so __reduce__ can recreate this wrapper.
        PARAM = method_class

        def transform(self, X):
            X = from_2d_array_to_nested(X)
            return super().transform(X)

        def fit(self, X, Y):
            X = from_2d_array_to_nested(X)
            return super().fit(X, Y)

        def __reduce__(self):
            # (callable, args, state): the callable is an instance of a
            # top-level class and method_class is pickled by reference, so
            # nothing here requires pickle to find this local class by name.
            return (_InitializeParameterized(), (self.PARAM,), self.__dict__)

    _SKTIME_WRAPPER_CACHE[method_class] = SKtimeWrapper
    return SKtimeWrapper
class _InitializeParameterized:
    """Pickle reconstructor for classes produced by ``sktime_wrapper``.

    An instance of this class is the callable that ``SKtimeWrapper.__reduce__``
    hands to pickle.  Calling it with the wrapped estimator class returns a
    bare instance of ``sktime_wrapper(method_class)`` whose ``__init__`` has
    not been run; pickle then restores the saved state onto it (through
    ``__setstate__`` if the class defines one, otherwise by updating
    ``__dict__``).
    """

    def __call__(self, method_class):
        target_cls = sktime_wrapper(method_class)
        # Start from an instance of this trivial class (it defines no
        # __init__), then re-label it so target_cls.__init__ is never run.
        blank = _InitializeParameterized()
        blank.__class__ = target_cls
        return blank
# Pickle an *instance*: the wrapper's __reduce__ is an instance method, so
# only instances (not the local class itself) become picklable.
model = sktime_wrapper(Rocket)()
with open('model.pkl', mode='wb') as f:
    pickle.dump(model, f)