Skip to content

Commit a1c90b0

Browse files
feat: add handle_tool_error and handle_validation_error to load_mcp_tools
1 parent 219b60c commit a1c90b0

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

langchain_mcp_adapters/client.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55
"""
66

77
import asyncio
8-
from collections.abc import AsyncIterator
8+
from collections.abc import AsyncIterator, Callable
99
from contextlib import asynccontextmanager
1010
from types import TracebackType
1111
from typing import Any
1212

1313
from langchain_core.documents.base import Blob
1414
from langchain_core.messages import AIMessage, HumanMessage
15-
from langchain_core.tools import BaseTool
15+
from langchain_core.tools import BaseTool, ToolException
1616
from mcp import ClientSession
17+
from pydantic import ValidationError
1718

1819
from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks
1920
from langchain_mcp_adapters.prompts import load_mcp_prompt
@@ -138,12 +139,22 @@ async def session(
138139
await session.initialize()
139140
yield session
140141

141-
async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
142+
async def get_tools(
143+
self,
144+
*,
145+
server_name: str | None = None,
146+
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
147+
handle_validation_error: (
148+
bool | str | Callable[[ValidationError], str] | None
149+
) = False,
150+
) -> list[BaseTool]:
142151
"""Get a list of all tools from all connected servers.
143152
144153
Args:
145154
server_name: Optional name of the server to get tools from.
146155
If None, all tools from all servers will be returned (default).
156+
handle_tool_error: Optional error handler for tool execution errors.
157+
handle_validation_error: Optional error handler for validation errors.
147158
148159
NOTE: a new session will be created for each tool call
149160
@@ -163,6 +174,8 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
163174
connection=self.connections[server_name],
164175
callbacks=self.callbacks,
165176
server_name=server_name,
177+
handle_tool_error=handle_tool_error,
178+
handle_validation_error=handle_validation_error,
166179
)
167180

168181
all_tools: list[BaseTool] = []
@@ -174,6 +187,8 @@ async def get_tools(self, *, server_name: str | None = None) -> list[BaseTool]:
174187
connection=connection,
175188
callbacks=self.callbacks,
176189
server_name=name,
190+
handle_tool_error=handle_tool_error,
191+
handle_validation_error=handle_validation_error,
177192
)
178193
)
179194
load_mcp_tool_tasks.append(load_mcp_tool_task)

langchain_mcp_adapters/tools.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
tools, handle tool execution, and manage tool conversion between the two formats.
55
"""
66

7+
from collections.abc import Callable
78
from typing import Any, cast, get_args
89

910
from langchain_core.tools import (
@@ -25,7 +26,7 @@
2526
TextContent,
2627
)
2728
from mcp.types import Tool as MCPTool
28-
from pydantic import BaseModel, create_model
29+
from pydantic import BaseModel, ValidationError, create_model
2930

3031
from langchain_mcp_adapters.callbacks import CallbackContext, Callbacks, _MCPCallbacks
3132
from langchain_mcp_adapters.sessions import Connection, create_session
@@ -112,6 +113,10 @@ def convert_mcp_tool_to_langchain_tool(
112113
connection: Connection | None = None,
113114
callbacks: Callbacks | None = None,
114115
server_name: str | None = None,
116+
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
117+
handle_validation_error: (
118+
bool | str | Callable[[ValidationError], str] | None
119+
) = False,
115120
) -> BaseTool:
116121
"""Convert an MCP tool to a LangChain tool.
117122
@@ -124,6 +129,8 @@ def convert_mcp_tool_to_langchain_tool(
124129
if a `session` is not provided
125130
callbacks: Optional callbacks for handling notifications and events
126131
server_name: Name of the server this tool belongs to
132+
handle_tool_error: Optional error handler for tool execution errors.
133+
handle_validation_error: Optional error handler for validation errors.
127134
128135
Returns:
129136
a LangChain tool
@@ -192,6 +199,8 @@ async def call_tool(
192199
coroutine=call_tool,
193200
response_format="content_and_artifact",
194201
metadata=metadata,
202+
handle_tool_error=handle_tool_error,
203+
handle_validation_error=handle_validation_error,
195204
)
196205

197206

@@ -201,6 +210,10 @@ async def load_mcp_tools(
201210
connection: Connection | None = None,
202211
callbacks: Callbacks | None = None,
203212
server_name: str | None = None,
213+
handle_tool_error: bool | str | Callable[[ToolException], str] | None = False,
214+
handle_validation_error: (
215+
bool | str | Callable[[ValidationError], str] | None
216+
) = False,
204217
) -> list[BaseTool]:
205218
"""Load all available MCP tools and convert them to LangChain tools.
206219
@@ -209,6 +222,8 @@ async def load_mcp_tools(
209222
connection: Connection config to create a new session if session is None.
210223
callbacks: Optional callbacks for handling notifications and events.
211224
server_name: Name of the server these tools belong to.
225+
handle_tool_error: Optional error handler for tool execution errors.
226+
handle_validation_error: Optional error handler for validation errors.
212227
213228
Returns:
214229
List of LangChain tools. Tool annotations are returned as part
@@ -247,6 +262,8 @@ async def load_mcp_tools(
247262
connection=connection,
248263
callbacks=callbacks,
249264
server_name=server_name,
265+
handle_tool_error=handle_tool_error,
266+
handle_validation_error=handle_validation_error,
250267
)
251268
for tool in tools
252269
]

0 commit comments

Comments
 (0)