Skip to content

Commit

Permalink
fix(spy): resolve source to origin of GenericAlias (#143)
Browse files Browse the repository at this point in the history
Fixes #142
  • Loading branch information
mcous authored Sep 25, 2022
1 parent 2096434 commit 7d021c2
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 20 deletions.
13 changes: 11 additions & 2 deletions decoy/spy_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ class BoundArgs(NamedTuple):


class SpyCore:
"""Core spy logic for mimicing a given `source` object.
"""Core spy logic for mimicking a given `source` object.
Arguments:
source: The source object the Spy is mimicing.
source: The source object the Spy is mimicking.
name: The spec's name. If `None`, will be derived from `source`.
Will fallback to a default value.
module_name: The spec's module name. If left unspecified,
Expand All @@ -47,6 +47,8 @@ def __init__(
module_name: Union[str, _FROM_SOURCE, None] = FROM_SOURCE,
is_async: bool = False,
) -> None:
source = _resolve_source(source)

self._source = source
self._name = _get_name(source) if name is None else name
self._module_name = (
Expand Down Expand Up @@ -139,6 +141,13 @@ def create_child_core(self, name: str, is_async: bool) -> "SpyCore":
)


def _resolve_source(source: Any) -> Any:
"""Resolve the source object, unwrapping any generic aliases."""
origin = inspect.getattr_static(source, "__origin__", None)

return origin if origin is not None else source


def _get_name(source: Any) -> str:
"""Get the name of a source object."""
source_name = getattr(source, "__name__", None) if source is not None else None
Expand Down
33 changes: 15 additions & 18 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,31 @@
"""Common test fixtures."""
from functools import lru_cache
from typing import Any
from typing import Any, Generic, TypeVar


class SomeClass:
"""Testing class."""

def foo(self, val: str) -> str:
"""Get the foo string."""
...

def bar(self, a: int, b: float, c: str) -> bool:
"""Get the bar bool based on a few inputs."""
...

@staticmethod
def fizzbuzz(hello: str) -> int:
"""Fizz some buzzes."""
...

def do_the_thing(self, *, flag: bool) -> None:
"""Perform a side-effect without a return value."""
...

@property
def primitive_property(self) -> str:
"""Get a primitive computed property."""
...

@lru_cache(maxsize=None)
def some_wrapped_method(self, val: str) -> str:
"""Get a thing through a wrapped method."""
...


class SomeNestedClass:
Expand All @@ -41,62 +35,65 @@ class SomeNestedClass:

def foo(self, val: str) -> str:
"""Get the foo string."""
...

@property
def child(self) -> SomeClass:
"""Get the child instance."""
...


class SomeAsyncClass:
"""Async testing class."""

async def foo(self, val: str) -> str:
"""Get the foo string."""
...

async def bar(self, a: int, b: float, c: str) -> bool:
"""Get the bar bool based on a few inputs."""
...

async def do_the_thing(self, *, flag: bool) -> None:
"""Perform a side-effect without a return value."""
...


class SomeAsyncCallableClass:
"""Async callable class."""

async def __call__(self, val: int) -> int:
"""Get an integer."""
...


class SomeCallableClass:
"""Async callable class."""

async def __call__(self, val: int) -> int:
"""Get an integer."""
...


def noop(*args: Any, **kwargs: Any) -> Any:
"""No-op."""
...


def some_func(val: str) -> str:
"""Test function."""
...


async def some_async_func(val: str) -> str:
"""Async test function."""
...


@lru_cache(maxsize=None)
def some_wrapped_func(val: str) -> str:
"""Wrapped test function."""
...


GenericT = TypeVar("GenericT")


class GenericClass(Generic[GenericT]):
"""A generic class definition."""

def hello(self, val: GenericT) -> None:
"""Say hello."""


ConcreteAlias = GenericClass[str]
"""An alias with a generic type specified"""
28 changes: 28 additions & 0 deletions tests/test_spy_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
SomeAsyncCallableClass,
SomeCallableClass,
SomeNestedClass,
GenericClass,
GenericT,
ConcreteAlias,
some_func,
some_async_func,
some_wrapped_func,
Expand Down Expand Up @@ -78,6 +81,16 @@ class GetNameSpec(NamedTuple):
expected_name="SomeNestedClass.child.foo",
expected_full_name="tests.fixtures.SomeNestedClass.child.foo",
),
GetNameSpec(
subject=SpyCore(source=GenericClass[int], name=None),
expected_name="GenericClass",
expected_full_name="tests.fixtures.GenericClass",
),
GetNameSpec(
subject=SpyCore(source=ConcreteAlias, name=None),
expected_name="GenericClass",
expected_full_name="tests.fixtures.GenericClass",
),
],
)
def test_get_name(
Expand Down Expand Up @@ -226,6 +239,21 @@ class GetSignatureSpec(NamedTuple):
return_annotation=str,
),
),
GetSignatureSpec(
subject=SpyCore(source=ConcreteAlias, name=None).create_child_core(
"hello", is_async=False
),
expected_signature=inspect.Signature(
parameters=[
inspect.Parameter(
name="val",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=GenericT,
)
],
return_annotation=None,
),
),
],
)
def test_get_signature(
Expand Down

0 comments on commit 7d021c2

Please sign in to comment.