import binascii
from base64 import b64decode
from typing import Optional

from fastapi.exceptions import HTTPException
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN


class HTTPBasicCredentials(BaseModel):
    """Username/password pair decoded from an HTTP Basic ``Authorization`` header."""

    username: str
    password: str


class HTTPAuthorizationCredentials(BaseModel):
    """``Authorization`` header value split into its scheme and credentials."""

    scheme: str
    credentials: str


class HTTPBase(SecurityBase):
    """Generic HTTP authentication dependency for an arbitrary ``scheme``.

    Reads the ``Authorization`` request header and returns its scheme and
    credentials parts. The scheme actually sent by the client is not compared
    against ``scheme``; ``scheme`` is only used for the OpenAPI model.

    Args:
        scheme: Authentication scheme name recorded in the OpenAPI model.
        scheme_name: Name of the security scheme; defaults to the class name.
        description: Optional description for the OpenAPI model.
        auto_error: If True, a missing/incomplete header raises a 403
            ``HTTPException``; if False, ``None`` is returned instead.
    """

    def __init__(
        self,
        *,
        scheme: str,
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        self.model = HTTPBaseModel(scheme=scheme, description=description)
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error

    async def __call__(
        self, request: Request
    ) -> Optional[HTTPAuthorizationCredentials]:
        """Return the header's scheme/credentials, or ``None`` if absent and not ``auto_error``.

        Raises:
            HTTPException: 403 "Not authenticated" when the header, its scheme
                or its credentials are missing and ``auto_error`` is True.
        """
        authorization = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            return None
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)


class HTTPBasic(HTTPBase):
    """HTTP Basic authentication dependency.

    Decodes the base64 ``username:password`` payload of a ``Basic``
    ``Authorization`` header into ``HTTPBasicCredentials``.

    Args:
        scheme_name: Name of the security scheme; defaults to the class name.
        realm: Optional realm included in the ``WWW-Authenticate`` header of
            401 responses.
        description: Optional description for the OpenAPI model.
        auto_error: If True, a missing header or non-Basic scheme raises a 401;
            if False, ``None`` is returned instead. A Basic header whose
            payload cannot be decoded always raises, regardless of this flag.
    """

    def __init__(
        self,
        *,
        scheme_name: Optional[str] = None,
        realm: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        self.model = HTTPBaseModel(scheme="basic", description=description)
        self.scheme_name = scheme_name or self.__class__.__name__
        self.realm = realm
        self.auto_error = auto_error

    async def __call__(  # type: ignore
        self, request: Request
    ) -> Optional[HTTPBasicCredentials]:
        """Return decoded Basic credentials, or ``None`` if absent and not ``auto_error``.

        Raises:
            HTTPException: 401 "Not authenticated" when the header is missing or
                not Basic and ``auto_error`` is True; 401 "Invalid
                authentication credentials" when the payload is not valid
                base64/ASCII or has no ``:`` separator.
        """
        authorization = request.headers.get("Authorization")
        scheme, param = get_authorization_scheme_param(authorization)
        # Every 401 raised below carries this WWW-Authenticate header.
        if self.realm:
            unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
        else:
            unauthorized_headers = {"WWW-Authenticate": "Basic"}
        if not authorization or scheme.lower() != "basic":
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_401_UNAUTHORIZED,
                    detail="Not authenticated",
                    headers=unauthorized_headers,
                )
            return None
        invalid_user_credentials_exc = HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Invalid authentication credentials",
            headers=unauthorized_headers,
        )
        try:
            data = b64decode(param).decode("ascii")
        except (ValueError, UnicodeDecodeError, binascii.Error) as err:
            # Keep the decoding error as the cause for easier debugging.
            raise invalid_user_credentials_exc from err
        # Split on the first ':' only, so passwords may themselves contain ':'.
        username, separator, password = data.partition(":")
        if not separator:
            raise invalid_user_credentials_exc
        return HTTPBasicCredentials(username=username, password=password)


class HTTPBearer(HTTPBase):
    """HTTP Bearer authentication dependency.

    Returns the raw token from a ``Bearer`` ``Authorization`` header; the token
    itself is not validated here.

    Args:
        bearerFormat: Optional token format hint for the OpenAPI model.
        scheme_name: Name of the security scheme; defaults to the class name.
        description: Optional description for the OpenAPI model.
        auto_error: If True, a missing header or non-Bearer scheme raises a
            403; if False, ``None`` is returned instead.
    """

    def __init__(
        self,
        *,
        bearerFormat: Optional[str] = None,
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        self.model = HTTPBearerModel(bearerFormat=bearerFormat, description=description)
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error

    async def __call__(
        self, request: Request
    ) -> Optional[HTTPAuthorizationCredentials]:
        """Return the Bearer scheme and token, or ``None`` when not ``auto_error``.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is missing or
                incomplete, or 403 "Invalid authentication credentials" when
                the scheme is not Bearer (only if ``auto_error`` is True).
        """
        authorization = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            return None
        if scheme.lower() != "bearer":
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN,
                    detail="Invalid authentication credentials",
                )
            return None
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)


class HTTPDigest(HTTPBase):
    """HTTP Digest authentication dependency.

    Returns the raw credentials from a ``Digest`` ``Authorization`` header; the
    digest response is not parsed or verified here.

    Args:
        scheme_name: Name of the security scheme; defaults to the class name.
        description: Optional description for the OpenAPI model.
        auto_error: If True, a missing header or non-Digest scheme raises a
            403; if False, ``None`` is returned instead.
    """

    def __init__(
        self,
        *,
        scheme_name: Optional[str] = None,
        description: Optional[str] = None,
        auto_error: bool = True,
    ):
        self.model = HTTPBaseModel(scheme="digest", description=description)
        self.scheme_name = scheme_name or self.__class__.__name__
        self.auto_error = auto_error

    async def __call__(
        self, request: Request
    ) -> Optional[HTTPAuthorizationCredentials]:
        """Return the Digest scheme and credentials, or ``None`` when not ``auto_error``.

        Raises:
            HTTPException: 403 "Not authenticated" when the header is missing or
                incomplete, or 403 "Invalid authentication credentials" when
                the scheme is not Digest (only if ``auto_error`` is True).
        """
        authorization = request.headers.get("Authorization")
        scheme, credentials = get_authorization_scheme_param(authorization)
        if not (authorization and scheme and credentials):
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            return None
        if scheme.lower() != "digest":
            # Honor auto_error here too (previously this always raised), matching
            # HTTPBearer, so an optional Digest dependency yields None when the
            # client sends a different scheme.
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN,
                    detail="Invalid authentication credentials",
                )
            return None
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)