|
8 | 8 | import httpx
|
9 | 9 | from mcp import types
|
10 | 10 | from pydantic import ValidationError
|
| 11 | +from starlette.exceptions import HTTPException |
11 | 12 | from starlette.requests import Request
|
12 | 13 | from starlette.responses import JSONResponse, Response
|
13 | 14 | from starlette.types import Receive, Scope, Send
|
@@ -113,39 +114,55 @@ async def handle_message(scope: Scope, receive: Receive, send: Send):
|
113 | 114 | read_stream, read_stream_writer, write_stream, write_stream_reader = make_streams()
|
114 | 115 |
|
115 | 116 | async def handle_post_message():
|
116 |
| - request = Request(scope, receive) |
117 | 117 | try:
|
118 |
| - json = await request.json() |
119 |
| - except JSONDecodeError as err: |
120 |
| - logger.error(f"Failed to parse message: {err}") |
121 |
| - response = Response("Could not parse message", status_code=400) |
122 |
| - await response(scope, receive, send) |
123 |
| - return |
124 |
| - try: |
125 |
| - client_message = types.JSONRPCMessage.model_validate(json) |
126 |
| - logger.debug(f"Validated client message: {client_message}") |
127 |
| - except ValidationError as err: |
128 |
| - logger.error(f"Failed to parse message: {err}") |
129 |
| - response = Response("Could not parse message", status_code=400) |
| 118 | + request = Request(scope, receive) |
| 119 | + if request.method != "POST": |
| 120 | + response = Response("Method not allowed", status_code=405) |
| 121 | + await response(scope, receive, send) |
| 122 | + return |
| 123 | + if scope["path"] != "/mcp": |
| 124 | + response = Response("Not found", status_code=404) |
| 125 | + await response(scope, receive, send) |
| 126 | + return |
| 127 | + try: |
| 128 | + json = await request.json() |
| 129 | + except JSONDecodeError as err: |
| 130 | + logger.error(f"Failed to parse message: {err}") |
| 131 | + response = Response("Could not parse message", status_code=400) |
| 132 | + await response(scope, receive, send) |
| 133 | + return |
| 134 | + |
| 135 | + try: |
| 136 | + client_message = types.JSONRPCMessage.model_validate(json) |
| 137 | + logger.debug(f"Validated client message: {client_message}") |
| 138 | + except ValidationError as err: |
| 139 | + logger.error(f"Failed to parse message: {err}") |
| 140 | + response = Response("Could not parse message", status_code=400) |
| 141 | + await response(scope, receive, send) |
| 142 | + return |
| 143 | + |
| 144 | + # As part of the MCP spec we need to initialize first. |
| 145 | + # In a stateful flow (e.g. stdio or sse transports) the client would |
| 146 | + # send an initialize request to the server, and the server would send |
| 147 | + # a response back to the client. In this case we're trying to be stateless, |
| 148 | + # so we'll handle the initialization ourselves. |
| 149 | + logger.debug("Initializing server") |
| 150 | + await initialize(read_stream_writer, write_stream_reader) |
| 151 | + |
| 152 | + # Alright, now we can send the client message. |
| 153 | + logger.debug("Sending client message") |
| 154 | + await read_stream_writer.send(client_message) |
| 155 | + |
| 156 | + # Wait for the server's response, and forward it to the client. |
| 157 | + server_message = await write_stream_reader.receive() |
| 158 | + obj = server_message.model_dump( |
| 159 | + by_alias=True, mode="json", exclude_none=True |
| 160 | + ) |
| 161 | + response = JSONResponse(obj) |
130 | 162 | await response(scope, receive, send)
|
131 |
| - return |
132 |
| - |
133 |
| - # As part of the MCP spec we need to initialize first. |
134 |
| - # In a stateful flow (e.g. stdio or sse transports) the client would |
135 |
| - # send an initialize request to the server, and the server would send |
136 |
| - # a response back to the client. In this case we're trying to be stateless, |
137 |
| - # so we'll handle the initialization ourselves. |
138 |
| - logger.debug("Initializing server") |
139 |
| - await initialize(read_stream_writer, write_stream_reader) |
140 |
| - |
141 |
| - # Alright, now we can send the client message. |
142 |
| - logger.debug("Sending client message") |
143 |
| - await read_stream_writer.send(client_message) |
144 |
| - # Wait for the server's response, and forward it to the client. |
145 |
| - server_message = await write_stream_reader.receive() |
146 |
| - obj = server_message.model_dump(by_alias=True, mode="json", exclude_none=True) |
147 |
| - response = JSONResponse(obj) |
148 |
| - await response(scope, receive, send) |
| 163 | + finally: |
| 164 | + await read_stream_writer.aclose() |
| 165 | + await write_stream_reader.aclose() |
149 | 166 |
|
150 | 167 | async with anyio.create_task_group() as tg:
|
151 | 168 | tg.start_soon(handle_post_message)
|
|
0 commit comments