Sklearn GridSearchCV object changed after pickle dump/load

I have a GridSearchCV object created with:

grid_search = GridSearchCV(pred_home_pipeline, param_grid)

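I would like to save the entire grid search object so that I can explore the model-tuning results later. I don't want to save only the best_estimator_. But after dumping and reloading, the reloaded and original grid_search objects differ in some way that I can't track down.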

import pickle

# save to disk
with open(filepath, 'wb') as handle:
    pickle.dump(grid_search, handle, protocol=pickle.HIGHEST_PROTOCOL)

# reload
with open(filepath, 'rb') as handle:
    grid_reloaded = pickle.load(handle)

# test object is unchanged after dump/reload
print(grid_search == grid_reloaded)    

False

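Strange. Looking at the output of print(grid_search) and print(grid_reloaded), they certainly look the same.

And they both produce exactly the same set of 525 predicted values for data that I held out entirely from the grid search process: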

grid_search_preds  = grid_search.predict(X_test)
grid_reloaded_preds= grid_reloaded.predict(X_test)

(grid_search_preds == grid_reloaded_preds).all()

True

...although the best_estimator_ attributes are technically not the same:

grid_search.best_estimator_ == grid_reloaded.best_estimator_

False

...even though the best_estimator_ attributes also look identical when comparing the output of print(grid_search.best_estimator_) and print(grid_reloaded.best_estimator_).

What's going on here? Is it safe to save the GridSearchCV object for later inspection?

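That's because the == comparison here only tells you whether the two are the same object, not whether their contents match.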

To see why, follow the object hierarchy and you'll see that no __eq__ function is overridden (nor __cmp__).

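As a result, the "==" comparison falls back to comparing object memory locations, and of course your reloaded instance and your current instance can't be equal. It is a check for whether they are the same object.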
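To make the fallback concrete, here is a minimal sketch using only the standard library. Plain and WithEq are hypothetical stand-ins for an estimator (they are not sklearn classes), and roundtrip is a small helper that does in memory what dump/load does on disk.

```python
import pickle


class Plain:
    """Hypothetical stand-in for an estimator: holds state, defines no __eq__."""

    def __init__(self, params):
        self.params = params


class WithEq(Plain):
    """Same state, but equality is defined by comparing attributes."""

    def __eq__(self, other):
        return type(other) is type(self) and self.__dict__ == other.__dict__


def roundtrip(obj):
    """Serialize and deserialize obj in memory, like dump/load to disk."""
    return pickle.loads(pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))


plain = Plain({"C": 1.0})
plain_reloaded = roundtrip(plain)

print(plain == plain_reloaded)                # False: default == compares identity
print(plain is plain_reloaded)                # False: two distinct objects
print(plain.params == plain_reloaded.params)  # True: the stored state survived

with_eq = WithEq({"C": 1.0})
print(with_eq == roundtrip(with_eq))          # True: __eq__ compares the state
```

So a False from == on a class without __eq__ says nothing about whether the pickled state was preserved; comparing the attributes you care about (or the predictions, as in the question) is the more informative check.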

See more here.

Here is sklearn contributor GaelVaroquaux's answer, from sklearn's GitHub, on why no __eq__ method is implemented here, along with a workaround for testing whether two sklearn objects are equal:

No, I would rather not add an eq. These things are very difficult to get right, and one should not expect a library to implement eq on complex objects.

One thing that you can do, is use joblib.hash to compute an MD5 hash of the object, and use this for comparison.
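A rough sketch of that suggestion, assuming joblib is installed. The dict below is a hypothetical stand-in for grid_search so the snippet runs without fitting anything; with the real objects you would pass grid_search and grid_reloaded instead. Whether the hashes match for a particular fitted estimator is something to check rather than assume.

```python
import pickle

import joblib

# Hypothetical stand-in for a fitted search object; any picklable object works.
original = {"param_grid": {"C": [0.1, 1.0, 10.0]}, "best_score": 0.87}
reloaded = pickle.loads(pickle.dumps(original, protocol=pickle.HIGHEST_PROTOCOL))

# joblib.hash returns a hex digest (MD5 by default) computed from the content.
original_hash = joblib.hash(original)
reloaded_hash = joblib.hash(reloaded)
print(original_hash == reloaded_hash)

# A copy with one changed value should hash differently.
changed = dict(original, best_score=0.5)
print(joblib.hash(changed) == original_hash)
```

Unlike ==, this compares content rather than identity, so it is a reasonable way to confirm that the dump/load round trip did not alter the object.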