Skip to content

Commit

Permalink
fix(call_stack): match spy IDs in get_by_rehearsals (#55)
Browse files Browse the repository at this point in the history
Closes #54
  • Loading branch information
mcous authored Aug 5, 2021
1 parent 5ce3bc5 commit 8845b5b
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
6 changes: 3 additions & 3 deletions decoy/call_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def consume_when_rehearsal(self) -> WhenRehearsal:
"""
try:
call = self._stack[-1]
except KeyError:
except IndexError:
raise MissingRehearsalError()
if not isinstance(call, SpyCall):
raise MissingRehearsalError()
Expand All @@ -47,12 +47,12 @@ def consume_verify_rehearsals(self, count: int) -> List[VerifyRehearsal]:
return rehearsals

def get_by_rehearsals(self, rehearsals: Sequence[VerifyRehearsal]) -> List[SpyCall]:
"""Get a list of all non-rehearsal calls to the given Spy IDs."""
"""Get all non-rehearsal calls to the spies in the given rehearsals."""
return [
call
for call in self._stack
if isinstance(call, SpyCall)
and any(rehearsal == call for rehearsal in rehearsals)
and any(rehearsal.spy_id == call.spy_id for rehearsal in rehearsals)
]

def get_all(self) -> List[BaseSpyCall]:
Expand Down
3 changes: 2 additions & 1 deletion decoy/verifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ def verify(
if times is not None:
if len(calls) == times:
return None

else:
for i, call in enumerate(calls):
for i in range(len(calls)):
calls_subset = calls[i : i + len(rehearsals)]

if calls_subset == rehearsals:
Expand Down
9 changes: 6 additions & 3 deletions tests/test_call_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ def test_push_and_consume_when_rehearsal() -> None:
def test_consume_when_rehearsal_raises_empty_error() -> None:
"""It should raise an error if the stack is empty on pop."""
subject = CallStack()
call = SpyCall(spy_id=42, spy_name="my_spy", args=(), kwargs={})

with pytest.raises(MissingRehearsalError):
subject.consume_when_rehearsal()

call = SpyCall(spy_id=42, spy_name="my_spy", args=(), kwargs={})
subject.push(call)
subject.consume_when_rehearsal()

Expand Down Expand Up @@ -64,7 +67,7 @@ def test_consume_verify_rehearsals_raises_error() -> None:


def test_get_by_rehearsal() -> None:
"""It can get a list of calls made matching a given rehearsal."""
"""It can get a list of calls made matching spy IDs of given rehearsals."""
subject = CallStack()
call_1 = SpyCall(spy_id=101, spy_name="spy_1", args=(1,), kwargs={})
call_2 = SpyCall(spy_id=101, spy_name="spy_1", args=(2,), kwargs={})
Expand All @@ -88,7 +91,7 @@ def test_get_by_rehearsal() -> None:
VerifyRehearsal(spy_id=202, spy_name="spy_2", args=(1,), kwargs={}),
]
)
assert result == [call_3]
assert result == [call_1, call_3, call_4]

result = subject.get_by_rehearsals(
[VerifyRehearsal(spy_id=303, spy_name="spy_3", args=(1,), kwargs={})]
Expand Down

0 comments on commit 8845b5b

Please sign in to comment.