From 04af01a08251f37882b6c581a546631080cf4623 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 18 Apr 2025 20:00:39 +0100 Subject: [PATCH 01/10] MCP function calling via createMessage working --- mcp-run-python/src/main.ts | 51 ++++++++++++++++-------- mcp-run-python/src/prepare_env.py | 64 ++++++++++++++++++++++++++----- mcp-run-python/src/runCode.ts | 60 ++++++++++++++++++++++++----- 3 files changed, 139 insertions(+), 36 deletions(-) diff --git a/mcp-run-python/src/main.ts b/mcp-run-python/src/main.ts index 6eb051f93..c1e80a4e3 100644 --- a/mcp-run-python/src/main.ts +++ b/mcp-run-python/src/main.ts @@ -15,15 +15,16 @@ const VERSION = '0.0.13' export async function main() { const { args } = Deno - if (args.length === 1 && args[0] === 'stdio') { - await runStdio() + const flags = parseArgs(args, { + string: ['port', 'functions'], + default: { port: '3001' }, + }) + const functions = flags.functions ? flags.functions.split(',') : [] + if (args.length >= 1 && args[0] === 'stdio') { + await runStdio(functions) } else if (args.length >= 1 && args[0] === 'sse') { - const flags = parseArgs(Deno.args, { - string: ['port'], - default: { port: '3001' }, - }) const port = parseInt(flags.port) - runSse(port) + runSse(functions, port) } else if (args.length === 1 && args[0] === 'warmup') { await warmup() } else { @@ -34,7 +35,8 @@ Invalid arguments. Usage: deno run -N -R=node_modules -W=node_modules --node-modules-dir=auto jsr:@pydantic/mcp-run-python [stdio|sse|warmup] options: - --port Port to run the SSE server on (default: 3001)`, + --port Port to run the SSE server on (default: 3001) + --functions Comma separated list of client functions which the server can call`, ) Deno.exit(1) } @@ -43,7 +45,7 @@ options: /* * Create an MCP server with the `run_python_code` tool registered. */ -function createServer(): McpServer { +function createServer(functions: string[]): McpServer { const server = new McpServer( { name: 'MCP Run Python', @@ -85,29 +87,46 @@ print('python code here') { python_code: z.string().describe('Python code to run') }, async ({ python_code }: { python_code: string }) => { const logPromises: Promise[] = [] - const result = await runCode([{ + const mainPy = { name: 'main.py', content: python_code, active: true, - }], (level, data) => { + } + const codeLog = (level: LoggingLevel, data: string) => { if (LogLevels.indexOf(level) >= LogLevels.indexOf(setLogLevel)) { logPromises.push(server.server.sendLoggingMessage({ level, data })) } - }) + } + async function clientCallback(func: string, args?: string, kwargs?: string) { + const { content } = await server.server.createMessage({ + messages: [], + maxTokens: 0, + systemPrompt: '__python_function_call__', + metadata: { func, args, kwargs }, + }) + if (content.type !== 'text') { + throw new Error('Expected return content type to be "text"') + } else { + return content.text + } + } + + const result = await runCode([mainPy], codeLog, functions, clientCallback) await Promise.all(logPromises) return { content: [{ type: 'text', text: asXml(result) }], } }, ) + return server } /* * Run the MCP server using the SSE transport, e.g. over HTTP. */ -function runSse(port: number) { - const mcpServer = createServer() +function runSse(functions: string[], port: number) { + const mcpServer = createServer(functions) const transports: { [sessionId: string]: SSEServerTransport } = {} const server = http.createServer(async (req, res) => { @@ -162,8 +181,8 @@ function runSse(port: number) { /* * Run the MCP server using the Stdio transport. */ -async function runStdio() { - const mcpServer = createServer() +async function runStdio(functions: string[]) { + const mcpServer = createServer(functions) const transport = new StdioServerTransport() await mcpServer.connect(transport) } diff --git a/mcp-run-python/src/prepare_env.py b/mcp-run-python/src/prepare_env.py index e22db9ca7..ccb8c02ea 100644 --- a/mcp-run-python/src/prepare_env.py +++ b/mcp-run-python/src/prepare_env.py @@ -10,15 +10,17 @@ import re import sys import traceback -from collections.abc import Iterable, Iterator +from collections.abc import Awaitable, Iterable, Iterator from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal, TypedDict +from typing import Any, Callable, Literal, TypedDict import micropip import pyodide_js import tomllib +from pydantic import ConfigDict, TypeAdapter +from pydantic_core import to_json from pyodide.code import find_imports __all__ = 'prepare_env', 'dump_json' @@ -31,18 +33,18 @@ class File(TypedDict): @dataclass -class Success: +class PrepSuccess: dependencies: list[str] | None kind: Literal['success'] = 'success' @dataclass -class Error: +class PrepError: message: str kind: Literal['error'] = 'error' -async def prepare_env(files: list[File]) -> Success | Error: +async def prepare_env(files: list[File]) -> PrepSuccess | PrepError: sys.setrecursionlimit(400) cwd = Path.cwd() @@ -68,14 +70,12 @@ async def prepare_env(files: list[File]) -> Success | Error: except Exception: with open(logs_filename) as f: logs = f.read() - return Error(message=f'{logs} {traceback.format_exc()}') + return PrepError(message=f'{logs} {traceback.format_exc()}') - return Success(dependencies=dependencies) + return PrepSuccess(dependencies=dependencies) -def dump_json(value: Any) -> str | None: - from pydantic_core import to_json - +def pretty_result(value: Any) -> str | None: if value is None: return None if isinstance(value, str): @@ -84,6 +84,50 @@ def dump_json(value: Any) -> str | None: return to_json(value, indent=2, fallback=_json_fallback).decode() +def dump_json(value: Any) -> str | None: + if value: + return to_json(value, fallback=_json_fallback).decode() + + +class CallSuccess(TypedDict): + return_value: Any + kind: Literal['success'] + + +class CallError(TypedDict): + exc_type: str + message: str + kind: Literal['error'] + + +call_result_ta: TypeAdapter[CallSuccess | CallError] = TypeAdapter( + CallSuccess | CallError, config=ConfigDict(defer_build=True) +) + + +@dataclass(slots=True) +class ClientCallback: + _func_name: str + _callback: Callable[[tuple[Any, ...], dict[str, Any]], Awaitable[str]] + + async def __call__(self, *args: Any, **kwargs: Any) -> Any: + result_json = await self._callback(args, kwargs) + result = call_result_ta.validate_json(result_json) + if result['kind'] == 'success': + return result['return_value'] + + exc_type, message = result['exc_type'], result['message'] + try: + exc_type_ = __builtins__[exc_type] + except KeyError: + raise Exception(f'{message}\n(Raised exception type: {exc_type})') + else: + raise exc_type_(message) + + def __repr__(self) -> str: + return f'' + + def _json_fallback(value: Any) -> Any: tp: Any = type(value) module = tp.__module__ diff --git a/mcp-run-python/src/runCode.ts b/mcp-run-python/src/runCode.ts index fdb1d7084..6cbb772a0 100644 --- a/mcp-run-python/src/runCode.ts +++ b/mcp-run-python/src/runCode.ts @@ -1,4 +1,4 @@ -/* eslint @typescript-eslint/no-explicit-any: off */ +// deno-lint-ignore-file no-explicit-any import { loadPyodide } from 'pyodide' import { preparePythonCode } from './prepareEnvCode.ts' import type { LoggingLevel } from '@modelcontextprotocol/sdk/types.js' @@ -12,10 +12,11 @@ export interface CodeFile { export async function runCode( files: CodeFile[], log: (level: LoggingLevel, data: string) => void, + functionNames?: string[], + clientCallback?: (func: string, args?: string, kwargs?: string) => Promise, ): Promise { // remove once https://github.com/pyodide/pyodide/pull/5514 is released const realConsoleLog = console.log - // deno-lint-ignore no-explicit-any console.log = (...args: any[]) => log('debug', args.join(' ')) const output: string[] = [] @@ -54,9 +55,20 @@ export async function runCode( pathlib.Path(`${dirPath}/${moduleName}.py`).write_text(preparePythonCode) - const preparePyEnv: PreparePyEnv = pyodide.pyimport(moduleName) + const { prepare_env, ClientCallback, dump_json, pretty_result }: PreparePyEnv = pyodide.pyimport(moduleName) - const prepareStatus = await preparePyEnv.prepare_env(pyodide.toPy(files)) + const prepareStatus = await prepare_env(pyodide.toPy(files)) + + const globals: { [key: string]: string | ClientCallbackType } = { __name__: '__main__' } + + if (functionNames && clientCallback) { + for (const functionName of functionNames) { + globals[functionName] = ClientCallback( + functionName, + async (args, kwargs) => pyodide.toPy(await clientCallback(functionName, dump_json(args), dump_json(kwargs))), + ) + } + } let runResult: RunSuccess | RunError if (prepareStatus.kind == 'error') { @@ -70,14 +82,14 @@ export async function runCode( const activeFile = files.find((f) => f.active)! || files[0] try { const rawValue = await pyodide.runPythonAsync(activeFile.content, { - globals: pyodide.toPy({ __name__: '__main__' }), + globals: pyodide.toPy(globals), filename: activeFile.name, }) runResult = { status: 'success', dependencies, output, - returnValueJson: preparePyEnv.dump_json(rawValue), + returnValueJson: pretty_result(rawValue), } } catch (err) { runResult = { @@ -99,7 +111,7 @@ interface RunSuccess { // we could record stdout and stderr separately, but I suspect simplicity is more important output: string[] dependencies: string[] - returnValueJson: string | null + returnValueJson: string | undefined } interface RunError { @@ -144,7 +156,6 @@ function escapeClosing(closingTag: string): (str: string) => string { return (str) => str.replace(regex, onMatch) } -// deno-lint-ignore no-explicit-any function formatError(err: any): string { let errStr = err.toString() errStr = errStr.replace(/^PythonError: +/, '') @@ -153,6 +164,13 @@ function formatError(err: any): string { / {2}File "\/lib\/python\d+\.zip\/_pyodide\/.*\n {4}.*\n(?: {4,}\^+\n)?/g, '', ) + // remove frames from _prepare_env.py + errStr = errStr.replace( + / {2}File "\/tmp\/mcp_run_python\/_prepare_env.py".*\n {4,}.+\n/g, + '', + ) + // remove trailing newlines + errStr = errStr.replace(/\n+$/, '') return errStr } @@ -164,8 +182,30 @@ interface PrepareError { kind: 'error' message: string } + +interface PyObject { + toJs(): any +} + +interface CallSuccess { + kind: 'success' + return_value: any +} + +interface CallError { + kind: 'error' + exc_type: string + message: string +} + +type ClientCallbackType = ( + func_name: string, + callback: (args: PyObject, kwargs: PyObject) => Promise, +) => any + interface PreparePyEnv { prepare_env: (files: CodeFile[]) => Promise - // deno-lint-ignore no-explicit-any - dump_json: (value: any) => string | null + ClientCallback: ClientCallbackType + pretty_result: (value: any) => string | undefined + dump_json: (value: any) => string | undefined } From fd23fac951ec7bcc439f2a3b84fc476fc07b073d Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 18 Apr 2025 20:12:39 +0100 Subject: [PATCH 02/10] cleanup --- mcp-run-python/src/prepare_env.py | 18 +++++++++--------- mcp-run-python/src/runCode.ts | 29 +++++++---------------------- 2 files changed, 16 insertions(+), 31 deletions(-) diff --git a/mcp-run-python/src/prepare_env.py b/mcp-run-python/src/prepare_env.py index ccb8c02ea..c4b1bc80a 100644 --- a/mcp-run-python/src/prepare_env.py +++ b/mcp-run-python/src/prepare_env.py @@ -75,7 +75,7 @@ async def prepare_env(files: list[File]) -> PrepSuccess | PrepError: return PrepSuccess(dependencies=dependencies) -def pretty_result(value: Any) -> str | None: +def dump_json(value: Any) -> str | None: if value is None: return None if isinstance(value, str): @@ -84,11 +84,6 @@ def pretty_result(value: Any) -> str | None: return to_json(value, indent=2, fallback=_json_fallback).decode() -def dump_json(value: Any) -> str | None: - if value: - return to_json(value, fallback=_json_fallback).decode() - - class CallSuccess(TypedDict): return_value: Any kind: Literal['success'] @@ -106,12 +101,12 @@ class CallError(TypedDict): @dataclass(slots=True) -class ClientCallback: +class RegisterFunction: _func_name: str - _callback: Callable[[tuple[Any, ...], dict[str, Any]], Awaitable[str]] + _callback: Callable[[str, str | None, str | None], Awaitable[str]] async def __call__(self, *args: Any, **kwargs: Any) -> Any: - result_json = await self._callback(args, kwargs) + result_json = await self._callback(self._func_name, _dump_args(args), _dump_args(kwargs)) result = call_result_ta.validate_json(result_json) if result['kind'] == 'success': return result['return_value'] @@ -128,6 +123,11 @@ def __repr__(self) -> str: return f'' +def _dump_args(value: Any) -> str | None: + if value: + return to_json(value, fallback=_json_fallback).decode() + + def _json_fallback(value: Any) -> Any: tp: Any = type(value) module = tp.__module__ diff --git a/mcp-run-python/src/runCode.ts b/mcp-run-python/src/runCode.ts index 6cbb772a0..d7599c5f3 100644 --- a/mcp-run-python/src/runCode.ts +++ b/mcp-run-python/src/runCode.ts @@ -55,18 +55,15 @@ export async function runCode( pathlib.Path(`${dirPath}/${moduleName}.py`).write_text(preparePythonCode) - const { prepare_env, ClientCallback, dump_json, pretty_result }: PreparePyEnv = pyodide.pyimport(moduleName) + const { prepare_env, RegisterFunction, dump_json }: PreparePyEnv = pyodide.pyimport(moduleName) const prepareStatus = await prepare_env(pyodide.toPy(files)) - const globals: { [key: string]: string | ClientCallbackType } = { __name__: '__main__' } + const globals: { [key: string]: string | RegisterFunctionType } = { __name__: '__main__' } if (functionNames && clientCallback) { for (const functionName of functionNames) { - globals[functionName] = ClientCallback( - functionName, - async (args, kwargs) => pyodide.toPy(await clientCallback(functionName, dump_json(args), dump_json(kwargs))), - ) + globals[functionName] = RegisterFunction(functionName, clientCallback) } } @@ -89,7 +86,7 @@ export async function runCode( status: 'success', dependencies, output, - returnValueJson: pretty_result(rawValue), + returnValueJson: dump_json(rawValue), } } catch (err) { runResult = { @@ -187,25 +184,13 @@ interface PyObject { toJs(): any } -interface CallSuccess { - kind: 'success' - return_value: any -} - -interface CallError { - kind: 'error' - exc_type: string - message: string -} - -type ClientCallbackType = ( +type RegisterFunctionType = ( func_name: string, - callback: (args: PyObject, kwargs: PyObject) => Promise, + callback: (func_name: string, args?: string, kwargs?: string) => Promise, ) => any interface PreparePyEnv { prepare_env: (files: CodeFile[]) => Promise - ClientCallback: ClientCallbackType - pretty_result: (value: any) => string | undefined + RegisterFunction: RegisterFunctionType dump_json: (value: any) => string | undefined } From a729094730885972bdf56ce76000d578fd6d1228 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Fri, 18 Apr 2025 20:22:11 +0100 Subject: [PATCH 03/10] cleanup more --- mcp-run-python/src/runCode.ts | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/mcp-run-python/src/runCode.ts b/mcp-run-python/src/runCode.ts index d7599c5f3..92cc5248e 100644 --- a/mcp-run-python/src/runCode.ts +++ b/mcp-run-python/src/runCode.ts @@ -1,4 +1,3 @@ -// deno-lint-ignore-file no-explicit-any import { loadPyodide } from 'pyodide' import { preparePythonCode } from './prepareEnvCode.ts' import type { LoggingLevel } from '@modelcontextprotocol/sdk/types.js' @@ -17,6 +16,7 @@ export async function runCode( ): Promise { // remove once https://github.com/pyodide/pyodide/pull/5514 is released const realConsoleLog = console.log + // deno-lint-ignore no-explicit-any console.log = (...args: any[]) => log('debug', args.join(' ')) const output: string[] = [] @@ -55,15 +55,15 @@ export async function runCode( pathlib.Path(`${dirPath}/${moduleName}.py`).write_text(preparePythonCode) - const { prepare_env, RegisterFunction, dump_json }: PreparePyEnv = pyodide.pyimport(moduleName) + const preparePyEnv: PreparePyEnv = pyodide.pyimport(moduleName) - const prepareStatus = await prepare_env(pyodide.toPy(files)) + const prepareStatus = await preparePyEnv.prepare_env(pyodide.toPy(files)) - const globals: { [key: string]: string | RegisterFunctionType } = { __name__: '__main__' } + const globals: Record = { __name__: '__main__' } if (functionNames && clientCallback) { for (const functionName of functionNames) { - globals[functionName] = RegisterFunction(functionName, clientCallback) + globals[functionName] = preparePyEnv.RegisterFunction(functionName, clientCallback) } } @@ -86,7 +86,7 @@ export async function runCode( status: 'success', dependencies, output, - returnValueJson: dump_json(rawValue), + returnValueJson: preparePyEnv.dump_json(rawValue), } } catch (err) { runResult = { @@ -153,6 +153,7 @@ function escapeClosing(closingTag: string): (str: string) => string { return (str) => str.replace(regex, onMatch) } +// deno-lint-ignore no-explicit-any function formatError(err: any): string { let errStr = err.toString() errStr = errStr.replace(/^PythonError: +/, '') @@ -180,17 +181,11 @@ interface PrepareError { message: string } -interface PyObject { - toJs(): any -} - -type RegisterFunctionType = ( - func_name: string, - callback: (func_name: string, args?: string, kwargs?: string) => Promise, -) => any - interface PreparePyEnv { prepare_env: (files: CodeFile[]) => Promise - RegisterFunction: RegisterFunctionType - dump_json: (value: any) => string | undefined + RegisterFunction: ( + func_name: string, + callback: (func_name: string, args?: string, kwargs?: string) => Promise, + ) => unknown + dump_json: (value: unknown) => string | undefined } From 1bbe51fdd1ef3b759dac3b4f61a83c7e83cc425a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 19 Apr 2025 07:59:49 -0700 Subject: [PATCH 04/10] fix tests --- mcp-run-python/test_mcp_servers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mcp-run-python/test_mcp_servers.py b/mcp-run-python/test_mcp_servers.py index a11a6d5ba..a45adb764 100644 --- a/mcp-run-python/test_mcp_servers.py +++ b/mcp-run-python/test_mcp_servers.py @@ -147,7 +147,6 @@ async def test_list_tools(mcp_session: ClientSession) -> None: print(unknown) ^^^^^^^ NameError: name 'unknown' is not defined - \ """), id='undefined-variable', From f57365a3e5a79e3a41a2d91882ef06e1651b98e9 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sat, 19 Apr 2025 08:14:18 -0700 Subject: [PATCH 05/10] tweam tool description --- mcp-run-python/src/main.ts | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/mcp-run-python/src/main.ts b/mcp-run-python/src/main.ts index c1e80a4e3..e38a2078a 100644 --- a/mcp-run-python/src/main.ts +++ b/mcp-run-python/src/main.ts @@ -65,13 +65,23 @@ The code may be async, and the value on the last line will be returned as the re The code will be executed with Python 3.12. -Dependencies may be defined via PEP 723 script metadata, e.g. to install "pydantic", the script should start -with a comment of the form: +Dependencies may be defined via PEP 723 script metadata. +To make HTTP requests, you must use the "httpx" library in async mode. + +For example: + +\`\`\`python # /// script -# dependencies = ['pydantic'] +# dependencies = ['httpx'] # /// -print('python code here') +import httpx + +async with httpx.AsyncClient() as client: + response = await client.get('https://example.com') +# return the text of the page +response.text +\`\`\` ` let setLogLevel: LoggingLevel = 'emergency' From 63a47797debd933e2118dc8e2af290fbd876902e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 20 Apr 2025 08:56:26 -0700 Subject: [PATCH 06/10] simplify callback registration --- mcp-run-python/src/main.ts | 59 +++++++++++++++++++++++++------------- 1 file changed, 39 insertions(+), 20 deletions(-) diff --git a/mcp-run-python/src/main.ts b/mcp-run-python/src/main.ts index e38a2078a..9f8d7f6c2 100644 --- a/mcp-run-python/src/main.ts +++ b/mcp-run-python/src/main.ts @@ -16,27 +16,28 @@ const VERSION = '0.0.13' export async function main() { const { args } = Deno const flags = parseArgs(args, { - string: ['port', 'functions'], + string: ['port', 'callbacks'], default: { port: '3001' }, }) - const functions = flags.functions ? flags.functions.split(',') : [] - if (args.length >= 1 && args[0] === 'stdio') { - await runStdio(functions) - } else if (args.length >= 1 && args[0] === 'sse') { - const port = parseInt(flags.port) - runSse(functions, port) - } else if (args.length === 1 && args[0] === 'warmup') { - await warmup() + const { _: [task], callbacks, port } = flags + if (task === 'stdio') { + await runStdio(callbacks) + } else if (task === 'sse') { + runSse(parseInt(port), callbacks) + } else if (task === 'warmup') { + await warmup(callbacks) } else { console.error( `\ Invalid arguments. -Usage: deno run -N -R=node_modules -W=node_modules --node-modules-dir=auto jsr:@pydantic/mcp-run-python [stdio|sse|warmup] +Usage: + deno run -N -R=node_modules -W=node_modules --node-modules-dir=auto \\ + jsr:@pydantic/mcp-run-python [stdio|sse|warmup] options: - --port Port to run the SSE server on (default: 3001) - --functions Comma separated list of client functions which the server can call`, + --port Port to run the SSE server on (default: 3001). + --callbacks Python code representing the signatures of client functions the server can call.`, ) Deno.exit(1) } @@ -45,7 +46,8 @@ options: /* * Create an MCP server with the `run_python_code` tool registered. */ -function createServer(functions: string[]): McpServer { +function createServer(callbacks?: string): McpServer { + const functions = _extractFunctions(callbacks) const server = new McpServer( { name: 'MCP Run Python', @@ -59,9 +61,9 @@ function createServer(functions: string[]): McpServer { }, ) - const toolDescription = `Tool to execute Python code and return stdout, stderr, and return value. + let toolDescription = `Tool to execute Python code and return stdout, stderr, and return value. -The code may be async, and the value on the last line will be returned as the return value. +The code may be async, and the value on the last line will be returned as the return. The code will be executed with Python 3.12. @@ -83,6 +85,15 @@ async with httpx.AsyncClient() as client: response.text \`\`\` ` + if (callbacks) { + toolDescription += ` +The following functions are globally available to call: + +\`\`\`python +${callbacks} +\`\`\` + ` + } let setLogLevel: LoggingLevel = 'emergency' @@ -135,8 +146,8 @@ response.text /* * Run the MCP server using the SSE transport, e.g. over HTTP. */ -function runSse(functions: string[], port: number) { - const mcpServer = createServer(functions) +function runSse(port: number, callbacks?: string) { + const mcpServer = createServer(callbacks) const transports: { [sessionId: string]: SSEServerTransport } = {} const server = http.createServer(async (req, res) => { @@ -191,8 +202,8 @@ function runSse(functions: string[], port: number) { /* * Run the MCP server using the Stdio transport. */ -async function runStdio(functions: string[]) { - const mcpServer = createServer(functions) +async function runStdio(callbacks?: string) { + const mcpServer = createServer(callbacks) const transport = new StdioServerTransport() await mcpServer.connect(transport) } @@ -200,7 +211,11 @@ async function runStdio(functions: string[]) { /* * Run pyodide to download packages which can otherwise interrupt the server */ -async function warmup() { +async function warmup(callbacks?: string) { + if (callbacks) { + const functions = _extractFunctions(callbacks) + console.error(`Functions extracted from callbacks: ${JSON.stringify(functions)}`) + } console.error( `Running warmup script for MCP Run Python version ${VERSION}...`, ) @@ -222,6 +237,10 @@ a console.log('\nwarmup successful 🎉') } +function _extractFunctions(callbacks?: string): string[] { + return callbacks ? [...callbacks.matchAll(/^async def (\w+)/g).map(([, f]) => f)] : [] +} + // list of log levels to use for level comparison const LogLevels: LoggingLevel[] = [ 'debug', From ac8ddfcdacd503a3a5628a1b352cd6d3e80a139e Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 20 Apr 2025 09:14:09 -0700 Subject: [PATCH 07/10] uprev.py --- mcp-run-python/{deno.json => deno.jsonc} | 2 +- mcp-run-python/uprev.py | 31 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) rename mcp-run-python/{deno.json => deno.jsonc} (98%) create mode 100644 mcp-run-python/uprev.py diff --git a/mcp-run-python/deno.json b/mcp-run-python/deno.jsonc similarity index 98% rename from mcp-run-python/deno.json rename to mcp-run-python/deno.jsonc index cbe71d74a..9e079919e 100644 --- a/mcp-run-python/deno.json +++ b/mcp-run-python/deno.jsonc @@ -32,7 +32,7 @@ "src/*.ts", "src/prepareEnvCode.ts", // required to override gitignore "README.md", - "deno.json" + "deno.jsonc" ] } } diff --git a/mcp-run-python/uprev.py b/mcp-run-python/uprev.py new file mode 100644 index 000000000..e653d59b5 --- /dev/null +++ b/mcp-run-python/uprev.py @@ -0,0 +1,31 @@ +import re +import sys + +from pathlib import Path + +if len(sys.argv) != 2: + print("Usage: python uprev.py ") + sys.exit(1) + +new_version = sys.argv[1] +this_dir = Path(__file__).parent + +path_regexes = [ + (this_dir / 'deno.jsonc', r'^\s+"version": "(.+?)"'), + (this_dir / 'src/main.ts', "^const VERSION = '(.+?)'"), + (this_dir / '../pydantic_ai_slim/pydantic_ai/mcp_run_python.py', "^MCP_RUN_PYTHON_VERSION = '(.+?)'") +] + +def replace_version(m: re.Match[str]) -> str: + version = m.group(1) + return m.group(0).replace(version, new_version) + +if __name__ == "__main__": + for path, regex in path_regexes: + path = path.resolve() + content = path.read_text() + content, count = re.subn(regex, replace_version, content, count=1, flags=re.M) + if count != 1: + raise ValueError(f"Failed to update version in {path}") + path.write_text(content) + print(f"Updated version to {new_version} in {path}") From 3e6b6938a22df957008566836774d3be03b36d22 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 20 Apr 2025 10:18:35 -0700 Subject: [PATCH 08/10] working on client tooling --- mcp-run-python/src/main.ts | 2 +- mcp-run-python/src/prepare_env.py | 2 +- mcp-run-python/test_mcp_servers.py | 2 +- mcp-run-python/uprev.py | 22 ++++--- .../pydantic_ai/mcp_run_python.py | 65 +++++++++++++++++++ 5 files changed, 80 insertions(+), 13 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/mcp_run_python.py diff --git a/mcp-run-python/src/main.ts b/mcp-run-python/src/main.ts index 9f8d7f6c2..3bb485238 100644 --- a/mcp-run-python/src/main.ts +++ b/mcp-run-python/src/main.ts @@ -87,7 +87,7 @@ response.text ` if (callbacks) { toolDescription += ` -The following functions are globally available to call: +The following functions are already defined and available to call: \`\`\`python ${callbacks} diff --git a/mcp-run-python/src/prepare_env.py b/mcp-run-python/src/prepare_env.py index c4b1bc80a..32422c296 100644 --- a/mcp-run-python/src/prepare_env.py +++ b/mcp-run-python/src/prepare_env.py @@ -139,7 +139,7 @@ def _json_fallback(value: Any) -> Any: elif module == 'pyodide.ffi': return value.to_py() else: - return repr(value) + return str(value) def _add_extra_dependencies(dependencies: list[str]) -> list[str]: diff --git a/mcp-run-python/test_mcp_servers.py b/mcp-run-python/test_mcp_servers.py index a45adb764..8886bc08e 100644 --- a/mcp-run-python/test_mcp_servers.py +++ b/mcp-run-python/test_mcp_servers.py @@ -19,9 +19,9 @@ DENO_ARGS = [ 'run', '-N', + '--node-modules-dir=auto', '-R=mcp-run-python/node_modules', '-W=mcp-run-python/node_modules', - '--node-modules-dir=auto', 'mcp-run-python/src/main.ts', ] diff --git a/mcp-run-python/uprev.py b/mcp-run-python/uprev.py index e653d59b5..f66cb0d64 100644 --- a/mcp-run-python/uprev.py +++ b/mcp-run-python/uprev.py @@ -1,31 +1,33 @@ import re import sys - from pathlib import Path if len(sys.argv) != 2: - print("Usage: python uprev.py ") + print('Usage: python uprev.py ') sys.exit(1) new_version = sys.argv[1] this_dir = Path(__file__).parent +root_dir = (this_dir / '..').resolve() path_regexes = [ (this_dir / 'deno.jsonc', r'^\s+"version": "(.+?)"'), (this_dir / 'src/main.ts', "^const VERSION = '(.+?)'"), - (this_dir / '../pydantic_ai_slim/pydantic_ai/mcp_run_python.py', "^MCP_RUN_PYTHON_VERSION = '(.+?)'") + (root_dir / 'pydantic_ai_slim/pydantic_ai/mcp_run_python.py', "^MCP_RUN_PYTHON_VERSION = '(.+?)'"), ] -def replace_version(m: re.Match[str]) -> str: - version = m.group(1) - return m.group(0).replace(version, new_version) -if __name__ == "__main__": +if __name__ == '__main__': for path, regex in path_regexes: - path = path.resolve() + path_pretty = path.relative_to(root_dir) + + def replace_version(m: re.Match[str]) -> str: + version = m.group(1) + print(f'Updated version from {version} to {new_version} in {path_pretty}') + return m.group(0).replace(version, new_version) + content = path.read_text() content, count = re.subn(regex, replace_version, content, count=1, flags=re.M) if count != 1: - raise ValueError(f"Failed to update version in {path}") + raise ValueError(f'Failed to update version in {path}') path.write_text(content) - print(f"Updated version to {new_version} in {path}") diff --git a/pydantic_ai_slim/pydantic_ai/mcp_run_python.py b/pydantic_ai_slim/pydantic_ai/mcp_run_python.py new file mode 100644 index 000000000..339ae9171 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/mcp_run_python.py @@ -0,0 +1,65 @@ +import ast +import inspect +from collections.abc import Awaitable, Sequence +from typing import Any, Callable, Literal + +from .mcp import MCPServerStdio + +__all__ = ('mcp_run_python_stdio',) + +MCP_RUN_PYTHON_VERSION = '0.0.13' +Callback = Callable[..., Awaitable[Any]] + + +def mcp_run_python_stdio(callbacks: Sequence[Callback] = (), *, local_code: bool = False) -> MCPServerStdio: + """Prepare a server server connection using `'stdio'` transport. + + Args: + callbacks: A sequence of callback functions to be register on the server. + local_code: Whether to run local `mcp-run-python` code. + + Returns: + A server connection definition. + """ + return MCPServerStdio('deno', args=_deno_args('stdio', callbacks, local_code)) + + +def _deno_args(mode: Literal['stdio', 'sse'], callbacks: Sequence[Callback], local_code: bool) -> list[str]: + path_prefix = 'mcp-run-python/' if local_code else '' + args = [ + 'run', + '-N', + f'-R={path_prefix}node_modules', + f'-W={path_prefix}node_modules', + '--node-modules-dir=auto', + 'mcp-run-python/src/main.ts' if local_code else f'jsr:@pydantic/mcp-run-python@{MCP_RUN_PYTHON_VERSION}', + mode, + ] + + if callbacks: + sigs = '\n\n'.join(_callback_signature(cb) for cb in callbacks) + args += ['--callbacks', sigs] + return args + + +def _callback_signature(func: Callback) -> str: + """Extract the signature of a function. + + This simply means getting the source code of the function, and removing the body of the function while keeping the docstring. + """ + source = inspect.getsource(func) + ast_mod = ast.parse(source) + assert isinstance(ast_mod, ast.Module), f'Expected Module, got {type(ast_mod)}' + assert len(ast_mod.body) == 1, f'Expected single function definition, got {len(ast_mod.body)}' + f = ast_mod.body[0] + assert isinstance(f, ast.AsyncFunctionDef), f'Expected an async function, got {type(func)}' + lines = lines = source.splitlines() + e = f.body[0] + # if the first expression is a docstring, keep it and no need for an ellipsis as the body + if isinstance(e, ast.Expr) and isinstance(e.value, ast.Constant) and isinstance(e.value.value, str): + e = f.body[1] + lines = lines[: e.lineno - 1] + else: + lines = lines[: e.lineno - 1] + lines.append(e.col_offset * ' ' + '...') + return '\n'.join(lines) From 786c1efb63cea0661e013851d0bf9d59e7ccd534 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 20 Apr 2025 16:21:32 -0700 Subject: [PATCH 09/10] cleanup server args --- pydantic_ai_slim/pydantic_ai/mcp_run_python.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/mcp_run_python.py b/pydantic_ai_slim/pydantic_ai/mcp_run_python.py index 339ae9171..5ca12af68 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp_run_python.py +++ b/pydantic_ai_slim/pydantic_ai/mcp_run_python.py @@ -21,18 +21,21 @@ def mcp_run_python_stdio(callbacks: Sequence[Callback] = (), *, local_code: bool Returns: A server connection definition. """ - return MCPServerStdio('deno', args=_deno_args('stdio', callbacks, local_code)) + return MCPServerStdio( + 'deno', + args=_deno_args('stdio', callbacks, local_code), + cwd='mcp-run-python' if local_code else None, + ) def _deno_args(mode: Literal['stdio', 'sse'], callbacks: Sequence[Callback], local_code: bool) -> list[str]: - path_prefix = 'mcp-run-python/' if local_code else '' args = [ 'run', '-N', - f'-R={path_prefix}node_modules', - f'-W={path_prefix}node_modules', + '-R=node_modules', + '-W=node_modules', '--node-modules-dir=auto', - 'mcp-run-python/src/main.ts' if local_code else f'jsr:@pydantic/mcp-run-python@{MCP_RUN_PYTHON_VERSION}', + 'src/main.ts' if local_code else f'jsr:@pydantic/mcp-run-python@{MCP_RUN_PYTHON_VERSION}', mode, ] From c6b2955ffdea1b973cb8f38f4f2579e07598016c Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 20 Apr 2025 18:49:33 -0700 Subject: [PATCH 10/10] support sampling in pydantic-ai --- mcp-run-python/src/main.ts | 8 +- pydantic_ai_slim/pydantic_ai/mcp.py | 44 ++++- .../pydantic_ai/mcp_run_python.py | 178 +++++++++++++++++- 3 files changed, 214 insertions(+), 16 deletions(-) diff --git a/mcp-run-python/src/main.ts b/mcp-run-python/src/main.ts index 3bb485238..ba46f7d72 100644 --- a/mcp-run-python/src/main.ts +++ b/mcp-run-python/src/main.ts @@ -22,7 +22,7 @@ export async function main() { const { _: [task], callbacks, port } = flags if (task === 'stdio') { await runStdio(callbacks) - } else if (task === 'sse') { + } else if (task === 'sse' || task === 'http') { runSse(parseInt(port), callbacks) } else if (task === 'warmup') { await warmup(callbacks) @@ -87,7 +87,7 @@ response.text ` if (callbacks) { toolDescription += ` -The following functions are already defined and available to call: +The following functions are already defined globally and available to call from within your code: \`\`\`python ${callbacks} @@ -122,8 +122,8 @@ ${callbacks} const { content } = await server.server.createMessage({ messages: [], maxTokens: 0, - systemPrompt: '__python_function_call__', - metadata: { func, args, kwargs }, + systemPrompt: '', + metadata: { pydantic_custom_use: '__python_function_call__', func, args, kwargs }, }) if (content.type !== 'text') { throw new Error('Expected return content type to be "text"') diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index c5dd95f2c..1fff18ea8 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -1,12 +1,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Sequence +from collections.abc import AsyncIterator, Awaitable, Sequence from contextlib import AsyncExitStack, asynccontextmanager from dataclasses import dataclass from pathlib import Path from types import TracebackType -from typing import Any +from typing import Any, Callable from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from mcp.types import JSONRPCMessage, LoggingLevel @@ -15,10 +15,11 @@ from pydantic_ai.tools import ToolDefinition try: + from mcp import types as mcp_types from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.stdio import StdioServerParameters, stdio_client - from mcp.types import CallToolResult + from mcp.shared.context import RequestContext except ImportError as _import_error: raise ImportError( 'Please install the `mcp` package to use the MCP server, ' @@ -27,6 +28,11 @@ __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP' +OptionalSamplingFunction = Callable[ + [RequestContext['ClientSession', Any], mcp_types.CreateMessageRequestParams], + Awaitable[mcp_types.CreateMessageResult | mcp_types.ErrorData | None], +] + class MCPServer(ABC): """Base class for attaching agents to MCP servers. @@ -55,8 +61,22 @@ async def client_streams( @abstractmethod def _get_log_level(self) -> LoggingLevel | None: """Get the log level for the MCP server.""" + + @abstractmethod + def _custom_sampling_callback(self) -> OptionalSamplingFunction | None: + """Maybe get a sampling callback function for this server definition.""" raise NotImplementedError('MCP Server subclasses must implement this method.') + async def _sampling_callback( + self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams + ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData: + """MCP sampling callback.""" + if custom_sampling_callback := self._custom_sampling_callback(): + if result := await custom_sampling_callback(context, params): + return result + + raise NotImplementedError('MCP Sampling not yet implemented, except for custom sampling callbacks') + async def list_tools(self) -> list[ToolDefinition]: """Retrieve tools that are currently active on the server. @@ -74,7 +94,7 @@ async def list_tools(self) -> list[ToolDefinition]: for tool in tools.tools ] - async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> mcp_types.CallToolResult: """Call a tool on the server. Args: @@ -90,7 +110,9 @@ async def __aenter__(self) -> Self: self._exit_stack = AsyncExitStack() self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams()) - client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream) + client = ClientSession( + read_stream=self._read_stream, write_stream=self._write_stream, sampling_callback=self._sampling_callback + ) self._client = await self._exit_stack.enter_async_context(client) await self._client.initialize() @@ -168,6 +190,9 @@ async def main(): cwd: str | Path | None = None """The working directory to use when spawning the process.""" + custom_sampling_callback: OptionalSamplingFunction | None = None + """Optional callback function for sampling.""" + @asynccontextmanager async def client_streams( self, @@ -181,6 +206,9 @@ async def client_streams( def _get_log_level(self) -> LoggingLevel | None: return self.log_level + def _custom_sampling_callback(self) -> OptionalSamplingFunction | None: + return self.custom_sampling_callback + @dataclass class MCPServerHTTP(MCPServer): @@ -248,6 +276,9 @@ async def main(): If `None`, no log level will be set. """ + custom_sampling_callback: OptionalSamplingFunction | None = None + """Optional callback function for sampling.""" + @asynccontextmanager async def client_streams( self, @@ -261,3 +292,6 @@ async def client_streams( def _get_log_level(self) -> LoggingLevel | None: return self.log_level + + def _custom_sampling_callback(self) -> OptionalSamplingFunction | None: + return self.custom_sampling_callback diff --git a/pydantic_ai_slim/pydantic_ai/mcp_run_python.py b/pydantic_ai_slim/pydantic_ai/mcp_run_python.py index 5ca12af68..6b46b7904 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp_run_python.py +++ b/pydantic_ai_slim/pydantic_ai/mcp_run_python.py @@ -1,11 +1,24 @@ import ast import inspect -from collections.abc import Awaitable, Sequence -from typing import Any, Callable, Literal +import subprocess +from collections.abc import AsyncIterator, Awaitable, Sequence +from contextlib import asynccontextmanager +from dataclasses import dataclass +from time import time +from typing import Any, Callable, Literal, cast, override -from .mcp import MCPServerStdio +import anyio +import httpx +import pydantic_core +from mcp import ClientSession, types as mcp_types +from mcp.shared.context import RequestContext +from pydantic import BaseModel, Json +from pydantic._internal._validate_call import ValidateCallWrapper # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import TypedDict -__all__ = ('mcp_run_python_stdio',) +from .mcp import MCPServerHTTP, MCPServerStdio + +__all__ = 'mcp_run_python_stdio', 'MCPRunPythonHTTP' MCP_RUN_PYTHON_VERSION = '0.0.13' Callback = Callable[..., Awaitable[Any]] @@ -16,7 +29,7 @@ def mcp_run_python_stdio(callbacks: Sequence[Callback] = (), *, local_code: bool Args: callbacks: A sequence of callback functions to be register on the server. - local_code: Whether to run local `mcp-run-python` code. + local_code: Whether to run local `mcp-run-python` code, this is mostly used for development and testing. Returns: A server connection definition. @@ -25,10 +38,76 @@ def mcp_run_python_stdio(callbacks: Sequence[Callback] = (), *, local_code: bool 'deno', args=_deno_args('stdio', callbacks, local_code), cwd='mcp-run-python' if local_code else None, + custom_sampling_callback=_PythonSamplingCallback(callbacks) if callbacks else None, ) -def _deno_args(mode: Literal['stdio', 'sse'], callbacks: Sequence[Callback], local_code: bool) -> list[str]: +@dataclass +class MCPRunPythonHTTP: + """Setup for `mcp-run-python` running with HTTP transport.""" + + callbacks: Sequence[Callback] = () + """Callbacks to be registered on the server.""" + port: int = 3001 + """Port to run the server on.""" + local_code: bool = False + """Whether to run local `mcp-run-python` code, this is mostly used for development and testing.""" + + @property + def url(self) -> str: + """URL the server will be run on.""" + return f'http://localhost:{self.port}/sse' + + def server_def(self, url: str | None = None) -> MCPServerHTTP: + """Create a server definition to pass to a pydantic-ai [`Agent`][pydantic_ai.Agent].""" + return MCPServerHTTP( + url or self.url, + custom_sampling_callback=_PythonSamplingCallback(self.callbacks) if self.callbacks else None, + ) + + def run(self) -> None: + """Run the server and block until it is terminated.""" + try: + subprocess.run(self._args(), cwd=self._cwd(), check=True) + except KeyboardInterrupt: + pass + + @asynccontextmanager + async def run_context(self, server_wait_timeout: float | None = 2) -> AsyncIterator[None]: + """Run the server as an async context manager. + + Args: + server_wait_timeout: The timeout in seconds to wait for the server to start, or `None` to not wait. + """ + p = await anyio.open_process(self._args(), cwd=self._cwd(), stdout=None, stderr=None) + async with p: + if server_wait_timeout: + await self.wait_for_server(server_wait_timeout) + yield + p.terminate() + + async def wait_for_server(self, timeout: float = 2): + """Wait for the server to be ready.""" + async with httpx.AsyncClient(timeout=0.01) as client: + start = time() + while True: + try: + await client.head(self.url) + except httpx.RequestError: + if time() - start > timeout: + raise TimeoutError(f'Server did not start within {timeout} seconds') + await anyio.sleep(0.1) + else: + break + + def _args(self) -> list[str]: + return ['deno'] + _deno_args('http', self.callbacks, self.local_code) + ['--port', str(self.port)] + + def _cwd(self) -> str | None: + return 'mcp-run-python' if self.local_code else None + + +def _deno_args(mode: Literal['stdio', 'http'], callbacks: Sequence[Callback], local_code: bool) -> list[str]: args = [ 'run', '-N', @@ -56,7 +135,7 @@ def _callback_signature(func: Callback) -> str: assert len(ast_mod.body) == 1, f'Expected single function definition, got {len(ast_mod.body)}' f = ast_mod.body[0] assert isinstance(f, ast.AsyncFunctionDef), f'Expected an async function, got {type(func)}' - lines = lines = source.splitlines() + lines = source.splitlines() e = f.body[0] # if the first expression is a docstring, keep it and no need for an ellipsis as the body if isinstance(e, ast.Expr) and isinstance(e.value, ast.Constant) and isinstance(e.value.value, str): @@ -65,4 +144,89 @@ def _callback_signature(func: Callback) -> str: else: lines = lines[: e.lineno - 1] lines.append(e.col_offset * ' ' + '...') + + # if the function has any decorators, this will remove them. + if f.lineno != 1: + lines = lines[f.lineno - 1 :] + return '\n'.join(lines) + + +class _PythonSamplingCallback: + def __init__(self, callbacks: Sequence[Callback]): + self.function_lookup: dict[str, ValidateCallWrapper] = {} + for callback in callbacks: + name = callback.__name__ + if name in self.function_lookup: + raise ValueError(f'Duplicate callback name: {name}') + self.function_lookup[name] = ValidateCallWrapper( + callback, # pyright: ignore[reportArgumentType] + None, + False, + None, + ) + + async def __call__( + self, + context: RequestContext[ClientSession, Any], + params: mcp_types.CreateMessageRequestParams, + ) -> mcp_types.CreateMessageResult | mcp_types.ErrorData | None: + if not params.metadata or params.metadata.get('pydantic_custom_use') != '__python_function_call__': + return None + + call_metadata = _PythonCallMetadata.model_validate(params.metadata) + if function_wrapper := self.function_lookup.get(call_metadata.func): + content: _CallSuccess | _CallError + try: + return_value = await function_wrapper.__pydantic_validator__.validate_python(call_metadata.args_kwargs) + except ValueError as e: + # special support for ValueError since it's commonly subclassed, and it's the parent of ValidationError + # TODO we should probably have specific support for other common errors + content = _CallError(exc_type='ValueError', message=str(e), kind='error') + except Exception as e: + content = _CallError(exc_type=e.__class__.__name__, message=str(e), kind='error') + else: + content = _CallSuccess(return_value=return_value, kind='success') + + content_text = pydantic_core.to_json(content, fallback=_json_fallback).decode() + return mcp_types.CreateMessageResult( + role='assistant', content=mcp_types.TextContent(type='text', text=content_text), model='python' + ) + else: + raise LookupError(f'Function `{call_metadata.func}` not found') + + @override + def __repr__(self) -> str: + return f'<_PythonSamplingCallback: {", ".join(map(repr, self.function_lookup))}>' + + +class _PythonCallMetadata(BaseModel): + func: str + args: Json[list[Any]] | None = None # JSON + kwargs: Json[dict[str, Any]] | None = None # JSON + + @property + def args_kwargs(self) -> pydantic_core.ArgsKwargs: + return pydantic_core.ArgsKwargs(tuple(self.args or ()), self.kwargs) + + +class _CallSuccess(TypedDict): + return_value: Any + kind: Literal['success'] + + +class _CallError(TypedDict): + exc_type: str + message: str + kind: Literal['error'] + + +def _json_fallback(value: Any) -> Any: + tp = cast(Any, type(value)) + if tp.__module__ == 'numpy': + if tp.__name__ in {'ndarray', 'matrix'}: + return value.tolist() + else: + return value.item() + else: + return repr(value)