import typing

from starlette._utils import is_async_callable
from starlette.concurrency import run_in_threadpool
from starlette.exceptions import HTTPException, WebSocketException
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from starlette.websockets import WebSocket


class ExceptionMiddleware:
    """ASGI middleware that routes exceptions raised by the wrapped app to handlers.

    Handlers are registered either per integer status code (consulted only for
    ``HTTPException`` instances) or per exception class (matched by walking the
    raised exception's MRO). ``HTTPException`` and ``WebSocketException`` have
    built-in default handlers, which explicitly registered handlers may replace.
    """

    def __init__(
        self,
        app: ASGIApp,
        handlers: typing.Optional[
            typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]]
        ] = None,
        debug: bool = False,
    ) -> None:
        """Wrap ``app`` and register any initial ``handlers``.

        Args:
            app: The downstream ASGI application.
            handlers: Optional mapping whose keys are integer status codes or
                exception classes and whose values are the handler callables.
            debug: Stored on the instance as ``self.debug``.
        """
        self.app = app
        self.debug = debug  # TODO: We ought to handle 404 cases if debug is set.
        self._status_handlers: typing.Dict[int, typing.Callable] = {}
        self._exception_handlers: typing.Dict[
            typing.Type[Exception], typing.Callable
        ] = {
            HTTPException: self.http_exception,
            WebSocketException: self.websocket_exception,
        }
        if handlers is not None:
            # Registered after the defaults, so caller-supplied entries for the
            # same exception class replace the built-in ones.
            for target, callback in handlers.items():
                self.add_exception_handler(target, callback)

    def add_exception_handler(
        self,
        exc_class_or_status_code: typing.Union[int, typing.Type[Exception]],
        handler: typing.Callable[[Request, Exception], Response],
    ) -> None:
        """Register ``handler`` for a status code or an exception class.

        Integer keys go into the status-code table; anything else is asserted
        to be an ``Exception`` subclass and goes into the exception-class table.
        """
        if isinstance(exc_class_or_status_code, int):
            self._status_handlers[exc_class_or_status_code] = handler
            return

        assert issubclass(exc_class_or_status_code, Exception)
        self._exception_handlers[exc_class_or_status_code] = handler

    def _lookup_exception_handler(
        self, exc: Exception
    ) -> typing.Optional[typing.Callable]:
        """Return the handler of the first class in ``exc``'s MRO that has one.

        Returns ``None`` when no class in the MRO is registered.
        """
        registry = self._exception_handlers
        return next(
            (registry[klass] for klass in type(exc).__mro__ if klass in registry),
            None,
        )

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app, dispatching any raised exception to a handler.

        Scopes other than ``"http"`` and ``"websocket"`` are passed straight
        through. An exception with no matching handler is re-raised. If a
        handler matches but an ``http.response.start`` message has already been
        sent, a ``RuntimeError`` chained to the original exception is raised.
        """
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        response_started = False

        async def tracking_send(message: Message) -> None:
            # Record that the response has begun before forwarding the message.
            nonlocal response_started

            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, tracking_send)
        except Exception as exc:
            # Status-code handlers take priority, but only for HTTPException.
            handler = (
                self._status_handlers.get(exc.status_code)
                if isinstance(exc, HTTPException)
                else None
            )
            if handler is None:
                handler = self._lookup_exception_handler(exc)
            if handler is None:
                raise exc

            if response_started:
                msg = "Caught handled exception, but response already started."
                raise RuntimeError(msg) from exc

            scope_type = scope["type"]
            if scope_type == "http":
                connection = Request(scope, receive=receive)
            elif scope_type == "websocket":
                # The websocket is given the raw ``send``, not the tracking wrapper.
                connection = WebSocket(scope, receive=receive, send=send)
            else:
                # Neither branch applies: nothing is invoked and the exception
                # is not propagated.
                return

            # Coroutine handlers are awaited directly; other callables are
            # invoked through run_in_threadpool.
            if is_async_callable(handler):
                outcome = await handler(connection, exc)
            else:
                outcome = await run_in_threadpool(handler, connection, exc)

            if scope_type == "http":
                # For HTTP the handler's return value is called as an ASGI
                # response; for websockets it is ignored.
                await outcome(scope, receive, tracking_send)

    def http_exception(self, request: Request, exc: HTTPException) -> Response:
        """Default ``HTTPException`` handler.

        Status codes 204 and 304 produce a plain ``Response`` with no detail
        text; every other status produces a ``PlainTextResponse`` carrying
        ``exc.detail``. ``exc.headers`` is forwarded in both cases.
        """
        status = exc.status_code
        if status in {204, 304}:
            return Response(status_code=status, headers=exc.headers)
        return PlainTextResponse(exc.detail, status_code=status, headers=exc.headers)

    async def websocket_exception(
        self, websocket: WebSocket, exc: WebSocketException
    ) -> None:
        """Default ``WebSocketException`` handler: close with the exception's code and reason."""
        await websocket.close(code=exc.code, reason=exc.reason)