from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, TypeVar, Union, overload

from . import validator
from .config import Extra
from .errors import ConfigError
from .main import BaseModel, create_model
from .typing import get_all_type_hints
from .utils import to_camel

__all__ = ('validate_arguments',)

if TYPE_CHECKING:
    from .typing import AnyCallable

    AnyCallableT = TypeVar('AnyCallableT', bound=AnyCallable)
    ConfigType = Union[None, Type[Any], Dict[str, Any]]


@overload
def validate_arguments(func: None = None, *, config: 'ConfigType' = None) -> Callable[['AnyCallableT'], 'AnyCallableT']:
    ...


@overload
def validate_arguments(func: 'AnyCallableT') -> 'AnyCallableT':
    ...


def validate_arguments(func: Optional['AnyCallableT'] = None, *, config: 'ConfigType' = None) -> Any:
    """
    Decorator to validate the arguments passed to a function.

    Can be used bare (``@validate_arguments``) or with a keyword argument
    (``@validate_arguments(config=...)``). When ``func`` is ``None`` the
    decorator itself is returned; otherwise the wrapped function is returned.

    The wrapper exposes four extra attributes: ``vd`` (the ``ValidatedFunction``),
    ``validate`` (builds the model instance without calling the function),
    ``raw_function`` (the undecorated callable) and ``model`` (the generated model class).
    """

    def validate(_func: 'AnyCallable') -> 'AnyCallable':
        vd = ValidatedFunction(_func, config)

        @wraps(_func)
        def wrapper_function(*args: Any, **kwargs: Any) -> Any:
            return vd.call(*args, **kwargs)

        wrapper_function.vd = vd  # type: ignore
        wrapper_function.validate = vd.init_model_instance  # type: ignore
        wrapper_function.raw_function = vd.raw_function  # type: ignore
        wrapper_function.model = vd.model  # type: ignore
        return wrapper_function

    # Compare against None rather than relying on truthiness: a callable object
    # may define __bool__/__len__ and evaluate as falsy while still being a valid target.
    if func is not None:
        return validate(func)
    else:
        return validate


ALT_V_ARGS = 'v__args'
ALT_V_KWARGS = 'v__kwargs'
V_POSITIONAL_ONLY_NAME = 'v__positional_only'
V_DUPLICATE_KWARGS = 'v__duplicate_kwargs'


