Skip to content

Commit

Permalink
Merge pull request #16 from modelcontextprotocol/davidsp/init-options
Browse files Browse the repository at this point in the history
Introduce Initialization options that are passed to ServerSession
  • Loading branch information
dsp-ant authored Oct 11, 2024
2 parents 047b5d8 + cc342a0 commit 9475815
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 14 deletions.
32 changes: 30 additions & 2 deletions mcp_python/server/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ReadResourceResult,
Resource,
ResourceReference,
ServerCapabilities,
ServerResult,
SetLevelRequest,
SubscribeRequest,
Expand All @@ -40,7 +41,6 @@

logger = logging.getLogger(__name__)


request_ctx: contextvars.ContextVar[RequestContext] = contextvars.ContextVar(
"request_ctx"
)
Expand All @@ -53,6 +53,33 @@ def __init__(self, name: str):
self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {}
logger.info(f"Initializing server '{name}'")

def create_initialization_options(self) -> types.InitializationOptions:
"""Create initialization options from this server instance."""
def pkg_version(package: str) -> str:
try:
from importlib.metadata import version
return version(package)
except Exception:
return "unknown"

return types.InitializationOptions(
server_name=self.name,
server_version=pkg_version("mcp_python"),
capabilities=self.get_capabilities(),
)

def get_capabilities(self) -> ServerCapabilities:
"""Convert existing handlers to a ServerCapabilities object."""
def get_capability(req_type: type) -> dict[str, Any] | None:
return {} if req_type in self.request_handlers else None

return ServerCapabilities(
prompts=get_capability(ListPromptsRequest),
resources=get_capability(ListResourcesRequest),
tools=get_capability(ListPromptsRequest),
logging=get_capability(SetLevelRequest)
)

@property
def request_context(self) -> RequestContext:
"""If called outside of a request context, this will raise a LookupError."""
Expand Down Expand Up @@ -280,9 +307,10 @@ async def run(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
initialization_options: types.InitializationOptions
):
with warnings.catch_warnings(record=True) as w:
async with ServerSession(read_stream, write_stream) as session:
async with ServerSession(read_stream, write_stream, initialization_options) as session:
async for message in session.incoming_messages:
logger.debug(f"Received message: {message}")

Expand Down
7 changes: 5 additions & 2 deletions mcp_python/server/__main__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
import sys

import importlib.metadata
import anyio

from mcp_python.server.session import ServerSession
from mcp_python.server.types import InitializationOptions
from mcp_python.server.stdio import stdio_server
from mcp_python.types import ServerCapabilities

if not sys.warnoptions:
import warnings
Expand All @@ -26,8 +28,9 @@ async def receive_loop(session: ServerSession):


async def main():
version = importlib.metadata.version("mcp_python")
async with stdio_server() as (read_stream, write_stream):
async with ServerSession(read_stream, write_stream) as session, write_stream:
async with ServerSession(read_stream, write_stream, InitializationOptions(server_name="mcp_python", server_version=version, capabilities=ServerCapabilities())) as session, write_stream:
await receive_loop(session)


Expand Down
14 changes: 6 additions & 8 deletions mcp_python/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
BaseSession,
RequestResponder,
)
from mcp_python.server.types import InitializationOptions
from mcp_python.shared.version import SUPPORTED_PROTOCOL_VERSION
from mcp_python.types import (
ClientNotification,
Expand Down Expand Up @@ -52,9 +53,11 @@ def __init__(
self,
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
write_stream: MemoryObjectSendStream[JSONRPCMessage],
init_options: InitializationOptions
) -> None:
super().__init__(read_stream, write_stream, ClientRequest, ClientNotification)
self._initialization_state = InitializationState.NotInitialized
self._init_options = init_options

async def _received_request(
self, responder: RequestResponder[ClientRequest, ServerResult]
Expand All @@ -66,15 +69,10 @@ async def _received_request(
ServerResult(
InitializeResult(
protocolVersion=SUPPORTED_PROTOCOL_VERSION,
capabilities=ServerCapabilities(
logging=None,
resources=None,
tools=None,
experimental=None,
prompts={},
),
capabilities=self._init_options.capabilities,
serverInfo=Implementation(
name="mcp_python", version="0.1.0"
name=self._init_options.server_name,
version=self._init_options.server_version
),
)
)
Expand Down
9 changes: 8 additions & 1 deletion mcp_python/server/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from dataclasses import dataclass
from typing import Literal

from mcp_python.types import Role
from pydantic import BaseModel
from mcp_python.types import Role, ServerCapabilities


@dataclass
Expand All @@ -25,3 +26,9 @@ class Message:
class PromptResponse:
messages: list[Message]
desc: str | None = None


class InitializationOptions(BaseModel):
server_name: str
server_version: str
capabilities: ServerCapabilities
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,8 @@ target-version = "py38"

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]

[tool.uv]
dev-dependencies = [
"trio>=0.26.2",
]
33 changes: 32 additions & 1 deletion tests/server/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import pytest

from mcp_python.client.session import ClientSession
from mcp_python.server import Server
from mcp_python.server.session import ServerSession
from mcp_python.server.types import InitializationOptions
from mcp_python.types import (
ClientNotification,
InitializedNotification,
JSONRPCMessage,
ServerCapabilities,
)


Expand All @@ -30,7 +33,7 @@ async def run_server():
nonlocal received_initialized

async with ServerSession(
client_to_server_receive, server_to_client_send
client_to_server_receive, server_to_client_send, InitializationOptions(server_name='mcp_python', server_version='0.1.0', capabilities=ServerCapabilities())
) as server_session:
async for message in server_session.incoming_messages:
if isinstance(message, Exception):
Expand All @@ -57,3 +60,31 @@ async def run_server():
pass

assert received_initialized


@pytest.mark.anyio
async def test_server_capabilities():
server = Server("test")

# Initially no capabilities
caps = server.get_capabilities()
assert caps.prompts is None
assert caps.resources is None

# Add a prompts handler
@server.list_prompts()
async def list_prompts():
return []

caps = server.get_capabilities()
assert caps.prompts == {}
assert caps.resources is None

# Add a resources handler
@server.list_resources()
async def list_resources():
return []

caps = server.get_capabilities()
assert caps.prompts == {}
assert caps.resources == {}

0 comments on commit 9475815

Please sign in to comment.