-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Add FastMCPToolset #2784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add FastMCPToolset #2784
Changes from all commits
cffc38f
456900e
9bac437
1cf320e
edd89f2
9c4fe38
a46222f
27592c7
0362fd7
eaa45c8
4bd0334
533e879
f2be96d
0fd6929
881f306
dfcad61
8776a67
a171272
6ea1dd3
81d004d
880f355
aade6d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
from __future__ import annotations | ||
|
||
import base64 | ||
from asyncio import Lock | ||
from contextlib import AsyncExitStack | ||
from dataclasses import KW_ONLY, dataclass, field | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING, Any, Literal, overload | ||
|
||
from fastmcp.client.transports import ClientTransport | ||
from fastmcp.mcp_config import MCPConfig | ||
from fastmcp.server import FastMCP | ||
from mcp.server.fastmcp import FastMCP as FastMCP1Server | ||
from pydantic import AnyUrl | ||
from typing_extensions import Self | ||
|
||
from pydantic_ai import messages | ||
from pydantic_ai.exceptions import ModelRetry | ||
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition | ||
from pydantic_ai.toolsets import AbstractToolset | ||
from pydantic_ai.toolsets.abstract import ToolsetTool | ||
|
||
try: | ||
from fastmcp.client import Client | ||
from fastmcp.exceptions import ToolError | ||
from mcp.types import ( | ||
AudioContent, | ||
ContentBlock, | ||
ImageContent, | ||
TextContent, | ||
Tool as MCPTool, | ||
) | ||
|
||
from pydantic_ai.mcp import TOOL_SCHEMA_VALIDATOR | ||
|
||
except ImportError as _import_error: | ||
raise ImportError( | ||
'Please install the `fastmcp` package to use the FastMCP server, ' | ||
'you can use the `fastmcp` optional group — `pip install "pydantic-ai-slim[fastmcp]"`' | ||
) from _import_error | ||
|
||
|
||
if TYPE_CHECKING: | ||
from fastmcp.client.client import CallToolResult | ||
|
||
|
||
FastMCPToolResult = messages.BinaryContent | dict[str, Any] | str | None | ||
|
||
ToolErrorBehavior = Literal['model_retry', 'error'] | ||
|
||
|
||
@dataclass | ||
class FastMCPToolset(AbstractToolset[AgentDepsT]): | ||
strawgate marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""A FastMCP Toolset that uses the FastMCP Client to call tools from a local or remote MCP Server. | ||
The Toolset can accept a FastMCP Client, a FastMCP Transport, or any other object which a FastMCP Transport can be created from. | ||
See https://gofastmcp.com/clients/transports for a full list of transports available. | ||
""" | ||
|
||
client: Client[Any] | ||
"""The FastMCP transport to use. This can be a local or remote MCP Server configuration, a transport string, or a FastMCP Client.""" | ||
|
||
_: KW_ONLY | ||
|
||
tool_error_behavior: Literal['model_retry', 'error'] = field(default='error') | ||
"""The behavior to take when a tool error occurs.""" | ||
|
||
max_retries: int = field(default=2) | ||
"""The maximum number of retries to attempt if a tool call fails.""" | ||
|
||
_id: str | None = field(default=None) | ||
|
||
@overload | ||
def __init__( | ||
self, | ||
*, | ||
client: Client[Any], | ||
max_retries: int = 2, | ||
tool_error_behavior: Literal['model_retry', 'error'] = 'error', | ||
id: str | None = None, | ||
) -> None: ... | ||
|
||
@overload | ||
def __init__( | ||
self, | ||
transport: ClientTransport | ||
| FastMCP | ||
| FastMCP1Server | ||
| AnyUrl | ||
| Path | ||
| MCPConfig | ||
| dict[str, Any] | ||
| str | ||
| None = None, | ||
*, | ||
max_retries: int = 2, | ||
tool_error_behavior: Literal['model_retry', 'error'] = 'error', | ||
id: str | None = None, | ||
) -> None: ... | ||
|
||
def __init__( | ||
self, | ||
transport: ClientTransport | ||
| FastMCP | ||
| FastMCP1Server | ||
| AnyUrl | ||
| Path | ||
| MCPConfig | ||
| dict[str, Any] | ||
| str | ||
| None = None, | ||
*, | ||
client: Client[Any] | None = None, | ||
max_retries: int = 2, | ||
tool_error_behavior: Literal['model_retry', 'error'] = 'error', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd prefer for the default to be |
||
id: str | None = None, | ||
) -> None: | ||
if not client and not transport: | ||
raise ValueError('Either client or transport must be provided') | ||
|
||
if client and transport: | ||
raise ValueError('Either client or transport must be provided, not both') | ||
|
||
if client: | ||
self.client = client | ||
else: | ||
self.client = Client[Any](transport=transport) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm guessing this will raise an error if the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, I can add tests for this, most common error raised will be a |
||
|
||
self._id = id | ||
self.max_retries = max_retries | ||
self.tool_error_behavior = tool_error_behavior | ||
|
||
self._enter_lock: Lock = Lock() | ||
self._running_count: int = 0 | ||
self._exit_stack: AsyncExitStack | None = None | ||
|
||
@property | ||
def id(self) -> str | None: | ||
DouweM marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return self._id | ||
|
||
async def __aenter__(self) -> Self: | ||
async with self._enter_lock: | ||
if self._running_count == 0 and self.client: | ||
self._exit_stack = AsyncExitStack() | ||
await self._exit_stack.enter_async_context(self.client) | ||
|
||
self._running_count += 1 | ||
|
||
return self | ||
|
||
async def __aexit__(self, *args: Any) -> bool | None: | ||
async with self._enter_lock: | ||
self._running_count -= 1 | ||
if self._running_count == 0 and self._exit_stack: | ||
await self._exit_stack.aclose() | ||
self._exit_stack = None | ||
|
||
return None | ||
|
||
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: | ||
async with self: | ||
mcp_tools: list[MCPTool] = await self.client.list_tools() | ||
|
||
return { | ||
tool.name: _convert_mcp_tool_to_toolset_tool(toolset=self, mcp_tool=tool, retries=self.max_retries) | ||
for tool in mcp_tools | ||
} | ||
|
||
async def call_tool( | ||
self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] | ||
) -> Any: | ||
async with self: | ||
try: | ||
call_tool_result: CallToolResult = await self.client.call_tool(name=name, arguments=tool_args) | ||
except ToolError as e: | ||
if self.tool_error_behavior == 'model_retry': | ||
raise ModelRetry(message=str(e)) from e | ||
else: | ||
raise e | ||
|
||
# If we have structured content, return that | ||
if call_tool_result.structured_content: | ||
strawgate marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return call_tool_result.structured_content | ||
|
||
# Otherwise, return the content | ||
return _map_fastmcp_tool_results(parts=call_tool_result.content) | ||
|
||
|
||
def _convert_mcp_tool_to_toolset_tool( | ||
toolset: FastMCPToolset[AgentDepsT], | ||
mcp_tool: MCPTool, | ||
retries: int, | ||
) -> ToolsetTool[AgentDepsT]: | ||
"""Convert an MCP tool to a toolset tool.""" | ||
return ToolsetTool[AgentDepsT]( | ||
tool_def=ToolDefinition( | ||
strawgate marked this conversation as resolved.
Show resolved
Hide resolved
|
||
name=mcp_tool.name, | ||
description=mcp_tool.description, | ||
parameters_json_schema=mcp_tool.inputSchema, | ||
metadata={ | ||
'meta': mcp_tool.meta, | ||
'annotations': mcp_tool.annotations.model_dump() if mcp_tool.annotations else None, | ||
'output_schema': mcp_tool.outputSchema or None, | ||
}, | ||
), | ||
toolset=toolset, | ||
max_retries=retries, | ||
args_validator=TOOL_SCHEMA_VALIDATOR, | ||
) | ||
|
||
|
||
def _map_fastmcp_tool_results(parts: list[ContentBlock]) -> list[FastMCPToolResult] | FastMCPToolResult: | ||
"""Map FastMCP tool results to toolset tool results.""" | ||
mapped_results = [_map_fastmcp_tool_result(part) for part in parts] | ||
|
||
if len(mapped_results) == 1: | ||
return mapped_results[0] | ||
|
||
return mapped_results | ||
|
||
|
||
def _map_fastmcp_tool_result(part: ContentBlock) -> FastMCPToolResult: | ||
if isinstance(part, TextContent): | ||
return part.text | ||
|
||
if isinstance(part, ImageContent | AudioContent): | ||
return messages.BinaryContent(data=base64.b64decode(part.data), media_type=part.mimeType) | ||
|
||
msg = f'Unsupported/Unknown content block type: {type(part)}' # pragma: no cover | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What are the other types? In There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably just those two, will take a look |
||
raise ValueError(msg) # pragma: no cover) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,7 +46,7 @@ requires-python = ">=3.10" | |
|
||
[tool.hatch.metadata.hooks.uv-dynamic-versioning] | ||
dependencies = [ | ||
"pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,ag-ui,retries,temporal,logfire]=={{ version }}", | ||
"pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,fastmcp,evals,ag-ui,retries,temporal,logfire]=={{ version }}", | ||
] | ||
|
||
[tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] | ||
|
@@ -90,6 +90,7 @@ dev = [ | |
"coverage[toml]>=7.10.3", | ||
"dirty-equals>=0.9.0", | ||
"duckduckgo-search>=7.0.0", | ||
"fastmcp>=2.12.0", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need it here if it's already coming in through |
||
"inline-snapshot>=0.19.3", | ||
"pytest>=8.3.3", | ||
"pytest-examples>=0.0.18", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to self: before merging this, see if it makes sense to move this to a separate doc that can be listed in the "MCP" section in the sidebar.