Skip to content

Commit a6ddde2

Browse files
guschnwgjlowin
andauthored
Allow OAuth instance to use the same httpx factory as the Transport (#2324)
* Allow OAuth instance to use the same httpx factory as the Transport * Fix test * Update SSL verification mode assertion in tests * This is actually not needed * Creating a Client instance is not needed for this test * Fix test * Apply httpx_client_factory fix to SSETransport Extends the OAuth httpx_client_factory changes to SSETransport. SSETransport had the same issues as StreamableHttpTransport where it wasn't passing the custom httpx client factory to OAuth, causing certificate verification settings to be ignored during OAuth flows. Changes: - Set httpx_client_factory before calling _set_auth() - Pass httpx_client_factory to OAuth constructor - Add test for SSETransport OAuth client factory propagation --------- Co-authored-by: Jeremiah Lowin <153965+jlowin@users.noreply.github.com>
1 parent a7563a3 commit a6ddde2

File tree

3 files changed

+45
-5
lines changed

3 files changed

+45
-5
lines changed

src/fastmcp/client/auth/oauth.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from key_value.aio.protocols import AsyncKeyValue
1313
from key_value.aio.stores.memory import MemoryStore
1414
from mcp.client.auth import OAuthClientProvider, TokenStorage
15+
from mcp.shared._httpx_utils import McpHttpClientFactory
1516
from mcp.shared.auth import (
1617
OAuthClientInformationFull,
1718
OAuthClientMetadata,
@@ -147,6 +148,7 @@ def __init__(
147148
token_storage: AsyncKeyValue | None = None,
148149
additional_client_metadata: dict[str, Any] | None = None,
149150
callback_port: int | None = None,
151+
httpx_client_factory: McpHttpClientFactory | None = None,
150152
):
151153
"""
152154
Initialize OAuth client provider for an MCP server.
@@ -164,6 +166,7 @@ def __init__(
164166
server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
165167

166168
# Setup OAuth client
169+
self.httpx_client_factory = httpx_client_factory or httpx.AsyncClient
167170
self.redirect_port = callback_port or find_available_port()
168171
redirect_uri = f"http://localhost:{self.redirect_port}/callback"
169172

@@ -226,7 +229,7 @@ async def _initialize(self) -> None:
226229
async def redirect_handler(self, authorization_url: str) -> None:
227230
"""Open browser for authorization, with pre-flight check for invalid client."""
228231
# Pre-flight check to detect invalid client_id before opening browser
229-
async with httpx.AsyncClient() as client:
232+
async with self.httpx_client_factory() as client:
230233
response = await client.get(authorization_url, follow_redirects=False)
231234

232235
# Check for client not found error (400 typically means bad client_id)

src/fastmcp/client/transports.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,16 @@ def __init__(
177177

178178
self.url = url
179179
self.headers = headers or {}
180-
self._set_auth(auth)
181180
self.httpx_client_factory = httpx_client_factory
181+
self._set_auth(auth)
182182

183183
if isinstance(sse_read_timeout, int | float):
184184
sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout))
185185
self.sse_read_timeout = sse_read_timeout
186186

187187
def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None):
188188
if auth == "oauth":
189-
auth = OAuth(self.url)
189+
auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory)
190190
elif isinstance(auth, str):
191191
auth = BearerAuth(auth)
192192
self.auth = auth
@@ -247,16 +247,16 @@ def __init__(
247247

248248
self.url = url
249249
self.headers = headers or {}
250-
self._set_auth(auth)
251250
self.httpx_client_factory = httpx_client_factory
251+
self._set_auth(auth)
252252

253253
if isinstance(sse_read_timeout, int | float):
254254
sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout))
255255
self.sse_read_timeout = sse_read_timeout
256256

257257
def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None):
258258
if auth == "oauth":
259-
auth = OAuth(self.url)
259+
auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory)
260260
elif isinstance(auth, str):
261261
auth = BearerAuth(auth)
262262
self.auth = auth
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from ssl import VerifyMode
2+
3+
import httpx
4+
5+
from fastmcp.client.transports import SSETransport, StreamableHttpTransport
6+
7+
8+
async def test_oauth_uses_same_client_as_transport_streamable_http():
9+
transport = StreamableHttpTransport(
10+
"https://some.fake.url/",
11+
httpx_client_factory=lambda *args, **kwargs: httpx.AsyncClient(
12+
verify=False, *args, **kwargs
13+
),
14+
auth="oauth",
15+
)
16+
17+
async with transport.auth.httpx_client_factory() as httpx_client: # type: ignore[attr-defined]
18+
assert (
19+
httpx_client._transport._pool._ssl_context.verify_mode
20+
== VerifyMode.CERT_NONE
21+
)
22+
23+
24+
async def test_oauth_uses_same_client_as_transport_sse():
25+
transport = SSETransport(
26+
"https://some.fake.url/",
27+
httpx_client_factory=lambda *args, **kwargs: httpx.AsyncClient(
28+
verify=False, *args, **kwargs
29+
),
30+
auth="oauth",
31+
)
32+
33+
async with transport.auth.httpx_client_factory() as httpx_client: # type: ignore[attr-defined]
34+
assert (
35+
httpx_client._transport._pool._ssl_context.verify_mode
36+
== VerifyMode.CERT_NONE
37+
)

0 commit comments

Comments
 (0)