Skip to content

Commit

Permalink
fix(types)!: do not try to infer types of overloaded functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mcous committed Oct 7, 2023
1 parent 14a24cc commit f59ed24
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 118 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,20 @@ expect(spy('hello')).toEqual('goodbye');

[asymmetric matchers]: https://vitest.dev/api/expect.html#expect-anything

#### Types of overloaded functions

Due to fundamental limitations of how TypeScript handles the types of overloaded functions, `when` will always pick the _last_ overload as the type of `TFunc`. You can use the `TFunc` type argument of when to customize this if you're stubbing a different overload:

```ts
function overloaded(): null;
function overloaded(input: number): string;
function overloaded(input?: number): string | null {
// ...
}

when<() => null>(overloaded).calledWith().thenReturn(null);
```

### `.thenReturn(value: TReturn)`

When the stubbing is satisfied, return `value`
Expand Down
16 changes: 6 additions & 10 deletions src/behaviors.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
import { equals } from '@vitest/expect';
import type {
AnyFunction,
AllParameters,
ReturnTypeFromArgs,
} from './types.ts';
import type { AnyFunction } from './types.ts';

export const ONCE = Symbol('ONCE');

export type StubValue<TValue> = TValue | typeof ONCE;

export interface BehaviorStack<TFunc extends AnyFunction> {
use: (
args: AllParameters<TFunc>
) => BehaviorEntry<AllParameters<TFunc>> | undefined;
args: Parameters<TFunc>
) => BehaviorEntry<Parameters<TFunc>> | undefined;

bindArgs: <TArgs extends AllParameters<TFunc>>(
bindArgs: <TArgs extends Parameters<TFunc>>(
args: TArgs
) => BoundBehaviorStack<ReturnTypeFromArgs<TFunc, TArgs>>;
) => BoundBehaviorStack<ReturnType<TFunc>>;
}

