为什么这段代码不用import sklearn就可以使用sklearn函数呢?

Why is this code able to use the sklearn function without import sklearn?

所以我看了一个教程,作者在anaconda环境(安装了sklearn)中使用pickled模型的predict功能时不需要import sklearn

我试图在 Google Colab 中重现它的最小版本。如果你有 pickled-sklearn-model,下面的代码可以在 Colab 中运行(安装了 sklearn):

import pickle
model = pickle.load(open("model.pkl", "rb"), encoding="bytes")
out = model.predict([[20, 0, 1, 1, 0]])
print(out)

我意识到我仍然需要安装 sklearn 包。如果我卸载 sklearn,predict 功能现在无法使用:

!pip uninstall scikit-learn
import pickle
model = pickle.load(open("model.pkl", "rb"), encoding="bytes")
out = model.predict([[20, 0, 1, 1, 0]])
print(out)

错误:

WARNING: Skipping scikit-learn as it is not installed.

---------------------------------------------------------------------------

ModuleNotFoundError                       Traceback (most recent call last)

<ipython-input-1-dec96951ae29> in <module>()
      1 get_ipython().system('pip uninstall scikit-learn')
      2 import pickle
----> 3 model = pickle.load(open("model.pkl", "rb"), encoding="bytes")
      4 out = model.predict([[20, 0, 1, 1, 0]])
      5 print(out)

ModuleNotFoundError: No module named 'sklearn'

那么,它是如何工作的?据我所知,pickle 不依赖于 scikit-learn。序列化模型做 import sklearn 吗? 为什么在第一个代码中不导入 scikit learn 就可以使用 predict 函数?

第一次 pickle 模型时,你安装了 sklearn。 pickle 文件的结构依赖于 sklearn,因为它所代表的对象的 class 是一个 sklearn class,而 pickle 需要知道 class 的详细信息' s 结构以便解开对象。

当您尝试在没有安装 sklearn 的情况下解压文件时,pickle 从文件中确定 class 对象是 sklearn.x.y.z 的实例,或者您有什么,并且然后 unpickling 失败,因为当 pickle 尝试解析该名称时找不到模块 sklearn。请注意,异常发生在 unpickling 行,而不是调用 predict 的行。

你不需要在你的代码中导入 sklearn 当它工作时,因为一旦对象被 unpickled,它知道它的 class 是什么以及它的所有方法名称是什么,所以你可以调用他们来自对象。

这里有几个问题,让我们一一分析:

So, how does it work? as far as I understand pickle doesn't depend on scikit-learn.

这里的 scikit-learn 没有什么特别之处。 Pickle 将对任何模块表现出这种行为。这是一个 Numpy 的例子:

will@will-desktop ~ $ python
Python 3.9.6 (default, Aug 24 2021, 18:12:51) 
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import sys
>>> 'numpy' in sys.modules
False
>>> import numpy
>>> 'numpy' in sys.modules
True
>>> pickle.dumps(numpy.array([1, 2, 3]))
b'\x80\x04\x95\xa0\x00\x00\x00\x00\x00\x00\x00\x8c\x15numpy.core.multiarray\x94\x8c\x0c_reconstruct\x94\x93\x94\x8c\x05numpy\x94\x8c\x07ndarray\x94\x93\x94K\x00\x85\x94C\x01b\x94\x87\x94R\x94(K\x01K\x03\x85\x94h\x03\x8c\x05dtype\x94\x93\x94\x8c\x02i8\x94\x89\x88\x87\x94R\x94(K\x03\x8c\x01<\x94NNNJ\xff\xff\xff\xffJ\xff\xff\xff\xffK\x00t\x94b\x89C\x18\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x94t\x94b.'
>>> exit()

到目前为止,我所做的是表明在新的 Python 进程中 'numpy' 不在 sys.modules 中(导入模块的字典)。然后我们导入 Numpy,并 pickle 一个 Numpy 数组。

然后在下面显示的一个新的Python过程中,我们看到在我们unpickle之前还没有导入数组Numpy,但是在我们导入之后已经导入了Numpy。

