如何通过 Flask 或 FastAPI 导入 MLflow 跟踪服务器 WSGI 应用程序?

How to import MLflow tracking server WSGI application via Flask or FastAPI?

MLflow 提供了一个非常酷的跟踪服务器,但是,该服务器不提供我需要的身份验证或 RBAC。

我想添加自己的身份验证和 RBAC 功能。我认为实现此目的的一种方法是导入 MLflow WSGI 应用程序对象并添加一些中间件层以在将请求传递到跟踪服务器之前执行身份验证/授权,本质上是通过我的自定义中间件堆栈代理请求。

我该怎么做?我可以从 these docs 看到我可以使用 FastAPI 导入另一个 WSGI 应用程序并添加自定义中间件,但我不确定一些事情

  1. 在哪里可以找到 MLflow 跟踪服务器 WSGI 应用程序(可以从哪里导入)?
  2. 如何将相关参数传递给 MLflow 跟踪服务器? IE。跟踪服务器需要参数来配置后端存储层、主机和端口。如果我只是导入应用程序对象,我该如何将这些参数传递给它?

编辑 - 看起来 Flask 应用程序可以在这里找到 https://github.com/mlflow/mlflow/blob/master/mlflow/server/__init__.py#L28

这个其实很简单,下面是一个使用FastAPI导入和挂载MLflow WSGI应用的例子。

import os
import subprocess
from fastapi import FastAPI
from fastapi.middleware.wsgi import WSGIMiddleware

from mlflow.server import app as mlflow_app

app = FastAPI()
app.mount("/", WSGIMiddleware(mlflow_app))

BACKEND_STORE_URI_ENV_VAR = "_MLFLOW_SERVER_FILE_STORE"
ARTIFACT_ROOT_ENV_VAR = "_MLFLOW_SERVER_ARTIFACT_ROOT"
ARTIFACTS_DESTINATION_ENV_VAR = "_MLFLOW_SERVER_ARTIFACT_DESTINATION"
PROMETHEUS_EXPORTER_ENV_VAR = "prometheus_multiproc_dir"
SERVE_ARTIFACTS_ENV_VAR = "_MLFLOW_SERVER_SERVE_ARTIFACTS"
ARTIFACTS_ONLY_ENV_VAR = "_MLFLOW_SERVER_ARTIFACTS_ONLY"

def parse_args():
    a = argparse.ArgumentParser()
    a.add_argument("--host", type=str, default="0.0.0.0")
    a.add_argument("--port", type=str, default="5000")
    a.add_argument("--backend-store-uri", type=str, default="sqlite:///mlflow.db")
    a.add_argument("--serve-artifacts", action="store_true", default=False)
    a.add_argument("--artifacts-destination", type=str)
    a.add_argument("--default-artifact-root", type=str)
    a.add_argument("--gunicorn-opts", type=str, default="")
    a.add_argument("--n-workers", type=str, default=1)
    return a.parse_args()

def run_command(cmd, env, cwd=None):
    cmd_env = os.environ.copy()
    if cmd_env:
        cmd_env.update(env)
    child = subprocess.Popen(
        cmd, env=cmd_env, cwd=cwd, text=True, stdin=subprocess.PIPE
    )
    child.communicate()
    exit_code = child.wait()
    if exit_code != 0:
        raise Exception("Non-zero exitcode: %s" % (exit_code))
    return exit_code

def run_server(args):
    env_map = dict()
    if args.backend_store_uri:
        env_map[BACKEND_STORE_URI_ENV_VAR] = args.backend_store_uri
    if args.serve_artifacts:
        env_map[SERVE_ARTIFACTS_ENV_VAR] = "true"
    if args.artifacts_destination:
        env_map[ARTIFACTS_DESTINATION_ENV_VAR] = args.artifacts_destination
    if args.default_artifact_root:
        env_map[ARTIFACT_ROOT_ENV_VAR] = args.default_artifact_root

    print(f"Envmap: {env_map}")

    #opts = args.gunicorn_opts.split(" ") if args.gunicorn_opts else []
    opts = args.gunicorn_opts if args.gunicorn_opts else ""

    cmd = [
        "gunicorn", "-b", f"{args.host}:{args.port}", "-w", f"{args.n_workers}", "-k", "uvicorn.workers.UvicornWorker", "server:app"
    ]
    run_command(cmd, env_map)

def main():
    args = parse_args()
    run_server(args)

if __name__ == "__main__":
    main()

运行喜欢

python server.py --artifacts-destination s3://mlflow-mr --default-artifact-root s3://mlflow-mr --serve-artifacts

然后导航到您的浏览器并查看跟踪服务器 运行!这允许您在跟踪服务器前面插入自定义 FastAPI 中间件