diff --git a/src/stubs.ts b/src/stubs.ts index 759d7ee..ea0186b 100644 --- a/src/stubs.ts +++ b/src/stubs.ts @@ -1,11 +1,10 @@ -import type { Mock } from 'vitest' import { createBehaviorStack, type BehaviorStack, BehaviorType, } from './behaviors.ts' import { NotAMockFunctionError } from './errors.ts' -import type { AnyFunction } from './types.ts' +import type { AnyFunction, MockInstance } from './types.ts' const BEHAVIORS_KEY = Symbol('behaviors') @@ -63,7 +62,7 @@ export const configureStub = ( return behaviors } -export const validateSpy = (maybeSpy: unknown): Mock => { +export const validateSpy = (maybeSpy: unknown): MockInstance => { if ( typeof maybeSpy === 'function' && 'mockImplementation' in maybeSpy && @@ -73,14 +72,14 @@ export const validateSpy = (maybeSpy: unknown): Mock => { 'getMockName' in maybeSpy && typeof maybeSpy.getMockName === 'function' ) { - return maybeSpy as Mock + return maybeSpy as unknown as MockInstance } throw new NotAMockFunctionError(maybeSpy) } export const getBehaviorStack = ( - spy: Mock, + spy: MockInstance, ): BehaviorStack | undefined => { const existingImplementation = spy.getMockImplementation() as | WhenStubImplementation diff --git a/src/types.ts b/src/types.ts index 08b5f82..2791792 100644 --- a/src/types.ts +++ b/src/types.ts @@ -2,3 +2,20 @@ /** Any function, for use in `extends` */ export type AnyFunction = (...args: never[]) => unknown + +/** + * Minimally typed version of Vitest's `MockInstance`. + * + * Used to ensure backwards compatibility + * with older versions of Vitest. + */ +export interface MockInstance { + getMockName(): string + getMockImplementation(): TFunc | undefined + mockImplementation: (impl: TFunc) => this + mock: MockContext +} + +export interface MockContext { + calls: Parameters[] +} diff --git a/src/vitest-when.ts b/src/vitest-when.ts index f582e7f..0577b08 100644 --- a/src/vitest-when.ts +++ b/src/vitest-when.ts @@ -1,6 +1,6 @@ import { configureStub } from './stubs.ts' import type { WhenOptions } from './behaviors.ts' -import type { AnyFunction } from './types.ts' +import type { AnyFunction, MockInstance } from './types.ts' import { getDebug, type DebugResult } from './debug.ts' export { type WhenOptions, type Behavior, BehaviorType } from './behaviors.ts' @@ -22,7 +22,7 @@ export interface Stub { } export const when = ( - spy: TFunc, + spy: TFunc | MockInstance, options: WhenOptions = {}, ): StubWrapper => { const behaviorStack = configureStub(spy) diff --git a/test/typing.test-d.ts b/test/typing.test-d.ts index 98e8295..05d6efb 100644 --- a/test/typing.test-d.ts +++ b/test/typing.test-d.ts @@ -23,6 +23,16 @@ describe('vitest-when type signatures', () => { assertType>(stub) }) + it('should handle an spied function', () => { + const target = { simple } + const spy = vi.spyOn(target, 'simple') + const stub = subject.when(spy).calledWith(1) + + stub.thenReturn('hello') + + assertType>(stub) + }) + it('should handle a simple function', () => { const stub = subject.when(simple).calledWith(1)