diff --git a/decoy/spy_events.py b/decoy/spy_events.py index 69b8d1c..3a0d1a2 100644 --- a/decoy/spy_events.py +++ b/decoy/spy_events.py @@ -83,10 +83,10 @@ def match_event(event: AnySpyEvent, rehearsal: SpyRehearsal) -> bool: try: args_match = all( - call.args[i] == value for i, value in enumerate(rehearsed_call.args) + value == call.args[i] for i, value in enumerate(rehearsed_call.args) ) kwargs_match = all( - call.kwargs[key] == value + value == call.kwargs[key] for key, value in rehearsed_call.kwargs.items() ) @@ -95,4 +95,4 @@ def match_event(event: AnySpyEvent, rehearsal: SpyRehearsal) -> bool: except (IndexError, KeyError): return False - return event.payload == rehearsal.payload + return rehearsal.payload == event.payload diff --git a/tests/test_spy_events.py b/tests/test_spy_events.py index 8c0b94f..aebf55d 100644 --- a/tests/test_spy_events.py +++ b/tests/test_spy_events.py @@ -176,3 +176,50 @@ def test_match_event( """It should match a call to a rehearsal.""" result = match_event(event, rehearsal) assert result is expected_result + + +def test_match_eq_override() -> None: + """It should prefer __eq__ from the rehearsal.""" + + class _Matcher: + def __eq__(self, other: object) -> bool: + return True + + class _Value: + def __eq__(self, other: object) -> bool: + return False + + event_args = SpyEvent( + spy=SpyInfo(id=42, name="my_spy", is_async=False), + payload=SpyCall(args=(_Value(),), kwargs={}), + ) + + event_kwargs = SpyEvent( + spy=SpyInfo(id=42, name="my_spy", is_async=False), + payload=SpyCall(args=(), kwargs={"value": _Value()}), + ) + + rehearsal_ars = WhenRehearsal( + spy=SpyInfo(id=42, name="my_spy", is_async=False), + payload=SpyCall(args=(_Matcher(),), kwargs={}), + ) + + rehearsal_kwargs = WhenRehearsal( + spy=SpyInfo(id=42, name="my_spy", is_async=False), + payload=SpyCall(args=(), kwargs={"value": _Matcher()}), + ) + + rehearsal_args_ignore_extra = WhenRehearsal( + spy=SpyInfo(id=42, name="my_spy", is_async=False), + payload=SpyCall(args=(_Matcher(),), kwargs={}, ignore_extra_args=True), + ) + + rehearsal_kwargs_ignore_extra = WhenRehearsal( + spy=SpyInfo(id=42, name="my_spy", is_async=False), + payload=SpyCall(args=(), kwargs={"value": _Matcher()}, ignore_extra_args=True), + ) + + assert match_event(event_args, rehearsal_ars) is True + assert match_event(event_kwargs, rehearsal_kwargs) is True + assert match_event(event_args, rehearsal_args_ignore_extra) is True + assert match_event(event_kwargs, rehearsal_kwargs_ignore_extra) is True