# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import sys
import unittest
from types import ModuleType
from textwrap import dedent
from pathlib import Path
from tempfile import TemporaryDirectory

from stubgen import Writer, walk_module

# Marks "there was no such entry in `sys.modules`" when snapshotting it,
# so that a pre-existing entry can be told apart from a missing one.
_MISSING = object()


def _restore_module(name: str, previous: object) -> None:
    """Put `sys.modules[name]` back to the state captured in `previous`.

    :param name: the module name whose registration should be undone.
    :param previous: the object that was registered under `name` before the
        test touched it, or `_MISSING` if nothing was registered.
    """
    if previous is _MISSING:
        sys.modules.pop(name, None)
    else:
        sys.modules[name] = previous  # type: ignore


def create_module(name: str, code: str) -> ModuleType:
    """Build a module object from source text and register it in `sys.modules`.

    `code` is dedented and executed in the namespace of a fresh module called
    `name`. If the executed code does not define `__all__`, one is synthesized
    from the module's members.

    :param name: name given to the module and used as its `sys.modules` key.
    :param code: Python source to execute as the module's body.
    :return: the newly created module.
    """
    mod = ModuleType(name)
    exec(dedent(code), mod.__dict__)
    if not hasattr(mod, "__all__"):
        # Manually populate `__all__` with all the members that don't start with `__`
        mod.__all__ = [k for k in mod.__dict__ if not k.startswith("__")]  # type: ignore
    sys.modules[name] = mod
    return mod