class ValidatedFunction:
    """
    Pairs a callable with a model generated from its signature, so that call
    arguments can be validated through the model before the callable is invoked.
    """

    def __init__(self, function: 'AnyCallableT', config: 'ConfigType'):  # noqa C901
        """
        Inspect ``function``'s signature and build the matching model.

        :param function: the callable whose arguments should be validated
        :param config: ``None``, a config class, or a dict of config attributes
        :raises ConfigError: if a parameter uses one of the reserved ``v__*`` names
        """
        from inspect import Parameter, signature

        parameters: Mapping[str, Parameter] = signature(function).parameters

        if parameters.keys() & {ALT_V_ARGS, ALT_V_KWARGS, V_POSITIONAL_ONLY_NAME, V_DUPLICATE_KWARGS}:
            raise ConfigError(
                f'"{ALT_V_ARGS}", "{ALT_V_KWARGS}", "{V_POSITIONAL_ONLY_NAME}" and "{V_DUPLICATE_KWARGS}" '
                f'are not permitted as argument names when using the "{validate_arguments.__name__}" decorator'
            )

        self.raw_function = function
        # positional index -> parameter name, for parameters that can be passed positionally
        self.arg_mapping: Dict[int, str] = {}
        self.positional_only_args = set()
        self.v_args_name = 'args'
        self.v_kwargs_name = 'kwargs'

        type_hints = get_all_type_hints(function)
        takes_args = False
        takes_kwargs = False

        fields: Dict[str, Tuple[Any, Any]] = {}
        for i, (name, p) in enumerate(parameters.items()):
            if p.annotation is p.empty:
                annotation = Any
            else:
                annotation = type_hints[name]

            # Ellipsis marks the field as required when the parameter has no default
            default = ... if p.default is p.empty else p.default
            if p.kind == Parameter.POSITIONAL_ONLY:
                self.arg_mapping[i] = name
                fields[name] = annotation, default
                fields[V_POSITIONAL_ONLY_NAME] = List[str], None
                self.positional_only_args.add(name)
            elif p.kind == Parameter.POSITIONAL_OR_KEYWORD:
                self.arg_mapping[i] = name
                fields[name] = annotation, default
                fields[V_DUPLICATE_KWARGS] = List[str], None
            elif p.kind == Parameter.KEYWORD_ONLY:
                fields[name] = annotation, default
            elif p.kind == Parameter.VAR_POSITIONAL:
                self.v_args_name = name
                fields[name] = Tuple[annotation, ...], None
                takes_args = True
            else:
                assert p.kind == Parameter.VAR_KEYWORD, p.kind
                self.v_kwargs_name = name
                fields[name] = Dict[str, annotation], None  # type: ignore
                takes_kwargs = True

        # these checks avoid a clash between "args" and a field with that name
        if not takes_args and self.v_args_name in fields:
            self.v_args_name = ALT_V_ARGS

        # same with "kwargs"
        if not takes_kwargs and self.v_kwargs_name in fields:
            self.v_kwargs_name = ALT_V_KWARGS

        if not takes_args:
            # we add the field so validation below can raise the correct exception
            fields[self.v_args_name] = List[Any], None

        if not takes_kwargs:
            # same with kwargs
            fields[self.v_kwargs_name] = Dict[Any, Any], None

        self.create_model(fields, takes_args, takes_kwargs, config)

    def init_model_instance(self, *args: Any, **kwargs: Any) -> BaseModel:
        """Build the model instance for the given call arguments without calling the function."""
        values = self.build_values(args, kwargs)
        return self.model(**values)

    def call(self, *args: Any, **kwargs: Any) -> Any:
        """Validate the given call arguments through the model, then invoke the raw function."""
        m = self.init_model_instance(*args, **kwargs)
        return self.execute(m)

    def build_values(self, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Translate positional and keyword call arguments into a dict of model field values.

        Positional arguments are assigned to the parameter names recorded in
        ``arg_mapping``; any surplus is collected into a list under ``v_args_name``.
        Keyword arguments matching a model field name or alias are stored under
        their key; the rest are collected under ``v_kwargs_name``. Positional-only
        parameters passed by keyword and keywords that repeat an already supplied
        value are recorded under the reserved ``v__*`` keys so the model's
        validators can report them.
        """
        values: Dict[str, Any] = {}
        if args:
            arg_iter = enumerate(args)
            for i, a in arg_iter:
                arg_name = self.arg_mapping.get(i)
                if arg_name is not None:
                    values[arg_name] = a
                else:
                    # first index without a named parameter: this argument and everything
                    # still left in the shared iterator become the var-positional values
                    values[self.v_args_name] = [a] + [rest for _, rest in arg_iter]
                    break

        var_kwargs: Dict[str, Any] = {}
        wrong_positional_args = []
        duplicate_kwargs = []
        var_field_names = {self.v_args_name, self.v_kwargs_name}
        # a set gives constant-time membership tests in the loop below
        fields_alias = {
            field.alias for name, field in self.model.__fields__.items() if name not in var_field_names
        }
        non_var_fields = set(self.model.__fields__) - var_field_names
        for k, v in kwargs.items():
            if k in non_var_fields or k in fields_alias:
                if k in self.positional_only_args:
                    wrong_positional_args.append(k)
                if k in values:
                    duplicate_kwargs.append(k)
                values[k] = v
            else:
                var_kwargs[k] = v

        if var_kwargs:
            values[self.v_kwargs_name] = var_kwargs
        if wrong_positional_args:
            values[V_POSITIONAL_ONLY_NAME] = wrong_positional_args
        if duplicate_kwargs:
            values[V_DUPLICATE_KWARGS] = duplicate_kwargs
        return values

    def execute(self, m: BaseModel) -> Any:
        """
        Call the raw function with the values held by the model instance ``m``.

        Only fields that were explicitly set, or that have a ``default_factory``,
        are forwarded; other parameters fall back to the function's own defaults.
        """
        d = {k: v for k, v in m._iter() if k in m.__fields_set__ or m.__fields__[k].default_factory}
        var_kwargs = d.pop(self.v_kwargs_name, {})

        if self.v_args_name in d:
            # Everything iterated before the var-positional entry is passed positionally,
            # its contents are appended, and everything after it is passed by keyword.
            # NOTE(review): relies on the iteration order of ``m._iter()`` - presumably
            # field definition order; confirm against BaseModel.
            args_: List[Any] = []
            in_kwargs = False
            kwargs = {}
            for name, value in d.items():
                if in_kwargs:
                    kwargs[name] = value
                elif name == self.v_args_name:
                    args_ += value
                    in_kwargs = True
                else:
                    args_.append(value)
            return self.raw_function(*args_, **kwargs, **var_kwargs)
        elif self.positional_only_args:
            args_ = []
            kwargs = {}
            for name, value in d.items():
                if name in self.positional_only_args:
                    args_.append(value)
                else:
                    kwargs[name] = value
            return self.raw_function(*args_, **kwargs, **var_kwargs)
        else:
            return self.raw_function(**d, **var_kwargs)

    def create_model(self, fields: Dict[str, Any], takes_args: bool, takes_kwargs: bool, config: 'ConfigType') -> None:
        """
        Create the model class used for validation and store it on ``self.model``.

        :param fields: field name -> ``(annotation, default)`` definitions
        :param takes_args: whether the function declares a ``*args`` parameter
        :param takes_kwargs: whether the function declares a ``**kwargs`` parameter
        :param config: ``None``, a config class, or a dict of config attributes
        :raises ConfigError: if the config defines ``fields`` or ``alias_generator``
        """
        pos_args = len(self.arg_mapping)

        class CustomConfig:
            pass

        if not TYPE_CHECKING:  # pragma: no branch
            if isinstance(config, dict):
                CustomConfig = type('Config', (), config)  # noqa: F811
            elif config is not None:
                CustomConfig = config  # noqa: F811

        if hasattr(CustomConfig, 'fields') or hasattr(CustomConfig, 'alias_generator'):
            raise ConfigError(
                'Setting the "fields" and "alias_generator" property on custom Config for '
                '@validate_arguments is not yet supported, please remove.'
            )

        class DecoratorBaseModel(BaseModel):
            @validator(self.v_args_name, check_fields=False, allow_reuse=True)
            def check_args(cls, v: Optional[List[Any]]) -> Optional[List[Any]]:
                # surplus positional arguments are only acceptable when the function has *args
                if takes_args or v is None:
                    return v

                raise TypeError(f'{pos_args} positional arguments expected but {pos_args + len(v)} given')

            @validator(self.v_kwargs_name, check_fields=False, allow_reuse=True)
            def check_kwargs(cls, v: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
                # unknown keyword arguments are only acceptable when the function has **kwargs
                if takes_kwargs or v is None:
                    return v

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v.keys()))
                raise TypeError(f'unexpected keyword argument{plural}: {keys}')

            @validator(V_POSITIONAL_ONLY_NAME, check_fields=False, allow_reuse=True)
            def check_positional_only(cls, v: Optional[List[str]]) -> None:
                if v is None:
                    return

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v))
                raise TypeError(f'positional-only argument{plural} passed as keyword argument{plural}: {keys}')

            @validator(V_DUPLICATE_KWARGS, check_fields=False, allow_reuse=True)
            def check_duplicate_kwargs(cls, v: Optional[List[str]]) -> None:
                if v is None:
                    return

                plural = '' if len(v) == 1 else 's'
                keys = ', '.join(map(repr, v))
                raise TypeError(f'multiple values for argument{plural}: {keys}')

            class Config(CustomConfig):
                # forbid extra fields unless the custom config sets "extra" itself
                extra = getattr(CustomConfig, 'extra', Extra.forbid)

        self.model = create_model(to_camel(self.raw_function.__name__), __base__=DecoratorBaseModel, **fields)