Skip to content

Commit

Permalink
fix: improve error handling and request cancellation for issue #88
Browse files Browse the repository at this point in the history
  • Loading branch information
dsp-ant committed Feb 4, 2025
1 parent 827e494 commit 08cfbe5
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 20 deletions.
34 changes: 22 additions & 12 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import Any, Callable, Generic, TypeVar
Expand Down Expand Up @@ -273,19 +274,28 @@ async def _receive_loop(self) -> None:
await self._incoming_message_stream_writer.send(responder)

elif isinstance(message.root, JSONRPCNotification):
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
try:
notification = self._receive_notification_type.model_validate(
message.root.model_dump(
by_alias=True, mode="json", exclude_none=True
)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(
notification
)
except Exception as e:
# For other validation errors, log and continue
logging.warning(
f"Failed to validate notification: {e}. "
f"Message was: {message.root}"
)
)
# Handle cancellation notifications
if isinstance(notification.root, CancelledNotification):
cancelled_id = notification.root.params.requestId
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
await self._received_notification(notification)
await self._incoming_message_stream_writer.send(notification)
else: # Response or error
stream = self._response_streams.pop(message.root.id, None)
if stream:
Expand Down
27 changes: 19 additions & 8 deletions tests/issues/test_88_random_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,23 @@ async def test_notification_validation_error(tmp_path: Path):

server = Server(name="test")
request_count = 0
slow_request_complete = False
slow_request_started = anyio.Event()
slow_request_complete = anyio.Event()

@server.call_tool()
async def slow_tool(
name: str, arg
) -> Sequence[TextContent | ImageContent | EmbeddedResource]:
nonlocal request_count, slow_request_complete
nonlocal request_count
request_count += 1

if name == "slow":
# Signal that slow request has started
slow_request_started.set()
# Long enough to ensure timeout
await anyio.sleep(0.2)
slow_request_complete = True
# Signal completion
slow_request_complete.set()
return [TextContent(type="text", text=f"slow {request_count}")]
elif name == "fast":
# Fast enough to complete before timeout
Expand Down Expand Up @@ -71,16 +75,16 @@ async def client(read_stream, write_stream):
# First call should work (fast operation)
result = await session.call_tool("fast")
assert result.content == [TextContent(type="text", text="fast 1")]
assert not slow_request_complete
assert not slow_request_complete.is_set()

# Second call should timeout (slow operation)
with pytest.raises(McpError) as exc_info:
await session.call_tool("slow")
assert "Timed out while waiting" in str(exc_info.value)

# Wait for slow request to complete in the background
await anyio.sleep(0.3)
assert slow_request_complete
with anyio.fail_after(1): # Timeout after 1 second
await slow_request_complete.wait()

# Third call should work (fast operation),
# proving server is still responsive
Expand All @@ -91,10 +95,17 @@ async def client(read_stream, write_stream):
server_writer, server_reader = anyio.create_memory_object_stream(1)
client_writer, client_reader = anyio.create_memory_object_stream(1)

server_ready = anyio.Event()

async def wrapped_server_handler(read_stream, write_stream):
server_ready.set()
await server_handler(read_stream, write_stream)

async with anyio.create_task_group() as tg:
tg.start_soon(server_handler, server_reader, client_writer)
tg.start_soon(wrapped_server_handler, server_reader, client_writer)
# Wait for server to start and initialize
await anyio.sleep(0.1)
with anyio.fail_after(1): # Timeout after 1 second
await server_ready.wait()
# Run client in a separate task to avoid cancellation
async with anyio.create_task_group() as client_tg:
client_tg.start_soon(client, client_reader, server_writer)

0 comments on commit 08cfbe5

Please sign in to comment.