在中间件上下文中获取 starlette 请求体

Get starlette request body in the middleware context

我有这样的中间件

class RequestContext(BaseHTTPMiddleware):
    """Middleware that tags each request with a UUID and logs the request and response.

    NOTE(review): the ``await request.body()`` call below is reported to stall
    requests that carry a body when run inside this middleware — confirm
    before reusing this pattern.
    """

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
        """Log the incoming request, call the next handler and tag the response."""
        # Store a fresh UUID in request_ctx; the value returned by set() is
        # passed to reset() at the end (presumably a contextvars token — verify).
        request_id = request_ctx.set(str(uuid4()))
        # Reads the request body up front so it can be logged.
        body = await request.body()
        if body:
            logger.info(...)  # log the request together with its body
        else:
            logger.info(...)  # log the request without a body

        response = await call_next(request)
        # Echo the stored request ID back to the client.
        response.headers['X-Request-ID'] = request_ctx.get()
        logger.info("%s" % (response.status_code))
        # Undo the set() performed at the start of this request.
        request_ctx.reset(request_id)

        return response

因此,body = await request.body() 行冻结了所有具有正文的请求,我从所有这些请求中得到了 504。在此上下文中如何安全地读取请求正文?我只想记录请求参数。

我不会创建一个继承自 BaseHTTPMiddleware 的中间件,因为它有一些 issues,FastAPI 让您有机会创建自己的路由器,根据我的经验,这种方法更好。

from fastapi import APIRouter, FastAPI, Request, Response, Body
from fastapi.routing import APIRoute

from typing import Callable, List
from uuid import uuid4


class ContextIncludedRoute(APIRoute):
    """Route class that prints the request body and stamps a ``Request-ID`` header."""

    def get_route_handler(self) -> Callable:
        """Return a handler that wraps the default one with body printing and ID tagging."""
        default_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            # The ID is generated before the endpoint is invoked.
            generated_id = str(uuid4())
            result: Response = await default_handler(request)

            # Print the body only when the request actually carried one.
            payload = await request.body()
            if payload:
                print(payload)

            result.headers["Request-ID"] = generated_id
            return result

        return custom_route_handler


# Application instance; its routes come from the router included below.
app = FastAPI()
# Every route registered on this router is built with ContextIncludedRoute.
router = APIRouter(route_class=ContextIncludedRoute)


@router.post("/context")
async def non_default_router(bod: List[str] = Body(...)):
    """Return the posted list of strings unchanged."""
    return bod


# Attach the router's routes to the application.
app.include_router(router)

按预期工作。

b'["string"]'
INFO:     127.0.0.1:49784 - "POST /context HTTP/1.1" 200 OK

如果你仍然想使用 BaseHTTPMiddleware,我最近也遇到了这个问题,并想出了一个解决方案:

中间件代码

import json

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request

from .async_iterator_wrapper import async_iterator_wrapper as aiwrap

class some_middleware(BaseHTTPMiddleware):
    """Middleware that buffers the response body so it can be inspected or logged."""

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
        """Call the downstream app, capture its response body and return the response.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request and returns the response.

        Returns:
            The downstream response, with its body iterator rebuilt from the
            buffered chunks so it can still be sent to the client.
        """
        # --------------------------
        # DO WHATEVER YOU NEED TO DO HERE
        # --------------------------

        response = await call_next(request)

        # Drain the streaming body; the iterator can only be consumed once.
        chunks = [section async for section in response.body_iterator]
        # Hand the response a fresh async iterator over the buffered chunks.
        response.body_iterator = aiwrap(chunks)

        # Format the body for logging: parsed JSON when possible, otherwise
        # the string form of the raw chunk list.
        try:
            resp_body = json.loads(b"".join(chunks).decode())
        except (TypeError, ValueError):
            # ValueError covers both invalid UTF-8 and invalid/empty JSON.
            resp_body = str(chunks)

        # ``resp_body`` is now ready to be logged here.
        return response

async_iterator_wrapper 的代码如下(来自 @Eddified 的回答):

class async_iterator_wrapper:
    """Expose a synchronous iterable through the asynchronous iterator protocol.

    Each ``__anext__`` call returns the next item of the wrapped iterable and
    raises ``StopAsyncIteration`` once it is exhausted.
    """

    def __init__(self, obj):
        """Create the wrapper around ``iter(obj)``."""
        self._it = iter(obj)

    def __aiter__(self):
        """Return self; the wrapper is its own async iterator."""
        return self

    async def __anext__(self):
        """Return the next wrapped item or raise ``StopAsyncIteration``."""
        try:
            return next(self._it)
        except StopIteration:
            # ``from None`` hides the internal StopIteration so tracebacks do
            # not show a misleading "during handling of the above exception".
            raise StopAsyncIteration from None

我真的希望这可以帮助别人!我发现这对日志记录很有帮助。

非常感谢@Eddified 的 aiwrap class

