Skip to content

Commit 762c60c

Browse files
committed
Fail fast if Grafana override settings are not provided in sse request
1 parent 87d4168 commit 762c60c

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

src/mcp_grafana/middleware.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from mcp.server import FastMCP
44
from starlette.datastructures import Headers
5+
from starlette.exceptions import HTTPException
56

67
from .client import GrafanaClient, grafana_client
78
from .settings import GrafanaSettings, grafana_settings
@@ -36,8 +37,9 @@ class GrafanaMiddleware:
3637
This should be used as a context manager before handling the /sse request.
3738
"""
3839

39-
def __init__(self, request):
40+
def __init__(self, request, fail_if_unset=True):
4041
self.request = request
42+
self.fail_if_unset = fail_if_unset
4143
self.settings_token = None
4244
self.client_token = None
4345

@@ -53,6 +55,8 @@ async def __aenter__(self):
5355
self.client_token = grafana_client.set(
5456
GrafanaClient.from_settings(new_settings)
5557
)
58+
elif self.fail_if_unset:
59+
raise HTTPException(status_code=403, detail="No Grafana settings found.")
5660

5761
async def __aexit__(self, exc_type, exc_val, exc_tb):
5862
if self.settings_token is not None:

tests/middleware_test.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import anyio
77
import httpx
8+
from mcp.server import FastMCP
89
from mcp.types import (
910
LATEST_PROTOCOL_VERSION,
1011
CallToolResult,
@@ -21,12 +22,19 @@
2122
import pytest
2223
from httpx_sse import aconnect_sse
2324

24-
from mcp_grafana import mcp
25+
from mcp_grafana.tools import add_tools
2526
from mcp_grafana.middleware import run_sse_async_with_middleware
2627

2728
from pytest_httpserver import HTTPServer
2829

2930

31+
@pytest.fixture
32+
def mcp():
33+
mcp = FastMCP("grafana")
34+
add_tools(mcp)
35+
return mcp
36+
37+
3038
class TestMiddleware:
3139
"""
3240
Test that our injected starlette middleware extracts headers and
@@ -36,7 +44,28 @@ class TestMiddleware:
3644
"""
3745

3846
@pytest.mark.asyncio
39-
async def test_multiple_requests(self):
47+
async def test_no_headers_provided(self, mcp: FastMCP):
48+
"""
49+
Ensure that the middleware fails if no headers are provided.
50+
"""
51+
52+
# Monkeypatch the MCP server to use our middleware.
53+
mcp.run_sse_async = MethodType(run_sse_async_with_middleware, mcp)
54+
mcp.settings.host = "127.0.0.1"
55+
mcp.settings.port = 9500
56+
async with anyio.create_task_group() as tg:
57+
tg.start_soon(mcp.run_sse_async, name="mcp")
58+
# Wait for the server to start.
59+
await asyncio.sleep(0.1)
60+
client = httpx.AsyncClient(
61+
base_url=f"http://{mcp.settings.host}:{mcp.settings.port}"
62+
)
63+
resp = await client.get("/sse")
64+
assert resp.status_code == httpx.codes.FORBIDDEN
65+
tg.cancel_scope.cancel()
66+
67+
@pytest.mark.asyncio
68+
async def test_multiple_requests(self, mcp: FastMCP):
4069
"""
4170
Ensure that the contextvars do not leak across requests.
4271

0 commit comments

Comments
 (0)