# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations import inspect import re import textwrap from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple ROOT_MODULE_NAME_PLACEHOLDER = "__root_module_name__" class Writer: """ Writer provides utilities for writing Python stubs. """ root_module_name: str path: Path subwriters: List[Writer] imports: Set[str] defs: List[str] generics: Set[str] def __init__(self, path: Path, root_module_name: str) -> None: self.path = path self.root_module_name = root_module_name self.subwriters = [] self.imports = set([]) self.defs = [] self.generics = set([]) def fix_path(self, path: str) -> str: """ Returns fixed version of given type path. It unescapes `\\[` and `\\]` and also populates placeholder for root module name. """ return path.replace(ROOT_MODULE_NAME_PLACEHOLDER, self.root_module_name).replace("\\[", "[").replace("\\]", "]") def submodule(self, path: Path) -> Writer: w = Writer(path, self.root_module_name) self.subwriters.append(w) return w def include(self, path: str) -> str: # `path` might be nested like: typing.Optional[typing.List[pokemon_service_server_sdk.model.GetPokemonSpecies]] # we need to process every subpath in a nested path paths = filter(lambda p: p, re.split("\\[|\\]|,| ", path)) for subpath in paths: parts = subpath.rsplit(".", maxsplit=1) # add `typing` to imports for a path like `typing.List` # but skip if the path doesn't have any namespace like `str` or `bool` if len(parts) == 2: self.imports.add(parts[0]) return path def fix_and_include(self, path: str) -> str: return self.include(self.fix_path(path)) def define(self, code: str) -> None: self.defs.append(code) def generic(self, name: str) -> None: self.generics.add(name) def dump(self) -> None: for w in self.subwriters: w.dump() generics = "" for g in sorted(self.generics): generics += f"{g} = {self.include('typing.TypeVar')}('{g}')\n" self.path.parent.mkdir(parents=True, exist_ok=True) contents = join([f"import {p}" for p in sorted(self.imports)]) contents += "\n\n" if generics: contents += generics + "\n" contents += join(self.defs) self.path.write_text(contents) class DocstringParserResult: def __init__(self) -> None: self.types: List[str] = [] self.params: List[Tuple[str, str]] = [] self.rtypes: List[str] = [] self.generics: List[str] = [] self.extends: List[str] = [] def parse_type_directive(line: str, res: DocstringParserResult): parts = line.split(" ", maxsplit=1) if len(parts) != 2: raise ValueError(f"Invalid `:type` directive: `{line}` must be in `:type T:` format") res.types.append(parts[1].rstrip(":")) def parse_rtype_directive(line: str, res: DocstringParserResult): parts = line.split(" ", maxsplit=1) if len(parts) != 2: raise ValueError(f"Invalid `:rtype` directive: `{line}` must be in `:rtype T:` format") res.rtypes.append(parts[1].rstrip(":")) def parse_param_directive(line: str, res: DocstringParserResult): parts = line.split(" ", maxsplit=2) if len(parts) != 3: raise ValueError(f"Invalid `:param` directive: `{line}` must be in `:param name T:` format") name = parts[1] ty = parts[2].rstrip(":") res.params.append((name, ty)) def parse_generic_directive(line: str, res: DocstringParserResult): parts = line.split(" ", maxsplit=1) if len(parts) != 2: raise ValueError(f"Invalid `:generic` directive: `{line}` must be in `:generic T:` format") res.generics.append(parts[1].rstrip(":")) def parse_extends_directive(line: str, res: DocstringParserResult): parts = line.split(" ", maxsplit=1) if len(parts) != 2: raise ValueError(f"Invalid `:extends` directive: `{line}` must be in `:extends Base[...]:` format") res.extends.append(parts[1].rstrip(":")) DocstringParserDirectives = { "type": parse_type_directive, "param": parse_param_directive, "rtype": parse_rtype_directive, "generic": parse_generic_directive, "extends": parse_extends_directive, } class DocstringParser: """ DocstringParser provides utilities for parsing type information from docstring. """ @staticmethod def parse(obj: Any) -> Optional[DocstringParserResult]: doc = inspect.getdoc(obj) if not doc: return None res = DocstringParserResult() for line in doc.splitlines(): line = line.strip() for d, p in DocstringParserDirectives.items(): if line.startswith(f":{d} ") and line.endswith(":"): p(line, res) return res @staticmethod def parse_type(obj: Any) -> str: result = DocstringParser.parse(obj) if not result or len(result.types) == 0: return "typing.Any" return result.types[0] @staticmethod def parse_function(obj: Any) -> Optional[Tuple[List[Tuple[str, str]], str]]: result = DocstringParser.parse(obj) if not result: return None return ( result.params, "None" if len(result.rtypes) == 0 else result.rtypes[0], ) @staticmethod def parse_class(obj: Any) -> Tuple[List[str], List[str]]: result = DocstringParser.parse(obj) if not result: return ([], []) return (result.generics, result.extends) @staticmethod def clean_doc(obj: Any) -> str: doc = inspect.getdoc(obj) if not doc: return "" def predicate(line: str) -> bool: for k in DocstringParserDirectives.keys(): if line.startswith(f":{k} ") and line.endswith(":"): return False return True return "\n".join([line for line in doc.splitlines() if predicate(line)]).strip() def indent(code: str, level: int = 4) -> str: return textwrap.indent(code, level * " ") def is_fn_like(obj: Any) -> bool: return ( inspect.isbuiltin(obj) or inspect.ismethod(obj) or inspect.isfunction(obj) or inspect.ismethoddescriptor(obj) or inspect.iscoroutine(obj) or inspect.iscoroutinefunction(obj) ) def is_scalar(obj: Any) -> bool: return isinstance(obj, (str, float, int, bool)) def join(args: List[str], delim: str = "\n") -> str: return delim.join(filter(lambda x: x, args)) def make_doc(obj: Any) -> str: doc = DocstringParser.clean_doc(obj) doc = textwrap.dedent(doc) if not doc: return "" return join(['"""', doc, '"""']) def make_field(writer: Writer, name: str, field: Any) -> str: return f"{name}: {writer.fix_and_include(DocstringParser.parse_type(field))}" def make_function( writer: Writer, name: str, obj: Any, include_docs: bool = True, parent: Optional[Any] = None, ) -> str: is_static_method = False if parent and isinstance(obj, staticmethod): # Get real method instance from `parent` if `obj` is a `staticmethod` is_static_method = True obj = getattr(parent, name) res = DocstringParser.parse_function(obj) if not res: # Make it `Any` if we can't parse the docstring return f"{name}: {writer.include('typing.Any')}" params, rtype = res # We're using signature for getting default values only, currently type hints are not supported # in signatures. We can leverage signatures more if it supports type hints in future. sig: Optional[inspect.Signature] = None try: sig = inspect.signature(obj) except Exception: pass def has_default(param: str, ty: str) -> bool: # PyO3 allows omitting `Option` params while calling a Rust function from Python, # we should always mark `typing.Optional[T]` values as they have default values to allow same # flexibiliy as runtime dynamics in type-stubs. if ty.startswith("typing.Optional["): return True if sig is None: return False sig_param = sig.parameters.get(param) return sig_param is not None and sig_param.default is not sig_param.empty receivers: List[str] = [] attrs: List[str] = [] if parent: if is_static_method: attrs.append("@staticmethod") else: receivers.append("self") def make_param(name: str, ty: str) -> str: fixed_ty = writer.fix_and_include(ty) param = f"{name}: {fixed_ty}" if has_default(name, fixed_ty): param += " = ..." return param params = join(receivers + [make_param(n, t) for n, t in params], delim=", ") attrs_str = join(attrs) rtype = writer.fix_and_include(rtype) body = "..." if include_docs: body = join([make_doc(obj), body]) return f""" {attrs_str} def {name}({params}) -> {rtype}: {indent(body)} """.lstrip() def make_class(writer: Writer, name: str, klass: Any) -> str: bases = list(filter(lambda n: n != "object", map(lambda b: b.__name__, klass.__bases__))) class_sig = DocstringParser.parse_class(klass) if class_sig: (generics, extends) = class_sig bases.extend(map(writer.fix_and_include, extends)) for g in generics: writer.generic(g) members: List[str] = [] class_vars: Dict[str, Any] = vars(klass) for member_name, member in sorted(class_vars.items(), key=lambda k: k[0]): if member_name.startswith("__"): continue if inspect.isdatadescriptor(member): members.append( join( [ make_field(writer, member_name, member), make_doc(member), ] ) ) elif is_fn_like(member): members.append( make_function(writer, member_name, member, parent=klass), ) elif isinstance(member, klass): # Enum variant members.append( join( [ f"{member_name}: {name}", make_doc(member), ] ) ) else: print(f"Unknown member type: {member}") if inspect.getdoc(klass) is not None: constructor_sig = DocstringParser.parse(klass) if constructor_sig is not None and ( # Make sure to only generate `__init__` if the class has a constructor defined len(constructor_sig.rtypes) > 0 or len(constructor_sig.params) > 0 ): members.append( make_function( writer, "__init__", klass, include_docs=False, parent=klass, ) ) bases_str = "" if len(bases) == 0 else f"({join(bases, delim=', ')})" doc = make_doc(klass) if doc: doc += "\n" body = join([doc, join(members, delim="\n\n") or "..."]) return f"""\ class {name}{bases_str}: {indent(body)} """ def walk_module(writer: Writer, mod: Any): exported = mod.__all__ for name, member in inspect.getmembers(mod): if name not in exported: continue if inspect.ismodule(member): subpath = writer.path.parent / name / "__init__.pyi" walk_module(writer.submodule(subpath), member) elif inspect.isclass(member): writer.define(make_class(writer, name, member)) elif is_fn_like(member): writer.define(make_function(writer, name, member)) elif is_scalar(member): writer.define(f"{name}: {type(member).__name__} = ...") else: print(f"Unknown type: {member}") def generate(module: str, outdir: str): path = Path(outdir) / "__init__.pyi" writer = Writer( path, module, ) walk_module( writer, importlib.import_module(module), ) writer.dump() if __name__ == "__main__": import argparse import importlib parser = argparse.ArgumentParser() parser.add_argument("module") parser.add_argument("outdir") args = parser.parse_args() generate(args.module, args.outdir)