原来 await request.json() 每个请求周期只能调用一次。因此,如果您需要访问多个中间件中的请求主体以进行过滤或身份验证等,那么可以解决这个问题,即创建一个自定义中间件来复制 request.state 中请求主体的内容。中间件应在必要时尽早加载。然后链中的每个中间件或控制器都可以从 request.state 访问请求主体,而不是再次调用 await request.json()。这是一个例子:

class CopyRequestMiddleware(BaseHTTPMiddleware):
    """Middleware that parses the JSON request body once and stores it on ``request.state``."""

    async def dispatch(self, request: Request, call_next):
        """Store the parsed JSON body in ``request.state.body`` and continue the chain.

        ``request.state.body`` is set to ``None`` when the body is empty or is
        not valid JSON, so requests without a JSON payload are not failed here.
        """
        try:
            request.state.body = await request.json()
        except ValueError:
            # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
            request.state.body = None

        return await call_next(request)

class LogRequestMiddleware(BaseHTTPMiddleware):
    """Middleware that prints the request body found on ``request.state.body``."""

    async def dispatch(self, request: Request, call_next):
        """Print the stored body, then pass the request to the next handler."""
        # Expects a middleware that runs earlier to have set request.state.body.
        print(request.state.body)
        return await call_next(request)

控制器也将从 request.state 访问请求正文

# Read the request body from request.state instead of parsing the request again.
request_body = request.state.body

之所以补充这个回答,是因为还没有人提到这种解决方案,而它对我有效:

from typing import Callable, Awaitable

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import StreamingResponse
from starlette.concurrency import iterate_in_threadpool

class LogStatsMiddleware(BaseHTTPMiddleware):
    """Middleware that logs the full response body of every request."""

    async def dispatch(  # type: ignore
        self, request: Request, call_next: Callable[[Request], Awaitable[StreamingResponse]],
    ) -> StreamingResponse:
        """Forward the request, log the response body and return the response.

        The body iterator is drained into memory and then replaced with a new
        iterator over the buffered chunks so the body can still be sent.
        """
        # Imported here so the method does not depend on a module-level import.
        import logging

        response = await call_next(request)
        chunks = [section async for section in response.body_iterator]
        # iterate_in_threadpool turns the plain iterator back into an async one.
        response.body_iterator = iterate_in_threadpool(iter(chunks))
        # Join all chunks: the body may be empty or split across several sections.
        # Undecodable bytes are replaced so logging cannot fail the request.
        logging.info("response_body=%s", b"".join(chunks).decode(errors="replace"))
        return response

def init_app(app):
    """Register LogStatsMiddleware on the given application."""
    app.add_middleware(LogStatsMiddleware)

iterate_in_threadpool 实际上是从迭代器对象生成异步迭代器

如果您查看 starlette.responses.StreamingResponse 的实现,您会发现该函数正是用于此

如果您只想读取请求参数,我找到的最佳解决方案是实现一个 “route_class”,并在创建 fastapi.APIRouter 时将其作为参数传入,这是因为在中间件中解析请求被认为是有问题的(is considered problematic)。据我了解,路由处理程序的本意是将异常处理逻辑附加到特定路由器,但由于它在每次路由调用之前都会被调用,您可以用它来访问请求参数。

Fastapi documentation

您可以执行以下操作:

class MyRequestLoggingRoute(APIRoute):
    """Route class that logs each request and reports the body on validation errors."""

    def get_route_handler(self) -> Callable:
        """Return a handler that logs the request before delegating to the default one."""
        original_route_handler = super().get_route_handler()

        async def custom_route_handler(request: Request) -> Response:
            body = await request.body()
            if body:
                logger.info(...)  # log request with body
            else:
                logger.info(...)  # log request without body
            try:
                return await original_route_handler(request)
            except RequestValidationError as exc:
                # Replace undecodable bytes so building the error detail cannot
                # itself fail on a non-UTF-8 body.
                detail = {"errors": exc.errors(), "body": body.decode(errors="replace")}
                # Chain the original error so it stays visible in tracebacks.
                raise HTTPException(status_code=422, detail=detail) from exc

        return custom_route_handler

您可以使用通用 ASGI 中间件安全地执行此操作:

from typing import Iterable, List, Protocol, Generator

import pytest

from starlette.responses import Response
from starlette.testclient import TestClient
from starlette.types import ASGIApp, Scope, Send, Receive, Message


class Logger(Protocol):
    """Structural type for any object that provides an ``info(message)`` method."""

    def info(self, message: str) -> None:
        """Accept a single message string."""
        ...


