Skip to content

Commit

Permalink
feat: add request cancellation and in-flight request tracking
Browse files Browse the repository at this point in the history
This commit adds support for request cancellation and tracking of
in-flight requests in the MCP protocol implementation. The key
architectural changes are:

1. Request Lifecycle Management:
   - Added _in_flight dictionary to BaseSession to track active requests
   - Requests are tracked from receipt until completion/cancellation
   - Added proper cleanup via on_complete callback

2. Cancellation Support:
   - Added CancelledNotification handling in _receive_loop
   - Implemented cancel() method in RequestResponder
   - Uses anyio.CancelScope for robust cancellation
   - Sends error response on cancellation

3. Request Context:
   - Added request_ctx ContextVar for request context
   - Ensures proper cleanup after request handling
   - Maintains request state throughout lifecycle

4. Error Handling:
   - Improved error propagation for cancelled requests
   - Added proper cleanup of cancelled requests
   - Maintains consistency of in-flight tracking

This change enables clients to cancel long-running requests and
servers to properly clean up resources when requests are cancelled.

Github-Issue:#88
  • Loading branch information
dsp-ant committed Feb 3, 2025
1 parent 888bdd3 commit 827e494
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 13 deletions.
13 changes: 9 additions & 4 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,15 @@ async def run(
logger.debug(f"Received message: {message}")

match message:
case RequestResponder(request=types.ClientRequest(root=req)):
await self._handle_request(
message, req, session, raise_exceptions
)
case (
RequestResponder(
request=types.ClientRequest(root=req)
) as responder
):
with responder:
await self._handle_request(
message, req, session, raise_exceptions
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)

Expand Down
64 changes: 55 additions & 9 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import Generic, TypeVar
from typing import Any, Callable, Generic, TypeVar

import anyio
import anyio.lowlevel
Expand All @@ -10,6 +10,7 @@

from mcp.shared.exceptions import McpError
from mcp.types import (
CancelledNotification,
ClientNotification,
ClientRequest,
ClientResult,
Expand Down Expand Up @@ -44,21 +45,55 @@ def __init__(
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: "BaseSession",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self._session = session
self._responded = False
self._completed = False
self._cancel_scope = anyio.CancelScope()
self._on_complete = on_complete

def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
self._cancel_scope.__enter__()
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
try:
if self._completed:
self._on_complete(self)
finally:
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)

async def respond(self, response: SendResultT | ErrorData) -> None:
assert not self._responded, "Request already responded to"
self._responded = True
assert not self._completed, "Request already responded to"

if not self.cancelled:
self._completed = True

await self._session._send_response(
request_id=self.request_id, response=response
)

async def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
await self._session._send_response(
request_id=self.request_id, response=response
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)

@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled

@property
def cancelled(self) -> bool:
return self._cancel_scope is not None and self._cancel_scope.cancel_called


class BaseSession(
AbstractAsyncContextManager,
Expand All @@ -82,6 +117,7 @@ class BaseSession(
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]

def __init__(
self,
Expand All @@ -99,6 +135,7 @@ def __init__(
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds
self._in_flight = {}

self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
Expand Down Expand Up @@ -219,27 +256,36 @@ async def _receive_loop(self) -> None:
by_alias=True, mode="json", exclude_none=True
)
)

responder = RequestResponder(
request_id=message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)

self._in_flight[responder.request_id] = responder
await self._received_request(responder)
if not responder._responded:
if not responder._completed:
await self._incoming_message_stream_writer.send(responder)

elif isinstance(message.root, JSONRPCNotification):
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)

await self._received_notification(notification)
await self._incoming_message_stream_writer.send(notification)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(notification)
else: # Response or error
stream = self._response_streams.pop(message.root.id, None)
if stream:
Expand Down
126 changes: 126 additions & 0 deletions tests/shared/test_session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
from typing import AsyncGenerator

import anyio
import pytest

import mcp.types as types
from mcp.client.session import ClientSession
from mcp.server.lowlevel.server import Server
from mcp.shared.exceptions import McpError
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.types import (
CancelledNotification,
CancelledNotificationParams,
ClientNotification,
ClientRequest,
EmptyResult,
)


@pytest.fixture
def mcp_server() -> Server:
return Server(name="test server")


@pytest.fixture
async def client_connected_to_server(
mcp_server: Server,
) -> AsyncGenerator[ClientSession, None]:
async with create_connected_server_and_client_session(mcp_server) as client_session:
yield client_session


@pytest.mark.anyio
async def test_in_flight_requests_cleared_after_completion(
client_connected_to_server: ClientSession,
):
"""Verify that _in_flight is empty after all requests complete."""
# Send a request and wait for response
response = await client_connected_to_server.send_ping()
assert isinstance(response, EmptyResult)

# Verify _in_flight is empty
assert len(client_connected_to_server._in_flight) == 0


@pytest.mark.anyio
async def test_request_cancellation():
"""Test that requests can be cancelled while in-flight."""
# The tool is already registered in the fixture

ev_tool_called = anyio.Event()
ev_cancelled = anyio.Event()
request_id = None

# Start the request in a separate task so we can cancel it
def make_server() -> Server:
server = Server(name="TestSessionServer")

# Register the tool handler
@server.call_tool()
async def handle_call_tool(name: str, arguments: dict | None) -> list:
nonlocal request_id, ev_tool_called
if name == "slow_tool":
request_id = server.request_context.request_id
ev_tool_called.set()
await anyio.sleep(10) # Long enough to ensure we can cancel
return []
raise ValueError(f"Unknown tool: {name}")

# Register the tool so it shows up in list_tools
@server.list_tools()
async def handle_list_tools() -> list[types.Tool]:
return [
types.Tool(
name="slow_tool",
description="A slow tool that takes 10 seconds to complete",
inputSchema={},
)
]

return server

async def make_request(client_session):
nonlocal ev_cancelled
try:
await client_session.send_request(
ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(
name="slow_tool", arguments={}
),
)
),
types.CallToolResult,
)
pytest.fail("Request should have been cancelled")
except McpError as e:
# Expected - request was cancelled
assert "Request cancelled" in str(e)
ev_cancelled.set()

async with create_connected_server_and_client_session(
make_server()
) as client_session:
async with anyio.create_task_group() as tg:
tg.start_soon(make_request, client_session)

# Wait for the request to be in-flight
with anyio.fail_after(1): # Timeout after 1 second
await ev_tool_called.wait()

# Send cancellation notification
assert request_id is not None
await client_session.send_notification(
ClientNotification(
CancelledNotification(
method="notifications/cancelled",
params=CancelledNotificationParams(requestId=request_id),
)
)
)

# Give cancellation time to process
with anyio.fail_after(1):
await ev_cancelled.wait()

0 comments on commit 827e494

Please sign in to comment.