Skip to content

Commit

Permalink
fix(types): prevent infinite type recursion with generic functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mcous committed Sep 23, 2023
1 parent 14a24cc commit 1456a1e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 16 deletions.
36 changes: 27 additions & 9 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,36 @@ type ExtractReturn<
: never
: never;

/** Transform an overloaded function into a union of functions. */
/**
* Transform an overloaded function into a union of functions.
*
* If the overload includes `() => never`, it must be hoisted to the "front",
* or else the recursion will stop too early.
* If the overload does not include `() => never`, adding it in can cause
* infinite recursion issues with generic functions, so we omit it.
* Omitting `() => never`, will cause "compatible" signatures to collapse
* (eg. `() => void & (a: 1) => void` becomes `(a?: 1) => void`),
* which is acceptable for this library.
*/
type ToOverloads<TFunc extends AnyFunction> = Exclude<
OverloadUnion<(() => never) & TFunc>,
OverloadUnion<TFunc extends () => never ? (() => never) & TFunc : TFunc>,
TFunc extends () => never ? never : () => never
>;

/** Recursively extract functions from an overload into a union. */
/**
* Recursively extract functions from an overload into a union.
*
* The tricks to this working are:
*
* 1. Inferring the parameter and return types of an overloaded function will
* use the last signature, which is apparently an explicit design choice.
* 2. Intersecting a single signature with the original intersection
* can reorder the intersection (possibly an undocumented side effect?).
* 3. Intersections can only be re-ordered, not narrowed (reduced),
* So, the intersection is rebuilt in the "TPartialOverload" generic,
* then recursion can be stopped when the full intersection has been rebuilt.
* Otherwise, this would result in an infinite recursion.
*/
type OverloadUnion<TFunc, TPartialOverload = unknown> = TFunc extends (
...args: infer TArgs
) => infer TReturn
Expand All @@ -49,12 +72,7 @@ type OverloadUnion<TFunc, TPartialOverload = unknown> = TFunc extends (
:
| OverloadUnion<
TPartialOverload & TFunc,
TPartialOverload &
((...args: TArgs) => TReturn) &
OverloadProps<TFunc>
TPartialOverload & ((...args: TArgs) => TReturn)
>
| ((...args: TArgs) => TReturn)
: never;

/** Properties attached to a function. */
type OverloadProps<TFunc> = Pick<TFunc, keyof TFunc>;
26 changes: 19 additions & 7 deletions test/typing.test-d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,12 @@ describe('vitest-when type signatures', () => {
assertType<subject.Stub<[1], string>>(stub);
});

it('should reject invalid usage of a simple function', () => {
// @ts-expect-error: args missing
subject.when(simple).calledWith();
it('should handle a generic function', () => {
const stub = subject.when(generic).calledWith(1);

// @ts-expect-error: args wrong type
subject.when(simple).calledWith('hello');
stub.thenReturn('hello');

// @ts-expect-error: return wrong type
subject.when(simple).calledWith(1).thenReturn(42);
assertType<subject.Stub<[number], string>>(stub);
});

it('should handle an overloaded function using its last overload', () => {
Expand Down Expand Up @@ -84,6 +81,17 @@ describe('vitest-when type signatures', () => {
// @ts-expect-error: return wrong type
subject.when(simple).calledWith(1).thenReturn(42);
});

it('should reject invalid usage of a generic function', () => {
// @ts-expect-error: args missing
subject.when(generic).calledWith();

// @ts-expect-error: args wrong type
subject.when(generic<string>).calledWith(42);

// @ts-expect-error: return wrong type
subject.when(generic).calledWith(1).thenReturn(42);
});
});

function untyped(...args: any[]): any {
Expand All @@ -94,6 +102,10 @@ function simple(input: number): string {
throw new Error(`simple(${input})`);
}

function generic<T>(input: T): string {
throw new Error(`generic(${input})`);
}

function overloaded(): null;
function overloaded(input: number): string;
function overloaded(input?: number): string | null {
Expand Down

0 comments on commit 1456a1e

Please sign in to comment.