|
| 1 | +import logging |
| 2 | +from contextlib import asynccontextmanager |
| 3 | +from json import JSONDecodeError |
| 4 | +from typing import Any, Tuple |
| 5 | + |
| 6 | +import anyio |
| 7 | +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
| 8 | +import httpx |
| 9 | +from mcp import types |
| 10 | +from pydantic import ValidationError |
| 11 | +from starlette.requests import Request |
| 12 | +from starlette.responses import JSONResponse, Response |
| 13 | +from starlette.types import Receive, Scope, Send |
| 14 | + |
| 15 | + |
| 16 | +logger = logging.getLogger(__name__) |
| 17 | + |
| 18 | + |
| 19 | +ReadStream = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] |
| 20 | +ReadStreamWriter = MemoryObjectSendStream[types.JSONRPCMessage | Exception] |
| 21 | +WriteStream = MemoryObjectSendStream[types.JSONRPCMessage] |
| 22 | +WriteStreamReader = MemoryObjectReceiveStream[types.JSONRPCMessage] |
| 23 | + |
| 24 | + |
| 25 | +def make_streams() -> Tuple[ |
| 26 | + ReadStream, ReadStreamWriter, WriteStream, WriteStreamReader |
| 27 | +]: |
| 28 | + read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception] |
| 29 | + read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception] |
| 30 | + |
| 31 | + write_stream: MemoryObjectSendStream[types.JSONRPCMessage] |
| 32 | + write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage] |
| 33 | + |
| 34 | + read_stream_writer, read_stream = anyio.create_memory_object_stream(0) |
| 35 | + write_stream, write_stream_reader = anyio.create_memory_object_stream(0) |
| 36 | + return read_stream, read_stream_writer, write_stream, write_stream_reader |
| 37 | + |
| 38 | + |
| 39 | +async def initialize( |
| 40 | + read_stream_writer: ReadStreamWriter, |
| 41 | + write_stream_reader: WriteStreamReader, |
| 42 | +): |
| 43 | + """ |
| 44 | + Initialize the MCP server for this request. |
| 45 | +
|
| 46 | + In a stateful transport (e.g. stdio or sse) the client would |
| 47 | + send an initialize request to the server, and the server would send |
| 48 | + an 'initialized' response back to the client. |
| 49 | +
|
| 50 | + In the HTTP transport we're trying to be stateless, so we'll have to |
| 51 | + handle the initialization ourselves. |
| 52 | +
|
| 53 | + This function handles that initialization by sending the required |
| 54 | + messages to the server and consuming the response. |
| 55 | + """ |
| 56 | + # First construct the initialize request. |
| 57 | + initialize_request = types.InitializeRequest( |
| 58 | + method="initialize", |
| 59 | + params=types.InitializeRequestParams( |
| 60 | + protocolVersion=types.LATEST_PROTOCOL_VERSION, |
| 61 | + capabilities=types.ClientCapabilities( |
| 62 | + experimental=None, |
| 63 | + roots=None, |
| 64 | + sampling=None, |
| 65 | + ), |
| 66 | + # TODO: get the name and version from the package metadata. |
| 67 | + clientInfo=types.Implementation(name="mcp-grafana", version="0.1.2"), |
| 68 | + ), |
| 69 | + ) |
| 70 | + initialize_request = types.JSONRPCRequest( |
| 71 | + jsonrpc="2.0", |
| 72 | + id=1, |
| 73 | + **initialize_request.model_dump(by_alias=True, mode="json"), |
| 74 | + ) |
| 75 | + # Send it to the server. |
| 76 | + await read_stream_writer.send(types.JSONRPCMessage(initialize_request)) |
| 77 | + # We can ignore the response since we're not sending it back to the client. |
| 78 | + await write_stream_reader.receive() |
| 79 | + |
| 80 | + # Next we need to notify the server that we're initialized. |
| 81 | + initialize_notification = types.JSONRPCNotification( |
| 82 | + jsonrpc="2.0", |
| 83 | + **types.ClientNotification( |
| 84 | + types.InitializedNotification(method="notifications/initialized"), |
| 85 | + ).model_dump(by_alias=True, mode="json"), |
| 86 | + ) |
| 87 | + await read_stream_writer.send(types.JSONRPCMessage(initialize_notification)) |
| 88 | + # Notifications don't have a response, so we don't need to await the |
| 89 | + # write stream reader. |
| 90 | + |
| 91 | + |
| 92 | +@asynccontextmanager |
| 93 | +async def handle_message(scope: Scope, receive: Receive, send: Send): |
| 94 | + """ |
| 95 | + ASGI application for handling MCP messages using the stateless HTTP transport. |
| 96 | +
|
| 97 | + This function is called for each incoming message. It creates a new |
| 98 | + stream for reading and writing messages, which will be used by the |
| 99 | + MCP server, and handles: |
| 100 | +
|
| 101 | + - decoding the client message from JSON into internal types |
| 102 | + - validating the client message |
| 103 | + - initializing the MCP server, which must be done on every request |
| 104 | + (since this is a stateless transport) |
| 105 | + - sending the client message to the MCP server |
| 106 | + - receiving the server's response |
| 107 | + - encoding the server's response into JSON and sending it back to the client |
| 108 | +
|
| 109 | + The returned read and write streams are intended to be passed to |
| 110 | + `mcp.server.lowlevel.Server.run()` as the `read_stream` and `write_stream` |
| 111 | + arguments. |
| 112 | + """ |
| 113 | + read_stream, read_stream_writer, write_stream, write_stream_reader = make_streams() |
| 114 | + |
| 115 | + async def handle_post_message(): |
| 116 | + request = Request(scope, receive) |
| 117 | + try: |
| 118 | + json = await request.json() |
| 119 | + except JSONDecodeError as err: |
| 120 | + logger.error(f"Failed to parse message: {err}") |
| 121 | + response = Response("Could not parse message", status_code=400) |
| 122 | + await response(scope, receive, send) |
| 123 | + return |
| 124 | + try: |
| 125 | + client_message = types.JSONRPCMessage.model_validate(json) |
| 126 | + logger.debug(f"Validated client message: {client_message}") |
| 127 | + except ValidationError as err: |
| 128 | + logger.error(f"Failed to parse message: {err}") |
| 129 | + response = Response("Could not parse message", status_code=400) |
| 130 | + await response(scope, receive, send) |
| 131 | + return |
| 132 | + |
| 133 | + # As part of the MCP spec we need to initialize first. |
| 134 | + # In a stateful flow (e.g. stdio or sse transports) the client would |
| 135 | + # send an initialize request to the server, and the server would send |
| 136 | + # a response back to the client. In this case we're trying to be stateless, |
| 137 | + # so we'll handle the initialization ourselves. |
| 138 | + logger.debug("Initializing server") |
| 139 | + await initialize(read_stream_writer, write_stream_reader) |
| 140 | + |
| 141 | + # Alright, now we can send the client message. |
| 142 | + logger.debug("Sending client message") |
| 143 | + await read_stream_writer.send(client_message) |
| 144 | + # Wait for the server's response, and forward it to the client. |
| 145 | + server_message = await write_stream_reader.receive() |
| 146 | + obj = server_message.model_dump(by_alias=True, mode="json", exclude_none=True) |
| 147 | + response = JSONResponse(obj) |
| 148 | + await response(scope, receive, send) |
| 149 | + |
| 150 | + async with anyio.create_task_group() as tg: |
| 151 | + tg.start_soon(handle_post_message) |
| 152 | + yield (read_stream, write_stream) |
| 153 | + |
| 154 | + |
| 155 | +@asynccontextmanager |
| 156 | +async def http_client(url: str, headers: dict[str, Any] | None = None): |
| 157 | + read_stream, read_stream_writer, write_stream, write_stream_reader = make_streams() |
| 158 | + |
| 159 | + async with anyio.create_task_group() as tg: |
| 160 | + try: |
| 161 | + |
| 162 | + async def http_rw(): |
| 163 | + logger.debug("Waiting for request body") |
| 164 | + body = await write_stream_reader.receive() |
| 165 | + |
| 166 | + logger.debug(f"Connecting to HTTP endpoint: {url}") |
| 167 | + async with httpx.AsyncClient(headers=headers) as client: |
| 168 | + response = await client.post( |
| 169 | + url, content=body.model_dump_json(by_alias=True) |
| 170 | + ) |
| 171 | + logger.debug(f"Received response: {response.status_code}") |
| 172 | + message = types.JSONRPCMessage.model_validate_json(response.content) |
| 173 | + await read_stream_writer.send(message) |
| 174 | + |
| 175 | + tg.start_soon(http_rw) |
| 176 | + try: |
| 177 | + yield read_stream, write_stream |
| 178 | + finally: |
| 179 | + tg.cancel_scope.cancel() |
| 180 | + finally: |
| 181 | + await read_stream_writer.aclose() |
| 182 | + await write_stream_reader.aclose() |
0 commit comments