diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7bb8821f7..5239a249b 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -8,7 +8,7 @@ import mcp.types as types from mcp.shared.context import RequestContext from mcp.shared.message import SessionMessage -from mcp.shared.session import BaseSession, RequestResponder +from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0") @@ -259,9 +259,9 @@ async def call_tool( name: str, arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CallToolResult: """Send a tools/call request.""" - return await self.send_request( types.ClientRequest( types.CallToolRequest( @@ -271,6 +271,7 @@ async def call_tool( ), types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, + progress_callback=progress_callback, ) async def list_prompts(self) -> types.ListPromptsResult: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index cce8b1184..449b489f5 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -3,7 +3,7 @@ from contextlib import AsyncExitStack from datetime import timedelta from types import TracebackType -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Protocol, TypeVar import anyio import httpx @@ -24,6 +24,9 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + ProgressNotification, + ProgressNotificationParams, + ProgressToken, RequestParams, ServerNotification, ServerRequest, @@ -39,6 +42,14 @@ "ReceiveNotificationT", ClientNotification, ServerNotification ) + +class ProgressFnT(Protocol): + async def __call__( + self, + params: ProgressNotificationParams, + ) -> None: ... + + RequestId = str | int @@ -168,7 +179,9 @@ class BaseSession( RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError] ] _request_id: int + _progress_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] + _in_progress: dict[ProgressToken, ProgressFnT] def __init__( self, @@ -187,6 +200,8 @@ def __init__( self._receive_notification_type = receive_notification_type self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} + self._progress_id = 0 + self._in_progress = {} self._exit_stack = AsyncExitStack() async def __aenter__(self) -> Self: @@ -214,12 +229,16 @@ async def send_request( result_type: type[ReceiveResultT], request_read_timeout_seconds: timedelta | None = None, metadata: MessageMetadata = None, + progress_callback: ProgressFnT | None = None, ) -> ReceiveResultT: """ Sends a request and wait for a response. Raises an McpError if the response contains an error. If a request read timeout is provided, it will take precedence over the session read timeout. + If progress_callback is provided any progress notifications sent from the + receiver will be passed back to the sender + Do not use this method to emit notifications! Use send_notification() instead. """ @@ -227,6 +246,38 @@ async def send_request( request_id = self._request_id self._request_id = request_id + 1 + progress_id = None + send_request = None + + if progress_callback is not None: + if request.root.params is None: + progress_id = self._progress_id + new_params = RequestParams( + _meta=RequestParams.Meta(progressToken=progress_id) + ) + else: + if ( + request.root.params.meta is None + or request.root.params.meta.progressToken is None + ): + progress_id = self._progress_id + new_params = request.root.params.model_copy( + update={"meta": RequestParams.Meta(progressToken=progress_id)} + ) + else: + raise ValueError( + "Request has progressToken and progress_callback provided " + "via send_request method only one or other is supported" + ) + + new_root = request.root.model_copy(update={"params": new_params}) + send_request = request.model_copy(update={"root": new_root}) + self._progress_id = progress_id + 1 + self._in_progress[progress_id] = progress_callback + + if send_request is None: + send_request = request + response_stream, response_stream_reader = anyio.create_memory_object_stream[ JSONRPCResponse | JSONRPCError ](1) @@ -236,11 +287,11 @@ async def send_request( jsonrpc_request = JSONRPCRequest( jsonrpc="2.0", id=request_id, - **request.model_dump(by_alias=True, mode="json", exclude_none=True), + **send_request.model_dump( + by_alias=True, mode="json", exclude_none=True + ), ) - # TODO: Support progress callbacks - await self._write_stream.send( SessionMessage( message=JSONRPCMessage(jsonrpc_request), metadata=metadata @@ -276,6 +327,8 @@ async def send_request( finally: self._response_streams.pop(request_id, None) + if progress_id is not None: + self._in_progress.pop(progress_id, None) await response_stream.aclose() await response_stream_reader.aclose() @@ -364,6 +417,20 @@ async def _receive_loop(self) -> None: if cancelled_id in self._in_flight: await self._in_flight[cancelled_id].cancel() else: + match notification.root: + case ProgressNotification(params=params): + if params.progressToken in self._in_progress: + progress_callback = self._in_progress[ + params.progressToken + ] + await progress_callback(params) + else: + logging.warning( + "Unknown progress token %s", + params.progressToken, + ) + case _: + pass await self._received_notification(notification) await self._handle_incoming(notification) except Exception as e: diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 6abcf70cb..08f228149 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -250,3 +250,106 @@ async def mock_server(): # Assert that the default client info was sent assert received_client_info == DEFAULT_CLIENT_INFO + + +@pytest.mark.anyio +async def test_client_session_progress(): + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ + SessionMessage + ](1) + + send_notification_count = 10 + + async def mock_server(): + session_message = await client_to_server_receive.receive() + jsonrpc_request = session_message.message + assert isinstance(jsonrpc_request.root, JSONRPCRequest) + request = ClientRequest.model_validate( + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) + ) + assert isinstance(request.root, types.CallToolRequest) + assert request.root.params.meta + assert request.root.params.meta.progressToken is not None + + progress_token = request.root.params.meta.progressToken + notifications = [ + types.ServerNotification( + root=types.ProgressNotification( + params=types.ProgressNotificationParams( + progressToken=progress_token, progress=i + ), + method="notifications/progress", + ) + ) + for i in range(send_notification_count) + ] + result = ServerResult(types.CallToolResult(content=[])) + + async with server_to_client_send: + for notification in notifications: + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + types.JSONRPCNotification( + jsonrpc="2.0", + **notification.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + await server_to_client_send.send( + SessionMessage( + JSONRPCMessage( + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.root.id, + result=result.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + ) + ) + ) + ) + + # Create a message handler to catch exceptions + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, Exception): + raise message + + progress_count = 0 + + async with ( + ClientSession( + server_to_client_receive, + client_to_server_send, + message_handler=message_handler, + ) as session, + anyio.create_task_group() as tg, + client_to_server_send, + client_to_server_receive, + server_to_client_send, + server_to_client_receive, + ): + tg.start_soon(mock_server) + + async def progress_callback(params: types.ProgressNotificationParams): + nonlocal progress_count + progress_count = progress_count + 1 + + result = await session.call_tool( + "tool_with_progress", progress_callback=progress_callback + ) + + # Assert the result + assert isinstance(result, types.CallToolResult) + assert len(result.content) == 0 + assert progress_count == send_notification_count