Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/fastmcp/client/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from key_value.aio.protocols import AsyncKeyValue
from key_value.aio.stores.memory import MemoryStore
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared._httpx_utils import McpHttpClientFactory
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
Expand Down Expand Up @@ -147,6 +148,7 @@ def __init__(
token_storage: AsyncKeyValue | None = None,
additional_client_metadata: dict[str, Any] | None = None,
callback_port: int | None = None,
httpx_client_factory: McpHttpClientFactory | None = None,
):
"""
Initialize OAuth client provider for an MCP server.
Expand All @@ -164,6 +166,7 @@ def __init__(
server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

# Setup OAuth client
self.httpx_client_factory = httpx_client_factory or httpx.AsyncClient
self.redirect_port = callback_port or find_available_port()
redirect_uri = f"http://localhost:{self.redirect_port}/callback"

Expand Down Expand Up @@ -226,7 +229,7 @@ async def _initialize(self) -> None:
async def redirect_handler(self, authorization_url: str) -> None:
"""Open browser for authorization, with pre-flight check for invalid client."""
# Pre-flight check to detect invalid client_id before opening browser
async with httpx.AsyncClient() as client:
async with self.httpx_client_factory() as client:
response = await client.get(authorization_url, follow_redirects=False)

# Check for client not found error (400 typically means bad client_id)
Expand Down
8 changes: 4 additions & 4 deletions src/fastmcp/client/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,16 @@ def __init__(

self.url = url
self.headers = headers or {}
self._set_auth(auth)
self.httpx_client_factory = httpx_client_factory
self._set_auth(auth)

if isinstance(sse_read_timeout, int | float):
sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout))
self.sse_read_timeout = sse_read_timeout

def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None):
if auth == "oauth":
auth = OAuth(self.url)
auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory)
elif isinstance(auth, str):
auth = BearerAuth(auth)
self.auth = auth
Expand Down Expand Up @@ -247,16 +247,16 @@ def __init__(

self.url = url
self.headers = headers or {}
self._set_auth(auth)
self.httpx_client_factory = httpx_client_factory
self._set_auth(auth)

if isinstance(sse_read_timeout, int | float):
sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout))
self.sse_read_timeout = sse_read_timeout

def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None):
if auth == "oauth":
auth = OAuth(self.url)
auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory)
elif isinstance(auth, str):
auth = BearerAuth(auth)
self.auth = auth
Expand Down
37 changes: 37 additions & 0 deletions tests/client/transports/test_transports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from ssl import VerifyMode

import httpx

from fastmcp.client.transports import SSETransport, StreamableHttpTransport


async def test_oauth_uses_same_client_as_transport_streamable_http():
transport = StreamableHttpTransport(
"https://some.fake.url/",
httpx_client_factory=lambda *args, **kwargs: httpx.AsyncClient(
verify=False, *args, **kwargs
),
auth="oauth",
)

async with transport.auth.httpx_client_factory() as httpx_client: # type: ignore[attr-defined]
assert (
httpx_client._transport._pool._ssl_context.verify_mode
== VerifyMode.CERT_NONE
)


async def test_oauth_uses_same_client_as_transport_sse():
transport = SSETransport(
"https://some.fake.url/",
httpx_client_factory=lambda *args, **kwargs: httpx.AsyncClient(
verify=False, *args, **kwargs
),
auth="oauth",
)

async with transport.auth.httpx_client_factory() as httpx_client: # type: ignore[attr-defined]
assert (
httpx_client._transport._pool._ssl_context.verify_mode
== VerifyMode.CERT_NONE
)