Pickle 后更改了 Sklearn gridsearchCV 对象 dump/load
Sklearn gridsearchCV object changed after pickle dump/load
我有一个用
创建的 gridsearchCV 对象
grid_search = GridSearchCV(pred_home_pipeline, param_grid)
我想保存整个网格搜索对象,以便稍后探索模型调整结果。我不想只保存 the best_estimator_
。但是在转储和重新加载之后,重新加载的和原始的 grid_search 对象在某些方面有所不同,我无法追踪。
# save to disk
with open(filepath, 'wb') as handle:
pickle.dump(grid_search, handle, protocol=pickle.HIGHEST_PROTOCOL)
# reload
with open(filepath, 'rb') as handle:
grid_reloaded = pickle.load(handle)
# test object is unchanged after dump/reload
print(grid_search == grid_reloaded)
False
奇怪。查看 print(grid_search)
和 print(grid_reloaded)
的输出,它们看起来确实是一样的。
他们为我完全从网格搜索过程中提取的数据创建了完全相同的一组 525 个预测值:
grid_search_preds = grid_search.predict(X_test)
grid_reloaded_preds= grid_reloaded.predict(X_test)
(grid_search_preds == grid_reloaded_preds).all()
True
...尽管 best_estimator_
属性在技术上并不相同:
grid_search.best_estimator_ == grid_reloaded.best_estimator_
False
...尽管 best_estimate_ 属性在比较 print(grid_search.best_estimatmator_)
和 print(grid_reloaded.best_estimator_)
时看起来也确实相同
这是怎么回事?保存 gridsearchcv 对象以供以后检查是否安全?
那是因为比较返回的是对象是否是同一个对象。
要了解原因,请遵循对象层次结构,您会看到没有 __eq__
函数被覆盖(或 __cmp__
):
因此,“==”比较回退到对象内存位置比较,当然您重新加载的实例和当前实例不能相等。这是比较以查看它们是否是同一对象。
查看更多 here。
这里是来自 sklearn 的 github 的 sklearn contributor GaelVaroquaux's answer 关于为什么这里没有实现 __eq__
方法,以及测试两个 sklearn 对象是否相等的解决方案:
No, I would rather not add an eq. These things are very difficult to
get right, and one should not expect a library to implement eq on
complex objects.
One thing that you can do, is use joblib.hash to compute an MD5 hash
of the object, and use this for comparison.
我有一个用
创建的 gridsearchCV 对象grid_search = GridSearchCV(pred_home_pipeline, param_grid)
我想保存整个网格搜索对象,以便稍后探索模型调整结果。我不想只保存 the best_estimator_
。但是在转储和重新加载之后,重新加载的和原始的 grid_search 对象在某些方面有所不同,我无法追踪。
# save to disk
with open(filepath, 'wb') as handle:
pickle.dump(grid_search, handle, protocol=pickle.HIGHEST_PROTOCOL)
# reload
with open(filepath, 'rb') as handle:
grid_reloaded = pickle.load(handle)
# test object is unchanged after dump/reload
print(grid_search == grid_reloaded)
False
奇怪。查看 print(grid_search)
和 print(grid_reloaded)
的输出,它们看起来确实是一样的。
他们为我完全从网格搜索过程中提取的数据创建了完全相同的一组 525 个预测值:
grid_search_preds = grid_search.predict(X_test)
grid_reloaded_preds= grid_reloaded.predict(X_test)
(grid_search_preds == grid_reloaded_preds).all()
True
...尽管 best_estimator_
属性在技术上并不相同:
grid_search.best_estimator_ == grid_reloaded.best_estimator_
False
...尽管 best_estimate_ 属性在比较 print(grid_search.best_estimatmator_)
和 print(grid_reloaded.best_estimator_)
这是怎么回事?保存 gridsearchcv 对象以供以后检查是否安全?
那是因为比较返回的是对象是否是同一个对象。
要了解原因,请遵循对象层次结构,您会看到没有 __eq__
函数被覆盖(或 __cmp__
):
因此,“==”比较回退到对象内存位置比较,当然您重新加载的实例和当前实例不能相等。这是比较以查看它们是否是同一对象。
查看更多 here。
这里是来自 sklearn 的 github 的 sklearn contributor GaelVaroquaux's answer 关于为什么这里没有实现 __eq__
方法,以及测试两个 sklearn 对象是否相等的解决方案:
No, I would rather not add an eq. These things are very difficult to get right, and one should not expect a library to implement eq on complex objects.
One thing that you can do, is use joblib.hash to compute an MD5 hash of the object, and use this for comparison.