export interface BoundBehaviorStack<TReturn> {
Expand All @@ -43,7 +39,7 @@ export interface BehaviorOptions<TValue> {
export const createBehaviorStack = <
TFunc extends AnyFunction
>(): BehaviorStack<TFunc> => {
const behaviors: BehaviorEntry<AllParameters<TFunc>>[] = [];
const behaviors: BehaviorEntry<Parameters<TFunc>>[] = [];

return {
use: (args) => {
Expand Down
10 changes: 5 additions & 5 deletions src/stubs.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import type { Mock as Spy } from 'vitest';
import { createBehaviorStack, type BehaviorStack } from './behaviors.ts';
import { NotAMockFunctionError } from './errors.ts';
import type { AnyFunction, AllParameters } from './types.ts';
import type { AnyFunction } from './types.ts';

const BEHAVIORS_KEY = Symbol('behaviors');

interface WhenStubImplementation<TFunc extends AnyFunction> {
(...args: AllParameters<TFunc>): unknown;
(...args: Parameters<TFunc>): unknown;
[BEHAVIORS_KEY]: BehaviorStack<TFunc>;
}

Expand All @@ -25,7 +25,7 @@ export const configureStub = <TFunc extends AnyFunction>(

const behaviors = createBehaviorStack<TFunc>();

const implementation = (...args: AllParameters<TFunc>): unknown => {
const implementation = (...args: Parameters<TFunc>): unknown => {
const behavior = behaviors.use(args);

if (behavior?.throwError) {
Expand All @@ -48,15 +48,15 @@ export const configureStub = <TFunc extends AnyFunction>(

const validateSpy = <TFunc extends AnyFunction>(
maybeSpy: unknown
): Spy<AllParameters<TFunc>, unknown> => {
): Spy<Parameters<TFunc>, unknown> => {
if (
typeof maybeSpy === 'function' &&
'mockImplementation' in maybeSpy &&
typeof maybeSpy.mockImplementation === 'function' &&
'getMockImplementation' in maybeSpy &&
typeof maybeSpy.getMockImplementation === 'function'
) {
return maybeSpy as Spy<AllParameters<TFunc>, unknown>;
return maybeSpy as Spy<Parameters<TFunc>, unknown>;
}

throw new NotAMockFunctionError(maybeSpy);
Expand Down
58 changes: 1 addition & 57 deletions src/types.ts
Original file line number Diff line number Diff line change
@@ -1,60 +1,4 @@
/**
* Get function arguments and return value types.
*
* Support for overloaded functions, thanks to @Shakeskeyboarde
* https://github.com/microsoft/TypeScript/issues/14107#issuecomment-1146738780
*/

import type { SpyInstance } from 'vitest';
/** Common type definitions. */

/** Any function, for use in `extends` */
export type AnyFunction = (...args: never[]) => unknown;

/** Acceptable arguments for a function.*/
export type AllParameters<TFunc extends AnyFunction> =
TFunc extends SpyInstance<infer TArgs, unknown>
? TArgs
: Parameters<ToOverloads<TFunc>>;

/** The return type of a function, given the actual arguments used.*/
export type ReturnTypeFromArgs<
TFunc extends AnyFunction,
TArgs extends unknown[]
> = TFunc extends SpyInstance<unknown[], infer TReturn>
? TReturn
: ExtractReturn<ToOverloads<TFunc>, TArgs>;

/** Given a functions and actual arguments used, extract the return type. */
type ExtractReturn<
TFunc extends AnyFunction,
TArgs extends unknown[]
> = TFunc extends (...args: infer TFuncArgs) => infer TFuncReturn
? TArgs extends TFuncArgs
? TFuncReturn
: never
: never;

/** Transform an overloaded function into a union of functions. */
type ToOverloads<TFunc extends AnyFunction> = Exclude<
OverloadUnion<(() => never) & TFunc>,
TFunc extends () => never ? never : () => never
>;

/** Recursively extract functions from an overload into a union. */
type OverloadUnion<TFunc, TPartialOverload = unknown> = TFunc extends (
...args: infer TArgs
) => infer TReturn
? TPartialOverload extends TFunc
? never
:
| OverloadUnion<
TPartialOverload & TFunc,
TPartialOverload &
((...args: TArgs) => TReturn) &
OverloadProps<TFunc>
>
| ((...args: TArgs) => TReturn)
: never;

/** Properties attached to a function. */
type OverloadProps<TFunc> = Pick<TFunc, keyof TFunc>;
10 changes: 3 additions & 7 deletions src/vitest-when.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
import { configureStub } from './stubs.ts';
import type { StubValue } from './behaviors.ts';
import type {
AnyFunction,
AllParameters,
ReturnTypeFromArgs,
} from './types.ts';
import type { AnyFunction } from './types.ts';

export { ONCE, type StubValue } from './behaviors.ts';
export * from './errors.ts';

export interface StubWrapper<TFunc extends AnyFunction> {
calledWith<TArgs extends AllParameters<TFunc>>(
calledWith<TArgs extends Parameters<TFunc>>(
...args: TArgs
): Stub<TArgs, ReturnTypeFromArgs<TFunc, TArgs>>;
): Stub<TArgs, ReturnType<TFunc>>;
}

export interface Stub<TArgs extends unknown[], TReturn> {
Expand Down
60 changes: 21 additions & 39 deletions test/typing.test-d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,12 @@ describe('vitest-when type signatures', () => {
assertType<subject.Stub<[1], string>>(stub);
});

it('should reject invalid usage of a simple function', () => {
// @ts-expect-error: args missing
subject.when(simple).calledWith();
it('should handle a generic function', () => {
const stub = subject.when(generic).calledWith(1);

// @ts-expect-error: args wrong type
subject.when(simple).calledWith('hello');
stub.thenReturn('hello');

// @ts-expect-error: return wrong type
subject.when(simple).calledWith(1).thenReturn(42);
assertType<subject.Stub<[number], string>>(stub);
});

it('should handle an overloaded function using its last overload', () => {
Expand All @@ -50,30 +47,14 @@ describe('vitest-when type signatures', () => {
assertType<subject.Stub<[1], string>>(stub);
});

it('should handle an overloaded function using its first overload', () => {
const stub = subject.when(overloaded).calledWith();
it('should handle an overloaded function using an explicit type', () => {
const stub = subject.when<() => null>(overloaded).calledWith();

stub.thenReturn(null);

assertType<subject.Stub<[], null>>(stub);
});

it('should handle an very overloaded function using its first overload', () => {
const stub = subject.when(veryOverloaded).calledWith();

stub.thenReturn(null);

assertType<subject.Stub<[], null>>(stub);
});

it('should handle an overloaded function using its last overload', () => {
const stub = subject.when(veryOverloaded).calledWith(1, 2, 3, 4);

stub.thenReturn(42);

assertType<subject.Stub<[1, 2, 3, 4], number>>(stub);
});

it('should reject invalid usage of a simple function', () => {
// @ts-expect-error: args missing
subject.when(simple).calledWith();
Expand All @@ -84,6 +65,17 @@ describe('vitest-when type signatures', () => {
// @ts-expect-error: return wrong type
subject.when(simple).calledWith(1).thenReturn(42);
});

it('should reject invalid usage of a generic function', () => {
// @ts-expect-error: args missing
subject.when(generic).calledWith();

// @ts-expect-error: args wrong type
subject.when(generic<string>).calledWith(42);

// @ts-expect-error: return wrong type
subject.when(generic).calledWith(1).thenReturn(42);
});
});

function untyped(...args: any[]): any {
Expand All @@ -94,22 +86,12 @@ function simple(input: number): string {
throw new Error(`simple(${input})`);
}

function generic<T>(input: T): string {
throw new Error(`generic(${input})`);
}

function overloaded(): null;
function overloaded(input: number): string;
function overloaded(input?: number): string | null {
throw new Error(`overloaded(${input})`);
}

function veryOverloaded(): null;
function veryOverloaded(i1: number): string;
function veryOverloaded(i1: number, i2: number): boolean;
function veryOverloaded(i1: number, i2: number, i3: number): null;
function veryOverloaded(i1: number, i2: number, i3: number, i4: number): number;
function veryOverloaded(
i1?: number,
i2?: number,
i3?: number,
i4?: number
): string | boolean | number | null {
throw new Error(`veryOverloaded(${i1}, ${i2}, ${i3}, ${i4})`);
}

0 comments on commit f59ed24

Please sign in to comment.