Saving an object with pickle when its class is returned by a function

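How can I pickle a model class that is returned by a function which defines its methods? I want to build the same wrapper for many classes, like the one below for the Rocket class in my case.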

The code below raises an error: Can't pickle local object 'sktime_wrapper.<locals>.SKtimeWrapper'

import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested

def sktime_wrapper(method_class):
    class SKtimeWrapper(method_class):
        def transform(self, X):
            X = from_2d_array_to_nested(X)
            return super().transform(X)

        def fit(self, X, Y):
            X = from_2d_array_to_nested(X)
            return super().fit(X, Y)

    return SKtimeWrapper


model = sktime_wrapper(Rocket)

with open('model.pkl','wb') as f:
    pickle.dump(model, f)
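The failure is not specific to sktime. The sketch below reproduces it using only the standard library. Base and make_wrapper are hypothetical stand-ins for Rocket and sktime_wrapper. Classes are pickled by reference, meaning pickle stores the module name and __qualname__. A class defined inside a function has <locals> in its __qualname__, so pickle cannot look it up again. This affects both the class itself and any instance of it.

```python
import pickle


class Base:
    """Hypothetical stand-in for sktime's Rocket."""

    def fit(self, X, Y):
        return self


def make_wrapper(method_class):
    # Hypothetical stand-in for sktime_wrapper: the class is created inside the function.
    class Wrapper(method_class):
        def fit(self, X, Y):
            return super().fit(X, Y)

    return Wrapper


WrappedBase = make_wrapper(Base)
print(WrappedBase.__qualname__)  # make_wrapper.<locals>.Wrapper

# Pickling the class, or an instance (which references the class), both fail.
for obj in (WrappedBase, WrappedBase()):
    try:
        pickle.dumps(obj)
    except (AttributeError, pickle.PicklingError) as exc:
        # CPython's C pickler reports "Can't pickle local object ...";
        # the exact exception type and message vary between Python versions.
        print(type(exc).__name__, exc)
```

Binding the result to a module-level name (WrappedBase) does not help. Pickle looks the class up by its __qualname__, not by the variable it was assigned to.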

If the class is defined as a top-level object, pickle works fine. The code below works and saves the model without any problem:

import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested

class SKtimeWrapper(Rocket):
    def transform(self, X):
        X = from_2d_array_to_nested(X)
        return super().transform(X)

    def fit(self, X, Y):
        X = from_2d_array_to_nested(X)
        return super().fit(X, Y)

model = SKtimeWrapper


with open('model.pkl','wb') as f:
    pickle.dump(model, f)
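The top-level version works because pickle never serializes the class body. It stores a reference (__module__ plus __qualname__) and resolves that reference again when loading. One alternative to a custom __reduce__ is to make each generated class look top-level: give it a unique name and register it in the module's namespace. This is a different technique from the one in the answer. The sketch assumes the wrappers are created at import time, so the process that unpickles has them too. The names make_wrapper, Base and the Wrapped<Name> scheme are hypothetical.

```python
import pickle


class Base:
    """Hypothetical stand-in for a model class such as Rocket."""

    def fit(self, X, Y):
        return self


def make_wrapper(method_class):
    class Wrapper(method_class):
        def fit(self, X, Y):
            return super().fit(X, Y)

    # Give the generated class a unique, top-level-looking name...
    name = f"Wrapped{method_class.__name__}"
    Wrapper.__name__ = name
    Wrapper.__qualname__ = name
    # ...and register it in this module so pickle can resolve the reference.
    globals()[name] = Wrapper
    return Wrapper


# Assumption: wrappers are created at import time, so a loading process has them too.
WrappedBase = make_wrapper(Base)

restored_cls = pickle.loads(pickle.dumps(WrappedBase))
restored_obj = pickle.loads(pickle.dumps(WrappedBase()))
print(restored_cls is WrappedBase)        # True
print(type(restored_obj) is WrappedBase)  # True
```

The trade-off is that the module which creates the wrappers must be imported before loading. If a wrapper is only created inside some function at runtime, unpickling in a fresh process will not find the name.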

Following the answers, I managed to get it working! I hope someone finds this useful. The trick is to implement the __reduce__() method.
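For context, __reduce__ tells pickle how to rebuild an object. It returns a tuple of the form (callable, args) or (callable, args, state). At load time pickle calls callable(*args). If a state is given, pickle then applies it with __setstate__, or by updating the new object's __dict__ when no __setstate__ is defined. The callable and its args are pickled too, so they must themselves be picklable (for example, top-level functions or classes). Here is a minimal standard-library sketch; Point and _rebuild_point are hypothetical names.

```python
import pickle


def _rebuild_point(x, y):
    # Top-level function: pickle stores only a reference to it.
    return Point(x, y)


class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.label = None

    def __reduce__(self):
        # (callable, args, state): on load pickle calls _rebuild_point(x, y),
        # then copies the state dict into the new object's __dict__.
        return (_rebuild_point, (self.x, self.y), {"label": self.label})


p = Point(1, 2)
p.label = "demo"
q = pickle.loads(pickle.dumps(p))
print(q.x, q.y, q.label)  # 1 2 demo
```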

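Below is a working example. Note that the object must be instantiated before saving. __reduce__ only applies to instances, so pickling the wrapper class itself still fails.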

import pickle
from sktime.transformations.panel.rocket import Rocket
from sktime.datatypes._panel._convert import from_2d_array_to_nested

def sktime_wrapper(method_class):
    class SKtimeWrapper(method_class):
        PARAM = method_class  # remember the wrapped class so __reduce__ can rebuild the wrapper
        def transform(self, X):
            X = from_2d_array_to_nested(X)
            return super().transform(X)

        def fit(self, X, Y):
            X = from_2d_array_to_nested(X)
            return super().fit(X, Y)

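        def __reduce__(self):
            # (callable, args, state): on load, pickle calls
            # _InitializeParameterized()(self.PARAM) and then restores self.__dict__
            return (_InitializeParameterized(), (self.PARAM,), self.__dict__)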

    return SKtimeWrapper

class _InitializeParameterized(object):
    """
    When called with the param value as the only argument, returns an
    un-initialized instance of the parameterized class. Subsequent __setstate__
    will be called by pickle.
    """
    def __call__(self, method_class):
        # make a simple object which has no complex __init__ (this one will do)
        obj = _InitializeParameterized()
        obj.__class__ = sktime_wrapper(method_class)
        return obj


model = sktime_wrapper(Rocket)()  # an instance, not the class: __reduce__ only applies to instances

with open('model.pkl','wb') as f:
    pickle.dump(model, f)
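The same idea can be checked without sktime installed. This is a standard-library-only variant, not the code above verbatim. Base is a hypothetical stand-in for Rocket. Instead of creating a placeholder object and reassigning its __class__, the hypothetical _rebuild function calls cls.__new__(cls) on a freshly generated wrapper class. That call likewise skips __init__ before pickle restores the saved __dict__, assuming the wrapped class does not define a custom __new__. The sketch also shows two consequences of the design. First, pickling the wrapper class itself still fails. Second, every load builds a new wrapper class, so the restored object's type is not the same object as the original wrapper class.

```python
import pickle


class Base:
    """Hypothetical stand-in for sktime's Rocket."""

    def __init__(self, n=3):
        self.n = n
        self.fitted = False

    def fit(self, X, Y):
        self.fitted = True
        return self


def make_wrapper(method_class):
    class Wrapper(method_class):
        PARAM = method_class  # remember which class was wrapped

        def fit(self, X, Y):
            # the real wrapper would convert X with from_2d_array_to_nested here
            return super().fit(X, Y)

        def __reduce__(self):
            # Rebuild via a top-level function; pickle then restores __dict__.
            return (_rebuild, (self.PARAM,), self.__dict__)

    return Wrapper


def _rebuild(method_class):
    # Re-create the wrapper class and an instance of it without calling __init__.
    cls = make_wrapper(method_class)
    return cls.__new__(cls)


WrappedBase = make_wrapper(Base)
model = WrappedBase(n=5).fit([[0.0]], [0])

restored = pickle.loads(pickle.dumps(model))
print(restored.n, restored.fitted)    # 5 True
print(isinstance(restored, Base))     # True
print(type(restored) is WrappedBase)  # False: a new wrapper class per load

try:
    pickle.dumps(WrappedBase)  # the class itself is still a local object
except (AttributeError, pickle.PicklingError) as exc:
    print("class not picklable:", type(exc).__name__)
```

Because each load produces a distinct class, isinstance checks against a previously created wrapper class return False. Checks against the wrapped base class (here Base) still hold.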