From 888bdd3c34313061dd1d1d69ad565569e3ca0f9b Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Fri, 3 Jan 2025 23:44:43 +0000 Subject: [PATCH 1/4] tests for issue 88 --- tests/issues/test_88_random_error.py | 100 +++++++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 tests/issues/test_88_random_error.py diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py new file mode 100644 index 0000000..8b979ab --- /dev/null +++ b/tests/issues/test_88_random_error.py @@ -0,0 +1,100 @@ +"""Test to reproduce issue #88: Random error thrown on response.""" + +from datetime import timedelta +from pathlib import Path +from typing import Sequence + +import anyio +import pytest + +from mcp.client.session import ClientSession +from mcp.server.lowlevel import Server +from mcp.shared.exceptions import McpError +from mcp.types import ( + EmbeddedResource, + ImageContent, + TextContent, +) + + +@pytest.mark.anyio +async def test_notification_validation_error(tmp_path: Path): + """Test that timeouts are handled gracefully and don't break the server. + + This test verifies that when a client request times out: + 1. The server task stays alive + 2. The server can still handle new requests + 3. The client can make new requests + 4. No resources are leaked + """ + + server = Server(name="test") + request_count = 0 + slow_request_complete = False + + @server.call_tool() + async def slow_tool( + name: str, arg + ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: + nonlocal request_count, slow_request_complete + request_count += 1 + + if name == "slow": + # Long enough to ensure timeout + await anyio.sleep(0.2) + slow_request_complete = True + return [TextContent(type="text", text=f"slow {request_count}")] + elif name == "fast": + # Fast enough to complete before timeout + await anyio.sleep(0.01) + return [TextContent(type="text", text=f"fast {request_count}")] + return [TextContent(type="text", text=f"unknown {request_count}")] + + async def server_handler(read_stream, write_stream): + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + raise_exceptions=True, + ) + + async def client(read_stream, write_stream): + # Use a timeout that's: + # - Long enough for fast operations (>10ms) + # - Short enough for slow operations (<200ms) + # - Not too short to avoid flakiness + async with ClientSession( + read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50) + ) as session: + await session.initialize() + + # First call should work (fast operation) + result = await session.call_tool("fast") + assert result.content == [TextContent(type="text", text="fast 1")] + assert not slow_request_complete + + # Second call should timeout (slow operation) + with pytest.raises(McpError) as exc_info: + await session.call_tool("slow") + assert "Timed out while waiting" in str(exc_info.value) + + # Wait for slow request to complete in the background + await anyio.sleep(0.3) + assert slow_request_complete + + # Third call should work (fast operation), + # proving server is still responsive + result = await session.call_tool("fast") + assert result.content == [TextContent(type="text", text="fast 3")] + + # Run server and client in separate task groups to avoid cancellation + server_writer, server_reader = anyio.create_memory_object_stream(1) + client_writer, client_reader = anyio.create_memory_object_stream(1) + + async with anyio.create_task_group() as tg: + tg.start_soon(server_handler, server_reader, client_writer) + # Wait for server to start and initialize + await anyio.sleep(0.1) + # Run client in a separate task to avoid cancellation + async with anyio.create_task_group() as client_tg: + client_tg.start_soon(client, client_reader, server_writer) From 827e494df493024bc63020920ba6d006127227cf Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Thu, 23 Jan 2025 20:10:02 +0000 Subject: [PATCH 2/4] feat: add request cancellation and in-flight request tracking This commit adds support for request cancellation and tracking of in-flight requests in the MCP protocol implementation. The key architectural changes are: 1. Request Lifecycle Management: - Added _in_flight dictionary to BaseSession to track active requests - Requests are tracked from receipt until completion/cancellation - Added proper cleanup via on_complete callback 2. Cancellation Support: - Added CancelledNotification handling in _receive_loop - Implemented cancel() method in RequestResponder - Uses anyio.CancelScope for robust cancellation - Sends error response on cancellation 3. Request Context: - Added request_ctx ContextVar for request context - Ensures proper cleanup after request handling - Maintains request state throughout lifecycle 4. Error Handling: - Improved error propagation for cancelled requests - Added proper cleanup of cancelled requests - Maintains consistency of in-flight tracking This change enables clients to cancel long-running requests and servers to properly clean up resources when requests are cancelled. Github-Issue:#88 --- src/mcp/server/lowlevel/server.py | 13 ++- src/mcp/shared/session.py | 64 ++++++++++++--- tests/shared/test_session.py | 126 ++++++++++++++++++++++++++++++ 3 files changed, 190 insertions(+), 13 deletions(-) create mode 100644 tests/shared/test_session.py diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 13d4fd9..3d91722 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -453,10 +453,15 @@ async def run( logger.debug(f"Received message: {message}") match message: - case RequestResponder(request=types.ClientRequest(root=req)): - await self._handle_request( - message, req, session, raise_exceptions - ) + case ( + RequestResponder( + request=types.ClientRequest(root=req) + ) as responder + ): + with responder: + await self._handle_request( + message, req, session, raise_exceptions + ) case types.ClientNotification(root=notify): await self._handle_notification(notify) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 5e114ec..ddfa909 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,6 +1,6 @@ from contextlib import AbstractAsyncContextManager from datetime import timedelta -from typing import Generic, TypeVar +from typing import Any, Callable, Generic, TypeVar import anyio import anyio.lowlevel @@ -10,6 +10,7 @@ from mcp.shared.exceptions import McpError from mcp.types import ( + CancelledNotification, ClientNotification, ClientRequest, ClientResult, @@ -44,21 +45,55 @@ def __init__( request_meta: RequestParams.Meta | None, request: ReceiveRequestT, session: "BaseSession", + on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any], ) -> None: self.request_id = request_id self.request_meta = request_meta self.request = request self._session = session - self._responded = False + self._completed = False + self._cancel_scope = anyio.CancelScope() + self._on_complete = on_complete + + def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": + self._cancel_scope.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + try: + if self._completed: + self._on_complete(self) + finally: + self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) async def respond(self, response: SendResultT | ErrorData) -> None: - assert not self._responded, "Request already responded to" - self._responded = True + assert not self._completed, "Request already responded to" + if not self.cancelled: + self._completed = True + + await self._session._send_response( + request_id=self.request_id, response=response + ) + + async def cancel(self) -> None: + """Cancel this request and mark it as completed.""" + self._cancel_scope.cancel() + self._completed = True # Mark as completed so it's removed from in_flight + # Send an error response to indicate cancellation await self._session._send_response( - request_id=self.request_id, response=response + request_id=self.request_id, + response=ErrorData(code=0, message="Request cancelled", data=None), ) + @property + def in_flight(self) -> bool: + return not self._completed and not self.cancelled + + @property + def cancelled(self) -> bool: + return self._cancel_scope is not None and self._cancel_scope.cancel_called + class BaseSession( AbstractAsyncContextManager, @@ -82,6 +117,7 @@ class BaseSession( RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError] ] _request_id: int + _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] def __init__( self, @@ -99,6 +135,7 @@ def __init__( self._receive_request_type = receive_request_type self._receive_notification_type = receive_notification_type self._read_timeout_seconds = read_timeout_seconds + self._in_flight = {} self._incoming_message_stream_writer, self._incoming_message_stream_reader = ( anyio.create_memory_object_stream[ @@ -219,6 +256,7 @@ async def _receive_loop(self) -> None: by_alias=True, mode="json", exclude_none=True ) ) + responder = RequestResponder( request_id=message.root.id, request_meta=validated_request.root.params.meta @@ -226,20 +264,28 @@ async def _receive_loop(self) -> None: else None, request=validated_request, session=self, + on_complete=lambda r: self._in_flight.pop(r.request_id, None), ) + self._in_flight[responder.request_id] = responder await self._received_request(responder) - if not responder._responded: + if not responder._completed: await self._incoming_message_stream_writer.send(responder) + elif isinstance(message.root, JSONRPCNotification): notification = self._receive_notification_type.model_validate( message.root.model_dump( by_alias=True, mode="json", exclude_none=True ) ) - - await self._received_notification(notification) - await self._incoming_message_stream_writer.send(notification) + # Handle cancellation notifications + if isinstance(notification.root, CancelledNotification): + cancelled_id = notification.root.params.requestId + if cancelled_id in self._in_flight: + await self._in_flight[cancelled_id].cancel() + else: + await self._received_notification(notification) + await self._incoming_message_stream_writer.send(notification) else: # Response or error stream = self._response_streams.pop(message.root.id, None) if stream: diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py new file mode 100644 index 0000000..65cf061 --- /dev/null +++ b/tests/shared/test_session.py @@ -0,0 +1,126 @@ +from typing import AsyncGenerator + +import anyio +import pytest + +import mcp.types as types +from mcp.client.session import ClientSession +from mcp.server.lowlevel.server import Server +from mcp.shared.exceptions import McpError +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import ( + CancelledNotification, + CancelledNotificationParams, + ClientNotification, + ClientRequest, + EmptyResult, +) + + +@pytest.fixture +def mcp_server() -> Server: + return Server(name="test server") + + +@pytest.fixture +async def client_connected_to_server( + mcp_server: Server, +) -> AsyncGenerator[ClientSession, None]: + async with create_connected_server_and_client_session(mcp_server) as client_session: + yield client_session + + +@pytest.mark.anyio +async def test_in_flight_requests_cleared_after_completion( + client_connected_to_server: ClientSession, +): + """Verify that _in_flight is empty after all requests complete.""" + # Send a request and wait for response + response = await client_connected_to_server.send_ping() + assert isinstance(response, EmptyResult) + + # Verify _in_flight is empty + assert len(client_connected_to_server._in_flight) == 0 + + +@pytest.mark.anyio +async def test_request_cancellation(): + """Test that requests can be cancelled while in-flight.""" + # The tool is already registered in the fixture + + ev_tool_called = anyio.Event() + ev_cancelled = anyio.Event() + request_id = None + + # Start the request in a separate task so we can cancel it + def make_server() -> Server: + server = Server(name="TestSessionServer") + + # Register the tool handler + @server.call_tool() + async def handle_call_tool(name: str, arguments: dict | None) -> list: + nonlocal request_id, ev_tool_called + if name == "slow_tool": + request_id = server.request_context.request_id + ev_tool_called.set() + await anyio.sleep(10) # Long enough to ensure we can cancel + return [] + raise ValueError(f"Unknown tool: {name}") + + # Register the tool so it shows up in list_tools + @server.list_tools() + async def handle_list_tools() -> list[types.Tool]: + return [ + types.Tool( + name="slow_tool", + description="A slow tool that takes 10 seconds to complete", + inputSchema={}, + ) + ] + + return server + + async def make_request(client_session): + nonlocal ev_cancelled + try: + await client_session.send_request( + ClientRequest( + types.CallToolRequest( + method="tools/call", + params=types.CallToolRequestParams( + name="slow_tool", arguments={} + ), + ) + ), + types.CallToolResult, + ) + pytest.fail("Request should have been cancelled") + except McpError as e: + # Expected - request was cancelled + assert "Request cancelled" in str(e) + ev_cancelled.set() + + async with create_connected_server_and_client_session( + make_server() + ) as client_session: + async with anyio.create_task_group() as tg: + tg.start_soon(make_request, client_session) + + # Wait for the request to be in-flight + with anyio.fail_after(1): # Timeout after 1 second + await ev_tool_called.wait() + + # Send cancellation notification + assert request_id is not None + await client_session.send_notification( + ClientNotification( + CancelledNotification( + method="notifications/cancelled", + params=CancelledNotificationParams(requestId=request_id), + ) + ) + ) + + # Give cancellation time to process + with anyio.fail_after(1): + await ev_cancelled.wait() From 08cfbe522aae48365f74147b20636e8bd715174d Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 4 Feb 2025 13:58:44 +0000 Subject: [PATCH 3/4] fix: improve error handling and request cancellation for issue #88 --- src/mcp/shared/session.py | 34 ++++++++++++++++++---------- tests/issues/test_88_random_error.py | 27 +++++++++++++++------- 2 files changed, 41 insertions(+), 20 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index ddfa909..e21bcbc 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -1,3 +1,4 @@ +import logging from contextlib import AbstractAsyncContextManager from datetime import timedelta from typing import Any, Callable, Generic, TypeVar @@ -273,19 +274,28 @@ async def _receive_loop(self) -> None: await self._incoming_message_stream_writer.send(responder) elif isinstance(message.root, JSONRPCNotification): - notification = self._receive_notification_type.model_validate( - message.root.model_dump( - by_alias=True, mode="json", exclude_none=True + try: + notification = self._receive_notification_type.model_validate( + message.root.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + ) + # Handle cancellation notifications + if isinstance(notification.root, CancelledNotification): + cancelled_id = notification.root.params.requestId + if cancelled_id in self._in_flight: + await self._in_flight[cancelled_id].cancel() + else: + await self._received_notification(notification) + await self._incoming_message_stream_writer.send( + notification + ) + except Exception as e: + # For other validation errors, log and continue + logging.warning( + f"Failed to validate notification: {e}. " + f"Message was: {message.root}" ) - ) - # Handle cancellation notifications - if isinstance(notification.root, CancelledNotification): - cancelled_id = notification.root.params.requestId - if cancelled_id in self._in_flight: - await self._in_flight[cancelled_id].cancel() - else: - await self._received_notification(notification) - await self._incoming_message_stream_writer.send(notification) else: # Response or error stream = self._response_streams.pop(message.root.id, None) if stream: diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 8b979ab..8609c20 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -30,19 +30,23 @@ async def test_notification_validation_error(tmp_path: Path): server = Server(name="test") request_count = 0 - slow_request_complete = False + slow_request_started = anyio.Event() + slow_request_complete = anyio.Event() @server.call_tool() async def slow_tool( name: str, arg ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: - nonlocal request_count, slow_request_complete + nonlocal request_count request_count += 1 if name == "slow": + # Signal that slow request has started + slow_request_started.set() # Long enough to ensure timeout await anyio.sleep(0.2) - slow_request_complete = True + # Signal completion + slow_request_complete.set() return [TextContent(type="text", text=f"slow {request_count}")] elif name == "fast": # Fast enough to complete before timeout @@ -71,7 +75,7 @@ async def client(read_stream, write_stream): # First call should work (fast operation) result = await session.call_tool("fast") assert result.content == [TextContent(type="text", text="fast 1")] - assert not slow_request_complete + assert not slow_request_complete.is_set() # Second call should timeout (slow operation) with pytest.raises(McpError) as exc_info: @@ -79,8 +83,8 @@ async def client(read_stream, write_stream): assert "Timed out while waiting" in str(exc_info.value) # Wait for slow request to complete in the background - await anyio.sleep(0.3) - assert slow_request_complete + with anyio.fail_after(1): # Timeout after 1 second + await slow_request_complete.wait() # Third call should work (fast operation), # proving server is still responsive @@ -91,10 +95,17 @@ async def client(read_stream, write_stream): server_writer, server_reader = anyio.create_memory_object_stream(1) client_writer, client_reader = anyio.create_memory_object_stream(1) + server_ready = anyio.Event() + + async def wrapped_server_handler(read_stream, write_stream): + server_ready.set() + await server_handler(read_stream, write_stream) + async with anyio.create_task_group() as tg: - tg.start_soon(server_handler, server_reader, client_writer) + tg.start_soon(wrapped_server_handler, server_reader, client_writer) # Wait for server to start and initialize - await anyio.sleep(0.1) + with anyio.fail_after(1): # Timeout after 1 second + await server_ready.wait() # Run client in a separate task to avoid cancellation async with anyio.create_task_group() as client_tg: client_tg.start_soon(client, client_reader, server_writer) From 733db0c9cfd2bef6b125c1962b0137a510233759 Mon Sep 17 00:00:00 2001 From: David Soria Parra Date: Tue, 4 Feb 2025 19:29:12 +0000 Subject: [PATCH 4/4] fix: enforce context manager usage for RequestResponder --- src/mcp/server/session.py | 23 ++++++++++++----------- src/mcp/shared/session.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index b71b372..d918b98 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -126,19 +126,20 @@ async def _received_request( case types.InitializeRequest(params=params): self._initialization_state = InitializationState.Initializing self._client_params = params - await responder.respond( - types.ServerResult( - types.InitializeResult( - protocolVersion=types.LATEST_PROTOCOL_VERSION, - capabilities=self._init_options.capabilities, - serverInfo=types.Implementation( - name=self._init_options.server_name, - version=self._init_options.server_version, - ), - instructions=self._init_options.instructions, + with responder: + await responder.respond( + types.ServerResult( + types.InitializeResult( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=self._init_options.capabilities, + serverInfo=types.Implementation( + name=self._init_options.server_name, + version=self._init_options.server_version, + ), + instructions=self._init_options.instructions, + ) ) ) - ) case _: if self._initialization_state != InitializationState.Initialized: raise RuntimeError( diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index e21bcbc..3d3988c 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -40,6 +40,21 @@ class RequestResponder(Generic[ReceiveRequestT, SendResultT]): + """Handles responding to MCP requests and manages request lifecycle. + + This class MUST be used as a context manager to ensure proper cleanup and + cancellation handling: + + Example: + with request_responder as resp: + await resp.respond(result) + + The context manager ensures: + 1. Proper cancellation scope setup and cleanup + 2. Request completion tracking + 3. Cleanup of in-flight requests + """ + def __init__( self, request_id: RequestId, @@ -55,19 +70,36 @@ def __init__( self._completed = False self._cancel_scope = anyio.CancelScope() self._on_complete = on_complete + self._entered = False # Track if we're in a context manager def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]": + """Enter the context manager, enabling request cancellation tracking.""" + self._entered = True + self._cancel_scope = anyio.CancelScope() self._cancel_scope.__enter__() return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Exit the context manager, performing cleanup and notifying completion.""" try: if self._completed: self._on_complete(self) finally: + self._entered = False + if not self._cancel_scope: + raise RuntimeError("No active cancel scope") self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) async def respond(self, response: SendResultT | ErrorData) -> None: + """Send a response for this request. + + Must be called within a context manager block. + Raises: + RuntimeError: If not used within a context manager + AssertionError: If request was already responded to + """ + if not self._entered: + raise RuntimeError("RequestResponder must be used as a context manager") assert not self._completed, "Request already responded to" if not self.cancelled: @@ -79,6 +111,11 @@ async def respond(self, response: SendResultT | ErrorData) -> None: async def cancel(self) -> None: """Cancel this request and mark it as completed.""" + if not self._entered: + raise RuntimeError("RequestResponder must be used as a context manager") + if not self._cancel_scope: + raise RuntimeError("No active cancel scope") + self._cancel_scope.cancel() self._completed = True # Mark as completed so it's removed from in_flight # Send an error response to indicate cancellation