From 05eb05eeae9a6a282c781896bd975e9bcaf071bf Mon Sep 17 00:00:00 2001 From: Inna Date: Mon, 24 Feb 2025 22:17:55 +0000 Subject: [PATCH 1/4] add request id to RequestParams --- src/mcp/shared/session.py | 20 ++++++++++++-------- src/mcp/types.py | 5 ++++- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 3d3988ce..042dd17e 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -25,6 +25,7 @@ ServerNotification, ServerRequest, ServerResult, + RequestId, ) SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) @@ -36,8 +37,6 @@ "ReceiveNotificationT", ClientNotification, ServerNotification ) -RequestId = str | int - class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -208,9 +207,12 @@ async def send_request( Do not use this method to emit notifications! Use send_notification() instead. """ - - request_id = self._request_id - self._request_id = request_id + 1 + params = request.root.params + if params is not None and params.request_id is not None: + request_id = params.request_id + else: + request_id = self._request_id + self._request_id = request_id + 1 response_stream, response_stream_reader = anyio.create_memory_object_stream[ JSONRPCResponse | JSONRPCError @@ -297,9 +299,11 @@ async def _receive_loop(self) -> None: responder = RequestResponder( request_id=message.root.id, - request_meta=validated_request.root.params.meta - if validated_request.root.params - else None, + request_meta=( + validated_request.root.params.meta + if validated_request.root.params + else None + ), request=validated_request, session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), diff --git a/src/mcp/types.py b/src/mcp/types.py index 7d867bd3..c3e1aa44 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -34,7 +34,9 @@ ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] -RequestId = str | int +ClientInitiatedRequestId = str +ServerInitiatedRequestId = int +RequestId = ClientInitiatedRequestId | ServerInitiatedRequestId AnyFunction: TypeAlias = Callable[..., Any] @@ -51,6 +53,7 @@ class Meta(BaseModel): model_config = ConfigDict(extra="allow") meta: Meta | None = Field(alias="_meta", default=None) + request_id: ClientInitiatedRequestId | None = Field(default=None) class NotificationParams(BaseModel): From 1e81458a62804360b579670543a572a3c913e30e Mon Sep 17 00:00:00 2001 From: Inna Date: Mon, 24 Feb 2025 22:21:27 +0000 Subject: [PATCH 2/4] add request id to cleint session call_tool and get_prompt --- src/mcp/client/session.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c1cc5b5f..8ea80f90 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -219,14 +219,19 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: ) async def call_tool( - self, name: str, arguments: dict | None = None + self, + name: str, + arguments: dict | None = None, + request_id: types.ClientInitiatedRequestId | None = None, ) -> types.CallToolResult: """Send a tools/call request.""" return await self.send_request( types.ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams(name=name, arguments=arguments), + params=types.CallToolRequestParams( + name=name, arguments=arguments, request_id=request_id + ), ) ), types.CallToolResult, @@ -244,14 +249,19 @@ async def list_prompts(self) -> types.ListPromptsResult: ) async def get_prompt( - self, name: str, arguments: dict[str, str] | None = None + self, + name: str, + arguments: dict[str, str] | None = None, + request_id: types.ClientInitiatedRequestId | None = None, ) -> types.GetPromptResult: """Send a prompts/get request.""" return await self.send_request( types.ClientRequest( types.GetPromptRequest( method="prompts/get", - params=types.GetPromptRequestParams(name=name, arguments=arguments), + params=types.GetPromptRequestParams( + name=name, arguments=arguments, request_id=request_id + ), ) ), types.GetPromptResult, From 3b40aa614e0f86c6ee37674a84837fdf5c9fcf8c Mon Sep 17 00:00:00 2001 From: Inna Date: Mon, 24 Feb 2025 22:22:23 +0000 Subject: [PATCH 3/4] fix lint --- src/mcp/shared/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 042dd17e..2afb2433 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -21,11 +21,11 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + RequestId, RequestParams, ServerNotification, ServerRequest, ServerResult, - RequestId, ) SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest) From ef4041d39f6b7496c45868c9ef7657e8e2d81be7 Mon Sep 17 00:00:00 2001 From: Inna Date: Mon, 24 Feb 2025 22:26:28 +0000 Subject: [PATCH 4/4] rename to custom and automatic request id --- src/mcp/client/session.py | 4 ++-- src/mcp/types.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 8ea80f90..b3041333 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -222,7 +222,7 @@ async def call_tool( self, name: str, arguments: dict | None = None, - request_id: types.ClientInitiatedRequestId | None = None, + request_id: types.CustomRequestId | None = None, ) -> types.CallToolResult: """Send a tools/call request.""" return await self.send_request( @@ -252,7 +252,7 @@ async def get_prompt( self, name: str, arguments: dict[str, str] | None = None, - request_id: types.ClientInitiatedRequestId | None = None, + request_id: types.CustomRequestId | None = None, ) -> types.GetPromptResult: """Send a prompts/get request.""" return await self.send_request( diff --git a/src/mcp/types.py b/src/mcp/types.py index c3e1aa44..5cd01b83 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -34,9 +34,9 @@ ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] -ClientInitiatedRequestId = str -ServerInitiatedRequestId = int -RequestId = ClientInitiatedRequestId | ServerInitiatedRequestId +CustomRequestId = str +AutomaticRequestId = int +RequestId = CustomRequestId | AutomaticRequestId AnyFunction: TypeAlias = Callable[..., Any] @@ -53,7 +53,7 @@ class Meta(BaseModel): model_config = ConfigDict(extra="allow") meta: Meta | None = Field(alias="_meta", default=None) - request_id: ClientInitiatedRequestId | None = Field(default=None) + request_id: CustomRequestId | None = Field(default=None) class NotificationParams(BaseModel):