Skip to content

Simplified client progress_callback support #632

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, RequestResponder
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS

DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
Expand Down Expand Up @@ -259,9 +259,9 @@ async def call_tool(
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
) -> types.CallToolResult:
"""Send a tools/call request."""

return await self.send_request(
types.ClientRequest(
types.CallToolRequest(
Expand All @@ -271,6 +271,7 @@ async def call_tool(
),
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
)

async def list_prompts(self) -> types.ListPromptsResult:
Expand Down
75 changes: 71 additions & 4 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from contextlib import AsyncExitStack
from datetime import timedelta
from types import TracebackType
from typing import Any, Generic, TypeVar
from typing import Any, Generic, Protocol, TypeVar

import anyio
import httpx
Expand All @@ -24,6 +24,9 @@
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
ProgressNotification,
ProgressNotificationParams,
ProgressToken,
RequestParams,
ServerNotification,
ServerRequest,
Expand All @@ -39,6 +42,14 @@
"ReceiveNotificationT", ClientNotification, ServerNotification
)


class ProgressFnT(Protocol):
async def __call__(
self,
params: ProgressNotificationParams,
) -> None: ...


RequestId = str | int


Expand Down Expand Up @@ -168,7 +179,9 @@ class BaseSession(
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
]
_request_id: int
_progress_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_in_progress: dict[ProgressToken, ProgressFnT]

def __init__(
self,
Expand All @@ -187,6 +200,8 @@ def __init__(
self._receive_notification_type = receive_notification_type
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._progress_id = 0
self._in_progress = {}
self._exit_stack = AsyncExitStack()

async def __aenter__(self) -> Self:
Expand Down Expand Up @@ -214,19 +229,55 @@ async def send_request(
result_type: type[ReceiveResultT],
request_read_timeout_seconds: timedelta | None = None,
metadata: MessageMetadata = None,
progress_callback: ProgressFnT | None = None,
) -> ReceiveResultT:
"""
Sends a request and wait for a response. Raises an McpError if the
response contains an error. If a request read timeout is provided, it
will take precedence over the session read timeout.

If progress_callback is provided any progress notifications sent from the
receiver will be passed back to the sender

Do not use this method to emit notifications! Use send_notification()
instead.
"""

request_id = self._request_id
self._request_id = request_id + 1

progress_id = None
send_request = None

if progress_callback is not None:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nasty if/else branching, ugly but functional main challenge is how to inject meta into params when the type of the request is unknown. Open to ideas on how to make this cleaner

if request.root.params is None:
progress_id = self._progress_id
new_params = RequestParams(
_meta=RequestParams.Meta(progressToken=progress_id)
)
else:
if (
request.root.params.meta is None
or request.root.params.meta.progressToken is None
):
progress_id = self._progress_id
new_params = request.root.params.model_copy(
update={"meta": RequestParams.Meta(progressToken=progress_id)}
)
else:
raise ValueError(
"Request has progressToken and progress_callback provided "
"via send_request method only one or other is supported"
)

new_root = request.root.model_copy(update={"params": new_params})
send_request = request.model_copy(update={"root": new_root})
self._progress_id = progress_id + 1
self._in_progress[progress_id] = progress_callback

if send_request is None:
send_request = request

response_stream, response_stream_reader = anyio.create_memory_object_stream[
JSONRPCResponse | JSONRPCError
](1)
Expand All @@ -236,11 +287,11 @@ async def send_request(
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
**send_request.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)

# TODO: Support progress callbacks

await self._write_stream.send(
SessionMessage(
message=JSONRPCMessage(jsonrpc_request), metadata=metadata
Expand Down Expand Up @@ -276,6 +327,8 @@ async def send_request(

finally:
self._response_streams.pop(request_id, None)
if progress_id is not None:
self._in_progress.pop(progress_id, None)
await response_stream.aclose()
await response_stream_reader.aclose()

Expand Down Expand Up @@ -364,6 +417,20 @@ async def _receive_loop(self) -> None:
if cancelled_id in self._in_flight:
await self._in_flight[cancelled_id].cancel()
else:
match notification.root:
case ProgressNotification(params=params):
if params.progressToken in self._in_progress:
progress_callback = self._in_progress[
params.progressToken
]
await progress_callback(params)
else:
logging.warning(
"Unknown progress token %s",
params.progressToken,
)
case _:
pass
await self._received_notification(notification)
await self._handle_incoming(notification)
except Exception as e:
Expand Down
103 changes: 103 additions & 0 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,106 @@ async def mock_server():

# Assert that the default client info was sent
assert received_client_info == DEFAULT_CLIENT_INFO


@pytest.mark.anyio
async def test_client_session_progress():
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
SessionMessage
](1)
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
SessionMessage
](1)

send_notification_count = 10

async def mock_server():
session_message = await client_to_server_receive.receive()
jsonrpc_request = session_message.message
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
request = ClientRequest.model_validate(
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
)
assert isinstance(request.root, types.CallToolRequest)
assert request.root.params.meta
assert request.root.params.meta.progressToken is not None

progress_token = request.root.params.meta.progressToken
notifications = [
types.ServerNotification(
root=types.ProgressNotification(
params=types.ProgressNotificationParams(
progressToken=progress_token, progress=i
),
method="notifications/progress",
)
)
for i in range(send_notification_count)
]
result = ServerResult(types.CallToolResult(content=[]))

async with server_to_client_send:
for notification in notifications:
await server_to_client_send.send(
SessionMessage(
JSONRPCMessage(
types.JSONRPCNotification(
jsonrpc="2.0",
**notification.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
)
)
)
await server_to_client_send.send(
SessionMessage(
JSONRPCMessage(
JSONRPCResponse(
jsonrpc="2.0",
id=jsonrpc_request.root.id,
result=result.model_dump(
by_alias=True, mode="json", exclude_none=True
),
)
)
)
)

# Create a message handler to catch exceptions
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult]
| types.ServerNotification
| Exception,
) -> None:
if isinstance(message, Exception):
raise message

progress_count = 0

async with (
ClientSession(
server_to_client_receive,
client_to_server_send,
message_handler=message_handler,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)

async def progress_callback(params: types.ProgressNotificationParams):
nonlocal progress_count
progress_count = progress_count + 1

result = await session.call_tool(
"tool_with_progress", progress_callback=progress_callback
)

# Assert the result
assert isinstance(result, types.CallToolResult)
assert len(result.content) == 0
assert progress_count == send_notification_count
Loading