From 3599de86273d0df2d19b03e1e25ad1c033af5bee Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 4 May 2025 13:34:08 +0100 Subject: [PATCH 1/3] elicitation --- src/mcp/client/session.py | 30 +++++++++++++ src/mcp/server/fastmcp/server.py | 33 +++++++++++++++ src/mcp/server/session.py | 35 ++++++++++++++- src/mcp/types.py | 45 ++++++++++++++++++-- tests/server/fastmcp/test_integration.py | 54 +++++++++++++++++++++++- 5 files changed, 191 insertions(+), 6 deletions(-) diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 7bb8821f..fdd25958 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -22,6 +22,14 @@ async def __call__( ) -> types.CreateMessageResult | types.ErrorData: ... +class ElicitationFnT(Protocol): + async def __call__( + self, + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, + ) -> types.ElicitResult | types.ErrorData: ... + + class ListRootsFnT(Protocol): async def __call__( self, context: RequestContext["ClientSession", Any] @@ -62,6 +70,16 @@ async def _default_sampling_callback( ) +async def _default_elicitation_callback( + context: RequestContext["ClientSession", Any], + params: types.ElicitRequestParams, +) -> types.ElicitResult | types.ErrorData: + return types.ErrorData( + code=types.INVALID_REQUEST, + message="Elicitation not supported", + ) + + async def _default_list_roots_callback( context: RequestContext["ClientSession", Any], ) -> types.ListRootsResult | types.ErrorData: @@ -97,6 +115,7 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], read_timeout_seconds: timedelta | None = None, sampling_callback: SamplingFnT | None = None, + elicitation_callback: ElicitationFnT | None = None, list_roots_callback: ListRootsFnT | None = None, logging_callback: LoggingFnT | None = None, message_handler: MessageHandlerFnT | None = None, @@ -111,12 +130,16 @@ def __init__( ) self._client_info = client_info or DEFAULT_CLIENT_INFO self._sampling_callback = sampling_callback or _default_sampling_callback + self._elicitation_callback = ( + elicitation_callback or _default_elicitation_callback + ) self._list_roots_callback = list_roots_callback or _default_list_roots_callback self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler async def initialize(self) -> types.InitializeResult: sampling = types.SamplingCapability() + elicitation = types.ElicitationCapability() roots = types.RootsCapability( # TODO: Should this be based on whether we # _will_ send notifications, or only whether @@ -132,6 +155,7 @@ async def initialize(self) -> types.InitializeResult: protocolVersion=types.LATEST_PROTOCOL_VERSION, capabilities=types.ClientCapabilities( sampling=sampling, + elicitation=elicitation, experimental=None, roots=roots, ), @@ -355,6 +379,12 @@ async def _received_request( client_response = ClientResponse.validate_python(response) await responder.respond(client_response) + case types.ElicitRequest(params=params): + with responder: + response = await self._elicitation_callback(ctx, params) + client_response = ClientResponse.validate_python(response) + await responder.respond(client_response) + case types.ListRootsRequest(): with responder: response = await self._list_roots_callback(ctx) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index 0e0b565c..bb8712fa 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -822,6 +822,39 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent ), "Context is not available outside of a request" return await self._fastmcp.read_resource(uri) + async def elicit( + self, + message: str, + requestedSchema: dict[str, Any], + ) -> dict[str, Any]: + """Elicit information from the client/user. + + This method can be used to interactively ask for additional information from the + client within a tool's execution. + The client might display the message to the user and collect a response + according to the provided schema. Or in case a client is an agent, it might + decide how to handle the elicitation -- either by asking the user or + automatically generating a response. + + Args: + message: The message to present to the user + requestedSchema: JSON Schema defining the expected response structure + + Returns: + The user's response as a dict matching the request schema structure + + Raises: + ValueError: If elicitation is not supported by the client or fails + """ + + result = await self.request_context.session.elicit( + message=message, + requestedSchema=requestedSchema, + related_request_id=self.request_id, + ) + + return result.response + async def log( self, level: Literal["debug", "info", "warning", "error"], diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index c769d1aa..648d5ff1 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -47,7 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: import mcp.types as types from mcp.server.models import InitializationOptions -from mcp.shared.message import SessionMessage +from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, RequestResponder, @@ -128,6 +128,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool: if client_caps.sampling is None: return False + if capability.elicitation is not None: + if client_caps.elicitation is None: + return False + if capability.experimental is not None: if client_caps.experimental is None: return False @@ -262,6 +266,35 @@ async def list_roots(self) -> types.ListRootsResult: types.ListRootsResult, ) + async def elicit( + self, + message: str, + requestedSchema: dict[str, Any], + related_request_id: types.RequestId | None = None, + ) -> types.ElicitResult: + """Send an elicitation/create request. + + Args: + message: The message to present to the user + requestedSchema: JSON Schema defining the expected response structure + + Returns: + The client's response + """ + return await self.send_request( + types.ServerRequest( + types.ElicitRequest( + method="elicitation/create", + params=types.ElicitRequestParams( + message=message, + requestedSchema=requestedSchema, + ), + ) + ), + types.ElicitResult, + metadata=ServerMessageMetadata(related_request_id=related_request_id), + ) + async def send_ping(self) -> types.EmptyResult: """Send a ping request.""" return await self.send_request( diff --git a/src/mcp/types.py b/src/mcp/types.py index 6ab7fba5..3ef13cfe 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -205,7 +205,13 @@ class RootsCapability(BaseModel): class SamplingCapability(BaseModel): - """Capability for logging operations.""" + """Capability for sampling operations.""" + + model_config = ConfigDict(extra="allow") + + +class ElicitationCapability(BaseModel): + """Capability for elicitation operations.""" model_config = ConfigDict(extra="allow") @@ -217,6 +223,8 @@ class ClientCapabilities(BaseModel): """Experimental, non-standard capabilities that the client supports.""" sampling: SamplingCapability | None = None """Present if the client supports sampling from an LLM.""" + elicitation: ElicitationCapability | None = None + """Present if the client supports elicitation from the user.""" roots: RootsCapability | None = None """Present if the client supports listing roots.""" model_config = ConfigDict(extra="allow") @@ -1141,11 +1149,42 @@ class ClientNotification( pass -class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]): +class ElicitRequestParams(RequestParams): + """Parameters for elicitation requests.""" + + message: str + """The message to present to the user.""" + + requestedSchema: dict[str, Any] + """ + A JSON Schema object defining the expected structure of the response. + """ + model_config = ConfigDict(extra="allow") + + +class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]): + """A request from the server to elicit information from the client.""" + + method: Literal["elicitation/create"] + params: ElicitRequestParams + + +class ElicitResult(Result): + """The client's response to an elicitation/create request from the server.""" + + response: dict[str, Any] + """The response from the client, matching the structure of requestedSchema.""" + + +class ClientResult( + RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult] +): pass -class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]): +class ServerRequest( + RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest] +): pass diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 281db2db..ae28d8a6 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -15,8 +15,8 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client -from mcp.server.fastmcp import FastMCP -from mcp.types import InitializeResult, TextContent +from mcp.server.fastmcp import Context, FastMCP +from mcp.types import InitializeResult, TextContent, ElicitResult @pytest.fixture @@ -45,6 +45,23 @@ def make_fastmcp_app(): def echo(message: str) -> str: return f"Echo: {message}" + # Add a tool that uses elicitation + @mcp.tool(description="A tool that uses elicitation") + async def ask_user(prompt: str, ctx: Context) -> str: + schema = { + "type": "object", + "properties": { + "answer": {"type": "string"}, + }, + "required": ["answer"], + } + + response = await ctx.elicit( + message=f"Tool wants to ask: {prompt}", + requestedSchema=schema, + ) + return f"User answered: {response['answer']}" + # Create the SSE app app: Starlette = mcp.sse_app() @@ -110,3 +127,36 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None: assert len(tool_result.content) == 1 assert isinstance(tool_result.content[0], TextContent) assert tool_result.content[0].text == "Echo: hello" + + +@pytest.mark.anyio +async def test_elicitation_feature(server: None, server_url: str) -> None: + """Test the elicitation feature.""" + + # Create a custom handler for elicitation requests + async def elicitation_callback(context, params): + # Verify the elicitation parameters + if params.message == "Tool wants to ask: What is your name?": + return ElicitResult(response={"answer": "Test User"}) + else: + raise ValueError("Unexpected elicitation message") + + # Connect to the server with our custom elicitation handler + async with sse_client(server_url + "/sse") as streams: + async with ClientSession( + *streams, elicitation_callback=elicitation_callback + ) as session: + # First initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == "NoAuthServer" + + # Call the tool that uses elicitation + tool_result = await session.call_tool( + "ask_user", {"prompt": "What is your name?"} + ) + # Verify the result + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + # # The test should only succeed with the successful elicitation response + assert tool_result.content[0].text == "User answered: Test User" From 032b630961b79c6f4f2aa75095d15c26b4b4ff8d Mon Sep 17 00:00:00 2001 From: ihrpr Date: Sun, 4 May 2025 13:50:05 +0100 Subject: [PATCH 2/3] add elicitation test using create_client_server_memory_streams --- src/mcp/shared/memory.py | 3 + tests/server/fastmcp/test_integration.py | 2 +- .../server/fastmcp/test_stdio_elicitation.py | 59 +++++++++++++++++++ 3 files changed, 63 insertions(+), 1 deletion(-) create mode 100644 tests/server/fastmcp/test_stdio_elicitation.py diff --git a/src/mcp/shared/memory.py b/src/mcp/shared/memory.py index b53f8dd6..615c88b7 100644 --- a/src/mcp/shared/memory.py +++ b/src/mcp/shared/memory.py @@ -13,6 +13,7 @@ import mcp.types as types from mcp.client.session import ( ClientSession, + ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, @@ -68,6 +69,7 @@ async def create_connected_server_and_client_session( message_handler: MessageHandlerFnT | None = None, client_info: types.Implementation | None = None, raise_exceptions: bool = False, + elicitation_callback: ElicitationFnT | None = None, ) -> AsyncGenerator[ClientSession, None]: """Creates a ClientSession that is connected to a running MCP server.""" async with create_client_server_memory_streams() as ( @@ -98,6 +100,7 @@ async def create_connected_server_and_client_session( logging_callback=logging_callback, message_handler=message_handler, client_info=client_info, + elicitation_callback=elicitation_callback, ) as client_session: await client_session.initialize() yield client_session diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index ae28d8a6..9fac6226 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -16,7 +16,7 @@ from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.server.fastmcp import Context, FastMCP -from mcp.types import InitializeResult, TextContent, ElicitResult +from mcp.types import ElicitResult, InitializeResult, TextContent @pytest.fixture diff --git a/tests/server/fastmcp/test_stdio_elicitation.py b/tests/server/fastmcp/test_stdio_elicitation.py new file mode 100644 index 00000000..e41f704f --- /dev/null +++ b/tests/server/fastmcp/test_stdio_elicitation.py @@ -0,0 +1,59 @@ +""" +Test the elicitation feature using stdio transport. +""" + +import pytest + +from mcp.server.fastmcp import Context, FastMCP +from mcp.shared.memory import create_connected_server_and_client_session +from mcp.types import ElicitResult, TextContent + + +@pytest.mark.anyio +async def test_stdio_elicitation(): + """Test the elicitation feature using stdio transport.""" + + # Create a FastMCP server with a tool that uses elicitation + mcp = FastMCP(name="StdioElicitationServer") + + @mcp.tool(description="A tool that uses elicitation") + async def ask_user(prompt: str, ctx: Context) -> str: + schema = { + "type": "object", + "properties": { + "answer": {"type": "string"}, + }, + "required": ["answer"], + } + + response = await ctx.elicit( + message=f"Tool wants to ask: {prompt}", + requestedSchema=schema, + ) + return f"User answered: {response['answer']}" + + # Create a custom handler for elicitation requests + async def elicitation_callback(context, params): + # Verify the elicitation parameters + if params.message == "Tool wants to ask: What is your name?": + return ElicitResult(response={"answer": "Test User"}) + else: + raise ValueError(f"Unexpected elicitation message: {params.message}") + + # Use memory-based session to test with stdio transport + async with create_connected_server_and_client_session( + mcp._mcp_server, elicitation_callback=elicitation_callback + ) as client_session: + # First initialize the session + result = await client_session.initialize() + assert result.serverInfo.name == "StdioElicitationServer" + + # Call the tool that uses elicitation + tool_result = await client_session.call_tool( + "ask_user", {"prompt": "What is your name?"} + ) + + # Verify the result + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "User answered: Test User" From 547d516a08ec7ae378325f91a8ef69bd33e39577 Mon Sep 17 00:00:00 2001 From: ihrpr Date: Tue, 6 May 2025 11:39:16 +0100 Subject: [PATCH 3/3] field rename --- src/mcp/server/fastmcp/server.py | 2 +- src/mcp/types.py | 2 +- tests/server/fastmcp/test_integration.py | 2 +- tests/server/fastmcp/test_stdio_elicitation.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index bb8712fa..757b0abc 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -853,7 +853,7 @@ async def elicit( related_request_id=self.request_id, ) - return result.response + return result.content async def log( self, diff --git a/src/mcp/types.py b/src/mcp/types.py index 3ef13cfe..cb1267d6 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1172,7 +1172,7 @@ class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]) class ElicitResult(Result): """The client's response to an elicitation/create request from the server.""" - response: dict[str, Any] + content: dict[str, Any] """The response from the client, matching the structure of requestedSchema.""" diff --git a/tests/server/fastmcp/test_integration.py b/tests/server/fastmcp/test_integration.py index 9fac6226..c0cc1f83 100644 --- a/tests/server/fastmcp/test_integration.py +++ b/tests/server/fastmcp/test_integration.py @@ -137,7 +137,7 @@ async def test_elicitation_feature(server: None, server_url: str) -> None: async def elicitation_callback(context, params): # Verify the elicitation parameters if params.message == "Tool wants to ask: What is your name?": - return ElicitResult(response={"answer": "Test User"}) + return ElicitResult(content={"answer": "Test User"}) else: raise ValueError("Unexpected elicitation message") diff --git a/tests/server/fastmcp/test_stdio_elicitation.py b/tests/server/fastmcp/test_stdio_elicitation.py index e41f704f..52555a28 100644 --- a/tests/server/fastmcp/test_stdio_elicitation.py +++ b/tests/server/fastmcp/test_stdio_elicitation.py @@ -36,7 +36,7 @@ async def ask_user(prompt: str, ctx: Context) -> str: async def elicitation_callback(context, params): # Verify the elicitation parameters if params.message == "Tool wants to ask: What is your name?": - return ElicitResult(response={"answer": "Test User"}) + return ElicitResult(content={"answer": "Test User"}) else: raise ValueError(f"Unexpected elicitation message: {params.message}")