will@will-desktop ~ $ python
Python 3.9.6 (default, Aug 24 2021, 18:12:51) 
[GCC 9.3.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import pickle
>>> import sys
>>> 'numpy' in sys.modules
False
>>> pickle.loads(b'\x80\x04\x95\xa0\x00\x00\x00\x00\x00\x00\x00\x8c\x15numpy.core.multiarray\x94\x8c\x0c_reconstruct\x94\x93\x94\x8c\x05numpy\x94\x8c\x07ndarray\x94\x93\x94K\x00\x85\x94C\x01b\x94\x87\x94R\x94(K\x01K\x03\x85\x94h\x03\x8c\x05dtype\x94\x93\x94\x8c\x02i8\x94\x89\x88\x87\x94R\x94(K\x03\x8c\x01<\x94NNNJ\xff\xff\xff\xffJ\xff\xff\xff\xffK\x00t\x94b\x89C\x18\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x94t\x94b.')
array([1, 2, 3])
>>> 'numpy' in sys.modules
True
>>> numpy
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NameError: name 'numpy' is not defined

尽管被导入了,但是numpy仍然不是定义的变量名。 Python 中的导入是全局的,但导入只会更新实际执行导入的模块的命名空间。如果我们想访问numpy,我们仍然需要写import numpy,但是由于Numpy已经在其他地方导入了,所以这不会重新运行 Numpy的模块初始化代码。相反,它将在我们模块的全局字典中创建一个 numpy 变量,并使其成为对之前存在的 Numpy 模块对象的引用,并且可以通过 sys.modules['numpy'].

访问

那么 Pickle 在这里做什么?它嵌入了有关使用哪个模块来定义 pickle 中 pickle 的信息。然后当它 unp​​ickle 某些东西时,它使用该信息导入模块,这样它就可以使用 class 的 unpickle 方法。我们可以查看 Pickle 模块的源代码,我们可以看到正在发生的事情:

_Pickler we see save method uses the save_global method. This in turn uses the whichmodule函数中获取模块名称('scikit-learn',在你的例子中),然后保存在pickle中。

_UnPickler we see the find_class method uses __import__ to import the module using the stored module name. The find_class method is used in a few of the load_* methods, such as load_inst 中,这是用于加载 class 实例的内容,例如您的模型实例:

def load_inst(self):
    module = self.readline()[:-1].decode("ascii")
    name = self.readline()[:-1].decode("ascii")
    klass = self.find_class(module, name)
    self._instantiate(klass, self.pop_mark())

The documentation for Unpickler.find_class explains:

Import module if necessary and return the object called name from it, where the module and name arguments are str objects.

The docs also explain how you can restrict this behaviour:

[You] may want to control what gets unpickled by customizing Unpickler.find_class(). Unlike its name suggests, Unpickler.find_class() is called whenever a global (i.e., a class or a function) is requested. Thus it is possible to either completely forbid globals or restrict them to a safe subset.

尽管这通常仅在取消不可信数据时相关,但此处似乎并非如此。


Does the serialized model do import sklearn?

严格来说,序列化模型本身 没有任何作用。如上所述,它全部由 Pickle 模块处理。


Why can I use predict function without import scikit learn in the first code?

因为 sklearn 在 unpickles 数据时被 Pickle 模块导入,从而为您提供了一个完全实现的模型对象。就像其他模块导入 sklearn,创建模型对象,然后将其作为函数参数传递到您的代码中一样。


由于这一切,为了解开你的模型,你需要安装 sklearn - 最好是用于创建 pickle 的相同版本。通常,Pickle 模块存储任何所需模块的完全限定路径,因此 pickle 对象的 Python 进程和 unpickle 对象的进程必须让所有 [1] 个所需模块都具有相同的完全限定名称。


[1] 需要注意的是,Pickle 模块可以自动 adjust/fix 特定 modules/classes 的某些导入,它们在 Python 2 和 3 之间具有不同的完全限定名称。来自the docs:

If fix_imports is true, pickle will try to map the old Python 2 names to the new names used in Python 3.