Skip to content

MCP function calling #1550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mcp-run-python/deno.json → mcp-run-python/deno.jsonc
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@
"src/*.ts",
"src/prepareEnvCode.ts", // required to override gitignore
"README.md",
"deno.json"
"deno.jsonc"
]
}
}
104 changes: 76 additions & 28 deletions mcp-run-python/src/main.ts
Original file line number Diff line number Diff line change
@@ -15,26 +15,29 @@ const VERSION = '0.0.13'

export async function main() {
const { args } = Deno
if (args.length === 1 && args[0] === 'stdio') {
await runStdio()
} else if (args.length >= 1 && args[0] === 'sse') {
const flags = parseArgs(Deno.args, {
string: ['port'],
default: { port: '3001' },
})
const port = parseInt(flags.port)
runSse(port)
} else if (args.length === 1 && args[0] === 'warmup') {
await warmup()
const flags = parseArgs(args, {
string: ['port', 'callbacks'],
default: { port: '3001' },
})
const { _: [task], callbacks, port } = flags
if (task === 'stdio') {
await runStdio(callbacks)
} else if (task === 'sse' || task === 'http') {
runSse(parseInt(port), callbacks)
} else if (task === 'warmup') {
await warmup(callbacks)
} else {
console.error(
`\
Invalid arguments.

Usage: deno run -N -R=node_modules -W=node_modules --node-modules-dir=auto jsr:@pydantic/mcp-run-python [stdio|sse|warmup]
Usage:
deno run -N -R=node_modules -W=node_modules --node-modules-dir=auto \\
jsr:@pydantic/mcp-run-python [stdio|sse|warmup]

options:
--port <port> Port to run the SSE server on (default: 3001)`,
--port <port> Port to run the SSE server on (default: 3001).
--callbacks <python-signatures> Python code representing the signatures of client functions the server can call.`,
)
Deno.exit(1)
}
@@ -43,7 +46,8 @@ options:
/*
* Create an MCP server with the `run_python_code` tool registered.
*/
function createServer(): McpServer {
function createServer(callbacks?: string): McpServer {
const functions = _extractFunctions(callbacks)
const server = new McpServer(
{
name: 'MCP Run Python',
@@ -57,20 +61,39 @@ function createServer(): McpServer {
},
)

const toolDescription = `Tool to execute Python code and return stdout, stderr, and return value.
let toolDescription = `Tool to execute Python code and return stdout, stderr, and return value.

The code may be async, and the value on the last line will be returned as the return value.
The code may be async, and the value on the last line will be returned as the return.

The code will be executed with Python 3.12.

Dependencies may be defined via PEP 723 script metadata, e.g. to install "pydantic", the script should start
with a comment of the form:
Dependencies may be defined via PEP 723 script metadata.

To make HTTP requests, you must use the "httpx" library in async mode.

For example:

\`\`\`python
# /// script
# dependencies = ['pydantic']
# dependencies = ['httpx']
# ///
print('python code here')
import httpx

async with httpx.AsyncClient() as client:
response = await client.get('https://example.com')
# return the text of the page
response.text
\`\`\`
`
if (callbacks) {
toolDescription += `
The following functions are already defined globally and available to call from within your code:

\`\`\`python
${callbacks}
\`\`\`
`
}

let setLogLevel: LoggingLevel = 'emergency'

@@ -85,29 +108,46 @@ print('python code here')
{ python_code: z.string().describe('Python code to run') },
async ({ python_code }: { python_code: string }) => {
const logPromises: Promise<void>[] = []
const result = await runCode([{
const mainPy = {
name: 'main.py',
content: python_code,
active: true,
}], (level, data) => {
}
const codeLog = (level: LoggingLevel, data: string) => {
if (LogLevels.indexOf(level) >= LogLevels.indexOf(setLogLevel)) {
logPromises.push(server.server.sendLoggingMessage({ level, data }))
}
})
}
async function clientCallback(func: string, args?: string, kwargs?: string) {
const { content } = await server.server.createMessage({
messages: [],
maxTokens: 0,
systemPrompt: '',
metadata: { pydantic_custom_use: '__python_function_call__', func, args, kwargs },
})
if (content.type !== 'text') {
throw new Error('Expected return content type to be "text"')
} else {
return content.text
}
}

const result = await runCode([mainPy], codeLog, functions, clientCallback)
await Promise.all(logPromises)
return {
content: [{ type: 'text', text: asXml(result) }],
}
},
)

return server
}

/*
* Run the MCP server using the SSE transport, e.g. over HTTP.
*/
function runSse(port: number) {
const mcpServer = createServer()
function runSse(port: number, callbacks?: string) {
const mcpServer = createServer(callbacks)
const transports: { [sessionId: string]: SSEServerTransport } = {}

const server = http.createServer(async (req, res) => {
@@ -162,16 +202,20 @@ function runSse(port: number) {
/*
* Run the MCP server using the Stdio transport.
*/
async function runStdio() {
const mcpServer = createServer()
async function runStdio(callbacks?: string) {
const mcpServer = createServer(callbacks)
const transport = new StdioServerTransport()
await mcpServer.connect(transport)
}

/*
* Run pyodide to download packages which can otherwise interrupt the server
*/
async function warmup() {
async function warmup(callbacks?: string) {
if (callbacks) {
const functions = _extractFunctions(callbacks)
console.error(`Functions extracted from callbacks: ${JSON.stringify(functions)}`)
}
console.error(
`Running warmup script for MCP Run Python version ${VERSION}...`,
)
@@ -193,6 +237,10 @@ a
console.log('\nwarmup successful 🎉')
}

function _extractFunctions(callbacks?: string): string[] {
return callbacks ? [...callbacks.matchAll(/^async def (\w+)/g).map(([, f]) => f)] : []
}

// list of log levels to use for level comparison
const LogLevels: LoggingLevel[] = [
'debug',
64 changes: 54 additions & 10 deletions mcp-run-python/src/prepare_env.py
Original file line number Diff line number Diff line change
@@ -10,15 +10,17 @@
import re
import sys
import traceback
from collections.abc import Iterable, Iterator
from collections.abc import Awaitable, Iterable, Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal, TypedDict
from typing import Any, Callable, Literal, TypedDict

import micropip
import pyodide_js
import tomllib
from pydantic import ConfigDict, TypeAdapter
from pydantic_core import to_json
from pyodide.code import find_imports

__all__ = 'prepare_env', 'dump_json'
@@ -31,18 +33,18 @@ class File(TypedDict):


@dataclass
class Success:
class PrepSuccess:
dependencies: list[str] | None
kind: Literal['success'] = 'success'


@dataclass
class Error:
class PrepError:
message: str
kind: Literal['error'] = 'error'


async def prepare_env(files: list[File]) -> Success | Error:
async def prepare_env(files: list[File]) -> PrepSuccess | PrepError:
sys.setrecursionlimit(400)

cwd = Path.cwd()
@@ -68,14 +70,12 @@ async def prepare_env(files: list[File]) -> Success | Error:
except Exception:
with open(logs_filename) as f:
logs = f.read()
return Error(message=f'{logs} {traceback.format_exc()}')
return PrepError(message=f'{logs} {traceback.format_exc()}')

return Success(dependencies=dependencies)
return PrepSuccess(dependencies=dependencies)


def dump_json(value: Any) -> str | None:
from pydantic_core import to_json

if value is None:
return None
if isinstance(value, str):
@@ -84,6 +84,50 @@ def dump_json(value: Any) -> str | None:
return to_json(value, indent=2, fallback=_json_fallback).decode()


class CallSuccess(TypedDict):
return_value: Any
kind: Literal['success']


class CallError(TypedDict):
exc_type: str
message: str
kind: Literal['error']


call_result_ta: TypeAdapter[CallSuccess | CallError] = TypeAdapter(
CallSuccess | CallError, config=ConfigDict(defer_build=True)
)


@dataclass(slots=True)
class RegisterFunction:
_func_name: str
_callback: Callable[[str, str | None, str | None], Awaitable[str]]

async def __call__(self, *args: Any, **kwargs: Any) -> Any:
result_json = await self._callback(self._func_name, _dump_args(args), _dump_args(kwargs))
result = call_result_ta.validate_json(result_json)
if result['kind'] == 'success':
return result['return_value']

exc_type, message = result['exc_type'], result['message']
try:
exc_type_ = __builtins__[exc_type]
except KeyError:
raise Exception(f'{message}\n(Raised exception type: {exc_type})')
else:
raise exc_type_(message)

def __repr__(self) -> str:
return f'<client callback {self._func_name}>'


def _dump_args(value: Any) -> str | None:
if value:
return to_json(value, fallback=_json_fallback).decode()


def _json_fallback(value: Any) -> Any:
tp: Any = type(value)
module = tp.__module__
@@ -95,7 +139,7 @@ def _json_fallback(value: Any) -> Any:
elif module == 'pyodide.ffi':
return value.to_py()
else:
return repr(value)
return str(value)


def _add_extra_dependencies(dependencies: list[str]) -> list[str]:
30 changes: 25 additions & 5 deletions mcp-run-python/src/runCode.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
/* eslint @typescript-eslint/no-explicit-any: off */
import { loadPyodide } from 'pyodide'
import { preparePythonCode } from './prepareEnvCode.ts'
import type { LoggingLevel } from '@modelcontextprotocol/sdk/types.js'
@@ -12,6 +11,8 @@ export interface CodeFile {
export async function runCode(
files: CodeFile[],
log: (level: LoggingLevel, data: string) => void,
functionNames?: string[],
clientCallback?: (func: string, args?: string, kwargs?: string) => Promise<string>,
): Promise<RunSuccess | RunError> {
// remove once https://github.com/pyodide/pyodide/pull/5514 is released
const realConsoleLog = console.log
@@ -58,6 +59,14 @@ export async function runCode(

const prepareStatus = await preparePyEnv.prepare_env(pyodide.toPy(files))

const globals: Record<string, unknown> = { __name__: '__main__' }

if (functionNames && clientCallback) {
for (const functionName of functionNames) {
globals[functionName] = preparePyEnv.RegisterFunction(functionName, clientCallback)
}
}

let runResult: RunSuccess | RunError
if (prepareStatus.kind == 'error') {
runResult = {
@@ -70,7 +79,7 @@ export async function runCode(
const activeFile = files.find((f) => f.active)! || files[0]
try {
const rawValue = await pyodide.runPythonAsync(activeFile.content, {
globals: pyodide.toPy({ __name__: '__main__' }),
globals: pyodide.toPy(globals),
filename: activeFile.name,
})
runResult = {
@@ -99,7 +108,7 @@ interface RunSuccess {
// we could record stdout and stderr separately, but I suspect simplicity is more important
output: string[]
dependencies: string[]
returnValueJson: string | null
returnValueJson: string | undefined
}

interface RunError {
@@ -153,6 +162,13 @@ function formatError(err: any): string {
/ {2}File "\/lib\/python\d+\.zip\/_pyodide\/.*\n {4}.*\n(?: {4,}\^+\n)?/g,
'',
)
// remove frames from _prepare_env.py
errStr = errStr.replace(
/ {2}File "\/tmp\/mcp_run_python\/_prepare_env.py".*\n {4,}.+\n/g,
'',
)
// remove trailing newlines
errStr = errStr.replace(/\n+$/, '')
return errStr
}

@@ -164,8 +180,12 @@ interface PrepareError {
kind: 'error'
message: string
}

interface PreparePyEnv {
prepare_env: (files: CodeFile[]) => Promise<PrepareSuccess | PrepareError>
// deno-lint-ignore no-explicit-any
dump_json: (value: any) => string | null
RegisterFunction: (
func_name: string,
callback: (func_name: string, args?: string, kwargs?: string) => Promise<string>,
) => unknown
dump_json: (value: unknown) => string | undefined
}
3 changes: 1 addition & 2 deletions mcp-run-python/test_mcp_servers.py
Original file line number Diff line number Diff line change
@@ -19,9 +19,9 @@
DENO_ARGS = [
'run',
'-N',
'--node-modules-dir=auto',
'-R=mcp-run-python/node_modules',
'-W=mcp-run-python/node_modules',
'--node-modules-dir=auto',
'mcp-run-python/src/main.ts',
]

@@ -147,7 +147,6 @@ async def test_list_tools(mcp_session: ClientSession) -> None:
print(unknown)
^^^^^^^
NameError: name 'unknown' is not defined
</error>\
"""),
id='undefined-variable',
33 changes: 33 additions & 0 deletions mcp-run-python/uprev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import re
import sys
from pathlib import Path

if len(sys.argv) != 2:
print('Usage: python uprev.py <new_version>')
sys.exit(1)

new_version = sys.argv[1]
this_dir = Path(__file__).parent
root_dir = (this_dir / '..').resolve()

path_regexes = [
(this_dir / 'deno.jsonc', r'^\s+"version": "(.+?)"'),
(this_dir / 'src/main.ts', "^const VERSION = '(.+?)'"),
(root_dir / 'pydantic_ai_slim/pydantic_ai/mcp_run_python.py', "^MCP_RUN_PYTHON_VERSION = '(.+?)'"),
]


if __name__ == '__main__':
for path, regex in path_regexes:
path_pretty = path.relative_to(root_dir)

def replace_version(m: re.Match[str]) -> str:
version = m.group(1)
print(f'Updated version from {version} to {new_version} in {path_pretty}')
return m.group(0).replace(version, new_version)

content = path.read_text()
content, count = re.subn(regex, replace_version, content, count=1, flags=re.M)
if count != 1:
raise ValueError(f'Failed to update version in {path}')
path.write_text(content)
44 changes: 39 additions & 5 deletions pydantic_ai_slim/pydantic_ai/mcp.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Sequence
from collections.abc import AsyncIterator, Awaitable, Sequence
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass
from pathlib import Path
from types import TracebackType
from typing import Any
from typing import Any, Callable

from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from mcp.types import JSONRPCMessage, LoggingLevel
@@ -15,10 +15,11 @@
from pydantic_ai.tools import ToolDefinition

try:
from mcp import types as mcp_types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.types import CallToolResult
from mcp.shared.context import RequestContext
except ImportError as _import_error:
raise ImportError(
'Please install the `mcp` package to use the MCP server, '
@@ -27,6 +28,11 @@

__all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP'

OptionalSamplingFunction = Callable[
[RequestContext['ClientSession', Any], mcp_types.CreateMessageRequestParams],
Awaitable[mcp_types.CreateMessageResult | mcp_types.ErrorData | None],
]


class MCPServer(ABC):
"""Base class for attaching agents to MCP servers.
@@ -55,8 +61,22 @@ async def client_streams(
@abstractmethod
def _get_log_level(self) -> LoggingLevel | None:
"""Get the log level for the MCP server."""

@abstractmethod
def _custom_sampling_callback(self) -> OptionalSamplingFunction | None:
"""Maybe get a sampling callback function for this server definition."""
raise NotImplementedError('MCP Server subclasses must implement this method.')

async def _sampling_callback(
self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams
) -> mcp_types.CreateMessageResult | mcp_types.ErrorData:
"""MCP sampling callback."""
if custom_sampling_callback := self._custom_sampling_callback():
if result := await custom_sampling_callback(context, params):
return result

raise NotImplementedError('MCP Sampling not yet implemented, except for custom sampling callbacks')

async def list_tools(self) -> list[ToolDefinition]:
"""Retrieve tools that are currently active on the server.
@@ -74,7 +94,7 @@ async def list_tools(self) -> list[ToolDefinition]:
for tool in tools.tools
]

async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> CallToolResult:
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> mcp_types.CallToolResult:
"""Call a tool on the server.
Args:
@@ -90,7 +110,9 @@ async def __aenter__(self) -> Self:
self._exit_stack = AsyncExitStack()

self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams())
client = ClientSession(read_stream=self._read_stream, write_stream=self._write_stream)
client = ClientSession(
read_stream=self._read_stream, write_stream=self._write_stream, sampling_callback=self._sampling_callback
)
self._client = await self._exit_stack.enter_async_context(client)

await self._client.initialize()
@@ -168,6 +190,9 @@ async def main():
cwd: str | Path | None = None
"""The working directory to use when spawning the process."""

custom_sampling_callback: OptionalSamplingFunction | None = None
"""Optional callback function for sampling."""

@asynccontextmanager
async def client_streams(
self,
@@ -181,6 +206,9 @@ async def client_streams(
def _get_log_level(self) -> LoggingLevel | None:
return self.log_level

def _custom_sampling_callback(self) -> OptionalSamplingFunction | None:
return self.custom_sampling_callback


@dataclass
class MCPServerHTTP(MCPServer):
@@ -248,6 +276,9 @@ async def main():
If `None`, no log level will be set.
"""

custom_sampling_callback: OptionalSamplingFunction | None = None
"""Optional callback function for sampling."""

@asynccontextmanager
async def client_streams(
self,
@@ -261,3 +292,6 @@ async def client_streams(

def _get_log_level(self) -> LoggingLevel | None:
return self.log_level

def _custom_sampling_callback(self) -> OptionalSamplingFunction | None:
return self.custom_sampling_callback
232 changes: 232 additions & 0 deletions pydantic_ai_slim/pydantic_ai/mcp_run_python.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
import ast
import inspect
import subprocess
from collections.abc import AsyncIterator, Awaitable, Sequence
from contextlib import asynccontextmanager
from dataclasses import dataclass
from time import time
from typing import Any, Callable, Literal, cast, override

import anyio
import httpx
import pydantic_core
from mcp import ClientSession, types as mcp_types
from mcp.shared.context import RequestContext
from pydantic import BaseModel, Json
from pydantic._internal._validate_call import ValidateCallWrapper # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import TypedDict

from .mcp import MCPServerHTTP, MCPServerStdio

__all__ = 'mcp_run_python_stdio', 'MCPRunPythonHTTP'

MCP_RUN_PYTHON_VERSION = '0.0.13'
Callback = Callable[..., Awaitable[Any]]


def mcp_run_python_stdio(callbacks: Sequence[Callback] = (), *, local_code: bool = False) -> MCPServerStdio:
"""Prepare a server server connection using `'stdio'` transport.
Args:
callbacks: A sequence of callback functions to be register on the server.
local_code: Whether to run local `mcp-run-python` code, this is mostly used for development and testing.
Returns:
A server connection definition.
"""
return MCPServerStdio(
'deno',
args=_deno_args('stdio', callbacks, local_code),
cwd='mcp-run-python' if local_code else None,
custom_sampling_callback=_PythonSamplingCallback(callbacks) if callbacks else None,
)


@dataclass
class MCPRunPythonHTTP:
"""Setup for `mcp-run-python` running with HTTP transport."""

callbacks: Sequence[Callback] = ()
"""Callbacks to be registered on the server."""
port: int = 3001
"""Port to run the server on."""
local_code: bool = False
"""Whether to run local `mcp-run-python` code, this is mostly used for development and testing."""

@property
def url(self) -> str:
"""URL the server will be run on."""
return f'http://localhost:{self.port}/sse'

def server_def(self, url: str | None = None) -> MCPServerHTTP:
"""Create a server definition to pass to a pydantic-ai [`Agent`][pydantic_ai.Agent]."""
return MCPServerHTTP(
url or self.url,
custom_sampling_callback=_PythonSamplingCallback(self.callbacks) if self.callbacks else None,
)

def run(self) -> None:
"""Run the server and block until it is terminated."""
try:
subprocess.run(self._args(), cwd=self._cwd(), check=True)
except KeyboardInterrupt:
pass

@asynccontextmanager
async def run_context(self, server_wait_timeout: float | None = 2) -> AsyncIterator[None]:
"""Run the server as an async context manager.
Args:
server_wait_timeout: The timeout in seconds to wait for the server to start, or `None` to not wait.
"""
p = await anyio.open_process(self._args(), cwd=self._cwd(), stdout=None, stderr=None)
async with p:
if server_wait_timeout:
await self.wait_for_server(server_wait_timeout)
yield
p.terminate()

async def wait_for_server(self, timeout: float = 2):
"""Wait for the server to be ready."""
async with httpx.AsyncClient(timeout=0.01) as client:
start = time()
while True:
try:
await client.head(self.url)
except httpx.RequestError:
if time() - start > timeout:
raise TimeoutError(f'Server did not start within {timeout} seconds')
await anyio.sleep(0.1)
else:
break

def _args(self) -> list[str]:
return ['deno'] + _deno_args('http', self.callbacks, self.local_code) + ['--port', str(self.port)]

def _cwd(self) -> str | None:
return 'mcp-run-python' if self.local_code else None


def _deno_args(mode: Literal['stdio', 'http'], callbacks: Sequence[Callback], local_code: bool) -> list[str]:
args = [
'run',
'-N',
'-R=node_modules',
'-W=node_modules',
'--node-modules-dir=auto',
'src/main.ts' if local_code else f'jsr:@pydantic/mcp-run-python@{MCP_RUN_PYTHON_VERSION}',
mode,
]

if callbacks:
sigs = '\n\n'.join(_callback_signature(cb) for cb in callbacks)
args += ['--callbacks', sigs]
return args


def _callback_signature(func: Callback) -> str:
"""Extract the signature of a function.
This simply means getting the source code of the function, and removing the body of the function while keeping the docstring.
"""
source = inspect.getsource(func)
ast_mod = ast.parse(source)
assert isinstance(ast_mod, ast.Module), f'Expected Module, got {type(ast_mod)}'
assert len(ast_mod.body) == 1, f'Expected single function definition, got {len(ast_mod.body)}'
f = ast_mod.body[0]
assert isinstance(f, ast.AsyncFunctionDef), f'Expected an async function, got {type(func)}'
lines = source.splitlines()
e = f.body[0]
# if the first expression is a docstring, keep it and no need for an ellipsis as the body
if isinstance(e, ast.Expr) and isinstance(e.value, ast.Constant) and isinstance(e.value.value, str):
e = f.body[1]
lines = lines[: e.lineno - 1]
else:
lines = lines[: e.lineno - 1]
lines.append(e.col_offset * ' ' + '...')

# if the function has any decorators, this will remove them.
if f.lineno != 1:
lines = lines[f.lineno - 1 :]

return '\n'.join(lines)


class _PythonSamplingCallback:
def __init__(self, callbacks: Sequence[Callback]):
self.function_lookup: dict[str, ValidateCallWrapper] = {}
for callback in callbacks:
name = callback.__name__
if name in self.function_lookup:
raise ValueError(f'Duplicate callback name: {name}')
self.function_lookup[name] = ValidateCallWrapper(
callback, # pyright: ignore[reportArgumentType]
None,
False,
None,
)

async def __call__(
self,
context: RequestContext[ClientSession, Any],
params: mcp_types.CreateMessageRequestParams,
) -> mcp_types.CreateMessageResult | mcp_types.ErrorData | None:
if not params.metadata or params.metadata.get('pydantic_custom_use') != '__python_function_call__':
return None

call_metadata = _PythonCallMetadata.model_validate(params.metadata)
if function_wrapper := self.function_lookup.get(call_metadata.func):
content: _CallSuccess | _CallError
try:
return_value = await function_wrapper.__pydantic_validator__.validate_python(call_metadata.args_kwargs)
except ValueError as e:
# special support for ValueError since it's commonly subclassed, and it's the parent of ValidationError
# TODO we should probably have specific support for other common errors
content = _CallError(exc_type='ValueError', message=str(e), kind='error')
except Exception as e:
content = _CallError(exc_type=e.__class__.__name__, message=str(e), kind='error')
else:
content = _CallSuccess(return_value=return_value, kind='success')

content_text = pydantic_core.to_json(content, fallback=_json_fallback).decode()
return mcp_types.CreateMessageResult(
role='assistant', content=mcp_types.TextContent(type='text', text=content_text), model='python'
)
else:
raise LookupError(f'Function `{call_metadata.func}` not found')

@override
def __repr__(self) -> str:
return f'<_PythonSamplingCallback: {", ".join(map(repr, self.function_lookup))}>'


class _PythonCallMetadata(BaseModel):
func: str
args: Json[list[Any]] | None = None # JSON
kwargs: Json[dict[str, Any]] | None = None # JSON

@property
def args_kwargs(self) -> pydantic_core.ArgsKwargs:
return pydantic_core.ArgsKwargs(tuple(self.args or ()), self.kwargs)


class _CallSuccess(TypedDict):
return_value: Any
kind: Literal['success']


class _CallError(TypedDict):
exc_type: str
message: str
kind: Literal['error']


def _json_fallback(value: Any) -> Any:
tp = cast(Any, type(value))
if tp.__module__ == 'numpy':
if tp.__name__ in {'ndarray', 'matrix'}:
return value.tolist()
else:
return value.item()
else:
return repr(value)