class BodyLoggingMiddleware:
    """ASGI middleware that logs the complete HTTP request body.

    The ``receive`` callable given to the wrapped app is replaced by one that
    records every body chunk. After the app finishes (or raises), any unread
    messages are drained so the whole body is logged.
    """

    def __init__(
        self,
        app: ASGIApp,
        logger: Logger,
    ) -> None:
        """Store the wrapped app and the logger that receives the body text."""
        self.app = app
        self.logger = logger

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app and log the request body once it is done."""
        # Non-HTTP scopes are passed through untouched.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        done = False
        chunks: "List[bytes]" = []

        async def wrapped_receive() -> Message:
            """Forward one message from ``receive`` and record its body chunk."""
            nonlocal done
            message = await receive()
            if message["type"] == "http.disconnect":
                # Nothing more will arrive after a disconnect.
                done = True
                return message
            body = message.get("body", b"")
            more_body = message.get("more_body", False)
            if not more_body:
                done = True
            chunks.append(body)
            return message

        try:
            await self.app(scope, wrapped_receive, send)
        finally:
            # The app may have stopped reading early; pull the remaining messages.
            while not done:
                await wrapped_receive()
            # Replace undecodable bytes: a decode error raised here would
            # otherwise mask any exception raised by the app.
            self.logger.info(b"".join(chunks).decode(errors="replace"))


async def consume_body_app(scope: Scope, receive: Receive, send: Send) -> None:
    """ASGI app that reads the whole request body, then sends a default ``Response``."""
    done = False
    while not done:
        msg = await receive()
        # The last message may omit "more_body" or set it to False; a message
        # without the key (such as a disconnect) also ends the loop.
        done = not msg.get("more_body", False)
    await Response()(scope, receive, send)


async def consume_partial_body_app(scope: Scope, receive: Receive, send: Send) -> None:
    """ASGI app that reads a single message, then sends a default ``Response``."""
    # Only one receive() call: any further body chunks are left unread.
    await receive()
    await Response()(scope, receive, send)


class TestException(Exception):
    """Exception raised on purpose by the failing test apps."""

    # The name starts with "Test", so tell pytest this is not a test class.
    __test__ = False


async def consume_body_and_error_app(scope: Scope, receive: Receive, send: Send) -> None:
    """ASGI app that reads the whole request body and then raises ``TestException``."""
    done = False
    while not done:
        msg = await receive()
        # The last message may omit "more_body" or set it to False; a message
        # without the key (such as a disconnect) also ends the loop.
        done = not msg.get("more_body", False)
    raise TestException


async def consume_partial_body_and_error_app(scope: Scope, receive: Receive, send: Send) -> None:
    """ASGI app that reads a single message and then raises ``TestException``."""
    # Only one receive() call: any further body chunks are left unread.
    await receive()
    raise TestException


class TestLogger:
    """Logger test double that appends every message to a caller-supplied list."""

    # The name starts with "Test", so tell pytest this is not a test class.
    __test__ = False

    def __init__(self, recorder: List[str]) -> None:
        """Keep a reference to ``recorder``; messages are appended to it in place."""
        self.recorder = recorder

    def info(self, message: str) -> None:
        """Append ``message`` to the recorder list."""
        self.recorder.append(message)


@pytest.mark.parametrize(
    "chunks, expected_logs", [
        ([b"foo", b" ", b"bar", b" ", b"baz"], ["foo bar baz"]),
    ]
)
@pytest.mark.parametrize(
    "app",
    [consume_body_app, consume_partial_body_app]
)
def test_body_logging_middleware_no_errors(chunks: Iterable[bytes], expected_logs: Iterable[str], app: ASGIApp) -> None:
    """The full body is logged whether the app reads all of it or only part of it."""
    logs: List[str] = []
    client = TestClient(BodyLoggingMiddleware(app, TestLogger(logs)))

    def chunk_gen() -> Generator[bytes, None, None]:
        # Supplied as a generator so the body is handed over chunk by chunk.
        yield from chunks

    resp = client.get("/", data=chunk_gen())
    assert resp.status_code == 200
    assert logs == expected_logs


@pytest.mark.parametrize(
    "chunks, expected_logs", [
        ([b"foo", b" ", b"bar", b" ", b"baz"], ["foo bar baz"]),
    ]
)
@pytest.mark.parametrize(
    "app",
    [consume_body_and_error_app, consume_partial_body_and_error_app]
)
def test_body_logging_middleware_with_errors(chunks: Iterable[bytes], expected_logs: Iterable[str], app: ASGIApp) -> None:
    """The full body is still logged when the wrapped app raises."""
    logs: List[str] = []
    client = TestClient(BodyLoggingMiddleware(app, TestLogger(logs)))

    def chunk_gen() -> Generator[bytes, None, None]:
        # Supplied as a generator so the body is handed over chunk by chunk.
        yield from chunks

    with pytest.raises(TestException):
        client.get("/", data=chunk_gen())
    assert logs == expected_logs


if __name__ == "__main__":
    import os

    # Run this file's tests and pass pytest's exit code on to the shell, so a
    # failing run does not exit with status 0.
    raise SystemExit(pytest.main(args=[os.path.abspath(__file__)]))