From 8845b5b99c017dfc2eae1c96a292e5ce29eb9136 Mon Sep 17 00:00:00 2001 From: Mike Cousins Date: Wed, 4 Aug 2021 23:36:04 -0400 Subject: [PATCH] fix(call_stack): match spy IDs in get_by_rehearsals (#55) Closes #54 --- decoy/call_stack.py | 6 +++--- decoy/verifier.py | 3 ++- tests/test_call_stack.py | 9 ++++++--- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/decoy/call_stack.py b/decoy/call_stack.py index 8b53145..3c66d17 100644 --- a/decoy/call_stack.py +++ b/decoy/call_stack.py @@ -23,7 +23,7 @@ def consume_when_rehearsal(self) -> WhenRehearsal: """ try: call = self._stack[-1] - except KeyError: + except IndexError: raise MissingRehearsalError() if not isinstance(call, SpyCall): raise MissingRehearsalError() @@ -47,12 +47,12 @@ def consume_verify_rehearsals(self, count: int) -> List[VerifyRehearsal]: return rehearsals def get_by_rehearsals(self, rehearsals: Sequence[VerifyRehearsal]) -> List[SpyCall]: - """Get a list of all non-rehearsal calls to the given Spy IDs.""" + """Get all non-rehearsal calls to the spies in the given rehearsals.""" return [ call for call in self._stack if isinstance(call, SpyCall) - and any(rehearsal == call for rehearsal in rehearsals) + and any(rehearsal.spy_id == call.spy_id for rehearsal in rehearsals) ] def get_all(self) -> List[BaseSpyCall]: diff --git a/decoy/verifier.py b/decoy/verifier.py index b426c4f..39644b1 100644 --- a/decoy/verifier.py +++ b/decoy/verifier.py @@ -18,8 +18,9 @@ def verify( if times is not None: if len(calls) == times: return None + else: - for i, call in enumerate(calls): + for i in range(len(calls)): calls_subset = calls[i : i + len(rehearsals)] if calls_subset == rehearsals: diff --git a/tests/test_call_stack.py b/tests/test_call_stack.py index 9133942..277a72e 100644 --- a/tests/test_call_stack.py +++ b/tests/test_call_stack.py @@ -21,8 +21,11 @@ def test_push_and_consume_when_rehearsal() -> None: def test_consume_when_rehearsal_raises_empty_error() -> None: """It should raise an error if the stack is empty on pop.""" subject = CallStack() - call = SpyCall(spy_id=42, spy_name="my_spy", args=(), kwargs={}) + with pytest.raises(MissingRehearsalError): + subject.consume_when_rehearsal() + + call = SpyCall(spy_id=42, spy_name="my_spy", args=(), kwargs={}) subject.push(call) subject.consume_when_rehearsal() @@ -64,7 +67,7 @@ def test_consume_verify_rehearsals_raises_error() -> None: def test_get_by_rehearsal() -> None: - """It can get a list of calls made matching a given rehearsal.""" + """It can get a list of calls made matching spy IDs of given rehearsals.""" subject = CallStack() call_1 = SpyCall(spy_id=101, spy_name="spy_1", args=(1,), kwargs={}) call_2 = SpyCall(spy_id=101, spy_name="spy_1", args=(2,), kwargs={}) @@ -88,7 +91,7 @@ def test_get_by_rehearsal() -> None: VerifyRehearsal(spy_id=202, spy_name="spy_2", args=(1,), kwargs={}), ] ) - assert result == [call_3] + assert result == [call_1, call_3, call_4] result = subject.get_by_rehearsals( [VerifyRehearsal(spy_id=303, spy_name="spy_3", args=(1,), kwargs={})]