Skip to content

Commit b3ddf55

Browse files
committed
Correctly close streams, better handle edge cases
1 parent f36666d commit b3ddf55

File tree

2 files changed

+54
-40
lines changed

2 files changed

+54
-40
lines changed

src/mcp_grafana/__init__.py

+6-9
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import anyio
55
import uvicorn
66
from mcp.server import FastMCP
7-
from starlette.requests import Request
87

98
from .tools import add_tools
109

@@ -18,14 +17,14 @@ class Transport(enum.StrEnum):
1817
class GrafanaMCP(FastMCP):
1918
async def run_http_async(self) -> None:
2019
from starlette.applications import Starlette
21-
from starlette.routing import Route
20+
from starlette.routing import Mount
2221

2322
from .transports.http import handle_message
2423

25-
async def handle_http(request: Request):
26-
async with handle_message(
27-
request.scope, request.receive, request._send
28-
) as (
24+
async def handle_http(scope, receive, send):
25+
if scope["type"] != "http":
26+
raise ValueError("Expected HTTP request")
27+
async with handle_message(scope, receive, send) as (
2928
read_stream,
3029
write_stream,
3130
):
@@ -37,9 +36,7 @@ async def handle_http(request: Request):
3736

3837
starlette_app = Starlette(
3938
debug=self.settings.debug,
40-
routes=[
41-
Route("/mcp", endpoint=handle_http, methods=["POST"]),
42-
],
39+
routes=[Mount("/", app=handle_http)],
4340
)
4441

4542
config = uvicorn.Config(

src/mcp_grafana/transports/http.py

+48-31
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import httpx
99
from mcp import types
1010
from pydantic import ValidationError
11+
from starlette.exceptions import HTTPException
1112
from starlette.requests import Request
1213
from starlette.responses import JSONResponse, Response
1314
from starlette.types import Receive, Scope, Send
@@ -113,39 +114,55 @@ async def handle_message(scope: Scope, receive: Receive, send: Send):
113114
read_stream, read_stream_writer, write_stream, write_stream_reader = make_streams()
114115

115116
async def handle_post_message():
116-
request = Request(scope, receive)
117117
try:
118-
json = await request.json()
119-
except JSONDecodeError as err:
120-
logger.error(f"Failed to parse message: {err}")
121-
response = Response("Could not parse message", status_code=400)
122-
await response(scope, receive, send)
123-
return
124-
try:
125-
client_message = types.JSONRPCMessage.model_validate(json)
126-
logger.debug(f"Validated client message: {client_message}")
127-
except ValidationError as err:
128-
logger.error(f"Failed to parse message: {err}")
129-
response = Response("Could not parse message", status_code=400)
118+
request = Request(scope, receive)
119+
if request.method != "POST":
120+
response = Response("Method not allowed", status_code=405)
121+
await response(scope, receive, send)
122+
return
123+
if scope["path"] != "/mcp":
124+
response = Response("Not found", status_code=404)
125+
await response(scope, receive, send)
126+
return
127+
try:
128+
json = await request.json()
129+
except JSONDecodeError as err:
130+
logger.error(f"Failed to parse message: {err}")
131+
response = Response("Could not parse message", status_code=400)
132+
await response(scope, receive, send)
133+
return
134+
135+
try:
136+
client_message = types.JSONRPCMessage.model_validate(json)
137+
logger.debug(f"Validated client message: {client_message}")
138+
except ValidationError as err:
139+
logger.error(f"Failed to parse message: {err}")
140+
response = Response("Could not parse message", status_code=400)
141+
await response(scope, receive, send)
142+
return
143+
144+
# As part of the MCP spec we need to initialize first.
145+
# In a stateful flow (e.g. stdio or sse transports) the client would
146+
# send an initialize request to the server, and the server would send
147+
# a response back to the client. In this case we're trying to be stateless,
148+
# so we'll handle the initialization ourselves.
149+
logger.debug("Initializing server")
150+
await initialize(read_stream_writer, write_stream_reader)
151+
152+
# Alright, now we can send the client message.
153+
logger.debug("Sending client message")
154+
await read_stream_writer.send(client_message)
155+
156+
# Wait for the server's response, and forward it to the client.
157+
server_message = await write_stream_reader.receive()
158+
obj = server_message.model_dump(
159+
by_alias=True, mode="json", exclude_none=True
160+
)
161+
response = JSONResponse(obj)
130162
await response(scope, receive, send)
131-
return
132-
133-
# As part of the MCP spec we need to initialize first.
134-
# In a stateful flow (e.g. stdio or sse transports) the client would
135-
# send an initialize request to the server, and the server would send
136-
# a response back to the client. In this case we're trying to be stateless,
137-
# so we'll handle the initialization ourselves.
138-
logger.debug("Initializing server")
139-
await initialize(read_stream_writer, write_stream_reader)
140-
141-
# Alright, now we can send the client message.
142-
logger.debug("Sending client message")
143-
await read_stream_writer.send(client_message)
144-
# Wait for the server's response, and forward it to the client.
145-
server_message = await write_stream_reader.receive()
146-
obj = server_message.model_dump(by_alias=True, mode="json", exclude_none=True)
147-
response = JSONResponse(obj)
148-
await response(scope, receive, send)
163+
finally:
164+
await read_stream_writer.aclose()
165+
await write_stream_reader.aclose()
149166

150167
async with anyio.create_task_group() as tg:
151168
tg.start_soon(handle_post_message)

0 commit comments

Comments
 (0)