Skip to content
Closed
3 changes: 2 additions & 1 deletion src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ async def sse_reader(
logger.info(
f"Received endpoint URL: {endpoint_url}"
)

url_parsed = urlparse(url)

endpoint_parsed = urlparse(endpoint_url)
if (
url_parsed.netloc != endpoint_parsed.netloc
Expand Down
10 changes: 9 additions & 1 deletion src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ async def handle_sse(request):
"""

import logging
import re
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import quote
Expand Down Expand Up @@ -95,7 +96,14 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

session_id = uuid4()
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
request_path = scope["path"]

match = re.match(r"^/([^/]+(?:/mcp)?)/sse$", request_path)
mount_prefix = match.group(1) if match else ""

session_uri = f"/{quote(mount_prefix)}{quote(self._endpoint)}"
session_uri += f"?session_id={session_id.hex}"

self._read_stream_writers[session_id] = read_stream_writer
logger.debug(f"Created new session with ID: {session_id}")

Expand Down
Loading