diff --git a/pyproject.toml b/pyproject.toml index 8f3d395..df76136 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,3 +34,7 @@ build-backend = "hatchling.build" [tool.pytest.ini_options] asyncio_default_fixture_loop_scope = "session" + +[tool.ruff] +line-length = 88 +target-version = "py310" diff --git a/src/mcp_grafana/__init__.py b/src/mcp_grafana/__init__.py index 761713f..fc29c25 100644 --- a/src/mcp_grafana/__init__.py +++ b/src/mcp_grafana/__init__.py @@ -1,7 +1,70 @@ +import enum +from typing import Literal + +import anyio +import uvicorn from mcp.server import FastMCP from .tools import add_tools + +class Transport(enum.StrEnum): + http = "http" + stdio = "stdio" + sse = "sse" + + +class GrafanaMCP(FastMCP): + async def run_http_async(self) -> None: + from starlette.applications import Starlette + from starlette.routing import Mount + + from .transports.http import handle_message + + async def handle_http(scope, receive, send): + if scope["type"] != "http": + raise ValueError("Expected HTTP request") + async with handle_message(scope, receive, send) as ( + read_stream, + write_stream, + ): + await self._mcp_server.run( + read_stream, + write_stream, + self._mcp_server.create_initialization_options(), + ) + + starlette_app = Starlette( + debug=self.settings.debug, + routes=[Mount("/", app=handle_http)], + ) + + config = uvicorn.Config( + starlette_app, + host=self.settings.host, + port=self.settings.port, + log_level=self.settings.log_level.lower(), + ) + server = uvicorn.Server(config) + await server.serve() + + def run(self, transport: Literal["http", "stdio", "sse"] = "stdio") -> None: + """Run the FastMCP server. Note this is a synchronous function. + + Args: + transport: Transport protocol to use ("stdio" or "sse") + """ + if transport not in Transport.__members__: + raise ValueError(f"Unknown transport: {transport}") + + if transport == "stdio": + anyio.run(self.run_stdio_async) + elif transport == "sse": + anyio.run(self.run_sse_async) + else: + anyio.run(self.run_http_async) + + # Create an MCP server -mcp = FastMCP("Grafana", log_level="DEBUG") +mcp = GrafanaMCP("Grafana", log_level="DEBUG") add_tools(mcp) diff --git a/src/mcp_grafana/cli.py b/src/mcp_grafana/cli.py index db6e846..0f598b4 100644 --- a/src/mcp_grafana/cli.py +++ b/src/mcp_grafana/cli.py @@ -1,17 +1,10 @@ -import enum - import typer -from . import mcp +from . import mcp, Transport app = typer.Typer() -class Transport(enum.StrEnum): - stdio = "stdio" - sse = "sse" - - @app.command() def run(transport: Transport = Transport.stdio): mcp.run(transport.value) diff --git a/src/mcp_grafana/transports/http.py b/src/mcp_grafana/transports/http.py new file mode 100644 index 0000000..b6b8da1 --- /dev/null +++ b/src/mcp_grafana/transports/http.py @@ -0,0 +1,204 @@ +import logging +from contextlib import asynccontextmanager +from json import JSONDecodeError +from typing import Any, Tuple + +import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +import httpx +from mcp import types +from pydantic import ValidationError +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.types import Receive, Scope, Send + + +logger = logging.getLogger(__name__) + + +ReadStream = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] +ReadStreamWriter = MemoryObjectSendStream[types.JSONRPCMessage | Exception] +WriteStream = MemoryObjectSendStream[types.JSONRPCMessage] +WriteStreamReader = MemoryObjectReceiveStream[types.JSONRPCMessage] + + +def make_streams() -> Tuple[ + ReadStream, ReadStreamWriter, WriteStream, WriteStreamReader +]: + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] + + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] + + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) + return read_stream, read_stream_writer, write_stream, write_stream_reader + + +async def initialize( + read_stream_writer: ReadStreamWriter, + write_stream_reader: WriteStreamReader, +): + """ + Initialize the MCP server for this request. + + In a stateful transport (e.g. stdio or sse) the client would + send an initialize request to the server, and the server would send + an 'initialized' response back to the client. + + In the HTTP transport we're trying to be stateless, so we'll have to + handle the initialization ourselves. + + This function handles that initialization by sending the required + messages to the server and consuming the response. + """ + # First construct the initialize request. + initialize_request = types.InitializeRequest( + method="initialize", + params=types.InitializeRequestParams( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities( + experimental=None, + roots=None, + sampling=None, + ), + # TODO: get the name and version from the package metadata. + clientInfo=types.Implementation(name="mcp-grafana", version="0.1.2"), + ), + ) + initialize_request = types.JSONRPCRequest( + jsonrpc="2.0", + id=1, + **initialize_request.model_dump(by_alias=True, mode="json"), + ) + # Send it to the server. + await read_stream_writer.send(types.JSONRPCMessage(initialize_request)) + # We can ignore the response since we're not sending it back to the client. + await write_stream_reader.receive() + + # Next we need to notify the server that we're initialized. + initialize_notification = types.JSONRPCNotification( + jsonrpc="2.0", + **types.ClientNotification( + types.InitializedNotification(method="notifications/initialized"), + ).model_dump(by_alias=True, mode="json"), + ) + await read_stream_writer.send(types.JSONRPCMessage(initialize_notification)) + # Notifications don't have a response, so we don't need to await the + # write stream reader. + + +@asynccontextmanager +async def handle_message(scope: Scope, receive: Receive, send: Send): + """ + ASGI application for handling MCP messages using the stateless HTTP transport. + + This function is called for each incoming message. It creates a new + stream for reading and writing messages, which will be used by the + MCP server, and handles: + + - decoding the client message from JSON into internal types + - validating the client message + - initializing the MCP server, which must be done on every request + (since this is a stateless transport) + - sending the client message to the MCP server + - receiving the server's response + - encoding the server's response into JSON and sending it back to the client + + The returned read and write streams are intended to be passed to + `mcp.server.lowlevel.Server.run()` as the `read_stream` and `write_stream` + arguments. + """ + read_stream, read_stream_writer, write_stream, write_stream_reader = make_streams() + + async def handle_post_message(): + try: + request = Request(scope, receive) + if request.method != "POST": + response = Response("Method not allowed", status_code=405) + await response(scope, receive, send) + return + if scope["path"] != "/mcp": + response = Response("Not found", status_code=404) + await response(scope, receive, send) + return + try: + json = await request.json() + except JSONDecodeError as err: + logger.error(f"Failed to parse message: {err}") + response = Response("Could not parse message", status_code=400) + await response(scope, receive, send) + return + + try: + client_message = types.JSONRPCMessage.model_validate(json) + logger.debug(f"Validated client message: {client_message}") + except ValidationError as err: + logger.error(f"Failed to parse message: {err}") + response = Response("Could not parse message", status_code=400) + await response(scope, receive, send) + return + + # As part of the MCP spec we need to initialize first. + # In a stateful flow (e.g. stdio or sse transports) the client would + # send an initialize request to the server, and the server would send + # a response back to the client. In this case we're trying to be stateless, + # so we'll handle the initialization ourselves. + logger.debug("Initializing server") + await initialize(read_stream_writer, write_stream_reader) + + # Alright, now we can send the client message. + logger.debug("Sending client message") + await read_stream_writer.send(client_message) + + # Wait for the server's response, and forward it to the client. + server_message = await write_stream_reader.receive() + obj = server_message.model_dump( + by_alias=True, mode="json", exclude_none=True + ) + response = JSONResponse(obj) + await response(scope, receive, send) + finally: + await read_stream_writer.aclose() + await write_stream_reader.aclose() + + async with anyio.create_task_group() as tg: + tg.start_soon(handle_post_message) + yield (read_stream, write_stream) + + +@asynccontextmanager +async def http_client(url: str, headers: dict[str, Any] | None = None): + read_stream, read_stream_writer, write_stream, write_stream_reader = make_streams() + + async with anyio.create_task_group() as tg: + try: + + async def http_rw(): + logger.debug("Waiting for request body") + body = await write_stream_reader.receive() + + u = httpx.URL(url) + u_cleaned = ( + u + if u.userinfo == b"" + else u.copy_with(username="", password="") + ) + logger.debug(f"Connecting to HTTP endpoint: {u_cleaned}") + async with httpx.AsyncClient(headers=headers) as client: + response = await client.post( + u, content=body.model_dump_json(by_alias=True) + ) + logger.debug(f"Received response: {response.status_code}") + message = types.JSONRPCMessage.model_validate_json(response.content) + await read_stream_writer.send(message) + + tg.start_soon(http_rw) + try: + yield read_stream, write_stream + finally: + tg.cancel_scope.cancel() + finally: + await read_stream_writer.aclose() + await write_stream_reader.aclose() diff --git a/tests/conftest.py b/tests/conftest.py index 1d20d27..63e9249 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,6 @@ +import pytest + + def pytest_addoption(parser): """ Add command line options for integration and cloud tests. @@ -29,3 +32,8 @@ def pytest_addoption(parser): default=False, help="enable cloud integration tests", ) + + +@pytest.fixture +def anyio_backend(): + return "asyncio" diff --git a/tests/http_test.py b/tests/http_test.py new file mode 100644 index 0000000..cd19562 --- /dev/null +++ b/tests/http_test.py @@ -0,0 +1,140 @@ +import multiprocessing +import socket +import time +from typing import AsyncGenerator, Generator + +import pytest +import uvicorn +from mcp.server.lowlevel.server import Server +from mcp.types import Tool +from mcp.client.session import ClientSession +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.routing import Route + +from mcp_grafana.transports.http import handle_message, http_client + + +@pytest.fixture +def server_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +@pytest.fixture +def server_url(server_port: int) -> str: + return f"http://127.0.0.1:{server_port}" + + +# A test server implementation. +class ServerTest(Server): + def __init__(self): + super().__init__("test_server_for_http") + + +# Test fixtures +def make_server_app() -> Starlette: + """Create test Starlette app with SSE transport""" + server = ServerTest() + + @server.list_tools() + async def handle_list_tools() -> list[Tool]: + return [ + Tool( + name="test_tool", + description="A test tool", + inputSchema={"type": "object", "properties": {}}, + ), + Tool( + name="test_tool2", + description="A second test tool", + inputSchema={"type": "object", "properties": {}}, + ), + ] + + async def handle_http(request: Request): + async with handle_message(request.scope, request.receive, request._send) as ( + read_stream, + write_stream, + ): + await server.run( + read_stream, + write_stream, + server.create_initialization_options(), + ) + + app = Starlette(routes=[Route("/mcp", endpoint=handle_http, methods=["POST"])]) + return app + + +def run_server(server_port: int) -> None: + app = make_server_app() + server = uvicorn.Server( + config=uvicorn.Config( + app=app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"starting server on {server_port}") + server.run() + + # Give server time to start + while not server.started: + print("waiting for server to start") + time.sleep(0.5) + + +@pytest.fixture() +def server(server_port: int) -> Generator[None, None, None]: + proc = multiprocessing.Process( + target=run_server, kwargs={"server_port": server_port}, daemon=True + ) + print("starting process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("waiting for server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError( + "Server failed to start after {} attempts".format(max_attempts) + ) + + yield + + print("killing server") + # Signal the server to stop + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.fixture +async def http_client_session( + server, + server_url: str, +) -> AsyncGenerator[ClientSession, None]: + async with http_client(url=server_url + "/mcp") as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + yield session + + +@pytest.mark.anyio +async def test_http_client_list_tools( + http_client_session: ClientSession, +) -> None: + session = http_client_session + response = await session.list_tools() + assert len(response.tools) == 2 + assert response.tools[0].name == "test_tool" + assert response.tools[1].name == "test_tool2"