import typing

import anyio

from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
    [Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")


class BaseHTTPMiddleware:
    """Base class for HTTP middleware expressed as a single ``dispatch`` coroutine.

    Either subclass and override :meth:`dispatch`, or pass a ``dispatch``
    callable to the constructor. The dispatch function receives the incoming
    :class:`Request` and a ``call_next`` coroutine that runs the wrapped app and
    returns its response. Non-``http`` scopes bypass ``dispatch`` entirely.
    """

    def __init__(
        self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None
    ) -> None:
        """Wrap ``app``; use ``dispatch`` if given, else the ``dispatch`` method."""
        self.app = app
        self.dispatch_func = self.dispatch if dispatch is None else dispatch

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point.

        Non-``http`` scopes are forwarded to the wrapped app unchanged. For
        ``http`` scopes, the response returned by the dispatch function is sent
        to the client, after which ``response_sent`` is set so the downstream
        app's receive/send wrappers stop talking to this request.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Set once the dispatch function's response has been sent to the client.
        response_sent = anyio.Event()

        async def call_next(request: Request) -> Response:
            """Run the downstream app and return its response as a stream.

            The downstream app's ``send`` messages are routed through an
            in-memory stream. This waits for ``http.response.start`` (skipping
            an optional leading ``http.response.debug`` message) and returns a
            streaming response whose body is read from that stream.

            Raises:
                Exception: whatever the downstream app raised, if it raised
                    before a response start message was received.
                RuntimeError: if the stream ended without any response.
            """
            app_exc: typing.Optional[Exception] = None
            send_stream, recv_stream = anyio.create_memory_object_stream()

            async def receive_or_disconnect() -> Message:
                """Receive from the client unless the response was already sent.

                Races ``request.receive`` against ``response_sent.wait``; whichever
                finishes first cancels the other. If ``response_sent`` is set,
                an ``http.disconnect`` message is returned instead.
                """
                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                async with anyio.create_task_group() as task_group:

                    async def wrap(func: typing.Callable[[], typing.Awaitable[T]]) -> T:
                        result = await func()
                        task_group.cancel_scope.cancel()
                        return result

                    task_group.start_soon(wrap, response_sent.wait)
                    message = await wrap(request.receive)

                if response_sent.is_set():
                    return {"type": "http.disconnect"}

                return message

            async def close_recv_stream_on_response_sent() -> None:
                """Close ``recv_stream`` once the response has been sent."""
                await response_sent.wait()
                recv_stream.close()

            async def send_no_error(message: Message) -> None:
                """Forward a downstream message, dropping it if nobody is reading."""
                try:
                    await send_stream.send(message)
                except anyio.BrokenResourceError:
                    # recv_stream has been closed, i.e. response_sent has been set.
                    return

            async def coro() -> None:
                """Run the downstream app, recording any exception in ``app_exc``."""
                nonlocal app_exc

                # Closing send_stream on exit lets the reader observe EndOfStream.
                async with send_stream:
                    try:
                        await self.app(scope, receive_or_disconnect, send_no_error)
                    except Exception as exc:
                        app_exc = exc

            # ``task_group`` is the group opened at the bottom of ``__call__``;
            # call_next is only invoked by the dispatch function inside it.
            task_group.start_soon(close_recv_stream_on_response_sent)
            task_group.start_soon(coro)

            try:
                message = await recv_stream.receive()
                info = message.get("info", None)
                # An optional leading debug message carries ``info``: skip past
                # it to the real response start, keeping ``info`` so the
                # returned response can re-emit it.
                if message["type"] == "http.response.debug" and info is not None:
                    message = await recv_stream.receive()
            except anyio.EndOfStream:
                if app_exc is not None:
                    raise app_exc
                raise RuntimeError("No response returned.")

            assert message["type"] == "http.response.start"

            async def body_stream() -> typing.AsyncGenerator[bytes, None]:
                """Yield non-empty body chunks, then re-raise any app exception."""
                async with recv_stream:
                    async for message in recv_stream:
                        assert message["type"] == "http.response.body"
                        body = message.get("body", b"")
                        if body:
                            yield body

                if app_exc is not None:
                    raise app_exc

            response = _StreamingResponse(
                status_code=message["status"], content=body_stream(), info=info
            )
            # ASGI makes ``headers`` optional on ``http.response.start``
            # (default ``[]``) and allows any iterable of header pairs, so
            # default the key and normalise the value to a list.
            response.raw_headers = list(message.get("headers", []))
            return response

        async with anyio.create_task_group() as task_group:
            request = Request(scope, receive=receive)
            response = await self.dispatch_func(request, call_next)
            await response(scope, receive, send)
            response_sent.set()

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Handle ``request``; override in subclasses or pass ``dispatch=``."""
        raise NotImplementedError()  # pragma: no cover


class _StreamingResponse(StreamingResponse):
    """``StreamingResponse`` that may first emit an ``http.response.debug`` message.

    :class:`BaseHTTPMiddleware` uses it so that the ``info`` payload of a
    downstream ``http.response.debug`` message is forwarded, not dropped.
    """

    def __init__(
        self,
        content: ContentStream,
        status_code: int = 200,
        headers: typing.Optional[typing.Mapping[str, str]] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
        info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
    ) -> None:
        self._info = info
        super().__init__(content, status_code, headers, media_type, background)

    async def stream_response(self, send: Send) -> None:
        """Send the debug message (if ``info`` is truthy), then the response."""
        # Only truthy info is sent; ``None`` or an empty mapping emits nothing.
        if self._info:
            await send({"type": "http.response.debug", "info": self._info})
        return await super().stream_response(send)