diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index c1cc5b5f..b3041333 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -219,14 +219,19 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: ) async def call_tool( - self, name: str, arguments: dict | None = None + self, + name: str, + arguments: dict | None = None, + request_id: types.CustomRequestId | None = None, ) -> types.CallToolResult: """Send a tools/call request.""" return await self.send_request( types.ClientRequest( types.CallToolRequest( method="tools/call", - params=types.CallToolRequestParams(name=name, arguments=arguments), + params=types.CallToolRequestParams( + name=name, arguments=arguments, request_id=request_id + ), ) ), types.CallToolResult, @@ -244,14 +249,19 @@ async def list_prompts(self) -> types.ListPromptsResult: ) async def get_prompt( - self, name: str, arguments: dict[str, str] | None = None + self, + name: str, + arguments: dict[str, str] | None = None, + request_id: types.CustomRequestId | None = None, ) -> types.GetPromptResult: """Send a prompts/get request.""" return await self.send_request( types.ClientRequest( types.GetPromptRequest( method="prompts/get", - params=types.GetPromptRequestParams(name=name, arguments=arguments), + params=types.GetPromptRequestParams( + name=name, arguments=arguments, request_id=request_id + ), ) ), types.GetPromptResult, diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 3d3988ce..2afb2433 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -21,6 +21,7 @@ JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, + RequestId, RequestParams, ServerNotification, ServerRequest, @@ -36,8 +37,6 @@ "ReceiveNotificationT", ClientNotification, ServerNotification ) -RequestId = str | int - class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -208,9 +207,12 @@ async def send_request( Do not use this method to emit notifications! Use send_notification() instead. """ - - request_id = self._request_id - self._request_id = request_id + 1 + params = request.root.params + if params is not None and params.request_id is not None: + request_id = params.request_id + else: + request_id = self._request_id + self._request_id = request_id + 1 response_stream, response_stream_reader = anyio.create_memory_object_stream[ JSONRPCResponse | JSONRPCError @@ -297,9 +299,11 @@ async def _receive_loop(self) -> None: responder = RequestResponder( request_id=message.root.id, - request_meta=validated_request.root.params.meta - if validated_request.root.params - else None, + request_meta=( + validated_request.root.params.meta + if validated_request.root.params + else None + ), request=validated_request, session=self, on_complete=lambda r: self._in_flight.pop(r.request_id, None), diff --git a/src/mcp/types.py b/src/mcp/types.py index 7d867bd3..5cd01b83 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -34,7 +34,9 @@ ProgressToken = str | int Cursor = str Role = Literal["user", "assistant"] -RequestId = str | int +CustomRequestId = str +AutomaticRequestId = int +RequestId = CustomRequestId | AutomaticRequestId AnyFunction: TypeAlias = Callable[..., Any] @@ -51,6 +53,7 @@ class Meta(BaseModel): model_config = ConfigDict(extra="allow") meta: Meta | None = Field(alias="_meta", default=None) + request_id: CustomRequestId | None = Field(default=None) class NotificationParams(BaseModel):