Skip to content

Commit

Permalink
fix(spy): follow __wrapped__ when getting specs and signatures (#134)
Browse files Browse the repository at this point in the history
Fixes #133
  • Loading branch information
mcous authored May 25, 2022
1 parent 29a3b08 commit 8d86195
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 9 deletions.
13 changes: 8 additions & 5 deletions decoy/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def get_signature(self) -> Optional[inspect.Signature]:
source = self._get_source()

try:
return inspect.signature(source)
return inspect.signature(source, follow_wrapped=True)
except (ValueError, TypeError):
return None

Expand Down Expand Up @@ -130,10 +130,13 @@ def get_child_spec(self, name: str) -> "Spec":
elif isinstance(child_source, staticmethod):
child_source = child_source.__func__

elif inspect.isfunction(child_source):
# consume the `self` argument of the method to ensure proper
# signature reporting by wrapping it in a partial
child_source = functools.partial(child_source, None)
else:
child_source = inspect.unwrap(child_source)

if inspect.isfunction(child_source):
# consume the `self` argument of the method to ensure proper
# signature reporting by wrapping it in a partial
child_source = functools.partial(child_source, None)

return Spec(source=child_source, name=child_name, module_name=self._module_name)

Expand Down
19 changes: 15 additions & 4 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Common test interfaces."""
from functools import lru_cache
from typing import Any


Expand Down Expand Up @@ -27,6 +28,11 @@ def primitive_property(self) -> str:
"""Get a primitive computed property."""
...

@lru_cache(maxsize=None)
def some_wrapped_method(self, val: str) -> str:
"""Get a thing through a wrapped method."""
...


class SomeNestedClass:
"""Nested testing class."""
Expand Down Expand Up @@ -75,17 +81,22 @@ async def __call__(self, val: int) -> int:
...


# NOTE: these `Any`s are forward references for call signature testing purposes
def noop(*args: Any, **kwargs: Any) -> Any:
"""No-op."""
pass
...


def some_func(val: str) -> str:
"""Test function."""
return "can't touch this"
...


async def some_async_func(val: str) -> str:
"""Async test function."""
return "can't touch this"
...


@lru_cache(maxsize=None)
def some_wrapped_func(val: str) -> str:
"""Wrapped test function."""
...
36 changes: 36 additions & 0 deletions tests/test_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
SomeNestedClass,
some_func,
some_async_func,
some_wrapped_func,
)


Expand Down Expand Up @@ -184,6 +185,34 @@ class GetSignatureSpec(NamedTuple):
return_annotation=int,
),
),
GetSignatureSpec(
subject=Spec(source=some_wrapped_func, name=None),
expected_signature=inspect.Signature(
parameters=[
inspect.Parameter(
name="val",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=str,
)
],
return_annotation=str,
),
),
GetSignatureSpec(
subject=Spec(source=SomeClass, name=None).get_child_spec(
"some_wrapped_method"
),
expected_signature=inspect.Signature(
parameters=[
inspect.Parameter(
name="val",
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=str,
)
],
return_annotation=str,
),
),
],
)
def test_get_signature(
Expand Down Expand Up @@ -315,6 +344,13 @@ class GetBindArgsSpec(NamedTuple):
expected_args=("hello",),
expected_kwargs={},
),
GetBindArgsSpec(
subject=Spec(source=some_wrapped_func, name=None),
input_args=(),
input_kwargs={"val": "hello"},
expected_args=("hello",),
expected_kwargs={},
),
],
)
def test_bind_args(
Expand Down

0 comments on commit 8d86195

Please sign in to comment.