如何在 Databricks 的实验中检索 model.pkl
How can I retrive the model.pkl in the experiment in Databricks
我想从我训练的模型中取回泡菜,我知道它在我在 Databricks 中的实验中的 运行 文件中。
看来mlflow.pyfunc.load_model
只能做predict
方法
有直接访问 pickle 的选项吗?
我还尝试使用 运行 中的路径使用 pickle.load(path)
(路径示例:dbfs:/databricks/mlflow-tracking/20526156406/92f3ec23bf614c9d934dd0195/artifacts/model/model.pkl)。
使用 frmwk 的原生 load_model() 方法(例如 sklearn.load_model())或 download_artifacts()
我最近找到了可以通过以下两种方法解决的方法:
- 在保存模型时使用自定义预测函数(查看 databricks 文档了解更多详情)。
Databricks 给出的示例
class AddN(mlflow.pyfunc.PythonModel):
def __init__(self, n):
self.n = n
def predict(self, context, model_input):
return model_input.apply(lambda column: column + self.n)
# Construct and save the model
model_path = "add_n_model"
add5_model = AddN(n=5)
mlflow.pyfunc.save_model(path=model_path, python_model=add5_model)
# Load the model in `python_function` format
loaded_model = mlflow.pyfunc.load_model(model_path)
- 在我们下载工件时加载模型工件:
from mlflow.tracking import MlflowClient
client = MlflowClient()
tmp_path = client.download_artifacts(run_id="0c7946c81fb64952bc8ccb3c7c66bca3", path='model/model.pkl')
f = open(tmp_path,'rb')
model = pickle.load(f)
f.close()
client.list_artifacts(run_id="0c7946c81fb64952bc8ccb3c7c66bca3", path="")
client.list_artifacts(run_id="0c7946c81fb64952bc8ccb3c7c66bca3", path="model")
我想从我训练的模型中取回泡菜,我知道它在我在 Databricks 中的实验中的 运行 文件中。
看来mlflow.pyfunc.load_model
只能做predict
方法
有直接访问 pickle 的选项吗?
我还尝试使用 运行 中的路径使用 pickle.load(path)
(路径示例:dbfs:/databricks/mlflow-tracking/20526156406/92f3ec23bf614c9d934dd0195/artifacts/model/model.pkl)。
使用 frmwk 的原生 load_model() 方法(例如 sklearn.load_model())或 download_artifacts()
我最近找到了可以通过以下两种方法解决的方法:
- 在保存模型时使用自定义预测函数(查看 databricks 文档了解更多详情)。
Databricks 给出的示例
class AddN(mlflow.pyfunc.PythonModel):
def __init__(self, n):
self.n = n
def predict(self, context, model_input):
return model_input.apply(lambda column: column + self.n)
# Construct and save the model
model_path = "add_n_model"
add5_model = AddN(n=5)
mlflow.pyfunc.save_model(path=model_path, python_model=add5_model)
# Load the model in `python_function` format
loaded_model = mlflow.pyfunc.load_model(model_path)
- 在我们下载工件时加载模型工件:
from mlflow.tracking import MlflowClient
client = MlflowClient()
tmp_path = client.download_artifacts(run_id="0c7946c81fb64952bc8ccb3c7c66bca3", path='model/model.pkl')
f = open(tmp_path,'rb')
model = pickle.load(f)
f.close()
client.list_artifacts(run_id="0c7946c81fb64952bc8ccb3c7c66bca3", path="")
client.list_artifacts(run_id="0c7946c81fb64952bc8ccb3c7c66bca3", path="model")