diff --git a/README.md b/README.md index 8536c28..54edb05 100644 --- a/README.md +++ b/README.md @@ -186,7 +186,7 @@ export const calculateQuestion = async (answer: number): Promise => { ### `when(spy: TFunc, options?: WhenOptions): StubWrapper` -Configures a `vi.fn()` mock function to act as a vitest-when stub. Adds an implementation to the function that initially no-ops, and returns an API to configure behaviors for given arguments using [`.calledWith(...)`][called-with] +Configures a `vi.fn()` or `vi.spyOn()` mock function to act as a vitest-when stub. Adds an implementation to the function that initially no-ops, and returns an API to configure behaviors for given arguments using [`.calledWith(...)`][called-with] ```ts import { vi } from 'vitest' @@ -264,6 +264,21 @@ when(overloaded).calledWith().thenReturn(null) when<() => null>(overloaded).calledWith().thenReturn(null) ``` +#### Fallback + +By default, if arguments do not match, a vitest-when stub will no-op and return `undefined`. You can customize this fallback by configuring your own unconditional behavior on the mock using Vitest's built-in [mock API][]. + +```ts +const spy = vi.fn().mockReturnValue('you messed up!') + +when(spy).calledWith('hello').thenReturn('world') + +spy('hello') // "world" +spy('jello') // "you messed up!" +``` + +[mock API]: https://vitest.dev/api/mock.html + ### `.thenReturn(value: TReturn)` When the stubbing is satisfied, return `value` diff --git a/src/debug.ts b/src/debug.ts index f5f7d84..851063b 100644 --- a/src/debug.ts +++ b/src/debug.ts @@ -4,7 +4,7 @@ import { } from 'pretty-format' import { validateSpy, getBehaviorStack } from './stubs' -import type { AnyFunction } from './types' +import type { AnyFunction, MockInstance } from './types' import { type Behavior, BehaviorType } from './behaviors' export interface DebugResult { @@ -21,11 +21,11 @@ export interface Stubbing { } export const getDebug = ( - spy: TFunc, + spy: TFunc | MockInstance, ): DebugResult => { - const target = validateSpy(spy) + const target = validateSpy(spy) const name = target.getMockName() - const behaviors = getBehaviorStack(target) + const behaviors = getBehaviorStack(target) const unmatchedCalls = behaviors?.getUnmatchedCalls() ?? target.mock.calls const stubbings = behaviors?.getAll().map((entry) => ({ diff --git a/src/stubs.ts b/src/stubs.ts index ea0186b..b1fe5a1 100644 --- a/src/stubs.ts +++ b/src/stubs.ts @@ -16,7 +16,7 @@ interface WhenStubImplementation { export const configureStub = ( maybeSpy: unknown, ): BehaviorStack => { - const spy = validateSpy(maybeSpy) + const spy = validateSpy(maybeSpy) const existingBehaviors = getBehaviorStack(spy) if (existingBehaviors) { @@ -24,11 +24,12 @@ export const configureStub = ( } const behaviors = createBehaviorStack() + const fallbackImplementation = spy.getMockImplementation() const implementation = (...args: Parameters) => { const behavior = behaviors.use(args)?.behavior ?? { - type: BehaviorType.RETURN, - value: undefined, + type: BehaviorType.DO, + callback: fallbackImplementation, } switch (behavior.type) { @@ -50,19 +51,21 @@ export const configureStub = ( } case BehaviorType.DO: { - return behavior.callback(...args) + return behavior.callback?.(...args) } } } spy.mockImplementation( - Object.assign(implementation, { [BEHAVIORS_KEY]: behaviors }), + Object.assign(implementation as TFunc, { [BEHAVIORS_KEY]: behaviors }), ) return behaviors } -export const validateSpy = (maybeSpy: unknown): MockInstance => { +export const validateSpy = ( + maybeSpy: unknown, +): MockInstance => { if ( typeof maybeSpy === 'function' && 'mockImplementation' in maybeSpy && @@ -72,14 +75,14 @@ export const validateSpy = (maybeSpy: unknown): MockInstance => { 'getMockName' in maybeSpy && typeof maybeSpy.getMockName === 'function' ) { - return maybeSpy as unknown as MockInstance + return maybeSpy as unknown as MockInstance } throw new NotAMockFunctionError(maybeSpy) } export const getBehaviorStack = ( - spy: MockInstance, + spy: MockInstance, ): BehaviorStack | undefined => { const existingImplementation = spy.getMockImplementation() as | WhenStubImplementation diff --git a/src/vitest-when.ts b/src/vitest-when.ts index 0577b08..4a1eb95 100644 --- a/src/vitest-when.ts +++ b/src/vitest-when.ts @@ -47,7 +47,7 @@ export interface DebugOptions { } export const debug = ( - spy: TFunc, + spy: TFunc | MockInstance, options: DebugOptions = {}, ): DebugResult => { const log = options.log ?? true diff --git a/test/vitest-when.test.ts b/test/vitest-when.test.ts index 4f3ae86..87283d2 100644 --- a/test/vitest-when.test.ts +++ b/test/vitest-when.test.ts @@ -55,6 +55,15 @@ describe('vitest-when', () => { expect(spy(1, 2, 3)).toEqual(undefined) }) + it('should fall back to original mock implementation', () => { + const spy = vi.fn().mockReturnValue(100) + + subject.when(spy).calledWith(1, 2, 3).thenReturn(4) + + expect(spy(1, 2, 3)).toEqual(4) + expect(spy()).toEqual(100) + }) + it('should return a number of times', () => { const spy = vi.fn()