Microsoft 身份验证 - Python Flask msal 示例应用移植到 FastAPI

Microsoft Authentication - Python Flask msal Example App Ported to FastAPI

我在 Web 开发方面经验不多,但最近开始使用 FastAPI,正在构建一个使用 Jinja2 模板的 MVC 应用程序。该应用借助 Power BI Embedded 容量,以“应用拥有数据”(app owns data)的方式提供多个嵌入式分析报表,这部分一切正常。不过,我想再添加一些模块,并希望使用 msal 包进行用户身份验证:将用户路由到 Microsoft 登录页面,让他们登录我在 Azure 中注册的多租户(multi-tenant)应用,然后通过重定向 URI 跳转回我的页面,获取令牌并完成授权。Microsoft 提供了一个很好的 Flask 示例(原帖中附有链接),但我无法将该示例移植到 FastAPI。

我可以把用户带到登录页面并完成登录,但无法在回调 URI 中获取令牌——请求确实到达了正确的路由,但我无法从响应中取得令牌。

有没有人已经(或者能否有人帮忙)把那个非常简单的 Flask 示例移植到 FastAPI?我在网上找到的 FastAPI 相关资料都是关于使用 Bearer 令牌请求头的后端 API,并不适用于 MVC 应用程序。

以下是我当前的代码。代码比较乱,因为里面夹杂了一些“测试”代码。

import msal
import requests
from fastapi import APIRouter, Request, Response
from fastapi.responses import RedirectResponse
from starlette.templating import Jinja2Templates

from config import get_settings

# Module-level singletons shared by every route in this router.
settings = get_settings()  # project configuration (CLIENT_ID, AUTH_SCOPE, ...)
router = APIRouter()
templates = Jinja2Templates(directory='templates')  # templates are looked up under ./templates


# Start sign-in: build the auth-code flow and (for testing) show the sign-in URL.
@router.get('/login', include_in_schema=False)
async def login(request: Request):
    """Create an MSAL auth-code flow, store it in the session, and show its sign-in URL.

    The flow is kept in ``request.session['flow']`` so the redirect handler can
    finish the exchange. For now the sign-in URL is rendered on ``error.html``
    rather than redirected to.
    """
    flow = _build_auth_code_flow(scopes=settings.AUTH_SCOPE)
    request.session['flow'] = flow
    context = {'request': request, 'message': flow['auth_uri']}
    return templates.TemplateResponse('error.html', context)


# Redirect-URI handler: Microsoft sends the browser here (?code=...&state=...) after sign-in.
@router.get('/getAToken', response_class=Response, include_in_schema=False)
async def authorize(request: Request):
    """Redeem the authorization code from the redirect for tokens.

    The auth-code flow saved by ``login`` is read from the session and matched
    against this request's query parameters. On success the ID-token claims are
    stored in ``request.session['user']`` and the token cache is written back to
    the session. Every outcome is currently rendered on ``error.html`` for
    debugging.

    NOTE(review): with Starlette's default cookie-backed SessionMiddleware the
    stored flow / token cache can exceed the ~4 KB browser cookie limit and be
    dropped; if the flow is missing here, consider a server-side session store.
    """
    try:
        cache = _load_cache(request)
        # The second argument is the auth *response*: the query parameters of this
        # redirect (code, state, ...), like Flask's request.args. MSAL validates
        # its 'state' against the flow, so passing request.session cannot work.
        result = _build_msal_app(cache=cache).acquire_token_by_auth_code_flow(
            request.session.get('flow', {}), dict(request.query_params)
        )
    except Exception as error:
        # Debug boundary: show the exception together with the raw query string.
        return templates.TemplateResponse('error.html', {'request': request, 'message': f'{error}: {str(request.query_params)}'})
    if 'error' in result:
        return templates.TemplateResponse('error.html', {'request': request, 'message': result})
    request.session['user'] = result.get('id_token_claims')
    # _save_cache takes (request, cache): the session is where the cache lives.
    _save_cache(request, cache)
    return templates.TemplateResponse('error.html', {'request': request, 'message': result})

    
def _load_cache(request: Request):
    """Return an MSAL token cache, restored from the session's "token_cache" entry if present."""
    token_cache = msal.SerializableTokenCache()
    saved = request.session.get("token_cache")
    if saved:
        token_cache.deserialize(saved)
    return token_cache


def _save_cache(request: Request, cache):
    """Persist *cache* into the session under "token_cache", only when it has changed."""
    if not cache.has_state_changed:
        return
    request.session["token_cache"] = cache.serialize()


def _build_msal_app(cache=None, authority=None):
    """Create a confidential-client MSAL app from project settings.

    *authority* overrides ``settings.AUTH_AUTHORITY`` when truthy; *cache* is an
    optional MSAL token cache to attach.
    """
    chosen_authority = authority or settings.AUTH_AUTHORITY
    return msal.ConfidentialClientApplication(
        settings.CLIENT_ID,
        authority=chosen_authority,
        client_credential=settings.CLIENT_SECRET,
        token_cache=cache,
    )


def _build_auth_code_flow(authority=None, scopes=None):
    """Start an MSAL auth-code flow that returns to ``settings.AUTH_REDIRECT``.

    *scopes* falls back to an empty list when not given; the returned flow dict
    (including its 'auth_uri') is meant to be stored until the redirect arrives.
    """
    msal_app = _build_msal_app(authority=authority)
    requested_scopes = scopes or []
    return msal_app.initiate_auth_code_flow(requested_scopes, redirect_uri=settings.AUTH_REDIRECT)


def _get_token_from_cache(scope=None, request=None):
    """Silently fetch a token for the signed-in user from the session's token cache.

    Args:
        scope: Scopes to request; when omitted, ``settings.AUTH_SCOPE`` (the
            scopes requested at login) is used instead of passing ``None`` to MSAL.
        request: The current request. Its session holds the serialized cache,
            so both ``_load_cache`` and ``_save_cache`` need it.

    Returns:
        The result of ``acquire_token_silent``, or ``None`` when the cache holds
        no account.

    Raises:
        TypeError: if ``request`` is not supplied.
    """
    if request is None:
        raise TypeError("_get_token_from_cache() requires the current request (request=...)")
    cache = _load_cache(request)  # This web app maintains one cache per session
    cca = _build_msal_app(cache=cache)
    accounts = cca.get_accounts()
    if not accounts:
        return None
    # Every account in this per-session cache belongs to the current signed-in user.
    scopes = settings.AUTH_SCOPE if scope is None else scope
    result = cca.acquire_token_silent(scopes, account=accounts[0])
    _save_cache(request, cache)
    return result

非常感谢任何帮助。很高兴回答任何问题。谢谢。

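这是因为 FastAPI(Starlette)默认的 SessionMiddleware 会把会话数据存储在客户端的 Cookie 中,而单个 Cookie 的大小上限约为 4096 字节。重定向时保存到会话中的数据(auth code flow)使 Cookie 超过了这个上限,导致数据没有被保存。starlette-session 提供了另一种 SessionMiddleware,它将会话数据存储在服务器端(例如 Redis),从而绕开了 Cookie 的大小限制。另外请注意,acquire_token_by_auth_code_flow 的第二个参数应为回调请求的查询参数(dict(request.query_params)),而不是 request.session。下面是一个基础的(但比较凌乱的)实现: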

import functools
import os

import msal
import uvicorn
from fastapi import FastAPI
from fastapi.templating import Jinja2Templates
from redis import Redis
from starlette.concurrency import run_in_threadpool
from starlette.requests import Request
from starlette.responses import RedirectResponse
from starlette_session import SessionMiddleware
from starlette_session.backends import BackendType


# Configuration. Each value can be overridden from the environment so real
# credentials never have to be committed; the literal defaults are placeholders.
app_client_id = os.environ.get("MSAL_CLIENT_ID", "sample_msal_client_id")
app_client_secret = os.environ.get("MSAL_CLIENT_SECRET", "sample_msal_client_secret")
tenant_id = os.environ.get("MSAL_TENANT_ID", "sample_msal_tenant_id")

app = FastAPI()


# Server-side sessions: the session data is kept in Redis via starlette_session
# instead of in the browser cookie, avoiding the cookie size limit.
redis_client = Redis(
    host=os.environ.get("REDIS_HOST", "localhost"),
    port=int(os.environ.get("REDIS_PORT", "6379")),
)
app.add_middleware(
    SessionMiddleware,
    # Signing key for the session cookie; set SESSION_SECRET_KEY in production.
    secret_key=os.environ.get("SESSION_SECRET_KEY", "SECURE_SECRET_KEY"),
    cookie_name="auth_cookie",
    backend_type=BackendType.redis,
    backend_client=redis_client,
)

templates = Jinja2Templates(directory="templates")

default_scope = ["https://graph.microsoft.com/.default"]
token_cache_key = "token_cache"  # session key holding the serialized MSAL cache

# Private Functions - Start
def _load_cache(session):
    """Build an MSAL token cache, seeded from *session* when a serialized cache is stored there."""
    token_cache = msal.SerializableTokenCache()
    serialized = session.get(token_cache_key)
    if serialized:
        token_cache.deserialize(serialized)
    return token_cache

def _save_cache(cache,session):
    """Write *cache* back into *session*, but only if MSAL reports it has changed."""
    if not cache.has_state_changed:
        return
    session[token_cache_key] = cache.serialize()

def _build_msal_app(cache=None):
    """Create a confidential-client MSAL app for the configured tenant, optionally backed by *cache*."""
    authority_url = f"https://login.microsoftonline.com/{tenant_id}"
    return msal.ConfidentialClientApplication(
        app_client_id,
        authority=authority_url,
        client_credential=app_client_secret,
        token_cache=cache,
    )

def _build_auth_code_flow(request):
    """Start an MSAL auth-code flow whose redirect URI is this app's ``callback`` route.

    The returned flow dict is stored in the session, so every value in it must be
    serializable. Newer Starlette versions return a ``URL`` object from
    ``url_for``; it is converted to ``str`` so the flow (which records the
    redirect URI) can be saved.
    """
    redirect_uri = str(request.url_for("callback"))
    return _build_msal_app().initiate_auth_code_flow(
        default_scope,  # Scopes
        redirect_uri=redirect_uri,
    )

def _get_token_from_cache(session):
    """Silently acquire a ``default_scope`` token for the session's signed-in user.

    Returns the ``acquire_token_silent`` result (and persists any cache change),
    or ``None`` when the session's cache holds no account.
    """
    cache = _load_cache(session)  # one token cache per browser session
    client = _build_msal_app(cache=cache)
    accounts = client.get_accounts()
    if not accounts:
        return None
    # Every account in this per-session cache belongs to the signed-in user.
    token = client.acquire_token_silent(default_scope, account=accounts[0])
    _save_cache(cache, session)
    return token
# Private Functions - End


# Custom Decorators - Start
def authenticated_endpoint(func):
    """Redirect to the login page unless the session holds a usable token.

    Decorates a FastAPI endpoint (``def`` or ``async def``) that declares a
    ``request: Request`` parameter. Before the endpoint runs, the session's MSAL
    token cache is checked with ``_get_token_from_cache``; if no usable token is
    found the browser is redirected to the ``login`` route instead.

    Only failures of that token lookup become a login redirect. Exceptions raised
    by the endpoint itself (e.g. ``HTTPException``) propagate normally instead of
    being swallowed.

    Raises:
        TypeError: if the endpoint call carries no ``Request``.
    """
    import inspect  # local: only needed to detect async endpoints
    import logging

    logger = logging.getLogger(__name__)

    def _find_request(args, kwargs):
        # FastAPI passes declared parameters by keyword; fall back to positional args.
        request = kwargs.get("request")
        if request is None:
            request = next((arg for arg in args if isinstance(arg, Request)), None)
        if request is None:
            raise TypeError(
                f"{func.__name__} must declare a 'request: Request' parameter "
                "to use @authenticated_endpoint"
            )
        return request

    def _login_redirect_or_none(args, kwargs):
        # Return a redirect to the login page when the user is not signed in, else None.
        request = _find_request(args, kwargs)
        try:
            token = _get_token_from_cache(request.session)
        except Exception:
            # An unreadable cache or a failed silent refresh means "not signed in".
            logger.warning("Token lookup failed; redirecting to login", exc_info=True)
            token = None
        # Defensively treat a result carrying an 'error' key as not signed in.
        if not token or "error" in token:
            return RedirectResponse(request.url_for("login"))
        return None

    if inspect.iscoroutinefunction(func):
        # NOTE(review): the token lookup may do blocking network I/O (silent
        # refresh) on the event loop for async endpoints.
        @functools.wraps(func)
        async def is_authenticated(*args, **kwargs):
            redirect = _login_redirect_or_none(args, kwargs)
            if redirect is not None:
                return redirect
            return await func(*args, **kwargs)

    else:

        @functools.wraps(func)
        def is_authenticated(*args, **kwargs):
            redirect = _login_redirect_or_none(args, kwargs)
            if redirect is not None:
                return redirect
            return func(*args, **kwargs)

    return is_authenticated
# Custom Decorators - End


# Endpoints - Start
@app.get("/")
@authenticated_endpoint
def index(request:Request):
    """Protected landing page: only reached when the session holds a valid token."""
    payload = {"result": "good"}
    return payload

@app.get("/login")
def login(request:Request):
    """Render the sign-in page (login.html) with the MSAL version and template config."""
    context = {
        "version": msal.__version__,
        "request": request,
        "config": {"B2C_RESET_PASSWORD_AUTHORITY": False},
    }
    return templates.TemplateResponse("login.html", context)

@app.get("/oauth/redirect")
def get_redirect_url(request:Request):
    """Start the auth-code flow, remember it in the session, and send the browser to sign in."""
    flow = _build_auth_code_flow(request)
    request.session["flow"] = flow
    return RedirectResponse(flow["auth_uri"])

@app.get("/callback")
async def callback(request:Request):
    """Finish sign-in: redeem the code from Microsoft's redirect for tokens.

    The stored auth-code flow is matched against this request's query
    parameters. On success the ID-token claims go to ``session["user"]``, the
    token cache is persisted, and the browser is sent to ``index``. An MSAL
    error result renders ``auth_error.html``. A flow/state mismatch (replayed
    or stale callback, expired session) redirects to ``index``, which routes
    unauthenticated users to the login page.
    """
    # The flow is single-use: remove it so a replayed callback cannot reuse it.
    flow = request.session.pop("flow", {})
    auth_response = dict(request.query_params)
    cache = _load_cache(request.session)
    msal_app = _build_msal_app(cache=cache)
    try:
        # The code redemption is a blocking HTTP call; keep it off the event loop.
        result = await run_in_threadpool(
            msal_app.acquire_token_by_auth_code_flow, flow, auth_response
        )
    except ValueError:
        # MSAL raises ValueError when the response's state does not match the flow.
        return RedirectResponse(request.url_for("index"))
    if "error" in result:
        return templates.TemplateResponse("auth_error.html",{
            "result": result,
            'request': request
        })
    request.session["user"] = result.get("id_token_claims")
    _save_cache(cache, request.session)
    return RedirectResponse(request.url_for("index"))
# Endpoints - End

if __name__ == "__main__":
    # Development server with auto-reload, listening on all interfaces.
    # "main:app" assumes this file is saved as main.py.
    uvicorn.run("main:app", host="0.0.0.0", port=4557, reload=True)