Skip to content

Propagate Custom Request ID to use in cancel notification for call tool and get prompt #231

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,19 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
)

async def call_tool(
self, name: str, arguments: dict | None = None
self,
name: str,
arguments: dict | None = None,
request_id: types.CustomRequestId | None = None,
) -> types.CallToolResult:
"""Send a tools/call request."""
return await self.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(name=name, arguments=arguments),
params=types.CallToolRequestParams(
name=name, arguments=arguments, request_id=request_id
),
)
),
types.CallToolResult,
Expand All @@ -244,14 +249,19 @@ async def list_prompts(self) -> types.ListPromptsResult:
)

async def get_prompt(
self, name: str, arguments: dict[str, str] | None = None
self,
name: str,
arguments: dict[str, str] | None = None,
request_id: types.CustomRequestId | None = None,
) -> types.GetPromptResult:
"""Send a prompts/get request."""
return await self.send_request(
types.ClientRequest(
types.GetPromptRequest(
method="prompts/get",
params=types.GetPromptRequestParams(name=name, arguments=arguments),
params=types.GetPromptRequestParams(
name=name, arguments=arguments, request_id=request_id
),
)
),
types.GetPromptResult,
Expand Down
20 changes: 12 additions & 8 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
RequestParams,
ServerNotification,
ServerRequest,
Expand All @@ -36,8 +37,6 @@
"ReceiveNotificationT", ClientNotification, ServerNotification
)

RequestId = str | int


class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.
Expand Down Expand Up @@ -208,9 +207,12 @@ async def send_request(
Do not use this method to emit notifications! Use send_notification()
instead.
"""

request_id = self._request_id
self._request_id = request_id + 1
params = request.root.params
if params is not None and params.request_id is not None:
request_id = params.request_id
else:
request_id = self._request_id
self._request_id = request_id + 1

response_stream, response_stream_reader = anyio.create_memory_object_stream[
JSONRPCResponse | JSONRPCError
Expand Down Expand Up @@ -297,9 +299,11 @@ async def _receive_loop(self) -> None:

responder = RequestResponder(
request_id=message.root.id,
request_meta=validated_request.root.params.meta
if validated_request.root.params
else None,
request_meta=(
validated_request.root.params.meta
if validated_request.root.params
else None
),
request=validated_request,
session=self,
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
Expand Down
5 changes: 4 additions & 1 deletion src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
ProgressToken = str | int
Cursor = str
Role = Literal["user", "assistant"]
RequestId = str | int
CustomRequestId = str
AutomaticRequestId = int
RequestId = CustomRequestId | AutomaticRequestId
AnyFunction: TypeAlias = Callable[..., Any]


Expand All @@ -51,6 +53,7 @@ class Meta(BaseModel):
model_config = ConfigDict(extra="allow")

meta: Meta | None = Field(alias="_meta", default=None)
request_id: CustomRequestId | None = Field(default=None)


class NotificationParams(BaseModel):
Expand Down