-
Notifications
You must be signed in to change notification settings - Fork 172
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add request cancellation and in-flight request tracking
This commit adds support for request cancellation and tracking of in-flight requests in the MCP protocol implementation. The key architectural changes are: 1. Request Lifecycle Management: - Added _in_flight dictionary to BaseSession to track active requests - Requests are tracked from receipt until completion/cancellation - Added proper cleanup via on_complete callback 2. Cancellation Support: - Added CancelledNotification handling in _receive_loop - Implemented cancel() method in RequestResponder - Uses anyio.CancelScope for robust cancellation - Sends error response on cancellation 3. Request Context: - Added request_ctx ContextVar for request context - Ensures proper cleanup after request handling - Maintains request state throughout lifecycle 4. Error Handling: - Improved error propagation for cancelled requests - Added proper cleanup of cancelled requests - Maintains consistency of in-flight tracking This change enables clients to cancel long-running requests and servers to properly clean up resources when requests are cancelled. Github-Issue:#88
- Loading branch information
Showing
3 changed files
with
190 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
from typing import AsyncGenerator | ||
|
||
import anyio | ||
import pytest | ||
|
||
import mcp.types as types | ||
from mcp.client.session import ClientSession | ||
from mcp.server.lowlevel.server import Server | ||
from mcp.shared.exceptions import McpError | ||
from mcp.shared.memory import create_connected_server_and_client_session | ||
from mcp.types import ( | ||
CancelledNotification, | ||
CancelledNotificationParams, | ||
ClientNotification, | ||
ClientRequest, | ||
EmptyResult, | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def mcp_server() -> Server: | ||
return Server(name="test server") | ||
|
||
|
||
@pytest.fixture | ||
async def client_connected_to_server( | ||
mcp_server: Server, | ||
) -> AsyncGenerator[ClientSession, None]: | ||
async with create_connected_server_and_client_session(mcp_server) as client_session: | ||
yield client_session | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_in_flight_requests_cleared_after_completion( | ||
client_connected_to_server: ClientSession, | ||
): | ||
"""Verify that _in_flight is empty after all requests complete.""" | ||
# Send a request and wait for response | ||
response = await client_connected_to_server.send_ping() | ||
assert isinstance(response, EmptyResult) | ||
|
||
# Verify _in_flight is empty | ||
assert len(client_connected_to_server._in_flight) == 0 | ||
|
||
|
||
@pytest.mark.anyio | ||
async def test_request_cancellation(): | ||
"""Test that requests can be cancelled while in-flight.""" | ||
# The tool is already registered in the fixture | ||
|
||
ev_tool_called = anyio.Event() | ||
ev_cancelled = anyio.Event() | ||
request_id = None | ||
|
||
# Start the request in a separate task so we can cancel it | ||
def make_server() -> Server: | ||
server = Server(name="TestSessionServer") | ||
|
||
# Register the tool handler | ||
@server.call_tool() | ||
async def handle_call_tool(name: str, arguments: dict | None) -> list: | ||
nonlocal request_id, ev_tool_called | ||
if name == "slow_tool": | ||
request_id = server.request_context.request_id | ||
ev_tool_called.set() | ||
await anyio.sleep(10) # Long enough to ensure we can cancel | ||
return [] | ||
raise ValueError(f"Unknown tool: {name}") | ||
|
||
# Register the tool so it shows up in list_tools | ||
@server.list_tools() | ||
async def handle_list_tools() -> list[types.Tool]: | ||
return [ | ||
types.Tool( | ||
name="slow_tool", | ||
description="A slow tool that takes 10 seconds to complete", | ||
inputSchema={}, | ||
) | ||
] | ||
|
||
return server | ||
|
||
async def make_request(client_session): | ||
nonlocal ev_cancelled | ||
try: | ||
await client_session.send_request( | ||
ClientRequest( | ||
types.CallToolRequest( | ||
method="tools/call", | ||
params=types.CallToolRequestParams( | ||
name="slow_tool", arguments={} | ||
), | ||
) | ||
), | ||
types.CallToolResult, | ||
) | ||
pytest.fail("Request should have been cancelled") | ||
except McpError as e: | ||
# Expected - request was cancelled | ||
assert "Request cancelled" in str(e) | ||
ev_cancelled.set() | ||
|
||
async with create_connected_server_and_client_session( | ||
make_server() | ||
) as client_session: | ||
async with anyio.create_task_group() as tg: | ||
tg.start_soon(make_request, client_session) | ||
|
||
# Wait for the request to be in-flight | ||
with anyio.fail_after(1): # Timeout after 1 second | ||
await ev_tool_called.wait() | ||
|
||
# Send cancellation notification | ||
assert request_id is not None | ||
await client_session.send_notification( | ||
ClientNotification( | ||
CancelledNotification( | ||
method="notifications/cancelled", | ||
params=CancelledNotificationParams(requestId=request_id), | ||
) | ||
) | ||
) | ||
|
||
# Give cancellation time to process | ||
with anyio.fail_after(1): | ||
await ev_cancelled.wait() |