import http.client
import inspect
import warnings
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast

from fastapi import routing
from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, Param
from fastapi.responses import Response
from fastapi.utils import (
    deep_dict_update,
    generate_operation_id_for_path,
    get_model_definitions,
    is_body_allowed_for_status_code,
)
from pydantic import BaseModel
from pydantic.fields import ModelField, Undefined
from pydantic.schema import (
    field_schema,
    get_flat_models_from_fields,
    get_model_name_map,
)
from pydantic.utils import lenient_issubclass
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY

# Schema of a single validation error item, registered under the
# "ValidationError" key by get_openapi_path().
validation_error_definition = {
    "title": "ValidationError",
    "type": "object",
    "properties": {
        "loc": {
            "title": "Location",
            "type": "array",
            "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
        },
        "msg": {"title": "Message", "type": "string"},
        "type": {"title": "Error Type", "type": "string"},
    },
    "required": ["loc", "msg", "type"],
}

# Schema of the 422 response body, registered under the "HTTPValidationError"
# key; its "detail" items reference the "ValidationError" schema.
validation_error_response_definition = {
    "title": "HTTPValidationError",
    "type": "object",
    "properties": {
        "detail": {
            "title": "Detail",
            "type": "array",
            "items": {"$ref": REF_PREFIX + "ValidationError"},
        }
    },
}

# Fallback descriptions for additional responses keyed by an upper-cased
# status range (or "DEFAULT") instead of a concrete numeric status code.
status_code_ranges: Dict[str, str] = {
    "1XX": "Information",
    "2XX": "Success",
    "3XX": "Redirection",
    "4XX": "Client Error",
    "5XX": "Server Error",
    "DEFAULT": "Default Response",
}


def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Collect the security schemes and requirements of a flattened dependant.

    Returns a 2-tuple:

    * a dict mapping each scheme name to its JSON-encoded scheme model
      (encoded by alias, with ``None`` values excluded);
    * a list with one ``{scheme_name: scopes}`` entry per security
      requirement, in the order they appear on ``flat_dependant``.
    """
    security_definitions = {}
    operation_security = []
    for security_requirement in flat_dependant.security_requirements:
        security_definition = jsonable_encoder(
            security_requirement.security_scheme.model,
            by_alias=True,
            exclude_none=True,
        )
        security_name = security_requirement.security_scheme.scheme_name
        security_definitions[security_name] = security_definition
        operation_security.append({security_name: security_requirement.scopes})
    return security_definitions, operation_security


def get_openapi_operation_parameters(
    *,
    all_route_params: Sequence[ModelField],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> List[Dict[str, Any]]:
    """Build the ``parameters`` entries of an operation from its param fields.

    Fields whose ``field_info.include_in_schema`` is falsy are skipped. Each
    remaining field yields a dict with ``name``, ``in``, ``required`` and
    ``schema``, plus ``description``, ``examples``/``example`` and
    ``deprecated`` when they are set on the field info. ``examples`` takes
    precedence over ``example``.
    """
    parameters = []
    for param in all_route_params:
        field_info = param.field_info
        field_info = cast(Param, field_info)
        if not field_info.include_in_schema:
            continue
        parameter = {
            "name": param.alias,
            "in": field_info.in_.value,
            "required": param.required,
            "schema": field_schema(
                param, model_name_map=model_name_map, ref_prefix=REF_PREFIX
            )[0],
        }
        if field_info.description:
            parameter["description"] = field_info.description
        if field_info.examples:
            parameter["examples"] = jsonable_encoder(field_info.examples)
        elif field_info.example != Undefined:
            parameter["example"] = jsonable_encoder(field_info.example)
        if field_info.deprecated:
            parameter["deprecated"] = field_info.deprecated
        parameters.append(parameter)
    return parameters


def get_openapi_operation_request_body(
    *,
    body_field: Optional[ModelField],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Optional[Dict[str, Any]]:
    """Build the ``requestBody`` object for an operation.

    Returns ``None`` when ``body_field`` is falsy. Otherwise returns a dict
    with a ``content`` entry keyed by the field's media type (holding the
    schema and, when set, ``examples`` or ``example``) and a ``required`` key
    that is only present when the body field is required.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    body_schema, _, _ = field_schema(
        body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
    )
    field_info = cast(Body, body_field.field_info)
    request_media_type = field_info.media_type
    required = body_field.required
    request_body_oai: Dict[str, Any] = {}
    if required:
        request_body_oai["required"] = required
    request_media_content: Dict[str, Any] = {"schema": body_schema}
    if field_info.examples:
        request_media_content["examples"] = jsonable_encoder(field_info.examples)
    elif field_info.example != Undefined:
        request_media_content["example"] = jsonable_encoder(field_info.example)
    request_body_oai["content"] = {request_media_type: request_media_content}
    return request_body_oai


