Skip to content

Add support for Elicitation #625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
- [Images](#images)
- [Context](#context)
- [Completions](#completions)
- [Elicitation](#elicitation)
- [Authentication](#authentication)
- [Running Your Server](#running-your-server)
- [Development Mode](#development-mode)
- [Claude Desktop Integration](#claude-desktop-integration)
Expand Down Expand Up @@ -74,7 +76,7 @@ The Model Context Protocol allows applications to provide context for LLMs in a

### Adding MCP to your python project

We recommend using [uv](https://docs.astral.sh/uv/) to manage your Python projects.
We recommend using [uv](https://docs.astral.sh/uv/) to manage your Python projects.

If you haven't created a uv-managed project yet, create one:

Expand Down Expand Up @@ -372,6 +374,50 @@ async def handle_completion(
return Completion(values=filtered)
return None
```
### Elicitation

Request additional information from users during tool execution:

```python
from mcp.server.fastmcp import FastMCP, Context
from mcp.server.elicitation import (
AcceptedElicitation,
DeclinedElicitation,
CancelledElicitation,
)
from pydantic import BaseModel, Field

mcp = FastMCP("Booking System")


@mcp.tool()
async def book_table(date: str, party_size: int, ctx: Context) -> str:
"""Book a table with confirmation"""

# Schema must only contain primitive types (str, int, float, bool)
class ConfirmBooking(BaseModel):
confirm: bool = Field(description="Confirm booking?")
notes: str = Field(default="", description="Special requests")

result = await ctx.elicit(
message=f"Confirm booking for {party_size} on {date}?", schema=ConfirmBooking
)
Comment on lines +402 to +404
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it gives us a better API and user experience if ctx.elicit(schema=SchemaT) always return an instance of SchemaT, or an exception.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My comment implies that an exception would be raised if user rejects.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then the problem will be how to distinguish between cancel and reject?

Copy link
Member

@Kludex Kludex Jun 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
result = await ctx.elicit(
message=f"Confirm booking for {party_size} on {date}?", schema=ConfirmBooking
)
try:
result = await ctx.elicit(
message=f"Confirm booking for {party_size} on {date}?",
schema=ConfirmBooking
)
except ElicitError as exc:
print(exc.reason)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this imply that we use exceptions for control flow? We should reserve exceptions for handling exceptional circumstances. I don't think decline fits into this.

I do prefer having a return value that indicates if it was accepted,declined,etc.


match result:
case AcceptedElicitation(data=data):
if data.confirm:
return f"Booked! Notes: {data.notes or 'None'}"
return "Booking cancelled"
case DeclinedElicitation():
return "Booking declined"
case CancelledElicitation():
return "Booking cancelled"
```

The `elicit()` method returns an `ElicitationResult` with:
- `action`: "accept", "decline", or "cancel"
- `data`: The validated response (only when accepted)
- `validation_error`: Any validation error message
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should just bubble.


### Authentication

Expand Down
30 changes: 30 additions & 0 deletions src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ async def __call__(
) -> types.CreateMessageResult | types.ErrorData: ...


class ElicitationFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData: ...


class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
Expand Down Expand Up @@ -58,6 +66,16 @@ async def _default_sampling_callback(
)


async def _default_elicitation_callback(
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Elicitation not supported",
)


async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
Expand Down Expand Up @@ -91,6 +109,7 @@ def __init__(
write_stream: MemoryObjectSendStream[SessionMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
elicitation_callback: ElicitationFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
Expand All @@ -105,12 +124,16 @@ def __init__(
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler

async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
elicitation = (
types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None
)
roots = (
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
Expand All @@ -128,6 +151,7 @@ async def initialize(self) -> types.InitializeResult:
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=sampling,
elicitation=elicitation,
experimental=None,
roots=roots,
),
Expand Down Expand Up @@ -362,6 +386,12 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)

case types.ElicitRequest(params=params):
with responder:
response = await self._elicitation_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)

case types.ListRootsRequest():
with responder:
response = await self._list_roots_callback(ctx)
Expand Down
111 changes: 111 additions & 0 deletions src/mcp/server/elicitation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""Elicitation utilities for MCP servers."""

from __future__ import annotations

import types
from typing import Generic, Literal, TypeVar, Union, get_args, get_origin

from pydantic import BaseModel
from pydantic.fields import FieldInfo

from mcp.server.session import ServerSession
from mcp.types import RequestId

ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)


class AcceptedElicitation(BaseModel, Generic[ElicitSchemaModelT]):
"""Result when user accepts the elicitation."""

action: Literal["accept"] = "accept"
data: ElicitSchemaModelT


class DeclinedElicitation(BaseModel):
"""Result when user declines the elicitation."""

action: Literal["decline"] = "decline"


class CancelledElicitation(BaseModel):
"""Result when user cancels the elicitation."""

action: Literal["cancel"] = "cancel"


ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation


# Primitive types allowed in elicitation schemas
_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool)


def _validate_elicitation_schema(schema: type[BaseModel]) -> None:
"""Validate that a Pydantic model only contains primitive field types."""
for field_name, field_info in schema.model_fields.items():
if not _is_primitive_field(field_info):
raise TypeError(
f"Elicitation schema field '{field_name}' must be a primitive type "
f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. "
f"Complex types like lists, dicts, or nested models are not allowed."
)


def _is_primitive_field(field_info: FieldInfo) -> bool:
"""Check if a field is a primitive type allowed in elicitation schemas."""
annotation = field_info.annotation

# Handle None type
if annotation is types.NoneType:
return True

# Handle basic primitive types
if annotation in _ELICITATION_PRIMITIVE_TYPES:
return True

# Handle Union types
origin = get_origin(annotation)
if origin is Union or origin is types.UnionType:
args = get_args(annotation)
# All args must be primitive types or None
return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args)

return False


async def elicit_with_validation(
session: ServerSession,
message: str,
schema: type[ElicitSchemaModelT],
related_request_id: RequestId | None = None,
) -> ElicitationResult[ElicitSchemaModelT]:
"""Elicit information from the client/user with schema validation.

This method can be used to interactively ask for additional information from the
client within a tool's execution. The client might display the message to the
user and collect a response according to the provided schema. Or in case a
client is an agent, it might decide how to handle the elicitation -- either by asking
the user or automatically generating a response.
"""
# Validate that schema only contains primitive types and fail loudly if not
_validate_elicitation_schema(schema)

json_schema = schema.model_json_schema()

result = await session.elicit(
message=message,
requestedSchema=json_schema,
related_request_id=related_request_id,
)

if result.action == "accept" and result.content:
# Validate and parse the content using the schema
validated_data = schema.model_validate(result.content)
return AcceptedElicitation(data=validated_data)
elif result.action == "decline":
return DeclinedElicitation()
elif result.action == "cancel":
return CancelledElicitation()
else:
# This should never happen, but handle it just in case
raise ValueError(f"Unexpected elicitation action: {result.action}")
32 changes: 32 additions & 0 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from mcp.server.auth.settings import (
AuthSettings,
)
from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation
from mcp.server.fastmcp.exceptions import ResourceError
from mcp.server.fastmcp.prompts import Prompt, PromptManager
from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
Expand Down Expand Up @@ -972,6 +973,37 @@ async def read_resource(self, uri: str | AnyUrl) -> Iterable[ReadResourceContent
assert self._fastmcp is not None, "Context is not available outside of a request"
return await self._fastmcp.read_resource(uri)

async def elicit(
self,
message: str,
schema: type[ElicitSchemaModelT],
) -> ElicitationResult[ElicitSchemaModelT]:
"""Elicit information from the client/user.

This method can be used to interactively ask for additional information from the
client within a tool's execution. The client might display the message to the
user and collect a response according to the provided schema. Or in case a
client is an agent, it might decide how to handle the elicitation -- either by asking
the user or automatically generating a response.

Args:
schema: A Pydantic model class defining the expected response structure, according to the specification,
only primive types are allowed.
message: Optional message to present to the user. If not provided, will use
a default message based on the schema

Returns:
An ElicitationResult containing the action taken and the data if accepted

Note:
Check the result.action to determine if the user accepted, declined, or cancelled.
The result.data will only be populated if action is "accept" and validation succeeded.
"""

return await elicit_with_validation(
session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id
)

async def log(
self,
level: Literal["debug", "info", "warning", "error"],
Expand Down
33 changes: 33 additions & 0 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:
if client_caps.sampling is None:
return False

if capability.elicitation is not None:
if client_caps.elicitation is None:
return False

if capability.experimental is not None:
if client_caps.experimental is None:
return False
Expand Down Expand Up @@ -251,6 +255,35 @@ async def list_roots(self) -> types.ListRootsResult:
types.ListRootsResult,
)

async def elicit(
self,
message: str,
requestedSchema: types.ElicitRequestedSchema,
related_request_id: types.RequestId | None = None,
) -> types.ElicitResult:
"""Send an elicitation/create request.

Args:
message: The message to present to the user
requestedSchema: Schema defining the expected response structure

Returns:
The client's response
"""
return await self.send_request(
types.ServerRequest(
types.ElicitRequest(
method="elicitation/create",
params=types.ElicitRequestParams(
message=message,
requestedSchema=requestedSchema,
),
)
),
types.ElicitResult,
metadata=ServerMessageMetadata(related_request_id=related_request_id),
)

async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return await self.send_request(
Expand Down
11 changes: 10 additions & 1 deletion src/mcp/shared/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

import mcp.types as types
from mcp.client.session import ClientSession, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
from mcp.client.session import (
ClientSession,
ElicitationFnT,
ListRootsFnT,
LoggingFnT,
MessageHandlerFnT,
SamplingFnT,
)
from mcp.server import Server
from mcp.shared.message import SessionMessage

Expand Down Expand Up @@ -53,6 +60,7 @@ async def create_connected_server_and_client_session(
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
raise_exceptions: bool = False,
elicitation_callback: ElicitationFnT | None = None,
) -> AsyncGenerator[ClientSession, None]:
"""Creates a ClientSession that is connected to a running MCP server."""
async with create_client_server_memory_streams() as (
Expand Down Expand Up @@ -83,6 +91,7 @@ async def create_connected_server_and_client_session(
logging_callback=logging_callback,
message_handler=message_handler,
client_info=client_info,
elicitation_callback=elicitation_callback,
) as client_session:
await client_session.initialize()
yield client_session
Expand Down
Loading
Loading