Skip to content

Elicitation prototype #625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ async def __call__(
) -> types.CreateMessageResult | types.ErrorData: ...


class ElicitationFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData: ...


class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
Expand Down Expand Up @@ -62,6 +70,16 @@ async def _default_sampling_callback(
)


async def _default_elicitation_callback(
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Elicitation not supported",
)


async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
Expand Down Expand Up @@ -97,6 +115,7 @@ def __init__(
write_stream: MemoryObjectSendStream[SessionMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
elicitation_callback: ElicitationFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
Expand All @@ -111,12 +130,16 @@ def __init__(
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._elicitation_callback = (
elicitation_callback or _default_elicitation_callback
)
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability()
elicitation = types.ElicitationCapability()
roots = types.RootsCapability(
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
Expand All @@ -132,6 +155,7 @@ async def initialize(self) -> types.InitializeResult:
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=sampling,
elicitation=elicitation,
experimental=None,
roots=roots,
),
Expand Down Expand Up @@ -355,6 +379,12 @@ async def _received_request(
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)

case types.ElicitRequest(params=params):
with responder:
response = await self._elicitation_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)

case types.ListRootsRequest():
with responder:
response = await self._list_roots_callback(ctx)
Expand Down
33 changes: 33 additions & 0 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,39 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent
), "Context is not available outside of a request"
return await self._fastmcp.read_resource(uri)

async def elicit(
self,
message: str,
requestedSchema: dict[str, Any],
) -> dict[str, Any]:
"""Elicit information from the client/user.

This method can be used to interactively ask for additional information from the
client within a tool's execution.
The client might display the message to the user and collect a response
according to the provided schema. Or in case a client is an agent, it might
decide how to handle the elicitation -- either by asking the user or
automatically generating a response.

Args:
message: The message to present to the user
requestedSchema: JSON Schema defining the expected response structure

Returns:
The user's response as a dict matching the request schema structure

Raises:
ValueError: If elicitation is not supported by the client or fails
"""

result = await self.request_context.session.elicit(
message=message,
requestedSchema=requestedSchema,
related_request_id=self.request_id,
)

return result.content

async def log(
self,
level: Literal["debug", "info", "warning", "error"],
Expand Down
35 changes: 34 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:

import mcp.types as types
from mcp.server.models import InitializationOptions
from mcp.shared.message import SessionMessage
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import (
BaseSession,
RequestResponder,
Expand Down Expand Up @@ -128,6 +128,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
if client_caps.sampling is None:
return False

if capability.elicitation is not None:
if client_caps.elicitation is None:
return False

if capability.experimental is not None:
if client_caps.experimental is None:
return False
Expand Down Expand Up @@ -262,6 +266,35 @@ async def list_roots(self) -> types.ListRootsResult:
types.ListRootsResult,
)

async def elicit(
self,
message: str,
requestedSchema: dict[str, Any],
related_request_id: types.RequestId | None = None,
) -> types.ElicitResult:
"""Send an elicitation/create request.

Args:
message: The message to present to the user
requestedSchema: JSON Schema defining the expected response structure

Returns:
The client's response
"""
return await self.send_request(
types.ServerRequest(
types.ElicitRequest(
method="elicitation/create",
params=types.ElicitRequestParams(
message=message,
requestedSchema=requestedSchema,
),
)
),
types.ElicitResult,
metadata=ServerMessageMetadata(related_request_id=related_request_id),
)

async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return await self.send_request(
Expand Down
3 changes: 3 additions & 0 deletions src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import mcp.types as types
from mcp.client.session import (
ClientSession,
ElicitationFnT,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
Expand Down Expand Up @@ -68,6 +69,7 @@ async def create_connected_server_and_client_session(
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
raise_exceptions: bool = False,
elicitation_callback: ElicitationFnT | None = None,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as (
Expand Down Expand Up @@ -98,6 +100,7 @@ async def create_connected_server_and_client_session(
logging_callback=logging_callback,
message_handler=message_handler,
client_info=client_info,
elicitation_callback=elicitation_callback,
) as client_session:
await client_session.initialize()
yield client_session
Expand Down
45 changes: 42 additions & 3 deletions src/mcp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,13 @@ class RootsCapability(BaseModel):


class SamplingCapability(BaseModel):
"""Capability for logging operations."""
"""Capability for sampling operations."""

model_config = ConfigDict(extra="allow")


class ElicitationCapability(BaseModel):
"""Capability for elicitation operations."""

model_config = ConfigDict(extra="allow")

Expand All @@ -217,6 +223,8 @@ class ClientCapabilities(BaseModel):
"""Experimental, non-standard capabilities that the client supports."""
sampling: SamplingCapability | None = None
"""Present if the client supports sampling from an LLM."""
elicitation: ElicitationCapability | None = None
"""Present if the client supports elicitation from the user."""
roots: RootsCapability | None = None
"""Present if the client supports listing roots."""
model_config = ConfigDict(extra="allow")
Expand Down Expand Up @@ -1141,11 +1149,42 @@ class ClientNotification(
pass


class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]):
class ElicitRequestParams(RequestParams):
"""Parameters for elicitation requests."""

message: str
"""The message to present to the user."""

requestedSchema: dict[str, Any]
"""
A JSON Schema object defining the expected structure of the response.
"""
model_config = ConfigDict(extra="allow")


class ElicitRequest(Request[ElicitRequestParams, Literal["elicitation/create"]]):
"""A request from the server to elicit information from the client."""

method: Literal["elicitation/create"]
params: ElicitRequestParams


class ElicitResult(Result):
"""The client's response to an elicitation/create request from the server."""

content: dict[str, Any]
"""The response from the client, matching the structure of requestedSchema."""


class ClientResult(
RootModel[EmptyResult | CreateMessageResult | ListRootsResult | ElicitResult]
):
pass


class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]):
class ServerRequest(
RootModel[PingRequest | CreateMessageRequest | ListRootsRequest | ElicitRequest]
):
pass


Expand Down
54 changes: 52 additions & 2 deletions tests/server/fastmcp/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.server.fastmcp import FastMCP
from mcp.types import InitializeResult, TextContent
from mcp.server.fastmcp import Context, FastMCP
from mcp.types import ElicitResult, InitializeResult, TextContent


@pytest.fixture
Expand Down Expand Up @@ -45,6 +45,23 @@ def make_fastmcp_app():
def echo(message: str) -> str:
return f"Echo: {message}"

# Add a tool that uses elicitation
@mcp.tool(description="A tool that uses elicitation")
async def ask_user(prompt: str, ctx: Context) -> str:
schema = {
"type": "object",
"properties": {
"answer": {"type": "string"},
},
"required": ["answer"],
}

response = await ctx.elicit(
message=f"Tool wants to ask: {prompt}",
requestedSchema=schema,
)
return f"User answered: {response['answer']}"

# Create the SSE app
app: Starlette = mcp.sse_app()

Expand Down Expand Up @@ -110,3 +127,36 @@ async def test_fastmcp_without_auth(server: None, server_url: str) -> None:
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
assert tool_result.content[0].text == "Echo: hello"


@pytest.mark.anyio
async def test_elicitation_feature(server: None, server_url: str) -> None:
"""Test the elicitation feature."""

# Create a custom handler for elicitation requests
async def elicitation_callback(context, params):
# Verify the elicitation parameters
if params.message == "Tool wants to ask: What is your name?":
return ElicitResult(content={"answer": "Test User"})
else:
raise ValueError("Unexpected elicitation message")

# Connect to the server with our custom elicitation handler
async with sse_client(server_url + "/sse") as streams:
async with ClientSession(
*streams, elicitation_callback=elicitation_callback
) as session:
# First initialize the session
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == "NoAuthServer"

# Call the tool that uses elicitation
tool_result = await session.call_tool(
"ask_user", {"prompt": "What is your name?"}
)
# Verify the result
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
# # The test should only succeed with the successful elicitation response
assert tool_result.content[0].text == "User answered: Test User"
59 changes: 59 additions & 0 deletions tests/server/fastmcp/test_stdio_elicitation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""
Test the elicitation feature using stdio transport.
"""

import pytest

from mcp.server.fastmcp import Context, FastMCP
from mcp.shared.memory import create_connected_server_and_client_session
from mcp.types import ElicitResult, TextContent


@pytest.mark.anyio
async def test_stdio_elicitation():
"""Test the elicitation feature using stdio transport."""

# Create a FastMCP server with a tool that uses elicitation
mcp = FastMCP(name="StdioElicitationServer")

@mcp.tool(description="A tool that uses elicitation")
async def ask_user(prompt: str, ctx: Context) -> str:
schema = {
"type": "object",
"properties": {
"answer": {"type": "string"},
},
"required": ["answer"],
}

response = await ctx.elicit(
message=f"Tool wants to ask: {prompt}",
requestedSchema=schema,
)
return f"User answered: {response['answer']}"

# Create a custom handler for elicitation requests
async def elicitation_callback(context, params):
# Verify the elicitation parameters
if params.message == "Tool wants to ask: What is your name?":
return ElicitResult(content={"answer": "Test User"})
else:
raise ValueError(f"Unexpected elicitation message: {params.message}")

# Use memory-based session to test with stdio transport
async with create_connected_server_and_client_session(
mcp._mcp_server, elicitation_callback=elicitation_callback
) as client_session:
# First initialize the session
result = await client_session.initialize()
assert result.serverInfo.name == "StdioElicitationServer"

# Call the tool that uses elicitation
tool_result = await client_session.call_tool(
"ask_user", {"prompt": "What is your name?"}
)

# Verify the result
assert len(tool_result.content) == 1
assert isinstance(tool_result.content[0], TextContent)
assert tool_result.content[0].text == "User answered: Test User"
Loading