def generate_operation_id(
    *, route: routing.APIRoute, method: str
) -> str:  # pragma: nocover
    """Deprecated: return the operation ID for ``route`` and ``method``.

    Emits a ``DeprecationWarning`` on every call. Returns
    ``route.operation_id`` when set, otherwise an ID generated from the route
    name, path format and method.
    """
    warnings.warn(
        "fastapi.openapi.utils.generate_operation_id() was deprecated, "
        "it is not used internally, and will be removed soon",
        DeprecationWarning,
        stacklevel=2,
    )
    if route.operation_id:
        return route.operation_id
    path: str = route.path_format
    return generate_operation_id_for_path(name=route.name, path=path, method=method)


def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
    """Return the operation summary for ``route``.

    Uses ``route.summary`` when set; otherwise derives a title from the route
    name by replacing underscores with spaces and title-casing it.
    ``method`` is accepted but not used.
    """
    if route.summary:
        return route.summary
    return route.name.replace("_", " ").title()


def get_openapi_operation_metadata(
    *, route: routing.APIRoute, method: str, operation_ids: Set[str]
) -> Dict[str, Any]:
    """Build the metadata part of an operation object.

    The result always contains ``summary`` and ``operationId`` and, when set
    on the route, ``tags``, ``description`` and ``deprecated``.

    ``operation_ids`` is mutated: the operation ID is added to it. If the ID
    was already present, a warning naming the endpoint (and its file, when
    available) is emitted.
    """
    operation: Dict[str, Any] = {}
    if route.tags:
        operation["tags"] = route.tags
    operation["summary"] = generate_operation_summary(route=route, method=method)
    if route.description:
        operation["description"] = route.description
    operation_id = route.operation_id or route.unique_id
    if operation_id in operation_ids:
        # Not every callable has ``__name__`` (e.g. functools.partial objects),
        # so fall back to its repr instead of raising AttributeError while
        # merely trying to build a warning message.
        endpoint_name = getattr(route.endpoint, "__name__", repr(route.endpoint))
        message = f"Duplicate Operation ID {operation_id} for function {endpoint_name}"
        file_name = getattr(route.endpoint, "__globals__", {}).get("__file__")
        if file_name:
            message += f" at {file_name}"
        warnings.warn(message)
    operation_ids.add(operation_id)
    operation["operationId"] = operation_id
    if route.deprecated:
        operation["deprecated"] = route.deprecated
    return operation


