diff --git a/src/fastmcp/client/transports.py b/src/fastmcp/client/transports.py index 81afc9c88..628f2ec16 100644 --- a/src/fastmcp/client/transports.py +++ b/src/fastmcp/client/transports.py @@ -6,7 +6,7 @@ import shutil import sys import warnings -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from pathlib import Path from typing import Any, Literal, TextIO, TypeVar, cast, overload @@ -254,6 +254,8 @@ def __init__( sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout)) self.sse_read_timeout = sse_read_timeout + self._get_session_id_cb: Callable[[], str | None] | None = None + def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None): if auth == "oauth": auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory) @@ -287,12 +289,25 @@ async def connect_session( auth=self.auth, **client_kwargs, ) as transport: - read_stream, write_stream, _ = transport + read_stream, write_stream, get_session_id = transport + self._get_session_id_cb = get_session_id async with ClientSession( read_stream, write_stream, **session_kwargs ) as session: yield session + def get_session_id(self) -> str | None: + if self._get_session_id_cb: + try: + return self._get_session_id_cb() + except Exception: + return None + return None + + async def close(self): + # Reset the session id callback + self._get_session_id_cb = None + def __repr__(self) -> str: return f"" diff --git a/tests/client/test_streamable_http.py b/tests/client/test_streamable_http.py index f0378e91c..5c1753d9a 100644 --- a/tests/client/test_streamable_http.py +++ b/tests/client/test_streamable_http.py @@ -175,6 +175,15 @@ async def test_http_headers(streamable_http_server: str): assert json_result["x-demo-header"] == "ABC" +async def test_session_id_callback(streamable_http_server: str): + """Test getting mcp-session-id from the transport.""" + transport = StreamableHttpTransport(streamable_http_server) + assert transport.get_session_id() is None + async with Client(transport=transport): + session_id = transport.get_session_id() + assert session_id is not None + + @pytest.mark.parametrize("streamable_http_server", [True, False], indirect=True) async def test_greet_with_progress_tool(streamable_http_server: str): """Test calling the greet tool."""