feat: add request cancellation and cleanup #167

Merged · 4 commits · Feb 4, 2025
13 changes: 9 additions & 4 deletions src/mcp/server/lowlevel/server.py
@@ -453,10 +453,15 @@ async def run(
logger.debug(f"Received message: {message}")

match message:
case RequestResponder(request=types.ClientRequest(root=req)):
await self._handle_request(
message, req, session, raise_exceptions
)
case (
RequestResponder(
request=types.ClientRequest(root=req)
) as responder
):
with responder:
await self._handle_request(
message, req, session, raise_exceptions
)
case types.ClientNotification(root=notify):
await self._handle_notification(notify)

23 changes: 12 additions & 11 deletions src/mcp/server/session.py
@@ -126,19 +126,20 @@ async def _received_request(
case types.InitializeRequest(params=params):
self._initialization_state = InitializationState.Initializing
self._client_params = params
await responder.respond(
types.ServerResult(
types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=types.Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
instructions=self._init_options.instructions,
with responder:
await responder.respond(
types.ServerResult(
types.InitializeResult(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=self._init_options.capabilities,
serverInfo=types.Implementation(
name=self._init_options.server_name,
version=self._init_options.server_version,
),
instructions=self._init_options.instructions,
)
)
)
)
case _:
if self._initialization_state != InitializationState.Initialized:
raise RuntimeError(
119 changes: 106 additions & 13 deletions src/mcp/shared/session.py
@@ -1,6 +1,7 @@
import logging
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import Generic, TypeVar
from typing import Any, Callable, Generic, TypeVar

import anyio
import anyio.lowlevel
@@ -10,6 +11,7 @@

from mcp.shared.exceptions import McpError
from mcp.types import (
CancelledNotification,
ClientNotification,
ClientRequest,
ClientResult,
@@ -38,27 +40,98 @@


class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.

This class MUST be used as a context manager to ensure proper cleanup and
cancellation handling:

Example:
with request_responder as resp:
await resp.respond(result)

The context manager ensures:
1. Proper cancellation scope setup and cleanup
2. Request completion tracking
3. Cleanup of in-flight requests
"""

def __init__(
self,
request_id: RequestId,
request_meta: RequestParams.Meta | None,
request: ReceiveRequestT,
session: "BaseSession",
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
) -> None:
self.request_id = request_id
self.request_meta = request_meta
self.request = request
self._session = session
self._responded = False
self._completed = False
self._cancel_scope = anyio.CancelScope()
self._on_complete = on_complete
self._entered = False # Track if we're in a context manager

def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
"""Enter the context manager, enabling request cancellation tracking."""
self._entered = True
self._cancel_scope = anyio.CancelScope()
self._cancel_scope.__enter__()
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""Exit the context manager, performing cleanup and notifying completion."""
try:
if self._completed:
self._on_complete(self)
finally:
self._entered = False
if not self._cancel_scope:
raise RuntimeError("No active cancel scope")
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)

async def respond(self, response: SendResultT | ErrorData) -> None:
assert not self._responded, "Request already responded to"
self._responded = True
"""Send a response for this request.

Must be called within a context manager block.
Raises:
RuntimeError: If not used within a context manager
AssertionError: If request was already responded to
"""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
assert not self._completed, "Request already responded to"

if not self.cancelled:
self._completed = True

await self._session._send_response(
request_id=self.request_id, response=response
)

async def cancel(self) -> None:
"""Cancel this request and mark it as completed."""
if not self._entered:
raise RuntimeError("RequestResponder must be used as a context manager")
if not self._cancel_scope:
raise RuntimeError("No active cancel scope")

self._cancel_scope.cancel()
self._completed = True # Mark as completed so it's removed from in_flight
# Send an error response to indicate cancellation
await self._session._send_response(
request_id=self.request_id, response=response
request_id=self.request_id,
response=ErrorData(code=0, message="Request cancelled", data=None),
)

@property
def in_flight(self) -> bool:
return not self._completed and not self.cancelled

@property
def cancelled(self) -> bool:
return self._cancel_scope is not None and self._cancel_scope.cancel_called


class BaseSession(
AbstractAsyncContextManager,
@@ -82,6 +155,7 @@ class BaseSession(
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
]
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]

def __init__(
self,
@@ -99,6 +173,7 @@ def __init__(
self._receive_request_type = receive_request_type
self._receive_notification_type = receive_notification_type
self._read_timeout_seconds = read_timeout_seconds
self._in_flight = {}

self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
anyio.create_memory_object_stream[
@@ -219,27 +294,45 @@ async def _receive_loop(self) -> None:
by_alias=True, mode="json", exclude_none=True
)
)

responder = RequestResponder(
request_id=message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
)

self._in_flight[responder.request_id] = responder
await self._received_request(responder)
if not responder._responded:
if not responder._completed:
await self._incoming_message_stream_writer.send(responder)

elif isinstance(message.root, JSONRPCNotification):
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
try:
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(
notification
)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. "
f"Message was: {message.root}"
)
)

await self._received_notification(notification)
await self._incoming_message_stream_writer.send(notification)
else: # Response or error
stream = self._response_streams.pop(message.root.id, None)
if stream:
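
To make the lifecycle above easier to follow in one place, here is a minimal, self-contained sketch of the pattern this diff introduces: the responder enters a cancel scope when used as a context manager, `cancel()` cancels that scope and marks the request complete, and an `on_complete` callback removes it from the in-flight map on exit. The `Responder`, `handle`, and `main` names below are illustrative stand-ins, not the actual `RequestResponder`/`BaseSession` API, and the sketch omits sending the cancellation error response.

```python
import anyio


class Responder:
    """Illustrative stand-in for RequestResponder: a context manager that owns
    a cancel scope and reports completion back to an in-flight map."""

    def __init__(self, request_id, on_complete):
        self.request_id = request_id
        self._on_complete = on_complete
        self._completed = False
        self._cancel_scope = anyio.CancelScope()

    def __enter__(self):
        self._cancel_scope.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._completed:
            self._on_complete(self)  # drop this request from the in-flight map
        # The scope absorbs the cancellation exception it caused itself.
        return self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)

    @property
    def cancelled(self) -> bool:
        return self._cancel_scope.cancel_called

    def cancel(self) -> None:
        """What a CancelledNotification ultimately triggers: cancel the scope so
        the in-progress handler is interrupted, and mark the request done."""
        self._cancel_scope.cancel()
        self._completed = True


async def main() -> None:
    in_flight: dict[int, Responder] = {}

    async def handle(request_id: int) -> None:
        responder = Responder(request_id, lambda r: in_flight.pop(r.request_id, None))
        in_flight[request_id] = responder
        with responder:              # the cancel scope is active only inside this block
            await anyio.sleep(10)    # stands in for a slow request handler
        print(f"request {request_id}: cancelled={responder.cancelled}, "
              f"still in flight={request_id in in_flight}")

    async with anyio.create_task_group() as tg:
        tg.start_soon(handle, 1)
        await anyio.sleep(0.1)   # let the handler start
        in_flight[1].cancel()    # as if a CancelledNotification for request 1 arrived


anyio.run(main)
```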
111 changes: 111 additions & 0 deletions tests/issues/test_88_random_error.py
@@ -0,0 +1,111 @@
"""Test to reproduce issue #88: Random error thrown on response."""

from datetime import timedelta
from pathlib import Path
from typing import Sequence

import anyio
import pytest

from mcp.client.session import ClientSession
from mcp.server.lowlevel import Server
from mcp.shared.exceptions import McpError
from mcp.types import (
EmbeddedResource,
ImageContent,
TextContent,
)


@pytest.mark.anyio
async def test_notification_validation_error(tmp_path: Path):
"""Test that timeouts are handled gracefully and don't break the server.

This test verifies that when a client request times out:
1. The server task stays alive
2. The server can still handle new requests
3. The client can make new requests
4. No resources are leaked
"""

server = Server(name="test")
request_count = 0
slow_request_started = anyio.Event()
slow_request_complete = anyio.Event()

@server.call_tool()
async def slow_tool(
name: str, arg
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
nonlocal request_count
request_count += 1

if name == "slow":
# Signal that slow request has started
slow_request_started.set()
# Long enough to ensure timeout
await anyio.sleep(0.2)
Comment on lines +46 to +47

Member:
I'm quite allergic to actually sleeping in tests, due to the flakiness and slowness. Is there any other way we can exercise these? Would a timeout of 0 work? Or a timeout in the past vs. a timeout in the far future?

Member (Author):
Let me fix this. This was Claude written. After a lot of back and forth, the right way is to use events and wait for them to trigger.

# Signal completion
slow_request_complete.set()
return [TextContent(type="text", text=f"slow {request_count}")]
elif name == "fast":
# Fast enough to complete before timeout
await anyio.sleep(0.01)
return [TextContent(type="text", text=f"fast {request_count}")]
return [TextContent(type="text", text=f"unknown {request_count}")]

async def server_handler(read_stream, write_stream):
await server.run(
read_stream,
write_stream,
server.create_initialization_options(),
raise_exceptions=True,
)

async def client(read_stream, write_stream):
# Use a timeout that's:
# - Long enough for fast operations (>10ms)
# - Short enough for slow operations (<200ms)
# - Not too short to avoid flakiness
async with ClientSession(
read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)
) as session:
await session.initialize()

# First call should work (fast operation)
result = await session.call_tool("fast")
assert result.content == [TextContent(type="text", text="fast 1")]
assert not slow_request_complete.is_set()

# Second call should timeout (slow operation)
with pytest.raises(McpError) as exc_info:
await session.call_tool("slow")
assert "Timed out while waiting" in str(exc_info.value)

# Wait for slow request to complete in the background
with anyio.fail_after(1): # Timeout after 1 second
await slow_request_complete.wait()

# Third call should work (fast operation),
# proving server is still responsive
result = await session.call_tool("fast")
assert result.content == [TextContent(type="text", text="fast 3")]

# Run server and client in separate task groups to avoid cancellation
server_writer, server_reader = anyio.create_memory_object_stream(1)
client_writer, client_reader = anyio.create_memory_object_stream(1)

server_ready = anyio.Event()

async def wrapped_server_handler(read_stream, write_stream):
server_ready.set()
await server_handler(read_stream, write_stream)

async with anyio.create_task_group() as tg:
tg.start_soon(wrapped_server_handler, server_reader, client_writer)
# Wait for server to start and initialize
with anyio.fail_after(1): # Timeout after 1 second
await server_ready.wait()
# Run client in a separate task to avoid cancellation
async with anyio.create_task_group() as client_tg:
client_tg.start_soon(client, client_reader, server_writer)
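
Following up on the review thread above about avoiding fixed sleeps in tests, here is a minimal sketch of the event-based pattern the author describes: tasks signal progress with `anyio.Event` and the test waits on those events under `anyio.fail_after`, so timing depends on real completion rather than a guessed duration. The `worker`, `started`, and `finished` names are illustrative, not part of the MCP codebase.

```python
import anyio


async def main() -> None:
    started = anyio.Event()
    finished = anyio.Event()

    async def worker() -> None:
        started.set()            # signal that work has begun
        await anyio.sleep(0)     # yield to the scheduler; real work would go here
        finished.set()           # signal completion

    async with anyio.create_task_group() as tg:
        tg.start_soon(worker)
        with anyio.fail_after(1):   # deadline guards against hangs, not timing
            await started.wait()
            await finished.wait()


anyio.run(main)
```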