From 7d021c2cdc58f74e9b803bcf7a9f5d56a6524192 Mon Sep 17 00:00:00 2001 From: Mike Cousins Date: Sun, 25 Sep 2022 16:01:19 -0400 Subject: [PATCH] fix(spy): resolve source to origin of GenericAlias (#143) Fixes #142 --- decoy/spy_core.py | 13 +++++++++++-- tests/fixtures.py | 33 +++++++++++++++------------------ tests/test_spy_core.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 54 insertions(+), 20 deletions(-) diff --git a/decoy/spy_core.py b/decoy/spy_core.py index 3390538..3465b3a 100644 --- a/decoy/spy_core.py +++ b/decoy/spy_core.py @@ -27,10 +27,10 @@ class BoundArgs(NamedTuple): class SpyCore: - """Core spy logic for mimicing a given `source` object. + """Core spy logic for mimicking a given `source` object. Arguments: - source: The source object the Spy is mimicing. + source: The source object the Spy is mimicking. name: The spec's name. If `None`, will be derived from `source`. Will fallback to a default value. module_name: The spec's module name. If left unspecified, @@ -47,6 +47,8 @@ def __init__( module_name: Union[str, _FROM_SOURCE, None] = FROM_SOURCE, is_async: bool = False, ) -> None: + source = _resolve_source(source) + self._source = source self._name = _get_name(source) if name is None else name self._module_name = ( @@ -139,6 +141,13 @@ def create_child_core(self, name: str, is_async: bool) -> "SpyCore": ) +def _resolve_source(source: Any) -> Any: + """Resolve the source object, unwrapping any generic aliases.""" + origin = inspect.getattr_static(source, "__origin__", None) + + return origin if origin is not None else source + + def _get_name(source: Any) -> str: """Get the name of a source object.""" source_name = getattr(source, "__name__", None) if source is not None else None diff --git a/tests/fixtures.py b/tests/fixtures.py index 08ddbe7..951d469 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,6 +1,6 @@ """Common test fixtures.""" from functools import lru_cache -from typing import Any +from typing import Any, Generic, TypeVar class SomeClass: @@ -8,30 +8,24 @@ class SomeClass: def foo(self, val: str) -> str: """Get the foo string.""" - ... def bar(self, a: int, b: float, c: str) -> bool: """Get the bar bool based on a few inputs.""" - ... @staticmethod def fizzbuzz(hello: str) -> int: """Fizz some buzzes.""" - ... def do_the_thing(self, *, flag: bool) -> None: """Perform a side-effect without a return value.""" - ... @property def primitive_property(self) -> str: """Get a primitive computed property.""" - ... @lru_cache(maxsize=None) def some_wrapped_method(self, val: str) -> str: """Get a thing through a wrapped method.""" - ... class SomeNestedClass: @@ -41,12 +35,10 @@ class SomeNestedClass: def foo(self, val: str) -> str: """Get the foo string.""" - ... @property def child(self) -> SomeClass: """Get the child instance.""" - ... class SomeAsyncClass: @@ -54,15 +46,12 @@ class SomeAsyncClass: async def foo(self, val: str) -> str: """Get the foo string.""" - ... async def bar(self, a: int, b: float, c: str) -> bool: """Get the bar bool based on a few inputs.""" - ... async def do_the_thing(self, *, flag: bool) -> None: """Perform a side-effect without a return value.""" - ... class SomeAsyncCallableClass: @@ -70,7 +59,6 @@ class SomeAsyncCallableClass: async def __call__(self, val: int) -> int: """Get an integer.""" - ... class SomeCallableClass: @@ -78,25 +66,34 @@ class SomeCallableClass: async def __call__(self, val: int) -> int: """Get an integer.""" - ... def noop(*args: Any, **kwargs: Any) -> Any: """No-op.""" - ... def some_func(val: str) -> str: """Test function.""" - ... async def some_async_func(val: str) -> str: """Async test function.""" - ... @lru_cache(maxsize=None) def some_wrapped_func(val: str) -> str: """Wrapped test function.""" - ... + + +GenericT = TypeVar("GenericT") + + +class GenericClass(Generic[GenericT]): + """A generic class definition.""" + + def hello(self, val: GenericT) -> None: + """Say hello.""" + + +ConcreteAlias = GenericClass[str] +"""An alias with a generic type specified""" diff --git a/tests/test_spy_core.py b/tests/test_spy_core.py index e6e6005..8a1df48 100644 --- a/tests/test_spy_core.py +++ b/tests/test_spy_core.py @@ -11,6 +11,9 @@ SomeAsyncCallableClass, SomeCallableClass, SomeNestedClass, + GenericClass, + GenericT, + ConcreteAlias, some_func, some_async_func, some_wrapped_func, @@ -78,6 +81,16 @@ class GetNameSpec(NamedTuple): expected_name="SomeNestedClass.child.foo", expected_full_name="tests.fixtures.SomeNestedClass.child.foo", ), + GetNameSpec( + subject=SpyCore(source=GenericClass[int], name=None), + expected_name="GenericClass", + expected_full_name="tests.fixtures.GenericClass", + ), + GetNameSpec( + subject=SpyCore(source=ConcreteAlias, name=None), + expected_name="GenericClass", + expected_full_name="tests.fixtures.GenericClass", + ), ], ) def test_get_name( @@ -226,6 +239,21 @@ class GetSignatureSpec(NamedTuple): return_annotation=str, ), ), + GetSignatureSpec( + subject=SpyCore(source=ConcreteAlias, name=None).create_child_core( + "hello", is_async=False + ), + expected_signature=inspect.Signature( + parameters=[ + inspect.Parameter( + name="val", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=GenericT, + ) + ], + return_annotation=None, + ), + ), ], ) def test_get_signature(