diff --git a/decoy/warning_checker.py b/decoy/warning_checker.py index 704fa75..36c1f31 100644 --- a/decoy/warning_checker.py +++ b/decoy/warning_checker.py @@ -1,5 +1,7 @@ """Warning checker.""" -from typing import Dict, List, Sequence +from collections import defaultdict +from itertools import groupby +from typing import Dict, List, NamedTuple, Sequence from warnings import warn from .spy_events import ( @@ -8,6 +10,7 @@ SpyEvent, VerifyRehearsal, WhenRehearsal, + SpyRehearsal, match_event, ) from .warnings import DecoyWarning, MiscalledStubWarning, RedundantVerifyWarning @@ -23,44 +26,62 @@ def check(all_calls: Sequence[AnySpyEvent]) -> None: _check_no_redundant_verify(all_calls) +class _Call(NamedTuple): + event: SpyEvent + all_rehearsals: List[SpyRehearsal] + matching_rehearsals: List[SpyRehearsal] + + def _check_no_miscalled_stubs(all_events: Sequence[AnySpyEvent]) -> None: """Ensure every call matches a rehearsal, if the spy has rehearsals.""" - all_calls_by_id: Dict[int, List[AnySpyEvent]] = {} + all_events_by_id: Dict[int, List[AnySpyEvent]] = defaultdict(list) + all_calls_by_id: Dict[int, List[_Call]] = defaultdict(list) for event in all_events: - if isinstance(event.payload, SpyCall): - spy_id = event.spy.id - spy_calls = all_calls_by_id.get(spy_id, []) - all_calls_by_id[spy_id] = [*spy_calls, event] + all_events_by_id[event.spy.id].append(event) + + for events in all_events_by_id.values(): + for index, event in enumerate(events): + if isinstance(event, SpyEvent) and isinstance(event.payload, SpyCall): + when_rehearsals = [ + rehearsal + for rehearsal in events[0:index] + if isinstance(rehearsal, WhenRehearsal) + and isinstance(rehearsal.payload, SpyCall) + ] + verify_rehearsals = [ + rehearsal + for rehearsal in events[index + 1 :] + if isinstance(rehearsal, VerifyRehearsal) + and isinstance(rehearsal.payload, SpyCall) + ] + + all_rehearsals: List[SpyRehearsal] = [ + *when_rehearsals, + *verify_rehearsals, + ] + matching_rehearsals = [ + rehearsal + for rehearsal in all_rehearsals + if match_event(event, rehearsal) + ] + + all_calls_by_id[event.spy.id].append( + _Call(event, all_rehearsals, matching_rehearsals) + ) for spy_calls in all_calls_by_id.values(): - unmatched: List[SpyEvent] = [] - - for index, call in enumerate(spy_calls): - past_stubs = [ - wr for wr in spy_calls[0:index] if isinstance(wr, WhenRehearsal) - ] - - matched_past_stubs = [wr for wr in past_stubs if match_event(call, wr)] - - matched_future_verifies = [ - vr - for vr in spy_calls[index + 1 :] - if isinstance(vr, VerifyRehearsal) and match_event(call, vr) - ] - - if ( - isinstance(call, SpyEvent) - and len(past_stubs) > 0 - and len(matched_past_stubs) == 0 - and len(matched_future_verifies) == 0 - ): - unmatched = [*unmatched, call] - if index == len(spy_calls) - 1: - _warn(MiscalledStubWarning(calls=unmatched, rehearsals=past_stubs)) - elif isinstance(call, WhenRehearsal) and len(unmatched) > 0: - _warn(MiscalledStubWarning(calls=unmatched, rehearsals=past_stubs)) - unmatched = [] + for rehearsals, grouped_calls in groupby(spy_calls, lambda c: c.all_rehearsals): + calls = list(grouped_calls) + is_stubbed = any(isinstance(r, WhenRehearsal) for r in rehearsals) + + if is_stubbed and all(len(c.matching_rehearsals) == 0 for c in calls): + _warn( + MiscalledStubWarning( + calls=[c.event for c in calls], + rehearsals=rehearsals, + ) + ) def _check_no_redundant_verify(all_calls: Sequence[AnySpyEvent]) -> None: diff --git a/decoy/warnings.py b/decoy/warnings.py index 140301e..508bd03 100644 --- a/decoy/warnings.py +++ b/decoy/warnings.py @@ -7,7 +7,7 @@ import os from typing import Sequence -from .spy_events import SpyEvent, WhenRehearsal, VerifyRehearsal +from .spy_events import SpyEvent, SpyRehearsal, VerifyRehearsal from .stringify import stringify_call, stringify_error_message, count @@ -34,12 +34,12 @@ class MiscalledStubWarning(DecoyWarning): calls: Actual calls to the mock. """ - rehearsals: Sequence[WhenRehearsal] + rehearsals: Sequence[SpyRehearsal] calls: Sequence[SpyEvent] def __init__( self, - rehearsals: Sequence[WhenRehearsal], + rehearsals: Sequence[SpyRehearsal], calls: Sequence[SpyEvent], ) -> None: heading = os.linesep.join( diff --git a/tests/test_call_handler.py b/tests/test_call_handler.py index 6bc49c0..3492b37 100644 --- a/tests/test_call_handler.py +++ b/tests/test_call_handler.py @@ -21,11 +21,7 @@ def stub_store(decoy: Decoy) -> StubStore: @pytest.fixture() -def subject( - decoy: Decoy, - spy_log: SpyLog, - stub_store: StubStore, -) -> CallHandler: +def subject(spy_log: SpyLog, stub_store: StubStore) -> CallHandler: """Get a CallHandler instance with its dependencies mocked out.""" return CallHandler( spy_log=spy_log, @@ -98,7 +94,6 @@ def test_handle_call_with_raise( def test_handle_call_with_action( decoy: Decoy, - spy_log: SpyLog, stub_store: StubStore, subject: CallHandler, ) -> None: @@ -120,7 +115,6 @@ def test_handle_call_with_action( def test_handle_prop_get_with_action( decoy: Decoy, - spy_log: SpyLog, stub_store: StubStore, subject: CallHandler, ) -> None: