diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 13d4fd91f..3d9172260 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -453,10 +453,15 @@ async def run( logger.debug(f"Received message: {message}") match message: - case RequestResponder(request=types.ClientRequest(root=req)): - await self._handle_request( - message, req, session, raise_exceptions - ) + case ( + RequestResponder( + request=types.ClientRequest(root=req) + ) as responder + ): + with responder: + await self._handle_request( + message, req, session, raise_exceptions + ) case types.ClientNotification(root=notify): await self._handle_notification(notify) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 5e114ecf5..ddfa90903 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,6 +1,6 @@ from contextlib import AbstractAsyncContextManager from datetime import timedelta -from typing import Generic, TypeVar +from typing import Any, Callable, Generic, TypeVar import anyio import anyio.lowlevel @@ -10,6 +10,7 @@ from mcp.shared.exceptions import McpError from mcp.types import ( + CancelledNotification, ClientNotification, ClientRequest, ClientResult, @@ -44,21 +45,55 @@ def __init__( request_meta: RequestParams.Meta | None, request: ReceiveRequestT, session: "BaseSession", + on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], ) -> None: self.request_id = request_id self.request_meta = request_meta self.request = request self._session = session - self._responded = False + self._completed = False + self._cancel_scope = anyio.CancelScope() + self._on_complete = on_complete + + def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": + self._cancel_scope.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + try: + if self._completed: + self._on_complete(self) + finally: + self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) async def respond(self, response: SendResultT | ErrorData) -> None: - assert not self._responded, "Request already responded to" - self._responded = True + assert not self._completed, "Request already responded to" + if not self.cancelled: + self._completed = True + + await self._session._send_response( + request_id=self.request_id, response=response + ) + + async def cancel(self) -> None: + """Cancel this request and mark it as completed.""" + self._cancel_scope.cancel() + self._completed = True # Mark as completed so it's removed from in_flight + # Send an error response to indicate cancellation await self._session._send_response( - request_id=self.request_id, response=response + request_id=self.request_id, + response=ErrorData(code=0, message="Request cancelled", data=None), ) + @property + def in_flight(self) -> bool: + return not self._completed and not self.cancelled + + @property + def cancelled(self) -> bool: + return self._cancel_scope is not None and self._cancel_scope.cancel_called + class BaseSession( AbstractAsyncContextManager, @@ -82,6 +117,7 @@ class BaseSession( RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError] ] _request_id: int + _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] def __init__( self, @@ -99,6 +135,7 @@ def __init__( self._receive_request_type = receive_request_type self._receive_notification_type = receive_notification_type self._read_timeout_seconds = read_timeout_seconds + self._in_flight = {} self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( anyio.create_memory_object_stream[ @@ -219,6 +256,7 @@ async def _receive_loop(self) -> None: by_alias=True, mode="json", exclude_none=True ) ) + responder = RequestResponder( request_id=message.root.id, request_meta=validated_request.root.params.meta @@ -226,20 +264,28 @@ async def _receive_loop(self) -> None: else None, request=validated_request, session=self, + on_complete=lambda r: self._in_flight.pop(r.request_id, None), ) + self._in_flight[responder.request_id] = responder await self._received_request(responder) - if not responder._responded: + if not responder._completed: await self._incoming_message_stream_writer.send(responder) + elif isinstance(message.root, JSONRPCNotification): notification = self._receive_notification_type.model_validate( message.root.model_dump( by_alias=True, mode="json", exclude_none=True ) ) - - await self._received_notification(notification) - await self._incoming_message_stream_writer.send(notification) + # Handle cancellation notifications + if isinstance(notification.root, CancelledNotification): + cancelled_id = notification.root.params.requestId + if cancelled_id in self._in_flight: + await self._in_flight[cancelled_id].cancel() + else: + await self._received_notification(notification) + await self._incoming_message_stream_writer.send(notification) else: # Response or error stream = self._response_streams.pop(message.root.id, None) if stream: diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py new file mode 100644 index 000000000..65cf061ec --- /dev/null +++ b/tests/shared/test_session.py @@ -0,0 +1,126 @@ +from typing import AsyncGenerator + +import anyio +import pytest + +import mcp.types as types +from mcp.client.session import ClientSession +from mcp.server.lowlevel.server import Server +from mcp.shared.exceptions import McpError +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import ( + CancelledNotification, + CancelledNotificationParams, + ClientNotification, + ClientRequest, + EmptyResult, +) + + +@pytest.fixture +def mcp_server() -> Server: + return Server(name="test server") + + +@pytest.fixture +async def client_connected_to_server( + mcp_server: Server, +) -> AsyncGenerator[ClientSession, None]: + async with create_connected_server_and_client_session(mcp_server) as client_session: + yield client_session + + +@pytest.mark.anyio +async def test_in_flight_requests_cleared_after_completion( + client_connected_to_server: ClientSession, +): + """Verify that _in_flight is empty after all requests complete.""" + # Send a request and wait for response + response = await client_connected_to_server.send_ping() + assert isinstance(response, EmptyResult) + + # Verify _in_flight is empty + assert len(client_connected_to_server._in_flight) == 0 + + +@pytest.mark.anyio +async def test_request_cancellation(): + """Test that requests can be cancelled while in-flight.""" + # The tool is already registered in the fixture + + ev_tool_called = anyio.Event() + ev_cancelled = anyio.Event() + request_id = None + + # Start the request in a separate task so we can cancel it + def make_server() -> Server: + server = Server(name="TestSessionServer") + + # Register the tool handler + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict | None) -> list: + nonlocal request_id, ev_tool_called + if name == "slow_tool": + request_id = server.request_context.request_id + ev_tool_called.set() + await anyio.sleep(10) # Long enough to ensure we can cancel + return [] + raise ValueError(f"Unknown tool: {name}") + + # Register the tool so it shows up in list_tools + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="slow_tool", + description="A slow tool that takes 10 seconds to complete", + inputSchema={}, + ) + ] + + return server + + async def make_request(client_session): + nonlocal ev_cancelled + try: + await client_session.send_request( + ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="slow_tool", arguments={} + ), + ) + ), + types.CallToolResult, + ) + pytest.fail("Request should have been cancelled") + except McpError as e: + # Expected - request was cancelled + assert "Request cancelled" in str(e) + ev_cancelled.set() + + async with create_connected_server_and_client_session( + make_server() + ) as client_session: + async with anyio.create_task_group() as tg: + tg.start_soon(make_request, client_session) + + # Wait for the request to be in-flight + with anyio.fail_after(1): # Timeout after 1 second + await ev_tool_called.wait() + + # Send cancellation notification + assert request_id is not None + await client_session.send_notification( + ClientNotification( + CancelledNotification( + method="notifications/cancelled", + params=CancelledNotificationParams(requestId=request_id), + ) + ) + ) + + # Give cancellation time to process + with anyio.fail_after(1): + await ev_cancelled.wait()