Skip to content

Commit

Permalink
refactor(warnings): rework MiscalledStubWarning checker for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
mcous committed Aug 13, 2023
1 parent 5277ada commit ef248fc
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 43 deletions.
87 changes: 54 additions & 33 deletions decoy/warning_checker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Warning checker."""
from typing import Dict, List, Sequence
from collections import defaultdict
from itertools import groupby
from typing import Dict, List, NamedTuple, Sequence
from warnings import warn

from .spy_events import (
Expand All @@ -8,6 +10,7 @@
SpyEvent,
VerifyRehearsal,
WhenRehearsal,
SpyRehearsal,
match_event,
)
from .warnings import DecoyWarning, MiscalledStubWarning, RedundantVerifyWarning
Expand All @@ -23,44 +26,62 @@ def check(all_calls: Sequence[AnySpyEvent]) -> None:
_check_no_redundant_verify(all_calls)


class _Call(NamedTuple):
event: SpyEvent
all_rehearsals: List[SpyRehearsal]
matching_rehearsals: List[SpyRehearsal]


def _check_no_miscalled_stubs(all_events: Sequence[AnySpyEvent]) -> None:
"""Ensure every call matches a rehearsal, if the spy has rehearsals."""
all_calls_by_id: Dict[int, List[AnySpyEvent]] = {}
all_events_by_id: Dict[int, List[AnySpyEvent]] = defaultdict(list)
all_calls_by_id: Dict[int, List[_Call]] = defaultdict(list)

for event in all_events:
if isinstance(event.payload, SpyCall):
spy_id = event.spy.id
spy_calls = all_calls_by_id.get(spy_id, [])
all_calls_by_id[spy_id] = [*spy_calls, event]
all_events_by_id[event.spy.id].append(event)

for events in all_events_by_id.values():
for index, event in enumerate(events):
if isinstance(event, SpyEvent) and isinstance(event.payload, SpyCall):
when_rehearsals = [
rehearsal
for rehearsal in events[0:index]
if isinstance(rehearsal, WhenRehearsal)
and isinstance(rehearsal.payload, SpyCall)
]
verify_rehearsals = [
rehearsal
for rehearsal in events[index + 1 :]
if isinstance(rehearsal, VerifyRehearsal)
and isinstance(rehearsal.payload, SpyCall)
]

all_rehearsals: List[SpyRehearsal] = [
*when_rehearsals,
*verify_rehearsals,
]
matching_rehearsals = [
rehearsal
for rehearsal in all_rehearsals
if match_event(event, rehearsal)
]

all_calls_by_id[event.spy.id].append(
_Call(event, all_rehearsals, matching_rehearsals)
)

for spy_calls in all_calls_by_id.values():
unmatched: List[SpyEvent] = []

for index, call in enumerate(spy_calls):
past_stubs = [
wr for wr in spy_calls[0:index] if isinstance(wr, WhenRehearsal)
]

matched_past_stubs = [wr for wr in past_stubs if match_event(call, wr)]

matched_future_verifies = [
vr
for vr in spy_calls[index + 1 :]
if isinstance(vr, VerifyRehearsal) and match_event(call, vr)
]

if (
isinstance(call, SpyEvent)
and len(past_stubs) > 0
and len(matched_past_stubs) == 0
and len(matched_future_verifies) == 0
):
unmatched = [*unmatched, call]
if index == len(spy_calls) - 1:
_warn(MiscalledStubWarning(calls=unmatched, rehearsals=past_stubs))
elif isinstance(call, WhenRehearsal) and len(unmatched) > 0:
_warn(MiscalledStubWarning(calls=unmatched, rehearsals=past_stubs))
unmatched = []
for rehearsals, grouped_calls in groupby(spy_calls, lambda c: c.all_rehearsals):
calls = list(grouped_calls)
is_stubbed = any(isinstance(r, WhenRehearsal) for r in rehearsals)

if is_stubbed and all(len(c.matching_rehearsals) == 0 for c in calls):
_warn(
MiscalledStubWarning(
calls=[c.event for c in calls],
rehearsals=rehearsals,
)
)


def _check_no_redundant_verify(all_calls: Sequence[AnySpyEvent]) -> None:
Expand Down
6 changes: 3 additions & 3 deletions decoy/warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import os
from typing import Sequence

from .spy_events import SpyEvent, WhenRehearsal, VerifyRehearsal
from .spy_events import SpyEvent, SpyRehearsal, VerifyRehearsal
from .stringify import stringify_call, stringify_error_message, count


Expand All @@ -34,12 +34,12 @@ class MiscalledStubWarning(DecoyWarning):
calls: Actual calls to the mock.
"""

rehearsals: Sequence[WhenRehearsal]
rehearsals: Sequence[SpyRehearsal]
calls: Sequence[SpyEvent]

def __init__(
self,
rehearsals: Sequence[WhenRehearsal],
rehearsals: Sequence[SpyRehearsal],
calls: Sequence[SpyEvent],
) -> None:
heading = os.linesep.join(
Expand Down
8 changes: 1 addition & 7 deletions tests/test_call_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ def stub_store(decoy: Decoy) -> StubStore:


@pytest.fixture()
def subject(
decoy: Decoy,
spy_log: SpyLog,
stub_store: StubStore,
) -> CallHandler:
def subject(spy_log: SpyLog, stub_store: StubStore) -> CallHandler:
"""Get a CallHandler instance with its dependencies mocked out."""
return CallHandler(
spy_log=spy_log,
Expand Down Expand Up @@ -98,7 +94,6 @@ def test_handle_call_with_raise(

def test_handle_call_with_action(
decoy: Decoy,
spy_log: SpyLog,
stub_store: StubStore,
subject: CallHandler,
) -> None:
Expand All @@ -120,7 +115,6 @@ def test_handle_call_with_action(

def test_handle_prop_get_with_action(
decoy: Decoy,
spy_log: SpyLog,
stub_store: StubStore,
subject: CallHandler,
) -> None:
Expand Down

0 comments on commit ef248fc

Please sign in to comment.