From 616e8dad800f81d41ad045504d2e8e655fbb2463 Mon Sep 17 00:00:00 2001
From: Michael Cousins <michael@cousins.io>
Date: Mon, 22 May 2023 15:33:17 -0500
Subject: [PATCH] fix: support typing overloaded functions (#1)

---
 .lintignore                     |   1 +
 README.md                       | 144 +++++++++++++++++++++++++-------
 example/meaning-of-life.test.ts |   4 +-
 package.json                    |   5 +-
 src/behaviors.ts                | 135 ++++++++++++++++++++++++------
 src/stubs.ts                    |  62 ++++++++------
 src/types.ts                    |  60 +++++++++++++
 src/vitest-when.ts              | 112 +++++++------------------
 test/typing.test-d.ts           | 115 +++++++++++++++++++++++++
 test/vitest-when.test.ts        |  11 +++
 10 files changed, 477 insertions(+), 172 deletions(-)
 create mode 100644 src/types.ts
 create mode 100644 test/typing.test-d.ts

diff --git a/.lintignore b/.lintignore
index 758cab9..b184608 100644
--- a/.lintignore
+++ b/.lintignore
@@ -2,3 +2,4 @@ coverage
 dist
 node_modules
 pnpm-lock.yaml
+tsconfig.vitest-temp.json
diff --git a/README.md b/README.md
index 77717ec..fae227a 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@
 [![ci badge][]][ci]
 [![coverage badge][]][coverage]
 
-Stub behaviors of [vitest][] mocks based on how they are called with a small, readable, and opinionated API. Inspired by [testdouble.js][] and [jest-when][].
+Stub behaviors of [Vitest][] mock functions with a small, readable API. Inspired by [testdouble.js][] and [jest-when][].
 
 ```shell
 npm install --save-dev vitest-when
@@ -20,28 +20,100 @@ npm install --save-dev vitest-when
 [coverage]: https://coveralls.io/github/mcous/vitest-when
 [coverage badge]: https://img.shields.io/coverallsCoverage/github/mcous/vitest-when?style=flat-square
 
-## Why?
+## Usage
 
-[Vitest mock functions][] are powerful, but have an overly permissive API, inherited from Jest. This API makes it hard to use mocks to their full potential of providing meaningful design feedback while writing tests.
+Create [stubs][] - fake objects that have pre-configured responses to matching arguments - from [Vitest's mock functions][]. With vitest-when, your stubs are:
 
-- It's easy to make silly mistakes, like mocking a return value without checking the arguments.
-- Mock usage requires calls in both the [arrange and assert][] phases a test (e.g. configure return value, assert called with proper arguments), which harms test readability and maintainability.
+- Easy to read
+- Hard to misconfigure, especially when using TypeScript
 
-To avoid these issues, vitest-when wraps vitest mocks in a focused, opinionated API that allows you to configure mock behaviors if and only if they are called as you expect.
+Wrap your `vi.fn()` mock - or a function imported from a `vi.mock`'d module - in [`when`][when], match on a set of arguments using [`calledWith`][called-with], and configure a behavior
 
-[vitest mock functions]: https://vitest.dev/api/mock.html#mockreset
-[arrange and assert]: https://github.com/testdouble/contributing-tests/wiki/Arrange-Act-Assert
+- [`.thenReturn()`][then-return] - Return a value
+- [`.thenResolve()`][then-resolve] - Resolve a `Promise`
+- [`.thenThrow()`][then-throw] - Throw an error
+- [`.thenReject()`][then-reject] - Reject a `Promise`
+- [`.thenDo()`][then-do] - Trigger a function
 
-## Usage
+If the stub is called with arguments that match `calledWith`, the configured behavior will occur. If the arguments do not match, the stub will no-op and return `undefined`.
+
+```ts
+import { vi, test, afterEach } from 'vitest';
+import { when } from '';
+
+afterEach(() => {
+  vi.resetAllMocks();
+});
+
+test('stubbing with vitest-when', () => {
+  const stub = vi.fn();
+
+  when(stub).calledWith(1, 2, 3).thenReturn(4);
+  when(stub).calledWith(4, 5, 6).thenReturn(7);
+
+  const result123 = stub(1, 2, 3);
+  expect(result).toBe(4);
+
+  const result456 = stub(4, 5, 6);
+  expect(result).toBe(7);
+
+  const result789 = stub(7, 8, 9);
+  expect(result).toBe(undefined);
+});
+```
+
+You should call `vi.resetAllMocks()` in your suite's `afterEach` hook to remove the implementation added by `when`. You can also set Vitest's [`mockReset`](https://vitest.dev/config/#mockreset) config to `true` instead of using `afterEach`.
+
+[vitest's mock functions]: https://vitest.dev/api/mock.html
+[stubs]: https://en.wikipedia.org/wiki/Test_stub
+[when]: #whenspy-tfunc-stubwrappertfunc
+[called-with]: #calledwithargs-targs-stubtargs-treturn
+[then-return]: #thenreturnvalue-treturn
+[then-resolve]: #thenresolvevalue-treturn
+[then-throw]: #thenthrowerror-unknown
+[then-reject]: #thenrejecterror-unknown
+[then-do]: #thendocallback-args-targs--treturn
+
+### Why not vanilla Vitest mocks?
+
+Vitest's mock functions are powerful, but have an overly permissive API, inherited from Jest. Vanilla `vi.fn()` mock functions are difficult to use well and easy to use poorly.
+
+- Mock usage is spread across the [arrange and assert][] phases of your test, with "act" in between, making the test harder to read.
+- If you forget the `expect(...).toHaveBeenCalledWith(...)` step, the test will pass even if the mock is called incorrectly.
+- `expect(...).toHaveBeenCalledWith(...)` is not type-checked, as of Vitest `0.31.0`.
+
+```ts
+// arrange
+const stub = vi.fn();
+stub.mockReturnValue('world');
+
+// act
+const result = stub('hello');
+
+// assert
+expect(stub).toHaveBeenCalledWith('hello');
+expect(result).toBe('world');
+```
+
+In contrast, when using vitest-when stubs:
 
-0. Add `vi.resetAllMocks` to your suite's `afterEach` hook
-1. Use `when(mock).calledWith(...)` to specify matching arguments
-2. Configure a behavior with a stub method:
-   - Return a value: `.thenReturn(...)`
-   - Resolve a `Promise`: `.thenResolve(...)`
-   - Throw an error: `.thenThrow(...)`
-   - Reject a `Promise`: `.thenReject(...)`
-   - Trigger a callback: `.thenDo(...)`
+- All stub configuration happens in the "arrange" phase of your test.
+- You cannot forget `calledWith`.
+- `calledWith` and `thenReturn` (et. al.) are fully type-checked.
+
+```ts
+// arrange
+const stub = vi.fn();
+when(stub).calledWith('hello').thenReturn('world');
+
+// act
+const result = stub('hello');
+
+// assert
+expect(result).toBe('world');
+```
+
+[arrange and assert]: https://github.com/testdouble/contributing-tests/wiki/Arrange-Act-Assert
 
 ### Example
 
@@ -59,12 +131,12 @@ import * as subject from './meaning-of-life.ts';
 vi.mock('./deep-thought.ts');
 vi.mock('./earth.ts');
 
-describe('subject under test', () => {
+describe('get the meaning of life', () => {
   afterEach(() => {
     vi.resetAllMocks();
   });
 
-  it('should delegate work to dependency', async () => {
+  it('should get the answer and the question', async () => {
     when(deepThought.calculateAnswer).calledWith().thenResolve(42);
     when(earth.calculateQuestion).calledWith(42).thenResolve("What's 6 by 9?");
 
@@ -73,7 +145,9 @@ describe('subject under test', () => {
     expect(result).toEqual({ question: "What's 6 by 9?", answer: 42 });
   });
 });
+```
 
+```ts
 // meaning-of-life.ts
 import { calculateAnswer } from './deep-thought.ts';
 import { calculateQuestion } from './earth.ts';
@@ -89,12 +163,16 @@ export const createMeaning = async (): Promise<Meaning> => {
 
   return { question, answer };
 };
+```
 
+```ts
 // deep-thought.ts
 export const calculateAnswer = async (): Promise<number> => {
   throw new Error(`calculateAnswer() not implemented`);
 };
+```
 
+```ts
 // earth.ts
 export const calculateQuestion = async (answer: number): Promise<string> => {
   throw new Error(`calculateQuestion(${answer}) not implemented`);
@@ -103,19 +181,32 @@ export const calculateQuestion = async (answer: number): Promise<string> => {
 
 ## API
 
-### `when(spy: Mock<TArgs, TReturn>).calledWith(...args: TArgs): Stub<TArgs, TReturn>`
+### `when(spy: TFunc): StubWrapper<TFunc>`
 
-Create's a stub for a given set of arguments that you can then configure with different behaviors.
+Configures a `vi.fn()` mock function to act as a vitest-when stub. Adds an implementation to the function that initially no-ops, and returns an API to configure behaviors for given arguments using [`.calledWith(...)`][called-with]
 
 ```ts
+import { vi } from 'vitest';
+import { when } from 'vitest-when';
+
 const spy = vi.fn();
+const stubWrapper = when(spy);
 
-when(spy).calledWith('hello').thenReturn('world');
+expect(spy()).toBe(undefined);
+```
+
+### `.calledWith(...args: TArgs): Stub<TArgs, TReturn>`
+
+Create a stub that matches a given set of arguments which you can configure with different behaviors using methods like [`.thenReturn(...)`][then-return].
+
+```ts
+const spy = vi.fn();
+const stub = when(spy).calledWith('hello').thenReturn('world');
 
 expect(spy('hello')).toEqual('world');
 ```
 
-When a call to a mock uses arguments that match those given to `calledWith`, a configured behavior will be triggered. All arguments must match, though you can use vitest's [asymmetric matchers][] to loosen the stubbing:
+When a call to a mock uses arguments that match those given to `calledWith`, a configured behavior will be triggered. All arguments must match, but you can use Vitest's [asymmetric matchers][] to loosen the stubbing:
 
 ```ts
 const spy = vi.fn();
@@ -338,10 +429,3 @@ when(spy)
 expect(spy('hello')).toEqual('world');
 expect(spy('hello')).toEqual('solar system');
 ```
-
-## See also
-
-- [testdouble-vitest][] - Use [testdouble.js][] mocks with Vitest instead of the default [tinyspy][] mocks.
-
-[testdouble-vitest]: https://github.com/mcous/testdouble-vitest
-[tinyspy]: https://github.com/tinylibs/tinyspy
diff --git a/example/meaning-of-life.test.ts b/example/meaning-of-life.test.ts
index a64cd25..ca5ecb4 100644
--- a/example/meaning-of-life.test.ts
+++ b/example/meaning-of-life.test.ts
@@ -8,12 +8,12 @@ import * as subject from './meaning-of-life.ts';
 vi.mock('./deep-thought.ts');
 vi.mock('./earth.ts');
 
-describe('subject under test', () => {
+describe('get the meaning of life', () => {
   afterEach(() => {
     vi.resetAllMocks();
   });
 
-  it('should delegate work to dependency', async () => {
+  it('should get the answer and the question', async () => {
     when(deepThought.calculateAnswer).calledWith().thenResolve(42);
     when(earth.calculateQuestion).calledWith(42).thenResolve("What's 6 by 9?");
 
diff --git a/package.json b/package.json
index 647abb0..de265c6 100644
--- a/package.json
+++ b/package.json
@@ -1,7 +1,7 @@
 {
   "name": "vitest-when",
   "version": "0.1.1",
-  "description": "Stub behaviors of vitest mocks based on how they are called",
+  "description": "Stub behaviors of Vitest mock functions with a small, readable API.",
   "type": "module",
   "exports": {
     ".": {
@@ -19,7 +19,7 @@
     "access": "public",
     "provenance": true
   },
-  "packageManager": "pnpm@8.5.0",
+  "packageManager": "pnpm@8.5.1",
   "author": "Michael Cousins <michael@cousins.io> (https://mike.cousins.io)",
   "license": "MIT",
   "repository": {
@@ -43,6 +43,7 @@
     "coverage": "vitest run --coverage",
     "check:format": "pnpm run _prettier --check",
     "check:lint": "pnpm run _eslint",
+    "check:types": "vitest typecheck --run",
     "format": "pnpm run _prettier --write && pnpm run _eslint --fix",
     "_eslint": "eslint --ignore-path .lintignore \"**/*.ts\"",
     "_prettier": "prettier --ignore-path .lintignore \"**/*.@(ts|json|yaml)\""
diff --git a/src/behaviors.ts b/src/behaviors.ts
index 5de1362..91ebcb9 100644
--- a/src/behaviors.ts
+++ b/src/behaviors.ts
@@ -1,28 +1,53 @@
 import { equals } from '@vitest/expect';
+import type {
+  AnyFunction,
+  AllParameters,
+  ReturnTypeFromArgs,
+} from './types.ts';
 
-export interface BehaviorEntry<TArgs extends unknown[], TReturn> {
+export const ONCE = Symbol('ONCE');
+
+export type StubValue<TValue> = TValue | typeof ONCE;
+
+export interface BehaviorStack<TFunc extends AnyFunction> {
+  use: (
+    args: AllParameters<TFunc>
+  ) => BehaviorEntry<AllParameters<TFunc>> | undefined;
+
+  bindArgs: <TArgs extends AllParameters<TFunc>>(
+    args: TArgs
+  ) => BoundBehaviorStack<ReturnTypeFromArgs<TFunc, TArgs>>;
+}
+
+export interface BoundBehaviorStack<TReturn> {
+  addReturn: (values: StubValue<TReturn>[]) => void;
+  addResolve: (values: StubValue<Awaited<TReturn>>[]) => void;
+  addThrow: (values: StubValue<unknown>[]) => void;
+  addReject: (values: StubValue<unknown>[]) => void;
+  addDo: (values: StubValue<AnyFunction>[]) => void;
+}
+
+export interface BehaviorEntry<TArgs extends unknown[]> {
   args: TArgs;
-  returnValue?: TReturn;
+  returnValue?: unknown;
   throwError?: unknown | undefined;
-  doCallback?: ((...args: TArgs) => TReturn) | undefined;
+  doCallback?: AnyFunction | undefined;
   times?: number | undefined;
 }
 
-export interface Behaviors<TArgs extends unknown[], TReturn> {
-  add: (behaviors: BehaviorEntry<TArgs, TReturn>[]) => void;
-  execute: (args: TArgs) => TReturn;
+export interface BehaviorOptions<TValue> {
+  value: TValue;
+  times: number | undefined;
 }
 
-export const createBehaviors = <TArgs extends unknown[], TReturn>(): Behaviors<
-  TArgs,
-  TReturn
-> => {
-  const behaviorStack: BehaviorEntry<TArgs, TReturn>[] = [];
+export const createBehaviorStack = <
+  TFunc extends AnyFunction
+>(): BehaviorStack<TFunc> => {
+  const behaviors: BehaviorEntry<AllParameters<TFunc>>[] = [];
 
   return {
-    add: (behaviors) => behaviorStack.unshift(...behaviors),
-    execute: (args) => {
-      const behavior = behaviorStack
+    use: (args) => {
+      const behavior = behaviors
         .filter((b) => behaviorAvailable(b))
         .find(behaviorHasArgs(args));
 
@@ -30,27 +55,83 @@ export const createBehaviors = <TArgs extends unknown[], TReturn>(): Behaviors<
         behavior.times -= 1;
       }
 
-      if (behavior?.throwError) {
-        throw behavior.throwError as Error;
-      }
-
-      if (behavior?.doCallback) {
-        return behavior.doCallback(...args);
-      }
-
-      return behavior?.returnValue as TReturn;
+      return behavior;
     },
+
+    bindArgs: (args) => ({
+      addReturn: (values) => {
+        behaviors.unshift(
+          ...getBehaviorOptions(values).map(({ value, times }) => ({
+            args,
+            times,
+            returnValue: value,
+          }))
+        );
+      },
+      addResolve: (values) => {
+        behaviors.unshift(
+          ...getBehaviorOptions(values).map(({ value, times }) => ({
+            args,
+            times,
+            returnValue: Promise.resolve(value),
+          }))
+        );
+      },
+      addThrow: (values) => {
+        behaviors.unshift(
+          ...getBehaviorOptions(values).map(({ value, times }) => ({
+            args,
+            times,
+            throwError: value,
+          }))
+        );
+      },
+      addReject: (values) => {
+        behaviors.unshift(
+          ...getBehaviorOptions(values).map(({ value, times }) => ({
+            args,
+            times,
+            returnValue: Promise.reject(value),
+          }))
+        );
+      },
+      addDo: (values) => {
+        behaviors.unshift(
+          ...getBehaviorOptions(values).map(({ value, times }) => ({
+            args,
+            times,
+            doCallback: value,
+          }))
+        );
+      },
+    }),
   };
 };
 
-const behaviorAvailable = <TArgs extends unknown[], TReturn>(
-  behavior: BehaviorEntry<TArgs, TReturn>
+const getBehaviorOptions = <TValue>(
+  valuesAndOptions: StubValue<TValue>[]
+): BehaviorOptions<TValue>[] => {
+  const once = valuesAndOptions.includes(ONCE);
+  let values = valuesAndOptions.filter((value) => value !== ONCE) as TValue[];
+
+  if (values.length === 0) {
+    values = [undefined as TValue];
+  }
+
+  return values.map((value, i) => ({
+    value,
+    times: once || i < values.length - 1 ? 1 : undefined,
+  }));
+};
+
+const behaviorAvailable = <TArgs extends unknown[]>(
+  behavior: BehaviorEntry<TArgs>
 ): boolean => {
   return behavior.times === undefined || behavior.times > 0;
 };
 
-const behaviorHasArgs = <TArgs extends unknown[], TReturn>(args: TArgs) => {
-  return (behavior: BehaviorEntry<TArgs, TReturn>): boolean => {
+const behaviorHasArgs = <TArgs extends unknown[]>(args: TArgs) => {
+  return (behavior: BehaviorEntry<TArgs>): boolean => {
     let i = 0;
 
     while (i < args.length || i < behavior.args.length) {
diff --git a/src/stubs.ts b/src/stubs.ts
index 5e6987c..d3e0298 100644
--- a/src/stubs.ts
+++ b/src/stubs.ts
@@ -1,46 +1,54 @@
 import type { Mock as Spy } from 'vitest';
-import { createBehaviors, type Behaviors } from './behaviors.ts';
+import { createBehaviorStack, type BehaviorStack } from './behaviors.ts';
 import { NotAMockFunctionError } from './errors.ts';
+import type { AnyFunction, AllParameters } from './types.ts';
 
 const BEHAVIORS_KEY = Symbol('behaviors');
 
-type BaseSpyImplementation<TArgs extends unknown[], TReturn> = (
-  ...args: TArgs
-) => TReturn;
-
-interface WhenStubImplementation<TArgs extends unknown[], TReturn>
-  extends BaseSpyImplementation<TArgs, TReturn> {
-  [BEHAVIORS_KEY]: Behaviors<TArgs, TReturn>;
+interface WhenStubImplementation<TFunc extends AnyFunction> {
+  (...args: AllParameters<TFunc>): unknown;
+  [BEHAVIORS_KEY]: BehaviorStack<TFunc>;
 }
 
-export const configureStub = <TArgs extends unknown[], TReturn>(
+export const configureStub = <TFunc extends AnyFunction>(
   maybeSpy: unknown
-): Behaviors<TArgs, TReturn> => {
-  const spy = validateSpy<TArgs, TReturn>(maybeSpy);
-  let implementation = spy.getMockImplementation() as
-    | BaseSpyImplementation<TArgs, TReturn>
-    | WhenStubImplementation<TArgs, TReturn>
+): BehaviorStack<TFunc> => {
+  const spy = validateSpy<TFunc>(maybeSpy);
+  const existingImplementation = spy.getMockImplementation() as
+    | WhenStubImplementation<TFunc>
+    | TFunc
     | undefined;
 
-  if (!implementation || !(BEHAVIORS_KEY in implementation)) {
-    const behaviors = createBehaviors<TArgs, TReturn>();
+  if (existingImplementation && BEHAVIORS_KEY in existingImplementation) {
+    return existingImplementation[BEHAVIORS_KEY];
+  }
 
-    implementation = Object.assign(
-      (...args: TArgs) => behaviors.execute(args),
-      { [BEHAVIORS_KEY]: behaviors }
-    );
+  const behaviors = createBehaviorStack<TFunc>();
 
-    spy.mockImplementation(implementation);
+  const implementation = (...args: AllParameters<TFunc>): unknown => {
+    const behavior = behaviors.use(args);
 
-    return behaviors;
-  }
+    if (behavior?.throwError) {
+      throw behavior.throwError as Error;
+    }
+
+    if (behavior?.doCallback) {
+      return behavior.doCallback(...args);
+    }
+
+    return behavior?.returnValue;
+  };
+
+  spy.mockImplementation(
+    Object.assign(implementation, { [BEHAVIORS_KEY]: behaviors })
+  );
 
-  return implementation[BEHAVIORS_KEY];
+  return behaviors;
 };
 
-const validateSpy = <TArgs extends unknown[], TReturn>(
+const validateSpy = <TFunc extends AnyFunction>(
   maybeSpy: unknown
-): Spy<TArgs, TReturn> => {
+): Spy<AllParameters<TFunc>, unknown> => {
   if (
     typeof maybeSpy === 'function' &&
     'mockImplementation' in maybeSpy &&
@@ -48,7 +56,7 @@ const validateSpy = <TArgs extends unknown[], TReturn>(
     'getMockImplementation' in maybeSpy &&
     typeof maybeSpy.getMockImplementation === 'function'
   ) {
-    return maybeSpy as Spy<TArgs, TReturn>;
+    return maybeSpy as Spy<AllParameters<TFunc>, unknown>;
   }
 
   throw new NotAMockFunctionError(maybeSpy);
diff --git a/src/types.ts b/src/types.ts
new file mode 100644
index 0000000..7fd9319
--- /dev/null
+++ b/src/types.ts
@@ -0,0 +1,60 @@
+/**
+ * Get function arguments and return value types.
+ *
+ * Support for overloaded functions, thanks to @Shakeskeyboarde
+ * https://github.com/microsoft/TypeScript/issues/14107#issuecomment-1146738780
+ */
+
+import type { SpyInstance } from 'vitest';
+
+/** Any function, for use in `extends` */
+export type AnyFunction = (...args: never[]) => unknown;
+
+/** Acceptable arguments for a function.*/
+export type AllParameters<TFunc extends AnyFunction> =
+  TFunc extends SpyInstance<infer TArgs, unknown>
+    ? TArgs
+    : Parameters<ToOverloads<TFunc>>;
+
+/** The return type of a function, given the actual arguments used.*/
+export type ReturnTypeFromArgs<
+  TFunc extends AnyFunction,
+  TArgs extends unknown[]
+> = TFunc extends SpyInstance<unknown[], infer TReturn>
+  ? TReturn
+  : ExtractReturn<ToOverloads<TFunc>, TArgs>;
+
+/** Given a functions and actual arguments used, extract the return type. */
+type ExtractReturn<
+  TFunc extends AnyFunction,
+  TArgs extends unknown[]
+> = TFunc extends (...args: infer TFuncArgs) => infer TFuncReturn
+  ? TArgs extends TFuncArgs
+    ? TFuncReturn
+    : never
+  : never;
+
+/** Transform an overloaded function into a union of functions. */
+type ToOverloads<TFunc extends AnyFunction> = Exclude<
+  OverloadUnion<(() => never) & TFunc>,
+  TFunc extends () => never ? never : () => never
+>;
+
+/** Recursively extract functions from an overload into a union. */
+type OverloadUnion<TFunc, TPartialOverload = unknown> = TFunc extends (
+  ...args: infer TArgs
+) => infer TReturn
+  ? TPartialOverload extends TFunc
+    ? never
+    :
+        | OverloadUnion<
+            TPartialOverload & TFunc,
+            TPartialOverload &
+              ((...args: TArgs) => TReturn) &
+              OverloadProps<TFunc>
+          >
+        | ((...args: TArgs) => TReturn)
+  : never;
+
+/** Properties attached to a function. */
+type OverloadProps<TFunc> = Pick<TFunc, keyof TFunc>;
diff --git a/src/vitest-when.ts b/src/vitest-when.ts
index 03f0ebd..b0acb60 100644
--- a/src/vitest-when.ts
+++ b/src/vitest-when.ts
@@ -1,9 +1,18 @@
 import { configureStub } from './stubs.ts';
-
+import type { StubValue } from './behaviors.ts';
+import type {
+  AnyFunction,
+  AllParameters,
+  ReturnTypeFromArgs,
+} from './types.ts';
+
+export { ONCE, type StubValue } from './behaviors.ts';
 export * from './errors.ts';
 
-export interface StubWrapper<TArgs extends unknown[], TReturn> {
-  calledWith: (...args: TArgs) => Stub<TArgs, TReturn>;
+export interface StubWrapper<TFunc extends AnyFunction> {
+  calledWith<TArgs extends AllParameters<TFunc>>(
+    ...args: TArgs
+  ): Stub<TArgs, ReturnTypeFromArgs<TFunc, TArgs>>;
 }
 
 export interface Stub<TArgs extends unknown[], TReturn> {
@@ -11,90 +20,25 @@ export interface Stub<TArgs extends unknown[], TReturn> {
   thenResolve: (...values: StubValue<Awaited<TReturn>>[]) => void;
   thenThrow: (...errors: StubValue<unknown>[]) => void;
   thenReject: (...errors: StubValue<unknown>[]) => void;
-  thenDo: (...callbacks: StubValue<Callback<TArgs, TReturn>>[]) => void;
+  thenDo: (...callbacks: StubValue<(...args: TArgs) => TReturn>[]) => void;
 }
 
-export type Callback<TArgs extends unknown[], TReturn> = (
-  ...args: TArgs
-) => TReturn;
-
-export type StubValue<TValue> = TValue | typeof ONCE;
-
-export const ONCE = Symbol('ONCE');
-
-export const when = <TArgs extends unknown[], TReturn>(
-  spy: (...args: TArgs) => TReturn
-): StubWrapper<TArgs, TReturn> => {
-  const behaviors = configureStub<TArgs, TReturn>(spy);
+export const when = <TFunc extends AnyFunction>(
+  spy: TFunc
+): StubWrapper<TFunc> => {
+  const behaviorStack = configureStub(spy);
 
   return {
-    calledWith: (...args: TArgs) => ({
-      thenReturn: (...values: StubValue<TReturn>[]) => {
-        behaviors.add(
-          getBehaviorOptions(values).map(({ value, times }) => ({
-            args,
-            times,
-            returnValue: value,
-          }))
-        );
-      },
-      thenResolve: (...values: StubValue<Awaited<TReturn>>[]) => {
-        behaviors.add(
-          getBehaviorOptions(values).map(({ value, times }) => ({
-            args,
-            times,
-            returnValue: Promise.resolve(value) as TReturn,
-          }))
-        );
-      },
-      thenReject: (...errors: StubValue<unknown>[]) => {
-        behaviors.add(
-          getBehaviorOptions(errors).map(({ value, times }) => ({
-            args,
-            times,
-            returnValue: Promise.reject(value) as TReturn,
-          }))
-        );
-      },
-      thenThrow: (...errors: StubValue<unknown>[]) => {
-        behaviors.add(
-          getBehaviorOptions(errors).map(({ value, times }) => ({
-            args,
-            times,
-            throwError: value,
-          }))
-        );
-      },
-      thenDo: (...callbacks: StubValue<Callback<TArgs, TReturn>>[]) => {
-        behaviors.add(
-          getBehaviorOptions(callbacks).map(({ value, times }) => ({
-            args,
-            times,
-            doCallback: value,
-          }))
-        );
-      },
-    }),
+    calledWith: (...args) => {
+      const boundBehaviors = behaviorStack.bindArgs(args);
+
+      return {
+        thenReturn: (...values) => boundBehaviors.addReturn(values),
+        thenResolve: (...values) => boundBehaviors.addResolve(values),
+        thenThrow: (...errors) => boundBehaviors.addThrow(errors),
+        thenReject: (...errors) => boundBehaviors.addReject(errors),
+        thenDo: (...callbacks) => boundBehaviors.addDo(callbacks),
+      };
+    },
   };
 };
-
-interface BehaviorOptions<TValue> {
-  value: TValue;
-  times: number | undefined;
-}
-
-const getBehaviorOptions = <TValue>(
-  valuesAndOptions: StubValue<TValue>[]
-): BehaviorOptions<TValue>[] => {
-  const once = valuesAndOptions.includes(ONCE);
-  let values = valuesAndOptions.filter((value) => value !== ONCE) as TValue[];
-
-  if (values.length === 0) {
-    values = [undefined as TValue];
-  }
-
-  return values.map((value, i) => ({
-    value,
-    times: once || i < values.length - 1 ? 1 : undefined,
-  }));
-};
diff --git a/test/typing.test-d.ts b/test/typing.test-d.ts
new file mode 100644
index 0000000..e477bce
--- /dev/null
+++ b/test/typing.test-d.ts
@@ -0,0 +1,115 @@
+/* eslint-disable
+  @typescript-eslint/no-explicit-any,
+  @typescript-eslint/restrict-template-expressions,
+  func-style
+*/
+
+import { vi, describe, it, assertType } from 'vitest';
+import * as subject from '../src/vitest-when.ts';
+
+describe('vitest-when type signatures', () => {
+  it('should handle an anonymous mock', () => {
+    const spy = vi.fn();
+    const stub = subject.when(spy).calledWith(1, 2, 3);
+
+    assertType<subject.Stub<[number, number, number], any>>(stub);
+  });
+
+  it('should handle an untyped function', () => {
+    const stub = subject.when(untyped).calledWith(1);
+
+    stub.thenReturn('hello');
+
+    assertType<subject.Stub<[number], any>>(stub);
+  });
+
+  it('should handle a simple function', () => {
+    const stub = subject.when(simple).calledWith(1);
+
+    stub.thenReturn('hello');
+
+    assertType<subject.Stub<[1], string>>(stub);
+  });
+
+  it('should reject invalid usage of a simple function', () => {
+    // @ts-expect-error: args missing
+    subject.when(simple).calledWith();
+
+    // @ts-expect-error: args wrong type
+    subject.when(simple).calledWith('hello');
+
+    // @ts-expect-error: return wrong type
+    subject.when(simple).calledWith(1).thenReturn(42);
+  });
+
+  it('should handle an overloaded function using its last overload', () => {
+    const stub = subject.when(overloaded).calledWith(1);
+
+    stub.thenReturn('hello');
+
+    assertType<subject.Stub<[1], string>>(stub);
+  });
+
+  it('should handle an overloaded function using its first overload', () => {
+    const stub = subject.when(overloaded).calledWith();
+
+    stub.thenReturn(null);
+
+    assertType<subject.Stub<[], null>>(stub);
+  });
+
+  it('should handle an very overloaded function using its first overload', () => {
+    const stub = subject.when(veryOverloaded).calledWith();
+
+    stub.thenReturn(null);
+
+    assertType<subject.Stub<[], null>>(stub);
+  });
+
+  it('should handle an overloaded function using its last overload', () => {
+    const stub = subject.when(veryOverloaded).calledWith(1, 2, 3, 4);
+
+    stub.thenReturn(42);
+
+    assertType<subject.Stub<[1, 2, 3, 4], number>>(stub);
+  });
+
+  it('should reject invalid usage of a simple function', () => {
+    // @ts-expect-error: args missing
+    subject.when(simple).calledWith();
+
+    // @ts-expect-error: args wrong type
+    subject.when(simple).calledWith('hello');
+
+    // @ts-expect-error: return wrong type
+    subject.when(simple).calledWith(1).thenReturn(42);
+  });
+});
+
+function untyped(...args: any[]): any {
+  throw new Error(`untyped(...${args})`);
+}
+
+function simple(input: number): string {
+  throw new Error(`simple(${input})`);
+}
+
+function overloaded(): null;
+function overloaded(input: number): string;
+function overloaded(input?: number): string | null {
+  throw new Error(`overloaded(${input})`);
+}
+
+function veryOverloaded(): null;
+function veryOverloaded(i1: number): string;
+function veryOverloaded(i1: number, i2: number): boolean;
+function veryOverloaded(i1: number, i2: number, i3: number): null;
+function veryOverloaded(i1: number, i2: number, i3: number, i4: number): number;
+function veryOverloaded(
+  i1?: number,
+  i2?: number,
+  i3?: number,
+  i4?: number
+): string | boolean | number | null {
+  throw new Error(`veryOverloaded(${i1}, ${i2}, ${i3}, ${i4})`);
+}
diff --git a/test/vitest-when.test.ts b/test/vitest-when.test.ts
index 7121b01..7b7ffba 100644
--- a/test/vitest-when.test.ts
+++ b/test/vitest-when.test.ts
@@ -243,4 +243,15 @@ describe('vitest-when', () => {
 
     expect(spy('foo')).toEqual(1000);
   });
+
+  it('should deeply check object arguments', () => {
+    const spy = vi.fn();
+
+    subject
+      .when(spy)
+      .calledWith({ foo: { bar: { baz: 0 } } })
+      .thenReturn(100);
+
+    expect(spy({ foo: { bar: { baz: 0 } } })).toEqual(100);
+  });
 });