import functools
import inspect
import typing
from urllib.parse import urlencode

from starlette._utils import is_async_callable
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection, Request
from starlette.responses import RedirectResponse, Response
from starlette.websockets import WebSocket

_CallableType = typing.TypeVar("_CallableType", bound=typing.Callable)


def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bool:
    """Return True when every entry of *scopes* is in ``conn.auth.scopes``.

    ``conn.auth`` is only read while iterating, so an empty *scopes*
    sequence is satisfied without touching the connection at all.
    """
    return all(scope in conn.auth.scopes for scope in scopes)


def _bound_argument(
    name: str,
    idx: int,
    args: typing.Tuple[typing.Any, ...],
    kwargs: typing.Dict[str, typing.Any],
) -> typing.Any:
    """Fetch the argument called *name* from a wrapped call.

    The keyword form wins; otherwise the positional slot *idx* is used, and
    ``None`` is returned when the argument was not supplied either way.
    """
    return kwargs.get(name, args[idx] if idx < len(args) else None)


def _deny(
    request: Request, status_code: int, redirect: typing.Optional[str]
) -> Response:
    """Build the outcome for a request that lacks the required scopes.

    With *redirect* set, return a 303 redirect to ``request.url_for(redirect)``
    carrying the originally requested URL in a ``next`` query parameter.
    Otherwise raise ``HTTPException`` with *status_code*.
    """
    if redirect is None:
        raise HTTPException(status_code=status_code)

    orig_request_qparam = urlencode({"next": str(request.url)})
    next_url = "{redirect_path}?{orig_request}".format(
        redirect_path=request.url_for(redirect),
        orig_request=orig_request_qparam,
    )
    return RedirectResponse(url=next_url, status_code=303)


def requires(
    scopes: typing.Union[str, typing.Sequence[str]],
    status_code: int = 403,
    redirect: typing.Optional[str] = None,
) -> typing.Callable[[_CallableType], _CallableType]:
    """Decorate an endpoint so it only runs when the connection holds *scopes*.

    Args:
        scopes: A single scope name or a sequence of scope names, all of
            which must be present (see ``has_required_scope``).
        status_code: Status used for the ``HTTPException`` raised when a
            request lacks the scopes and no *redirect* is given.
        redirect: When not ``None``, passed to ``request.url_for`` to build
            the target of a 303 redirect issued instead of raising.

    Returns:
        A decorator. The decorated function must declare a parameter named
        ``request`` or ``websocket``.

    Raises:
        Exception: At decoration time, if the function has neither a
            ``request`` nor a ``websocket`` parameter.
    """
    scopes_list = [scopes] if isinstance(scopes, str) else list(scopes)

    def decorator(func: typing.Callable) -> typing.Callable:
        sig = inspect.signature(func)
        # Remember which parameter carries the connection and where it sits,
        # so the wrappers can find it whether passed by position or keyword.
        for idx, parameter in enumerate(sig.parameters.values()):
            if parameter.name in ("request", "websocket"):
                type_ = parameter.name
                break
        else:
            raise Exception(
                f'No "request" or "websocket" argument on function "{func}"'
            )

        if type_ == "websocket":
            # Handle websocket functions. (Always async)
            @functools.wraps(func)
            async def websocket_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> None:
                websocket = _bound_argument("websocket", idx, args, kwargs)
                assert isinstance(websocket, WebSocket)

                if not has_required_scope(websocket, scopes_list):
                    # Missing scopes: close the socket instead of calling func.
                    await websocket.close()
                else:
                    await func(*args, **kwargs)

            return websocket_wrapper

        elif is_async_callable(func):
            # Handle async request/response functions.
            @functools.wraps(func)
            async def async_wrapper(
                *args: typing.Any, **kwargs: typing.Any
            ) -> Response:
                request = _bound_argument("request", idx, args, kwargs)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    return _deny(request, status_code, redirect)
                return await func(*args, **kwargs)

            return async_wrapper

        else:
            # Handle sync request/response functions.
            @functools.wraps(func)
            def sync_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
                request = _bound_argument("request", idx, args, kwargs)
                assert isinstance(request, Request)

                if not has_required_scope(request, scopes_list):
                    return _deny(request, status_code, redirect)
                return func(*args, **kwargs)

            return sync_wrapper

    return decorator  # type: ignore[return-value]


class AuthenticationError(Exception):
    """Exception type for authentication failures."""


class AuthenticationBackend:
    """Base class for authentication backends."""

    async def authenticate(
        self, conn: HTTPConnection
    ) -> typing.Optional[typing.Tuple["AuthCredentials", "BaseUser"]]:
        """Return a ``(credentials, user)`` pair for *conn*, or ``None``.

        The base implementation raises ``NotImplementedError``.
        """
        raise NotImplementedError()  # pragma: no cover


class AuthCredentials:
    """Holds the list of scopes attached to a connection."""

    def __init__(self, scopes: typing.Optional[typing.Sequence[str]] = None):
        # Always store a fresh list so the caller's sequence is never shared.
        self.scopes = [] if scopes is None else list(scopes)


class BaseUser:
    """User interface; every property raises ``NotImplementedError`` here."""

    @property
    def is_authenticated(self) -> bool:
        raise NotImplementedError()  # pragma: no cover

    @property
    def display_name(self) -> str:
        raise NotImplementedError()  # pragma: no cover

    @property
    def identity(self) -> str:
        raise NotImplementedError()  # pragma: no cover


class SimpleUser(BaseUser):
    """Authenticated user identified only by a username."""

    def __init__(self, username: str) -> None:
        self.username = username

    @property
    def is_authenticated(self) -> bool:
        return True

    @property
    def display_name(self) -> str:
        return self.username


class UnauthenticatedUser(BaseUser):
    """User that is not authenticated and has an empty display name."""

    @property
    def is_authenticated(self) -> bool:
        return False

    @property
    def display_name(self) -> str:
        return ""