feat: add PoC of stateless HTTP transport #24

Closed
wants to merge 4 commits into from
Changes from 2 commits
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -34,3 +34,7 @@ build-backend = "hatchling.build"

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "session"

[tool.ruff]
line-length = 88
target-version = "py310"
65 changes: 64 additions & 1 deletion src/mcp_grafana/__init__.py
@@ -1,7 +1,70 @@
import enum
from typing import Literal

import anyio
import uvicorn
from mcp.server import FastMCP

from .tools import add_tools


class Transport(enum.StrEnum):
http = "http"
stdio = "stdio"
sse = "sse"


class GrafanaMCP(FastMCP):
async def run_http_async(self) -> None:
from starlette.applications import Starlette
from starlette.routing import Mount

from .transports.http import handle_message

async def handle_http(scope, receive, send):
if scope["type"] != "http":
raise ValueError("Expected HTTP request")
async with handle_message(scope, receive, send) as (
read_stream,
write_stream,
):
await self._mcp_server.run(
read_stream,
write_stream,
self._mcp_server.create_initialization_options(),
)

starlette_app = Starlette(
debug=self.settings.debug,
routes=[Mount("/", app=handle_http)],
)

config = uvicorn.Config(
starlette_app,
host=self.settings.host,
port=self.settings.port,
log_level=self.settings.log_level.lower(),
)
server = uvicorn.Server(config)
await server.serve()

def run(self, transport: Literal["http", "stdio", "sse"] = "stdio") -> None:
"""Run the FastMCP server. Note this is a synchronous function.

Args:
transport: Transport protocol to use ("http", "stdio" or "sse")
"""
if transport not in Transport:
raise ValueError(f"Unknown transport: {transport}")

if transport == "stdio":
anyio.run(self.run_stdio_async)
elif transport == "sse":
anyio.run(self.run_sse_async)
else:
anyio.run(self.run_http_async)


# Create an MCP server
mcp = FastMCP("Grafana", log_level="DEBUG")
mcp = GrafanaMCP("Grafana", log_level="DEBUG")
add_tools(mcp)
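Note (not part of the diff): with the changes above, the new transport can be started programmatically. A minimal sketch, assuming the default FastMCP host/port settings:

```python
# Sketch: run the stateless HTTP transport added in this PR.
# GrafanaMCP.run("http") dispatches to run_http_async(), which mounts the
# ASGI handler at "/" and serves it with uvicorn; the handler itself only
# accepts POSTs to the "/mcp" path (see transports/http.py below).
from mcp_grafana import mcp

if __name__ == "__main__":
    mcp.run("http")
```

Each client request is then a single JSON-RPC message POSTed to /mcp; the transport performs the initialize handshake internally on every request, as described in transports/http.py below.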
9 changes: 1 addition & 8 deletions src/mcp_grafana/cli.py
@@ -1,17 +1,10 @@
import enum

import typer

from . import mcp
from . import mcp, Transport

app = typer.Typer()


class Transport(enum.StrEnum):
stdio = "stdio"
sse = "sse"


@app.command()
def run(transport: Transport = Transport.stdio):
mcp.run(transport.value)
198 changes: 198 additions & 0 deletions src/mcp_grafana/transports/http.py
@@ -0,0 +1,198 @@
import logging
from contextlib import asynccontextmanager
from json import JSONDecodeError
from typing import Any, Tuple

import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import httpx
from mcp import types
from pydantic import ValidationError
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.types import Receive, Scope, Send


logger = logging.getLogger(__name__)


ReadStream = MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
ReadStreamWriter = MemoryObjectSendStream[types.JSONRPCMessage | Exception]
WriteStream = MemoryObjectSendStream[types.JSONRPCMessage]
WriteStreamReader = MemoryObjectReceiveStream[types.JSONRPCMessage]


def make_streams() -> Tuple[
ReadStream, ReadStreamWriter, WriteStream, WriteStreamReader
]:
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]

write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
return read_stream, read_stream_writer, write_stream, write_stream_reader


async def initialize(
read_stream_writer: ReadStreamWriter,
write_stream_reader: WriteStreamReader,
):
"""
Initialize the MCP server for this request.

In a stateful transport (e.g. stdio or sse) the client would
send an initialize request to the server, and the server would send
an 'initialized' response back to the client.

In the HTTP transport we're trying to be stateless, so we'll have to
handle the initialization ourselves.

This function handles that initialization by sending the required
messages to the server and consuming the response.
"""
# First construct the initialize request.
initialize_request = types.InitializeRequest(
method="initialize",
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
experimental=None,
roots=None,
sampling=None,
),
# TODO: get the name and version from the package metadata.
clientInfo=types.Implementation(name="mcp-grafana", version="0.1.2"),
),
)
initialize_request = types.JSONRPCRequest(
jsonrpc="2.0",
id=1,
**initialize_request.model_dump(by_alias=True, mode="json"),
)
# Send it to the server.
await read_stream_writer.send(types.JSONRPCMessage(initialize_request))
# We can ignore the response since we're not sending it back to the client.
await write_stream_reader.receive()

# Next we need to notify the server that we're initialized.
initialize_notification = types.JSONRPCNotification(
jsonrpc="2.0",
**types.ClientNotification(
types.InitializedNotification(method="notifications/initialized"),
).model_dump(by_alias=True, mode="json"),
)
await read_stream_writer.send(types.JSONRPCMessage(initialize_notification))
# Notifications don't have a response, so we don't need to await the
# write stream reader.


@asynccontextmanager
async def handle_message(scope: Scope, receive: Receive, send: Send):
"""
ASGI application for handling MCP messages using the stateless HTTP transport.

This function is called for each incoming message. It creates a new
stream for reading and writing messages, which will be used by the
MCP server, and handles:

- decoding the client message from JSON into internal types
- validating the client message
- initializing the MCP server, which must be done on every request
(since this is a stateless transport)
- sending the client message to the MCP server
- receiving the server's response
- encoding the server's response into JSON and sending it back to the client

The returned read and write streams are intended to be passed to
`mcp.server.lowlevel.Server.run()` as the `read_stream` and `write_stream`
arguments.
"""
read_stream, read_stream_writer, write_stream, write_stream_reader = make_streams()

async def handle_post_message():
try:
request = Request(scope, receive)
if request.method != "POST":
response = Response("Method not allowed", status_code=405)
await response(scope, receive, send)
return
if scope["path"] != "/mcp":
response = Response("Not found", status_code=404)
await response(scope, receive, send)
return
try:
json = await request.json()
except JSONDecodeError as err:
logger.error(f"Failed to parse message: {err}")
response = Response("Could not parse message", status_code=400)
await response(scope, receive, send)
return

try:
client_message = types.JSONRPCMessage.model_validate(json)
logger.debug(f"Validated client message: {client_message}")
except ValidationError as err:
logger.error(f"Failed to parse message: {err}")
response = Response("Could not parse message", status_code=400)
await response(scope, receive, send)
return

# As part of the MCP spec we need to initialize first.
# In a stateful flow (e.g. stdio or sse transports) the client would
# send an initialize request to the server, and the server would send
# a response back to the client. In this case we're trying to be stateless,
# so we'll handle the initialization ourselves.
logger.debug("Initializing server")
await initialize(read_stream_writer, write_stream_reader)

# Alright, now we can send the client message.
logger.debug("Sending client message")
await read_stream_writer.send(client_message)

# Wait for the server's response, and forward it to the client.
server_message = await write_stream_reader.receive()
obj = server_message.model_dump(
by_alias=True, mode="json", exclude_none=True
)
response = JSONResponse(obj)
await response(scope, receive, send)
finally:
await read_stream_writer.aclose()
await write_stream_reader.aclose()

async with anyio.create_task_group() as tg:
tg.start_soon(handle_post_message)
yield (read_stream, write_stream)


@asynccontextmanager
async def http_client(url: str, headers: dict[str, Any] | None = None):
read_stream, read_stream_writer, write_stream, write_stream_reader = make_streams()

async with anyio.create_task_group() as tg:
try:

async def http_rw():
logger.debug("Waiting for request body")
body = await write_stream_reader.receive()

logger.debug(f"Connecting to HTTP endpoint: {url}")
Review thread on the debug log above:

Contributor: nit: This has bitten me before when the URLs contain usernames/passwords 😬 but it's behind debug so might be ok!

Contributor: I wouldn't say that's ok, we commonly enable debug logging in prod as needed. We should make sure to clean the URL before writing it.

Collaborator (author): Good catch, fixed.

(One possible redaction approach is sketched after this file's diff.)

async with httpx.AsyncClient(headers=headers) as client:
response = await client.post(
url, content=body.model_dump_json(by_alias=True)
)
logger.debug(f"Received response: {response.status_code}")
message = types.JSONRPCMessage.model_validate_json(response.content)
await read_stream_writer.send(message)

tg.start_soon(http_rw)
try:
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream_reader.aclose()
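On the review thread above about credentials leaking into the debug log: the follow-up fix isn't visible in this two-commit view. The snippet below is an illustrative sketch of one way to strip userinfo from a URL before logging it, not necessarily the author's actual change:

```python
# Illustrative only: strip any username/password from a URL before logging.
from urllib.parse import urlsplit, urlunsplit

def redact_url(url: str) -> str:
    parts = urlsplit(url)
    if parts.username is None and parts.password is None:
        return url
    # Rebuild the netloc without the userinfo component.
    netloc = parts.hostname or ""
    if parts.port is not None:
        netloc = f"{netloc}:{parts.port}"
    return urlunsplit((parts.scheme, netloc, parts.path, parts.query, parts.fragment))

# e.g. logger.debug(f"Connecting to HTTP endpoint: {redact_url(url)}")
```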
8 changes: 8 additions & 0 deletions tests/conftest.py
@@ -1,3 +1,6 @@
import pytest


def pytest_addoption(parser):
"""
Add command line options for integration and cloud tests.
@@ -29,3 +32,8 @@ def pytest_addoption(parser):
default=False,
help="enable cloud integration tests",
)


@pytest.fixture
def anyio_backend():
return "asyncio"