def get_openapi_path(
    *, route: routing.APIRoute, model_name_map: Dict[type, str], operation_ids: Set[str]
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Build the path item for ``route``.

    Returns a 3-tuple ``(path, security_schemes, definitions)``:

    * ``path`` maps each lower-cased HTTP method to its operation object;
    * ``security_schemes`` maps scheme names to their definitions;
    * ``definitions`` holds the validation error schemas when a 422 response
      was added to any operation.

    All three are empty when ``route.include_in_schema`` is falsy.
    ``operation_ids`` is mutated with the IDs of the generated operations.
    """
    path: Dict[str, Any] = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: Type[Response] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: Optional[str] = current_response_class.media_type
    if route.include_in_schema:
        for method in route.methods:
            operation = get_openapi_operation_metadata(
                route=route, method=method, operation_ids=operation_ids
            )
            parameters: List[Dict[str, Any]] = []
            flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            all_route_params = get_flat_params(route.dependant)
            operation_parameters = get_openapi_operation_parameters(
                all_route_params=all_route_params, model_name_map=model_name_map
            )
            parameters.extend(operation_parameters)
            if parameters:
                all_parameters = {
                    (param["in"], param["name"]): param for param in parameters
                }
                required_parameters = {
                    (param["in"], param["name"]): param
                    for param in parameters
                    if param.get("required")
                }
                # Make sure required definitions of the same parameter take
                # precedence over non-required definitions
                all_parameters.update(required_parameters)
                operation["parameters"] = list(all_parameters.values())
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field, model_name_map=model_name_map
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks = {}
                for callback in route.callbacks:
                    if isinstance(callback, routing.APIRoute):
                        # Only the callback's path item is used here; its
                        # security schemes and definitions are discarded.
                        (
                            cb_path,
                            cb_security_schemes,
                            cb_definitions,
                        ) = get_openapi_path(
                            route=callback,
                            model_name_map=model_name_map,
                            operation_ids=operation_ids,
                        )
                        callbacks[callback.name] = {callback.path: cb_path}
                operation["callbacks"] = callbacks
            if route.status_code is not None:
                status_code = str(route.status_code)
            else:
                # Bind a fallback first: if the response class' __init__ has no
                # integer ``status_code`` default, the name would otherwise be
                # unbound below (or stale from a previous method iteration).
                # NOTE(review): 200 is assumed as the conventional default
                # success status - confirm against the response classes in use.
                status_code = "200"
                # It would probably make more sense for all response classes to
                # have an explicit default status_code and to read it from
                # them, instead of relying on this signature inspection.
                # TODO: probably make status_code a default class attribute for
                # all responses in Starlette
                response_signature = inspect.signature(current_response_class.__init__)
                status_code_param = response_signature.parameters.get("status_code")
                if status_code_param is not None and isinstance(
                    status_code_param.default, int
                ):
                    status_code = str(status_code_param.default)
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            if route_response_media_type and is_body_allowed_for_status_code(
                route.status_code
            ):
                response_schema = {"type": "string"}
                if lenient_issubclass(current_response_class, JSONResponse):
                    if route.response_field:
                        response_schema, _, _ = field_schema(
                            route.response_field,
                            model_name_map=model_name_map,
                            ref_prefix=REF_PREFIX,
                        )
                    else:
                        response_schema = {}
                operation.setdefault("responses", {}).setdefault(
                    status_code, {}
                ).setdefault("content", {}).setdefault(route_response_media_type, {})[
                    "schema"
                ] = response_schema
            if route.responses:
                operation_responses = operation.setdefault("responses", {})
                for (
                    additional_status_code,
                    additional_response,
                ) in route.responses.items():
                    process_response = additional_response.copy()
                    # "model" is not part of the emitted response object; its
                    # schema comes from route.response_fields below.
                    process_response.pop("model", None)
                    status_code_key = str(additional_status_code).upper()
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    openapi_response = operation_responses.setdefault(
                        status_code_key, {}
                    )
                    assert isinstance(
                        process_response, dict
                    ), "An additional response must be a dict"
                    field = route.response_fields.get(additional_status_code)
                    additional_field_schema: Optional[Dict[str, Any]] = None
                    if field:
                        additional_field_schema, _, _ = field_schema(
                            field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
                        )
                        media_type = route_response_media_type or "application/json"
                        additional_schema = (
                            process_response.setdefault("content", {})
                            .setdefault(media_type, {})
                            .setdefault("schema", {})
                        )
                        deep_dict_update(additional_schema, additional_field_schema)
                    # The range table is consulted first, so int() is only
                    # reached for keys that are not a known range/"default".
                    status_text: Optional[str] = status_code_ranges.get(
                        str(additional_status_code).upper()
                    ) or http.client.responses.get(int(additional_status_code))
                    description = (
                        process_response.get("description")
                        or openapi_response.get("description")
                        or status_text
                        or "Additional Response"
                    )
                    deep_dict_update(openapi_response, process_response)
                    openapi_response["description"] = description
            http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
            # Add the validation error response only when the operation takes
            # input and no 422 / "4XX" / "default" response is declared yet.
            if (all_route_params or route.body_field) and not any(
                status in operation["responses"]
                for status in (http422, "4XX", "default")
            ):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )
            if route.openapi_extra:
                deep_dict_update(operation, route.openapi_extra)
            path[method.lower()] = operation
    return path, security_schemes, definitions


def get_flat_models_from_routes(
    routes: Sequence[BaseRoute],
) -> Set[Union[Type[BaseModel], Type[Enum]]]:
    """Collect the flat set of models referenced by the given routes.

    Only ``APIRoute`` instances with a truthy ``include_in_schema`` are
    considered. Body fields, response fields, additional response fields and
    flattened parameters are gathered, and models from each route's callbacks
    are collected recursively.
    """
    body_fields_from_routes: List[ModelField] = []
    responses_from_routes: List[ModelField] = []
    request_fields_from_routes: List[ModelField] = []
    callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
    for route in routes:
        if getattr(route, "include_in_schema", None) and isinstance(
            route, routing.APIRoute
        ):
            if route.body_field:
                assert isinstance(
                    route.body_field, ModelField
                ), "A request body must be a Pydantic Field"
                body_fields_from_routes.append(route.body_field)
            if route.response_field:
                responses_from_routes.append(route.response_field)
            if route.response_fields:
                responses_from_routes.extend(route.response_fields.values())
            if route.callbacks:
                callback_flat_models |= get_flat_models_from_routes(route.callbacks)
            params = get_flat_params(route.dependant)
            request_fields_from_routes.extend(params)

    flat_models = callback_flat_models | get_flat_models_from_fields(
        body_fields_from_routes + responses_from_routes + request_fields_from_routes,
        known_models=set(),
    )
    return flat_models


def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.0.2",
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    tags: Optional[List[Dict[str, Any]]] = None,
    servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
    terms_of_service: Optional[str] = None,
    contact: Optional[Dict[str, Union[str, Any]]] = None,
    license_info: Optional[Dict[str, Union[str, Any]]] = None,
) -> Dict[str, Any]:
    """Generate the OpenAPI document for ``routes`` as a JSON-compatible dict.

    ``title`` and ``version`` always populate ``info``; ``description``,
    ``terms_of_service``, ``contact`` and ``license_info`` are added to
    ``info`` only when truthy, as are the top-level ``servers`` and ``tags``.
    Only ``APIRoute`` instances contribute paths. Schemas under
    ``components`` are emitted sorted by name. The assembled document is
    passed through the ``OpenAPI`` model and encoded by alias with ``None``
    values excluded.
    """
    info: Dict[str, Any] = {"title": title, "version": version}
    if description:
        info["description"] = description
    if terms_of_service:
        info["termsOfService"] = terms_of_service
    if contact:
        info["contact"] = contact
    if license_info:
        info["license"] = license_info
    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers
    components: Dict[str, Dict[str, Any]] = {}
    paths: Dict[str, Dict[str, Any]] = {}
    operation_ids: Set[str] = set()
    flat_models = get_flat_models_from_routes(routes)
    model_name_map = get_model_name_map(flat_models)
    definitions = get_model_definitions(
        flat_models=flat_models, model_name_map=model_name_map
    )
    for route in routes:
        if not isinstance(route, routing.APIRoute):
            continue
        # get_openapi_path() always returns a 3-tuple, so unpack it directly.
        path, security_schemes, path_definitions = get_openapi_path(
            route=route, model_name_map=model_name_map, operation_ids=operation_ids
        )
        if path:
            paths.setdefault(route.path_format, {}).update(path)
        if security_schemes:
            components.setdefault("securitySchemes", {}).update(security_schemes)
        if path_definitions:
            definitions.update(path_definitions)
    if definitions:
        components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
    if components:
        output["components"] = components
    output["paths"] = paths
    if tags:
        output["tags"] = tags
    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)  # type: ignore