Skip to content

Commit ddc6022

Browse files
committed
Add test for middleware
This commit adds a test for the new GrafanaMiddleware class. It is a bit heavy-handed due to the fact that it requires quite a lot of moving parts. There are quite a few comments in the test file which hopefully explains what is going on.
1 parent 0fae385 commit ddc6022

File tree

5 files changed

+349
-49
lines changed

5 files changed

+349
-49
lines changed

pyproject.toml

+2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,10 @@ mcp-grafana = "mcp_grafana.cli:app"
2121

2222
[dependency-groups]
2323
dev = [
24+
"httpx-sse>=0.4.0",
2425
"pytest>=8.3.4",
2526
"pytest-asyncio>=0.25.2",
27+
"pytest-httpserver>=1.1.1",
2628
]
2729
lint = [
2830
"ruff>=0.8.5",

src/mcp_grafana/cli.py

+3-49
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import enum
22
from types import MethodType
33

4-
from mcp.server import FastMCP
54
import typer
65

76
from . import mcp
@@ -17,57 +16,12 @@ class Transport(enum.StrEnum):
1716
@app.command()
1817
def run(transport: Transport = Transport.stdio, header_auth: bool = False):
1918
if transport == Transport.sse and header_auth:
19+
from .middleware import run_sse_async_with_middleware
20+
2021
# Monkeypatch the run_sse_async method to inject a Grafana middleware.
2122
# This is a bit of a hack, but fastmcp doesn't have a way of adding
2223
# middleware. It's not unreasonable to do this really, since fastmcp
2324
# is just a thin wrapper around the low level mcp server.
24-
mcp.run_sse_async = MethodType(run_sse_async, mcp)
25+
mcp.run_sse_async = MethodType(run_sse_async_with_middleware, mcp)
2526

2627
mcp.run(transport.value)
27-
28-
29-
async def run_sse_async(self: FastMCP) -> None:
30-
"""
31-
Run the server using SSE transport, with a middleware that extracts
32-
Grafana authentication information from the request headers.
33-
34-
The vast majority of this code is the same as the original run_sse_async
35-
method (see https://github.com/modelcontextprotocol/python-sdk/blob/44c0004e6c69e336811bb6793b7176e1eda50015/src/mcp/server/fastmcp/server.py#L436-L468).
36-
"""
37-
38-
from mcp.server.sse import SseServerTransport
39-
from starlette.applications import Starlette
40-
from starlette.routing import Mount, Route
41-
import uvicorn
42-
43-
from .middleware import GrafanaMiddleware
44-
45-
sse = SseServerTransport("/messages/")
46-
47-
async def handle_sse(request):
48-
async with GrafanaMiddleware(request):
49-
async with sse.connect_sse(
50-
request.scope, request.receive, request._send
51-
) as streams:
52-
await self._mcp_server.run(
53-
streams[0],
54-
streams[1],
55-
self._mcp_server.create_initialization_options(),
56-
)
57-
58-
starlette_app = Starlette(
59-
debug=self.settings.debug,
60-
routes=[
61-
Route("/sse", endpoint=handle_sse),
62-
Mount("/messages/", app=sse.handle_post_message),
63-
],
64-
)
65-
66-
config = uvicorn.Config(
67-
starlette_app,
68-
host=self.settings.host,
69-
port=self.settings.port,
70-
log_level=self.settings.log_level.lower(),
71-
)
72-
server = uvicorn.Server(config)
73-
await server.serve()

src/mcp_grafana/middleware.py

+46
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dataclasses import dataclass
22

3+
from mcp.server import FastMCP
34
from starlette.datastructures import Headers
45

56
from .client import GrafanaClient, grafana_client
@@ -55,3 +56,48 @@ async def __aenter__(self):
5556
async def __aexit__(self, exc_type, exc_val, exc_tb):
5657
if self.token is not None:
5758
grafana_settings.reset(self.token)
59+
60+
61+
async def run_sse_async_with_middleware(self: FastMCP) -> None:
62+
"""
63+
Run the server using SSE transport, with a middleware that extracts
64+
Grafana authentication information from the request headers.
65+
66+
The vast majority of this code is the same as the original run_sse_async
67+
method (see https://github.com/modelcontextprotocol/python-sdk/blob/44c0004e6c69e336811bb6793b7176e1eda50015/src/mcp/server/fastmcp/server.py#L436-L468).
68+
"""
69+
70+
from mcp.server.sse import SseServerTransport
71+
from starlette.applications import Starlette
72+
from starlette.routing import Mount, Route
73+
import uvicorn
74+
75+
sse = SseServerTransport("/messages/")
76+
77+
async def handle_sse(request):
78+
async with GrafanaMiddleware(request):
79+
async with sse.connect_sse(
80+
request.scope, request.receive, request._send
81+
) as streams:
82+
await self._mcp_server.run(
83+
streams[0],
84+
streams[1],
85+
self._mcp_server.create_initialization_options(),
86+
)
87+
88+
starlette_app = Starlette(
89+
debug=self.settings.debug,
90+
routes=[
91+
Route("/sse", endpoint=handle_sse),
92+
Mount("/messages/", app=sse.handle_post_message),
93+
],
94+
)
95+
96+
config = uvicorn.Config(
97+
starlette_app,
98+
host=self.settings.host,
99+
port=self.settings.port,
100+
log_level=self.settings.log_level.lower(),
101+
)
102+
server = uvicorn.Server(config)
103+
await server.serve()

tests/middleware_test.py

+222
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
import asyncio
2+
import json
3+
from types import MethodType
4+
from typing import AsyncIterator
5+
6+
import anyio
7+
import httpx
8+
from mcp.types import (
9+
LATEST_PROTOCOL_VERSION,
10+
CallToolResult,
11+
ClientCapabilities,
12+
ClientNotification,
13+
Implementation,
14+
InitializeRequest,
15+
InitializeRequestParams,
16+
InitializedNotification,
17+
JSONRPCNotification,
18+
JSONRPCRequest,
19+
JSONRPCResponse,
20+
)
21+
import pytest
22+
from httpx_sse import aconnect_sse
23+
24+
from mcp_grafana import mcp
25+
from mcp_grafana.middleware import run_sse_async_with_middleware
26+
27+
from pytest_httpserver import HTTPServer
28+
29+
30+
class TestMiddleware:
31+
"""
32+
Test that our injected starlette middleware extracts headers and
33+
overrides settings per-request, as expected.
34+
35+
Also ensure that the contextvars do not leak across requests.
36+
"""
37+
38+
@pytest.mark.asyncio
39+
async def test_multiple_requests(self):
40+
"""
41+
Ensure that the contextvars do not leak across requests.
42+
43+
This is a bit of a tricky test, since we're not actually
44+
testing the middleware itself, but rather the contextvars
45+
that it uses.
46+
47+
We need to:
48+
1. Start a couple of mock Grafana servers
49+
2. Start our MCP server
50+
3. Send a request to the MCP server pointing to the first
51+
Grafana server (using the X-Grafana-Url header)
52+
4. Send a different request to the MCP server pointing to
53+
the second Grafana server (using the X-Grafana-Url header)
54+
5. Ensure that the right request goes to the right server
55+
"""
56+
57+
# Start a couple of mock Grafana servers.
58+
with HTTPServer(port=10000) as g1, HTTPServer(port=10001) as g2:
59+
# Set up some responses from those servers.
60+
61+
g1.expect_oneshot_request("/api/datasources").respond_with_json([{"id": 1}])
62+
g1.expect_oneshot_request(
63+
"/api/plugins/grafana-incident-app/resources/api/IncidentsService.CreateIncident",
64+
method="POST",
65+
# TODO: add proper request body.
66+
).respond_with_json({}) # TODO: add response body
67+
68+
g2.expect_oneshot_request(
69+
"/api/datasources/proxy/uid/foo/api/v1/label/__name__/values"
70+
).respond_with_json({
71+
"status": "success",
72+
"data": [
73+
"metric1",
74+
"metric2",
75+
],
76+
})
77+
78+
# Hardcode a port for the MCP server.
79+
mcp.settings.host = "127.0.0.1"
80+
mcp.settings.port = 10002
81+
82+
# Create clients for each server.
83+
# Note these clients send requests to the MCP server, not the Grafana server.
84+
# The initial SSE request includes headers that tell the server which
85+
# Grafana server to send tool requests to.
86+
g1_client = httpx.AsyncClient(
87+
base_url=f"http://{mcp.settings.host}:{mcp.settings.port}",
88+
)
89+
g2_client = httpx.AsyncClient(
90+
base_url=f"http://{mcp.settings.host}:{mcp.settings.port}"
91+
)
92+
93+
# Monkeypatch the MCP server to use our middleware.
94+
mcp.run_sse_async = MethodType(run_sse_async_with_middleware, mcp)
95+
96+
async with anyio.create_task_group() as tg:
97+
tg.start_soon(mcp.run_sse_async, name="mcp")
98+
# Wait for the server to start.
99+
await asyncio.sleep(0.1)
100+
101+
# Send SSE requests to the MCP server, one for each Grafana server.
102+
# We can access tool call results over the SSE stream.
103+
async with (
104+
aconnect_sse(
105+
g1_client,
106+
"GET",
107+
"/sse",
108+
headers={
109+
"X-Grafana-Url": f"http://{g1.host}:{g1.port}",
110+
"X-Grafana-Api-Key": "abcd123",
111+
},
112+
) as g1_source,
113+
aconnect_sse(
114+
g2_client,
115+
"GET",
116+
"/sse",
117+
headers={
118+
"X-Grafana-Url": f"http://{g2.host}:{g2.port}",
119+
"X-Grafana-Api-Key": "efgh456",
120+
},
121+
) as g2_source,
122+
):
123+
g1_iter = g1_source.aiter_sse()
124+
g2_iter = g2_source.aiter_sse()
125+
# The URL to use is in the first SSE message.
126+
g1_url = (await g1_iter.__anext__()).data
127+
g2_url = (await g2_iter.__anext__()).data
128+
129+
# The MCP protocol requires us to send an initialize request
130+
# before we can send any other requests.
131+
await initialize(g1_client, g1_url, g1_iter)
132+
await initialize(g2_client, g2_url, g2_iter)
133+
134+
# Send a tool call request using the first URL.
135+
await g1_client.post(
136+
g1_url,
137+
json={
138+
"jsonrpc": "2.0",
139+
"id": 2,
140+
"method": "tools/call",
141+
"params": {"name": "list_datasources"},
142+
},
143+
)
144+
result = await jsonrpc_result(g1_iter)
145+
# This must have come from the first Grafana server.
146+
assert json.loads(result.content[0].text) == json.dumps( # type: ignore
147+
[{"id": 1}], indent=4
148+
)
149+
150+
# Send a tool call request using the second URL.
151+
await g2_client.post(
152+
g2_url,
153+
json={
154+
"jsonrpc": "2.0",
155+
"id": 2,
156+
"method": "tools/call",
157+
"params": {
158+
"name": "list_prometheus_metric_names",
159+
"arguments": {"datasource_uid": "foo", "regex": ".*"},
160+
},
161+
},
162+
)
163+
result = await jsonrpc_result(g2_iter)
164+
metrics = [x.text for x in result.content] # type: ignore
165+
# This must have come from the second Grafana server.
166+
assert metrics == ["metric1", "metric2"]
167+
168+
# As ridiculous as it sounds, there is no way to stop the uvicorn
169+
# server other than raising a signal (sigint or sigterm), which would
170+
# also cause the test to fail. Instead, we just cancel the task group
171+
# and let the test finish.
172+
# The annoying part of this is that there are tons of extra logs emitted
173+
# by uvicorn which can't be captured by pytest...
174+
tg.cancel_scope.cancel()
175+
176+
177+
async def initialize(client: httpx.AsyncClient, url: str, stream: AsyncIterator):
178+
"""
179+
Handle the initialization handshake with the MCP server.
180+
"""
181+
req = InitializeRequest(
182+
method="initialize",
183+
params=InitializeRequestParams(
184+
protocolVersion=LATEST_PROTOCOL_VERSION,
185+
capabilities=ClientCapabilities(
186+
sampling=None,
187+
experimental=None,
188+
),
189+
clientInfo=Implementation(name="mcp-grafana", version="0.1.2"),
190+
),
191+
)
192+
jdoc = JSONRPCRequest(
193+
jsonrpc="2.0",
194+
id=1,
195+
**req.model_dump(by_alias=True, mode="json"),
196+
)
197+
resp = await client.post(url, json=jdoc.model_dump(by_alias=True))
198+
resp.raise_for_status()
199+
200+
req = ClientNotification(
201+
InitializedNotification(method="notifications/initialized")
202+
)
203+
jdoc = JSONRPCNotification(
204+
jsonrpc="2.0",
205+
**req.model_dump(by_alias=True, mode="json"),
206+
)
207+
await client.post(url, json=jdoc.model_dump(by_alias=True))
208+
209+
# Consume the stream to ensure that the initialization handshake
210+
# is complete.
211+
sse = await stream.__anext__()
212+
data = json.loads(sse.data)
213+
assert "result" in data
214+
215+
216+
async def jsonrpc_result(stream: AsyncIterator) -> CallToolResult:
217+
"""
218+
Extract the result of a 'call tool' JSONRPC request from the SSE stream.
219+
"""
220+
jdoc = (await stream.__anext__()).data
221+
resp = JSONRPCResponse.model_validate_json(jdoc)
222+
return CallToolResult.model_validate(resp.result)

0 commit comments

Comments
 (0)