Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/fastmcp/client/auth/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from key_value.aio.protocols import AsyncKeyValue
from key_value.aio.stores.memory import MemoryStore
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared._httpx_utils import McpHttpClientFactory
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
Expand Down Expand Up @@ -46,7 +47,7 @@ async def check_if_auth_required(
Returns:
True if auth appears to be required, False otherwise
"""
async with httpx.AsyncClient(**(httpx_kwargs or {})) as client:
async with self.httpx_client_factory(**(httpx_kwargs or {})) as client:
try:
# Try a simple request to the endpoint
response = await client.get(mcp_url, timeout=5.0)
Expand Down Expand Up @@ -147,6 +148,7 @@ def __init__(
token_storage: AsyncKeyValue | None = None,
additional_client_metadata: dict[str, Any] | None = None,
callback_port: int | None = None,
httpx_client_factory: McpHttpClientFactory | None = None,
):
"""
Initialize OAuth client provider for an MCP server.
Expand All @@ -164,6 +166,7 @@ def __init__(
server_base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

# Setup OAuth client
self.httpx_client_factory = httpx_client_factory or httpx.AsyncClient
self.redirect_port = callback_port or find_available_port()
redirect_uri = f"http://localhost:{self.redirect_port}/callback"

Expand Down Expand Up @@ -226,7 +229,7 @@ async def _initialize(self) -> None:
async def redirect_handler(self, authorization_url: str) -> None:
"""Open browser for authorization, with pre-flight check for invalid client."""
# Pre-flight check to detect invalid client_id before opening browser
async with httpx.AsyncClient() as client:
async with self.httpx_client_factory() as client:
response = await client.get(authorization_url, follow_redirects=False)

# Check for client not found error (400 typically means bad client_id)
Expand Down
4 changes: 2 additions & 2 deletions src/fastmcp/client/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,16 +247,16 @@ def __init__(

self.url = url
self.headers = headers or {}
self._set_auth(auth)
self.httpx_client_factory = httpx_client_factory
self._set_auth(auth)

if isinstance(sse_read_timeout, int | float):
sse_read_timeout = datetime.timedelta(seconds=float(sse_read_timeout))
self.sse_read_timeout = sse_read_timeout

def _set_auth(self, auth: httpx.Auth | Literal["oauth"] | str | None):
if auth == "oauth":
auth = OAuth(self.url)
auth = OAuth(self.url, httpx_client_factory=self.httpx_client_factory)
elif isinstance(auth, str):
auth = BearerAuth(auth)
self.auth = auth
Expand Down
17 changes: 17 additions & 0 deletions tests/client/transports/test_transports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from ssl import VerifyMode

from fastmcp.client import Client
from fastmcp.client.transports import StreamableHttpTransport
import httpx

async def test_oauth_uses_same_client_as_transport():
transport = StreamableHttpTransport(
"https://some.fake.url/",
httpx_client_factory=lambda *args, **kwargs: httpx.AsyncClient(verify=False, *args, **kwargs),
auth="oauth",
)
client = Client(transport=transport)

httpx_client = transport.auth.httpx_client_factory()

assert httpx_client._transport._pool._ssl_context.verify_mode == ssl.VerifyMode.CERT_NONE
Loading