class TestStubgen(unittest.TestCase):
    """Checks the stub text written by `Writer`/`walk_module` for small in-memory modules."""

    def test_function_without_docstring(self):
        self.single_mod(
            """
            def foo():
                pass
            """,
            """
            import typing

            foo: typing.Any
            """,
        )

    def test_regular_function(self):
        self.single_mod(
            """
            def foo(bar):
                '''
                :param bar str:
                :rtype bool:
                '''
                pass
            """,
            """
            def foo(bar: str) -> bool:
                ...
            """,
        )

    def test_function_with_default_value(self):
        self.single_mod(
            """
            def foo(bar, qux=None):
                '''
                :param bar int:
                :param qux typing.Optional[str]:
                :rtype None:
                '''
                pass
            """,
            """
            import typing

            def foo(bar: int, qux: typing.Optional[str] = ...) -> None:
                ...
            """,
        )

    def test_empty_class(self):
        self.single_mod(
            """
            class Foo:
                pass
            """,
            """
            class Foo:
                ...
            """,
        )

    def test_class(self):
        self.single_mod(
            """
            class Foo:
                @property
                def bar(self):
                    '''
                    :type typing.List[bool]:
                    '''
                    pass

                def qux(self, a, b, c):
                    '''
                    :param a typing.Dict[typing.List[int]]:
                    :param b str:
                    :param c float:
                    :rtype typing.Union[int, str, bool]:
                    '''
                    pass
            """,
            """
            import typing

            class Foo:
                bar: typing.List[bool]

                def qux(self, a: typing.Dict[typing.List[int]], b: str, c: float) -> typing.Union[int, str, bool]:
                    ...
            """,
        )

    def test_class_with_constructor_signature(self):
        self.single_mod(
            """
            class Foo:
                '''
                :param bar str:
                :rtype None:
                '''
            """,
            """
            class Foo:
                def __init__(self, bar: str) -> None:
                    ...
            """,
        )

    def test_class_with_static_method(self):
        self.single_mod(
            """
            class Foo:
                @staticmethod
                def bar(name):
                    '''
                    :param name str:
                    :rtype typing.List[bool]:
                    '''
                    pass
            """,
            """
            import typing

            class Foo:
                @staticmethod
                def bar(name: str) -> typing.List[bool]:
                    ...
            """,
        )

    def test_class_with_an_undocumented_descriptor(self):
        self.single_mod(
            """
            class Foo:
                @property
                def bar(self):
                    pass
            """,
            """
            import typing

            class Foo:
                bar: typing.Any
            """,
        )

    def test_enum(self):
        self.single_mod(
            """
            class Foo:
                def __init__(self, name):
                    pass

            Foo.Bar = Foo("Bar")
            Foo.Baz = Foo("Baz")
            Foo.Qux = Foo("Qux")
            """,
            """
            class Foo:
                Bar: Foo

                Baz: Foo

                Qux: Foo
            """,
        )

    def test_generic(self):
        self.single_mod(
            """
            class Foo:
                '''
                :generic T:
                :generic U:
                :extends typing.Generic[T]:
                :extends typing.Generic[U]:
                '''

                @property
                def bar(self):
                    '''
                    :type typing.Tuple[T, U]:
                    '''
                    pass

                def baz(self, a):
                    '''
                    :param a U:
                    :rtype T:
                    '''
                    pass
            """,
            """
            import typing

            T = typing.TypeVar('T')
            U = typing.TypeVar('U')

            class Foo(typing.Generic[T], typing.Generic[U]):
                bar: typing.Tuple[T, U]

                def baz(self, a: U) -> T:
                    ...
            """,
        )

    def test_items_with_docstrings(self):
        self.single_mod(
            """
            class Foo:
                '''
                This is the docstring of Foo.

                And it has multiple lines.

                :generic T:
                :extends typing.Generic[T]:
                :param member T:
                '''

                @property
                def bar(self):
                    '''
                    This is the docstring of property `bar`.

                    :type typing.Optional[T]:
                    '''
                    pass

                def baz(self, t):
                    '''
                    This is the docstring of method `baz`.

                    :param t T:
                    :rtype T:
                    '''
                    pass
            """,
            '''
            import typing

            T = typing.TypeVar('T')

            class Foo(typing.Generic[T]):
                """
                This is the docstring of Foo.

                And it has multiple lines.
                """

                bar: typing.Optional[T]
                """
                This is the docstring of property `bar`.
                """

                def baz(self, t: T) -> T:
                    """
                    This is the docstring of method `baz`.
                    """
                    ...


                def __init__(self, member: T) -> None:
                    ...
            ''',
        )

    def test_adds_default_to_optional_types(self):
        # Since PyO3 provides `impl FromPyObject for Option` and maps Python `None` to Rust `None`,
        # you don't have to pass `None` explicitly. Type-stubs also shouldn't require `None`s
        # to be passed explicitly (meaning they should have a default value).
        self.single_mod(
            """
            def foo(bar, qux):
                '''
                :param bar typing.Optional[int]:
                :param qux typing.List[typing.Optional[int]]:
                :rtype int:
                '''
                pass
            """,
            """
            import typing

            def foo(bar: typing.Optional[int] = ..., qux: typing.List[typing.Optional[int]]) -> int:
                ...
            """,
        )

    def test_multiple_mods(self):
        # `foo.bar` has to be registered first: the body of `foo` looks it up
        # in `sys.modules` while it is being executed.
        self.install_module(
            "foo.bar",
            """
            class Bar:
                '''
                :param qux str:
                :rtype None:
                '''
                pass
            """,
        )

        foo = self.install_module(
            "foo",
            """
            import sys

            bar = sys.modules["foo.bar"]

            class Foo:
                '''
                :param a __root_module_name__.bar.Bar:
                :param b typing.Optional[__root_module_name__.bar.Bar]:
                :rtype None:
                '''

                @property
                def a(self):
                    '''
                    :type __root_module_name__.bar.Bar:
                    '''
                    pass

                @property
                def b(self):
                    '''
                    :type typing.Optional[__root_module_name__.bar.Bar]:
                    '''
                    pass

            __all__ = ["bar", "Foo"]
            """,
        )

        with TemporaryDirectory() as temp_dir:
            foo_path = Path(temp_dir) / "foo.pyi"
            bar_path = Path(temp_dir) / "bar" / "__init__.pyi"

            writer = Writer(foo_path, "foo")
            walk_module(writer, foo)
            writer.dump()

            self.assert_stub(
                foo_path,
                """
                import foo.bar
                import typing

                class Foo:
                    a: foo.bar.Bar

                    b: typing.Optional[foo.bar.Bar]

                    def __init__(self, a: foo.bar.Bar, b: typing.Optional[foo.bar.Bar] = ...) -> None:
                        ...
                """,
            )

            self.assert_stub(
                bar_path,
                """
                class Bar:
                    def __init__(self, qux: str) -> None:
                        ...
                """,
            )

    def install_module(self, name: str, code: str) -> ModuleType:
        """Register a module via `create_module` for the duration of one test.

        The `sys.modules` entry for `name` is snapshotted first and restored
        by a cleanup callback, so modules created by one test neither leak
        into later tests nor permanently replace a real module of the same
        name (the single-module tests use the name `test`).

        :param name: module name to register.
        :param code: source text for the module body.
        :return: the created module.
        """
        previous = sys.modules.get(name, _MISSING)
        # Registered before the module is created so the entry is restored
        # even if executing `code` raises part-way through.
        self.addCleanup(_restore_module, name, previous)
        return create_module(name, code)

    def single_mod(self, mod_code: str, expected_stub: str) -> None:
        """Generate the stub for one module named `test` and compare it to `expected_stub`.

        :param mod_code: source text of the module to generate a stub for.
        :param expected_stub: expected stub text; compared after dedenting and stripping.
        """
        with TemporaryDirectory() as temp_dir:
            mod = self.install_module("test", mod_code)
            path = Path(temp_dir) / "test.pyi"

            writer = Writer(path, "test")
            walk_module(writer, mod)
            writer.dump()

            self.assert_stub(path, expected_stub)

    def assert_stub(self, path: Path, expected: str) -> None:
        """Assert that the file at `path` equals `expected`.

        The file contents are stripped; `expected` is dedented and stripped,
        so expectations can be written as indented triple-quoted literals.

        :param path: stub file to read.
        :param expected: expected stub text.
        """
        self.assertEqual(path.read_text().strip(), dedent(expected).strip())


if __name__ == "__main__":
    unittest.main()