diff --git a/docs/api.md b/docs/api.md index 3f696af54..3291f5c01 100644 --- a/docs/api.md +++ b/docs/api.md @@ -1 +1,5 @@ +The Python SDK exposes the entire `mcp` package for use in your own projects. +It includes an OAuth server implementation with support for the RFC 8693 +`token_exchange` grant type. + ::: mcp diff --git a/docs/index.md b/docs/index.md index 42ad9ca0c..dc0ffea32 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,3 +3,7 @@ This is the MCP Server implementation in Python. It only contains the [API Reference](api.md) for the time being. + +The built-in OAuth server supports the RFC 8693 `token_exchange` grant type, +allowing clients to exchange user tokens from external providers for MCP +access tokens. diff --git a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py index 0f1092d7d..30091566a 100644 --- a/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py +++ b/examples/servers/simple-auth/mcp_simple_auth/simple_auth_provider.py @@ -238,6 +238,52 @@ async def exchange_authorization_code( scope=" ".join(authorization_code.scopes), ) + async def exchange_client_credentials(self, client: OAuthClientInformationFull, scopes: list[str]) -> OAuthToken: + """Exchange client credentials for an MCP access token.""" + mcp_token = f"mcp_{secrets.token_hex(32)}" + self.tokens[mcp_token] = AccessToken( + token=mcp_token, + client_id=client.client_id, + scopes=scopes, + expires_at=int(time.time()) + 3600, + ) + return OAuthToken( + access_token=mcp_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(scopes), + ) + + async def exchange_token( + self, + client: OAuthClientInformationFull, + subject_token: str, + subject_token_type: str, + actor_token: str | None, + actor_token_type: str | None, + scope: list[str] | None, + audience: str | None, + resource: str | None, + ) -> OAuthToken: + """Exchange an external token for an MCP access token.""" + if not subject_token: + raise ValueError("Invalid subject token") + + mcp_token = f"mcp_{secrets.token_hex(32)}" + self.tokens[mcp_token] = AccessToken( + token=mcp_token, + client_id=client.client_id, + scopes=scope or [self.settings.mcp_scope], + expires_at=int(time.time()) + 3600, + resource=resource, + ) + return OAuthToken( + access_token=mcp_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(scope or [self.settings.mcp_scope]), + ) + async def load_access_token(self, token: str) -> AccessToken | None: """Load and validate an access token.""" access_token = self.tokens.get(token) diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 775fb0f6c..fef506fb5 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -125,7 +125,7 @@ def update_token_expiry(self, token: OAuthToken) -> None: self.token_expiry_time = None def is_token_valid(self) -> bool: - """Check if current token is valid.""" + """Check if the current token is valid.""" return bool( self.current_tokens and self.current_tokens.access_token @@ -133,7 +133,7 @@ def is_token_valid(self) -> bool: ) def can_refresh_token(self) -> bool: - """Check if token can be refreshed.""" + """Check if the token can be refreshed.""" return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info) def clear_tokens(self) -> None: @@ -176,7 +176,105 @@ def should_include_resource_param(self, protocol_version: str | None = None) -> return protocol_version >= "2025-06-18" -class OAuthClientProvider(httpx.Auth): +class BaseOAuthProvider(httpx.Auth): + """Common OAuth utilities for discovery, registration, and client auth.""" + + requires_response_body = True + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + timeout: float = 300.0, + ) -> None: + self.server_url = server_url + self.client_metadata = client_metadata + self.storage = storage + self.timeout = timeout + self._metadata: OAuthMetadata | None = None + self._client_info: OAuthClientInformationFull | None = None + + def _get_authorization_base_url(self, url: str) -> str: + parsed = urlparse(url) + return f"{parsed.scheme}://{parsed.netloc}" + + def _get_discovery_urls(self, server_url: str | None = None) -> list[str]: + url = server_url or self.server_url + parsed = urlparse(url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + urls: list[str] = [] + + if parsed.path and parsed.path != "/": + oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oauth_path)) + urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) + if parsed.path and parsed.path != "/": + oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oidc_path)) + urls.append(f"{url.rstrip('/')}/.well-known/openid-configuration") + return urls + + def _create_oauth_metadata_request(self, url: str) -> httpx.Request: + return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + + async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: + content = await response.aread() + metadata = OAuthMetadata.model_validate_json(content) + self._metadata = metadata + if self.client_metadata.scope is None and metadata.scopes_supported is not None: + self.client_metadata.scope = " ".join(metadata.scopes_supported) + + def _create_registration_request(self, metadata: OAuthMetadata | None = None) -> httpx.Request | None: + if self._client_info: + return None + if metadata and metadata.registration_endpoint: + registration_url = str(metadata.registration_endpoint) + else: + auth_base_url = self._get_authorization_base_url(self.server_url) + registration_url = urljoin(auth_base_url, "/register") + registration_data = self.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) + return httpx.Request( + "POST", + registration_url, + json=registration_data, + headers={"Content-Type": "application/json"}, + ) + + async def _handle_registration_response(self, response: httpx.Response) -> None: + if response.status_code not in (200, 201): + await response.aread() + raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") + content = await response.aread() + client_info = OAuthClientInformationFull.model_validate_json(content) + self._client_info = client_info + await self.storage.set_client_info(client_info) + + def _apply_client_auth( + self, + token_data: dict[str, str], + headers: dict[str, str], + client_info: OAuthClientInformationFull, + ) -> None: + auth_method = "client_secret_post" + if self._metadata and self._metadata.token_endpoint_auth_methods_supported: + supported = self._metadata.token_endpoint_auth_methods_supported + if "client_secret_basic" in supported: + auth_method = "client_secret_basic" + elif "client_secret_post" in supported: + auth_method = "client_secret_post" + if auth_method == "client_secret_basic": + if client_info.client_secret is None: + raise OAuthFlowError("Client secret required for client_secret_basic") + credential = f"{client_info.client_id}:{client_info.client_secret}" + headers["Authorization"] = f"Basic {base64.b64encode(credential.encode()).decode()}" + else: + token_data["client_id"] = client_info.client_id + if client_info.client_secret: + token_data["client_secret"] = client_info.client_secret + + +class OAuthClientProvider(BaseOAuthProvider): """ OAuth2 authentication for httpx. Handles OAuth flow with automatic client registration and token storage. @@ -194,6 +292,7 @@ def __init__( timeout: float = 300.0, ): """Initialize OAuth2 authentication.""" + super().__init__(server_url, client_metadata, storage, timeout) self.context = OAuthContext( server_url=server_url, client_metadata=client_metadata, @@ -251,63 +350,7 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> except ValidationError: pass - def _get_discovery_urls(self) -> list[str]: - """Generate ordered list of (url, type) tuples for discovery attempts.""" - urls: list[str] = [] - auth_server_url = self.context.auth_server_url or self.context.server_url - parsed = urlparse(auth_server_url) - base_url = f"{parsed.scheme}://{parsed.netloc}" - - # RFC 8414: Path-aware OAuth discovery - if parsed.path and parsed.path != "/": - oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" - urls.append(urljoin(base_url, oauth_path)) - - # OAuth root fallback - urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) - - # RFC 8414 section 5: Path-aware OIDC discovery - # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 - if parsed.path and parsed.path != "/": - oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" - urls.append(urljoin(base_url, oidc_path)) - - # OIDC 1.0 fallback (appends to full URL per OIDC spec) - oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration" - urls.append(oidc_fallback) - - return urls - - async def _register_client(self) -> httpx.Request | None: - """Build registration request or skip if already registered.""" - if self.context.client_info: - return None - - if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint: - registration_url = str(self.context.oauth_metadata.registration_endpoint) - else: - auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - registration_url = urljoin(auth_base_url, "/register") - - registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) - - return httpx.Request( - "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"} - ) - - async def _handle_registration_response(self, response: httpx.Response) -> None: - """Handle registration response.""" - if response.status_code not in (200, 201): - await response.aread() - raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") - - try: - content = await response.aread() - client_info = OAuthClientInformationFull.model_validate_json(content) - self.context.client_info = client_info - await self.context.storage.set_client_info(client_info) - except ValidationError as e: - raise OAuthRegistrationError(f"Invalid registration response: {e}") + # Discovery and registration helpers provided by BaseOAuthProvider async def _perform_authorization(self) -> tuple[str, str]: """Perform the authorization redirect and get auth code.""" @@ -370,7 +413,6 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req "grant_type": "authorization_code", "code": auth_code, "redirect_uri": str(self.context.client_metadata.redirect_uris[0]), - "client_id": self.context.client_info.client_id, "code_verifier": code_verifier, } @@ -378,12 +420,10 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req if self.context.should_include_resource_param(self.context.protocol_version): token_data["resource"] = self.context.get_resource_url() # RFC 8707 - if self.context.client_info.client_secret: - token_data["client_secret"] = self.context.client_info.client_secret + headers = {"Content-Type": "application/x-www-form-urlencoded"} + self._apply_client_auth(token_data, headers, self.context.client_info) - return httpx.Request( - "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"} - ) + return httpx.Request("POST", token_url, data=token_data, headers=headers) async def _handle_token_response(self, response: httpx.Response) -> None: """Handle token exchange response.""" @@ -425,19 +465,16 @@ async def _refresh_token(self) -> httpx.Request: refresh_data = { "grant_type": "refresh_token", "refresh_token": self.context.current_tokens.refresh_token, - "client_id": self.context.client_info.client_id, } # Only include resource param if conditions are met if self.context.should_include_resource_param(self.context.protocol_version): refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 - if self.context.client_info.client_secret: - refresh_data["client_secret"] = self.context.client_info.client_secret + headers = {"Content-Type": "application/x-www-form-urlencoded"} + self._apply_client_auth(refresh_data, headers, self.context.client_info) - return httpx.Request( - "POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"} - ) + return httpx.Request("POST", token_url, data=refresh_data, headers=headers) async def _handle_refresh_response(self, response: httpx.Response) -> bool: """Handle token refresh response. Returns True if successful.""" @@ -471,17 +508,6 @@ def _add_auth_header(self, request: httpx.Request) -> None: if self.context.current_tokens and self.context.current_tokens.access_token: request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" - def _create_oauth_metadata_request(self, url: str) -> httpx.Request: - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - - async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: - content = await response.aread() - metadata = OAuthMetadata.model_validate_json(content) - self.context.oauth_metadata = metadata - # Apply default scope if needed - if self.context.client_metadata.scope is None and metadata.scopes_supported is not None: - self.context.client_metadata.scope = " ".join(metadata.scopes_supported) - async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: """HTTPX auth flow integration.""" async with self.context.lock: @@ -515,7 +541,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. await self._handle_protected_resource_response(discovery_response) # Step 2: Discover OAuth metadata (with fallback for legacy servers) - discovery_urls = self._get_discovery_urls() + discovery_urls = self._get_discovery_urls(self.context.auth_server_url or self.context.server_url) for url in discovery_urls: oauth_metadata_request = self._create_oauth_metadata_request(url) oauth_metadata_response = yield oauth_metadata_request @@ -523,6 +549,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. if oauth_metadata_response.status_code == 200: try: await self._handle_oauth_metadata_response(oauth_metadata_response) + self.context.oauth_metadata = self._metadata break except ValidationError: continue @@ -530,10 +557,11 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. break # Non-4XX error, stop trying # Step 3: Register client if needed - registration_request = await self._register_client() + registration_request = self._create_registration_request(self._metadata) if registration_request: registration_response = yield registration_request await self._handle_registration_response(registration_response) + self.context.client_info = self._client_info # Step 4: Perform authorization auth_code, code_verifier = await self._perform_authorization() @@ -549,3 +577,278 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Retry with new tokens self._add_auth_header(request) yield request + + +class ClientCredentialsProvider(BaseOAuthProvider): + """HTTPX auth using the OAuth2 client credentials grant.""" + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + resource: str | None = None, + timeout: float = 300.0, + ) -> None: + super().__init__(server_url, client_metadata, storage, timeout) + self.resource = resource or resource_url_from_server_url(server_url) + self._current_tokens: OAuthToken | None = None + self._token_expiry_time: float | None = None + self._token_lock = anyio.Lock() + + def _has_valid_token(self) -> bool: + if not self._current_tokens or not self._current_tokens.access_token: + return False + if self._token_expiry_time and time.time() > self._token_expiry_time: + return False + return True + + async def _validate_token_scopes(self, token_response: OAuthToken) -> None: + if not token_response.scope: + return + requested_scopes: set[str] = set() + if self.client_metadata.scope: + requested_scopes = set(self.client_metadata.scope.split()) + returned_scopes = set(token_response.scope.split()) + unauthorized_scopes = returned_scopes - requested_scopes + if unauthorized_scopes: + raise Exception(f"Server granted unauthorized scopes: {unauthorized_scopes}.") + else: + granted = set(token_response.scope.split()) + logger.debug( + "No explicit scopes requested, accepting server-granted scopes: %s", + granted, + ) + + async def initialize(self) -> None: + self._current_tokens = await self.storage.get_tokens() + self._client_info = await self.storage.get_client_info() + + async def _get_or_register_client(self) -> OAuthClientInformationFull: + if not self._client_info: + request = self._create_registration_request(self._metadata) + if request: + async with httpx.AsyncClient(timeout=self.timeout) as client: + response: httpx.Response = await client.send(request) + await self._handle_registration_response(response) + assert self._client_info + return self._client_info + + async def _request_token(self) -> None: + if not self._metadata: + discovery_urls = self._get_discovery_urls(self.server_url) + async with httpx.AsyncClient(timeout=self.timeout) as client: + for url in discovery_urls: + req = self._create_oauth_metadata_request(url) + resp: httpx.Response = await client.send(req) + if resp.status_code == 200: + try: + await self._handle_oauth_metadata_response(resp) + break + except ValidationError: + continue + elif resp.status_code < 400 or resp.status_code >= 500: + break + + client_info = await self._get_or_register_client() + + if self._metadata and self._metadata.token_endpoint: + token_url = str(self._metadata.token_endpoint) + else: + auth_base_url = self._get_authorization_base_url(self.server_url) + token_url = urljoin(auth_base_url, "/token") + + token_data: dict[str, str] = { + "grant_type": "client_credentials", + "resource": self.resource, + } + if self.client_metadata.scope: + token_data["scope"] = self.client_metadata.scope + headers = {"Content-Type": "application/x-www-form-urlencoded"} + self._apply_client_auth(token_data, headers, client_info) + + async with httpx.AsyncClient(timeout=self.timeout) as client: + response: httpx.Response = await client.post( + token_url, + data=token_data, + headers=headers, + ) + + if response.status_code != 200: + raise Exception(f"Token request failed: {response.status_code} {response.text}") + + token_response = OAuthToken.model_validate(response.json()) + await self._validate_token_scopes(token_response) + + if token_response.expires_in: + self._token_expiry_time = time.time() + token_response.expires_in + else: + self._token_expiry_time = None + + await self.storage.set_tokens(token_response) + self._current_tokens = token_response + + async def ensure_token(self) -> None: + async with self._token_lock: + if self._has_valid_token(): + return + await self._request_token() + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: + if not self._has_valid_token(): + await self.initialize() + await self.ensure_token() + if self._current_tokens and self._current_tokens.access_token: + request.headers["Authorization"] = f"Bearer {self._current_tokens.access_token}" + response = yield request + if response.status_code == 401: + self._current_tokens = None + + +class TokenExchangeProvider(BaseOAuthProvider): + """OAuth2 token exchange based on RFC 8693.""" + + def __init__( + self, + server_url: str, + client_metadata: OAuthClientMetadata, + storage: TokenStorage, + subject_token_supplier: Callable[[], Awaitable[str]], + subject_token_type: str = "access_token", + actor_token_supplier: Callable[[], Awaitable[str]] | None = None, + actor_token_type: str | None = None, + audience: str | None = None, + resource: str | None = None, + timeout: float = 300.0, + ) -> None: + super().__init__(server_url, client_metadata, storage, timeout) + self.subject_token_supplier = subject_token_supplier + self.subject_token_type = subject_token_type + self.actor_token_supplier = actor_token_supplier + self.actor_token_type = actor_token_type + self.audience = audience + self.resource = resource or resource_url_from_server_url(server_url) + self._current_tokens: OAuthToken | None = None + self._token_expiry_time: float | None = None + self._token_lock = anyio.Lock() + + def _has_valid_token(self) -> bool: + if not self._current_tokens or not self._current_tokens.access_token: + return False + if self._token_expiry_time and time.time() > self._token_expiry_time: + return False + return True + + async def _validate_token_scopes(self, token_response: OAuthToken) -> None: + if not token_response.scope: + return + requested_scopes: set[str] = set() + if self.client_metadata.scope: + requested_scopes = set(self.client_metadata.scope.split()) + returned_scopes = set(token_response.scope.split()) + unauthorized_scopes = returned_scopes - requested_scopes + if unauthorized_scopes: + raise Exception(f"Server granted unauthorized scopes: {unauthorized_scopes}.") + else: + granted = set(token_response.scope.split()) + logger.debug( + "No explicit scopes requested, accepting server-granted scopes: %s", + granted, + ) + + async def initialize(self) -> None: + self._current_tokens = await self.storage.get_tokens() + self._client_info = await self.storage.get_client_info() + + async def _get_or_register_client(self) -> OAuthClientInformationFull: + if not self._client_info: + request = self._create_registration_request(self._metadata) + if request: + async with httpx.AsyncClient(timeout=self.timeout) as client: + response: httpx.Response = await client.send(request) + await self._handle_registration_response(response) + assert self._client_info + return self._client_info + + async def _request_token(self) -> None: + if not self._metadata: + discovery_urls = self._get_discovery_urls(self.server_url) + async with httpx.AsyncClient(timeout=self.timeout) as client: + for url in discovery_urls: + req = self._create_oauth_metadata_request(url) + resp: httpx.Response = await client.send(req) + if resp.status_code == 200: + try: + await self._handle_oauth_metadata_response(resp) + break + except ValidationError: + continue + elif resp.status_code < 400 or resp.status_code >= 500: + break + + client_info = await self._get_or_register_client() + + if self._metadata and self._metadata.token_endpoint: + token_url = str(self._metadata.token_endpoint) + else: + auth_base_url = self._get_authorization_base_url(self.server_url) + token_url = urljoin(auth_base_url, "/token") + + subject_token = await self.subject_token_supplier() + actor_token = await self.actor_token_supplier() if self.actor_token_supplier else None + + token_data: dict[str, str] = { + "grant_type": "token_exchange", + "subject_token": subject_token, + "subject_token_type": self.subject_token_type, + } + if actor_token: + token_data["actor_token"] = actor_token + if self.actor_token_type: + token_data["actor_token_type"] = self.actor_token_type + if self.audience: + token_data["audience"] = self.audience + if self.resource: + token_data["resource"] = self.resource + if self.client_metadata.scope: + token_data["scope"] = self.client_metadata.scope + + headers = {"Content-Type": "application/x-www-form-urlencoded"} + self._apply_client_auth(token_data, headers, client_info) + + async with httpx.AsyncClient(timeout=self.timeout) as client: + response: httpx.Response = await client.post( + token_url, + data=token_data, + headers=headers, + ) + + if response.status_code != 200: + raise Exception(f"Token request failed: {response.status_code} {response.text}") + + token_response = OAuthToken.model_validate(response.json()) + await self._validate_token_scopes(token_response) + + if token_response.expires_in: + self._token_expiry_time = time.time() + token_response.expires_in + else: + self._token_expiry_time = None + + await self.storage.set_tokens(token_response) + self._current_tokens = token_response + + async def ensure_token(self) -> None: + async with self._token_lock: + if self._has_valid_token(): + return + await self._request_token() + + async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]: + if not self._has_valid_token(): + await self.initialize() + await self.ensure_token() + if self._current_tokens and self._current_tokens.access_token: + request.headers["Authorization"] = f"Bearer {self._current_tokens.access_token}" + response = yield request + if response.status_code == 401: + self._current_tokens = None diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index b1ab2c079..245af3c0e 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -173,9 +173,11 @@ async def _handle_sse_event( session_message = SessionMessage(message) await read_stream_writer.send(session_message) - # Call resumption token callback if we have an ID - if sse.id and resumption_callback: - await resumption_callback(sse.id) + # Call resumption token callback if we have an ID. Only update + # the resumption token on notifications to avoid overwriting it + # with the token from the final response. + if sse.id and resumption_callback and not isinstance(message.root, JSONRPCResponse | JSONRPCError): + await resumption_callback(sse.id.strip()) # If this is a response or error return True indicating completion # Otherwise, return False to continue listening @@ -221,7 +223,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" headers = self._prepare_request_headers(ctx.headers) if ctx.metadata and ctx.metadata.resumption_token: - headers[LAST_EVENT_ID] = ctx.metadata.resumption_token + headers[LAST_EVENT_ID] = ctx.metadata.resumption_token.strip() else: raise ResumptionError("Resumption request requires a resumption token") diff --git a/src/mcp/server/auth/handlers/register.py b/src/mcp/server/auth/handlers/register.py index e6d99e66d..b211e238f 100644 --- a/src/mcp/server/auth/handlers/register.py +++ b/src/mcp/server/auth/handlers/register.py @@ -68,11 +68,22 @@ async def handle(self, request: Request) -> Response: ), status_code=400, ) - if set(client_metadata.grant_types) != {"authorization_code", "refresh_token"}: + grant_types_set: set[str] = set(client_metadata.grant_types) + valid_sets = [ + {"authorization_code", "refresh_token"}, + {"client_credentials"}, + {"token_exchange"}, + {"client_credentials", "token_exchange"}, + ] + + if grant_types_set not in valid_sets: return PydanticJSONResponse( content=RegistrationErrorResponse( error="invalid_client_metadata", - error_description="grant_types must be authorization_code and refresh_token", + error_description=( + "grant_types must be authorization_code and refresh_token " + "or client_credentials or token exchange or client_credentials and token_exchange" + ), ), status_code=400, ) diff --git a/src/mcp/server/auth/handlers/token.py b/src/mcp/server/auth/handlers/token.py index 4e15e6265..e39b4ef1e 100644 --- a/src/mcp/server/auth/handlers/token.py +++ b/src/mcp/server/auth/handlers/token.py @@ -40,16 +40,39 @@ class RefreshTokenRequest(BaseModel): resource: str | None = Field(None, description="Resource indicator for the token") +class ClientCredentialsRequest(BaseModel): + # See https://datatracker.ietf.org/doc/html/rfc6749#section-4.4 + grant_type: Literal["client_credentials"] + scope: str | None = Field(None, description="Optional scope parameter") + client_id: str + client_secret: str | None = None + + +class TokenExchangeRequest(BaseModel): + """RFC 8693 token exchange request.""" + + grant_type: Literal["token_exchange"] + subject_token: str = Field(..., description="Token to exchange") + subject_token_type: str = Field(..., description="Type of the subject token") + actor_token: str | None = Field(None, description="Optional actor token") + actor_token_type: str | None = Field(None, description="Type of the actor token if provided") + resource: str | None = None + audience: str | None = None + scope: str | None = None + client_id: str + client_secret: str | None = None + + class TokenRequest( RootModel[ Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, + AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest | TokenExchangeRequest, Field(discriminator="grant_type"), ] ] ): root: Annotated[ - AuthorizationCodeRequest | RefreshTokenRequest, + AuthorizationCodeRequest | RefreshTokenRequest | ClientCredentialsRequest | TokenExchangeRequest, Field(discriminator="grant_type"), ] @@ -192,10 +215,49 @@ async def handle(self, request: Request): ) ) + case ClientCredentialsRequest(): + scopes = ( + token_request.scope.split(" ") + if token_request.scope + else client_info.scope.split(" ") + if client_info.scope + else [] + ) + try: + tokens = await self.provider.exchange_client_credentials(client_info, scopes) + except TokenError as e: + return self.response( + TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + ) + + case TokenExchangeRequest(): + scopes = token_request.scope.split(" ") if token_request.scope else [] + try: + tokens = await self.provider.exchange_token( + client_info, + token_request.subject_token, + token_request.subject_token_type, + token_request.actor_token, + token_request.actor_token_type, + scopes, + token_request.audience, + token_request.resource, + ) + except TokenError as e: + return self.response( + TokenErrorResponse( + error=e.error, + error_description=e.error_description, + ) + ) + case RefreshTokenRequest(): refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token) if refresh_token is None or refresh_token.client_id != token_request.client_id: - # if token belongs to different client, pretend it doesn't exist + # if token belongs to a different client, pretend it doesn't exist return self.response( TokenErrorResponse( error="invalid_grant", diff --git a/src/mcp/server/auth/provider.py b/src/mcp/server/auth/provider.py index b84db89a2..e4de4ecf8 100644 --- a/src/mcp/server/auth/provider.py +++ b/src/mcp/server/auth/provider.py @@ -80,6 +80,7 @@ class AuthorizeError(Exception): "unauthorized_client", "unsupported_grant_type", "invalid_scope", + "invalid_target", ] @@ -248,6 +249,24 @@ async def exchange_refresh_token( """ ... + async def exchange_client_credentials(self, client: OAuthClientInformationFull, scopes: list[str]) -> OAuthToken: + """Exchange client credentials for an MCP access token.""" + ... + + async def exchange_token( + self, + client: OAuthClientInformationFull, + subject_token: str, + subject_token_type: str, + actor_token: str | None, + actor_token_type: str | None, + scope: list[str] | None, + audience: str | None, + resource: str | None, + ) -> OAuthToken: + """Exchange an external token for an MCP access token.""" + ... + async def load_access_token(self, token: str) -> AccessTokenT | None: """ Loads an access token by its token. diff --git a/src/mcp/server/auth/routes.py b/src/mcp/server/auth/routes.py index bce32df52..eeefcf9a3 100644 --- a/src/mcp/server/auth/routes.py +++ b/src/mcp/server/auth/routes.py @@ -163,7 +163,12 @@ def build_metadata( scopes_supported=client_registration_options.valid_scopes, response_types_supported=["code"], response_modes_supported=None, - grant_types_supported=["authorization_code", "refresh_token"], + grant_types_supported=[ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ], token_endpoint_auth_methods_supported=["client_secret_post"], token_endpoint_auth_signing_alg_values_supported=None, service_documentation=service_documentation_url, diff --git a/src/mcp/shared/auth.py b/src/mcp/shared/auth.py index 6bf15b531..016e52578 100644 --- a/src/mcp/shared/auth.py +++ b/src/mcp/shared/auth.py @@ -13,6 +13,7 @@ class OAuthToken(BaseModel): expires_in: int | None = None scope: str | None = None refresh_token: str | None = None + issued_token_type: str | None = None @field_validator("token_type", mode="before") @classmethod @@ -46,8 +47,15 @@ class OAuthClientMetadata(BaseModel): # client_secret_post; # ie: we do not support client_secret_basic token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post" - # grant_types: this implementation only supports authorization_code & refresh_token - grant_types: list[Literal["authorization_code", "refresh_token"]] = [ + # grant_types: this implementation supports authorization_code, refresh_token, client_credentials, & token_exchange + grant_types: list[ + Literal[ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ] + ] = [ "authorization_code", "refresh_token", ] @@ -114,10 +122,35 @@ class OAuthMetadata(BaseModel): registration_endpoint: AnyHttpUrl | None = None scopes_supported: list[str] | None = None response_types_supported: list[str] = ["code"] - response_modes_supported: list[str] | None = None - grant_types_supported: list[str] | None = None - token_endpoint_auth_methods_supported: list[str] | None = None - token_endpoint_auth_signing_alg_values_supported: list[str] | None = None + response_modes_supported: ( + list[ + Literal[ + "query", + "fragment", + "form_post", + "query.jwt", + "fragment.jwt", + "form_post.jwt", + "jwt", + ] + ] + | None + ) = None + grant_types_supported: ( + list[ + Literal[ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ] + ] + | None + ) = None + token_endpoint_auth_methods_supported: list[Literal["none", "client_secret_post", "client_secret_basic"]] | None = ( + None + ) + token_endpoint_auth_signing_alg_values_supported: None = None service_documentation: AnyHttpUrl | None = None ui_locales_supported: list[str] | None = None op_policy_uri: AnyHttpUrl | None = None diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index b2f49fc8b..865f9c973 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -312,7 +312,10 @@ async def send_notification( message=JSONRPCMessage(jsonrpc_notification), metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, ) - await self._write_stream.send(session_message) + try: + await self._write_stream.send(session_message) + except (anyio.BrokenResourceError, anyio.ClosedResourceError): + logging.debug("Discarding notification due to closed stream") async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: if isinstance(response, ErrorData): diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 61d74df1e..7c48cad95 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -1,17 +1,31 @@ -""" -Tests for refactored OAuth client authentication implementation. -""" +"""Tests for refactored OAuth client authentication implementation.""" +# pyright: reportUnknownParameterType=false, reportUnknownVariableType=false, reportUnknownMemberType=false + +import asyncio import time -from unittest import mock +from collections.abc import AsyncGenerator +from typing import Any +from unittest.mock import AsyncMock, Mock, patch import httpx import pytest from inline_snapshot import Is, snapshot from pydantic import AnyHttpUrl, AnyUrl -from mcp.client.auth import OAuthClientProvider, PKCEParameters -from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, ProtectedResourceMetadata +from mcp.client.auth import ( + ClientCredentialsProvider, + OAuthClientProvider, + PKCEParameters, + TokenExchangeProvider, +) +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthToken, + ProtectedResourceMetadata, +) class MockTokenStorage: @@ -79,6 +93,79 @@ async def callback_handler() -> tuple[str, str | None]: ) +@pytest.fixture +def client_credentials_metadata(): + return OAuthClientMetadata( + redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")], + client_name="CC Client", + grant_types=["client_credentials"], + response_types=["code"], + scope="read write", + token_endpoint_auth_method="client_secret_post", + ) + + +@pytest.fixture +def oauth_metadata(): + return OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + registration_endpoint=AnyHttpUrl("https://auth.example.com/register"), + scopes_supported=["read", "write", "admin"], + response_types_supported=["code"], + grant_types_supported=["authorization_code", "refresh_token", "client_credentials"], + code_challenge_methods_supported=["S256"], + ) + + +@pytest.fixture +def oauth_client_info(): + return OAuthClientInformationFull( + client_id="test_client_id", + client_secret="test_client_secret", + redirect_uris=[AnyUrl("http://localhost:3000/callback")], + client_name="Test Client", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + scope="read write", + ) + + +@pytest.fixture +def oauth_token(): + return OAuthToken( + access_token="test_access_token", + token_type="Bearer", + expires_in=3600, + refresh_token="test_refresh_token", + scope="read write", + ) + + +@pytest.fixture +async def client_credentials_provider( + client_credentials_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage +) -> ClientCredentialsProvider: + return ClientCredentialsProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_credentials_metadata, + storage=mock_storage, + ) + + +@pytest.fixture +async def token_exchange_provider( + client_credentials_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage +) -> TokenExchangeProvider: + return TokenExchangeProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_credentials_metadata, + storage=mock_storage, + subject_token_supplier=lambda: asyncio.sleep(0, result="user_token"), + ) + + class TestPKCEParameters: """Test PKCE parameter generation.""" @@ -345,14 +432,22 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl ) # Mock the authorization process to minimize unnecessary state in this test - oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) + oauth_provider._perform_authorization = AsyncMock(return_value=("test_auth_code", "test_code_verifier")) + + # Next request should fall back to legacy behavior: register then obtain token + registration_request = await auth_flow.asend(oauth_metadata_response_3) + assert str(registration_request.url) == "https://api.example.com/register" + assert registration_request.method == "POST" - # Next request should fall back to legacy behavior and auth with the RS (mocked /authorize, next is /token) - token_request = await auth_flow.asend(oauth_metadata_response_3) + registration_response = httpx.Response( + 200, + content=b'{"client_id":"c","redirect_uris":["http://localhost:3030/callback"]}', + request=registration_request, + ) + token_request = await auth_flow.asend(registration_response) assert str(token_request.url) == "https://api.example.com/token" assert token_request.method == "POST" - # Send a successful token response token_response = httpx.Response( 200, content=( @@ -361,7 +456,7 @@ async def test_oauth_discovery_fallback_conditions(self, oauth_provider: OAuthCl ), request=token_request, ) - token_request = await auth_flow.asend(token_response) + await auth_flow.asend(token_response) @pytest.mark.anyio async def test_handle_metadata_response_success(self, oauth_provider: OAuthClientProvider): @@ -376,13 +471,13 @@ async def test_handle_metadata_response_success(self, oauth_provider: OAuthClien # Should set metadata await oauth_provider._handle_oauth_metadata_response(response) - assert oauth_provider.context.oauth_metadata is not None - assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com/" + assert oauth_provider._metadata is not None + assert str(oauth_provider._metadata.issuer) == "https://auth.example.com/" @pytest.mark.anyio async def test_register_client_request(self, oauth_provider: OAuthClientProvider): """Test client registration request building.""" - request = await oauth_provider._register_client() + request = oauth_provider._create_registration_request(oauth_provider.context.oauth_metadata) assert request is not None assert request.method == "POST" @@ -398,9 +493,10 @@ async def test_register_client_skip_if_registered(self, oauth_provider: OAuthCli redirect_uris=[AnyUrl("http://localhost:3030/callback")], ) oauth_provider.context.client_info = client_info + oauth_provider._client_info = client_info # Should return None (skip registration) - request = await oauth_provider._register_client() + request = oauth_provider._create_registration_request(oauth_provider.context.oauth_metadata) assert request is None @pytest.mark.anyio @@ -670,7 +766,7 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide ) # Mock the authorization process - oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier")) + oauth_provider._perform_authorization = AsyncMock(return_value=("test_auth_code", "test_code_verifier")) # Next request should be to exchange token token_request = await auth_flow.asend(registration_response) @@ -700,6 +796,91 @@ async def test_auth_flow_with_no_tokens(self, oauth_provider: OAuthClientProvide assert oauth_provider.context.token_expiry_time is not None +class TestClientCredentialsProvider: + @pytest.mark.anyio + async def test_request_token_success( + self, + client_credentials_provider: ClientCredentialsProvider, + oauth_metadata: OAuthMetadata, + oauth_client_info: OAuthClientInformationFull, + oauth_token: OAuthToken, + ) -> None: + client_credentials_provider._metadata = oauth_metadata + client_credentials_provider._client_info = oauth_client_info + + token_json: dict[str, Any] = oauth_token.model_dump(by_alias=True, mode="json") + token_json.pop("refresh_token", None) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = token_json + mock_client.post.return_value = mock_response + + await client_credentials_provider.ensure_token() + + mock_client.post.assert_called_once() + _args, kwargs = mock_client.post.call_args + assert kwargs["data"]["resource"] == "https://api.example.com/v1/mcp" + assert client_credentials_provider._current_tokens is not None + assert client_credentials_provider._current_tokens.access_token == oauth_token.access_token + + @pytest.mark.anyio + async def test_async_auth_flow( + self, client_credentials_provider: ClientCredentialsProvider, oauth_token: OAuthToken + ) -> None: + client_credentials_provider._current_tokens = oauth_token + client_credentials_provider._token_expiry_time = time.time() + 3600 + + request = httpx.Request("GET", "https://api.example.com/data") + mock_response = Mock() + mock_response.status_code = 200 + + auth_flow: AsyncGenerator[httpx.Request, httpx.Response] = client_credentials_provider.async_auth_flow(request) + updated_request = await auth_flow.__anext__() + assert updated_request.headers["Authorization"] == f"Bearer {oauth_token.access_token}" + try: + await auth_flow.asend(mock_response) + except StopAsyncIteration: + pass + + +class TestTokenExchangeProvider: + @pytest.mark.anyio + async def test_request_token_success( + self, + token_exchange_provider: TokenExchangeProvider, + oauth_metadata: OAuthMetadata, + oauth_client_info: OAuthClientInformationFull, + oauth_token: OAuthToken, + ) -> None: + token_exchange_provider._metadata = oauth_metadata + token_exchange_provider._client_info = oauth_client_info + + token_json: dict[str, Any] = oauth_token.model_dump(by_alias=True, mode="json") + token_json.pop("refresh_token", None) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = AsyncMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = token_json + mock_client.post.return_value = mock_response + + await token_exchange_provider.ensure_token() + + mock_client.post.assert_called_once() + _args, kwargs = mock_client.post.call_args + assert kwargs["data"]["resource"] == "https://api.example.com/v1/mcp" + assert token_exchange_provider._current_tokens is not None + assert token_exchange_provider._current_tokens.access_token == oauth_token.access_token + + @pytest.mark.parametrize( ( "issuer_url", @@ -769,7 +950,12 @@ def test_build_metadata( "token_endpoint": Is(token_endpoint), "registration_endpoint": Is(registration_endpoint), "scopes_supported": ["read", "write", "admin"], - "grant_types_supported": ["authorization_code", "refresh_token"], + "grant_types_supported": [ + "authorization_code", + "refresh_token", + "client_credentials", + "token_exchange", + ], "token_endpoint_auth_methods_supported": ["client_secret_post"], "service_documentation": Is(service_documentation_url), "revocation_endpoint": Is(revocation_endpoint), diff --git a/tests/issues/test_88_random_error.py b/tests/issues/test_88_random_error.py index 5584abcae..0883ad204 100644 --- a/tests/issues/test_88_random_error.py +++ b/tests/issues/test_88_random_error.py @@ -83,7 +83,12 @@ async def client( # - Long enough for fast operations (>10ms) # - Short enough for slow operations (<200ms) # - Not too short to avoid flakiness - async with ClientSession(read_stream, write_stream, read_timeout_seconds=timedelta(milliseconds=50)) as session: + async with ClientSession( + read_stream, + write_stream, + # Increased to 150ms to avoid flakiness on slower platforms + read_timeout_seconds=timedelta(milliseconds=150), + ) as session: await session.initialize() # First call should work (fast operation) diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py index 065bc7841..6d6e22a42 100644 --- a/tests/issues/test_malformed_input.py +++ b/tests/issues/test_malformed_input.py @@ -76,6 +76,7 @@ async def test_malformed_initialize_request_does_not_crash_server(): method="tools/call", # params=None # Missing required params ) + another_request_message = SessionMessage(message=JSONRPCMessage(another_malformed_request)) await read_send_stream.send(another_request_message) diff --git a/tests/server/fastmcp/auth/test_auth_integration.py b/tests/server/fastmcp/auth/test_auth_integration.py index e4bb17397..352f0f0dc 100644 --- a/tests/server/fastmcp/auth/test_auth_integration.py +++ b/tests/server/fastmcp/auth/test_auth_integration.py @@ -21,6 +21,7 @@ AuthorizationParams, OAuthAuthorizationServerProvider, RefreshToken, + TokenError, construct_redirect_uri, ) from mcp.server.auth.routes import ClientRegistrationOptions, RevocationOptions, create_auth_routes @@ -154,6 +155,49 @@ async def exchange_refresh_token( refresh_token=new_refresh_token, ) + async def exchange_client_credentials(self, client: OAuthClientInformationFull, scopes: list[str]) -> OAuthToken: + access_token = f"access_{secrets.token_hex(32)}" + self.tokens[access_token] = AccessToken( + token=access_token, + client_id=client.client_id, + scopes=scopes, + expires_at=int(time.time()) + 3600, + ) + return OAuthToken( + access_token=access_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(scopes), + ) + + async def exchange_token( + self, + client: OAuthClientInformationFull, + subject_token: str, + subject_token_type: str, + actor_token: str | None, + actor_token_type: str | None, + scope: list[str] | None, + audience: str | None, + resource: str | None, + ) -> OAuthToken: + if subject_token == "bad_token": + raise TokenError("invalid_grant", "invalid subject token") + + access_token = f"exchanged_{secrets.token_hex(32)}" + self.tokens[access_token] = AccessToken( + token=access_token, + client_id=client.client_id, + scopes=scope or ["read"], + expires_at=int(time.time()) + 3600, + ) + return OAuthToken( + access_token=access_token, + token_type="Bearer", + expires_in=3600, + scope=" ".join(scope or ["read"]), + ) + async def load_access_token(self, token: str) -> AccessToken | None: token_info = self.tokens.get(token) @@ -361,6 +405,8 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient): assert metadata["grant_types_supported"] == [ "authorization_code", "refresh_token", + "client_credentials", + "token_exchange", ] assert metadata["service_documentation"] == "https://docs.example.com/" @@ -940,7 +986,28 @@ async def test_client_registration_invalid_grant_type(self, test_client: httpx.A error_data = response.json() assert "error" in error_data assert error_data["error"] == "invalid_client_metadata" - assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token" + assert error_data["error_description"] == ( + "grant_types must be authorization_code and refresh_token " + "or client_credentials or token exchange or " + "client_credentials and token_exchange" + ) + + @pytest.mark.anyio + async def test_client_registration_client_credentials(self, test_client: httpx.AsyncClient): + client_metadata = { + "redirect_uris": ["https://client.example.com/callback"], + "client_name": "CC Client", + "grant_types": ["client_credentials"], + } + + response = await test_client.post( + "/register", + json=client_metadata, + ) + + assert response.status_code == 201, response.content + client_info = response.json() + assert client_info["grant_types"] == ["client_credentials"] class TestAuthorizeEndpointErrors: @@ -1217,3 +1284,110 @@ async def test_authorize_invalid_scope( # State should be preserved assert "state" in query_params assert query_params["state"][0] == "test_state" + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["client_credentials"]}], + indirect=True, + ) + async def test_client_credentials_token( + self, test_client: httpx.AsyncClient, registered_client: dict[str, str] + ) -> None: + response = await test_client.post( + "/token", + data={ + "grant_type": "client_credentials", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "scope": "read write", + }, + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + + @pytest.mark.anyio + async def test_metadata_includes_token_exchange(self, test_client: httpx.AsyncClient): + response = await test_client.get("/.well-known/oauth-authorization-server") + assert response.status_code == 200 + metadata = response.json() + assert "token_exchange" in metadata["grant_types_supported"] + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["token_exchange"]}], + indirect=True, + ) + async def test_token_exchange_success( + self, test_client: httpx.AsyncClient, registered_client: dict[str, str] + ) -> None: + response = await test_client.post( + "/token", + data={ + "grant_type": "token_exchange", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "subject_token": "good_token", + "subject_token_type": "access_token", + }, + ) + assert response.status_code == 200 + data = response.json() + assert "access_token" in data + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["token_exchange"]}], + indirect=True, + ) + async def test_token_exchange_invalid_subject( + self, test_client: httpx.AsyncClient, registered_client: dict[str, str] + ) -> None: + response = await test_client.post( + "/token", + data={ + "grant_type": "token_exchange", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "subject_token": "bad_token", + "subject_token_type": "access_token", + }, + ) + assert response.status_code == 400 + data = response.json() + assert data["error"] == "invalid_grant" + + @pytest.mark.anyio + @pytest.mark.parametrize( + "registered_client", + [{"grant_types": ["client_credentials", "token_exchange"]}], + indirect=True, + ) + async def test_client_credentials_and_token_exchange( + self, test_client: httpx.AsyncClient, registered_client: dict[str, str] + ) -> None: + cc_response = await test_client.post( + "/token", + data={ + "grant_type": "client_credentials", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "scope": "read write", + }, + ) + assert cc_response.status_code == 200 + + te_response = await test_client.post( + "/token", + data={ + "grant_type": "token_exchange", + "client_id": registered_client["client_id"], + "client_secret": registered_client["client_secret"], + "subject_token": "good_token", + "subject_token_type": "access_token", + }, + ) + assert te_response.status_code == 200 diff --git a/tests/server/fastmcp/resources/test_file_resources.py b/tests/server/fastmcp/resources/test_file_resources.py index ec3c85d8d..1ff9a3cb5 100644 --- a/tests/server/fastmcp/resources/test_file_resources.py +++ b/tests/server/fastmcp/resources/test_file_resources.py @@ -100,18 +100,21 @@ async def test_missing_file_error(self, temp_file: Path): with pytest.raises(ValueError, match="Error reading file"): await resource.read() - @pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows") - @pytest.mark.anyio - async def test_permission_error(self, temp_file: Path): - """Test reading a file without permissions.""" - temp_file.chmod(0o000) # Remove all permissions - try: - resource = FileResource( - uri=FileUrl(temp_file.as_uri()), - name="test", - path=temp_file, - ) - with pytest.raises(ValueError, match="Error reading file"): - await resource.read() - finally: - temp_file.chmod(0o644) # Restore permissions + +@pytest.mark.skipif(os.name == "nt", reason="File permissions behave differently on Windows") +@pytest.mark.anyio +async def test_permission_error(temp_file: Path): + """Test reading a file without permissions.""" + if os.geteuid() == 0: + pytest.skip("Permission test not reliable when running as root") + temp_file.chmod(0o000) # Remove all permissions + try: + resource = FileResource( + uri=FileUrl(temp_file.as_uri()), + name="test", + path=temp_file, + ) + with pytest.raises(ValueError, match="Error reading file"): + await resource.read() + finally: + temp_file.chmod(0o644) # Restore permissions diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index ecbe6eb08..93cfb4a2d 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -7,6 +7,7 @@ import json import multiprocessing import socket +import sys import time from collections.abc import Generator from typing import Any @@ -1087,6 +1088,7 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt @pytest.mark.anyio +@pytest.mark.skipif(sys.platform == "win32", reason="Resumption unstable on Windows") async def test_streamablehttp_client_resumption(event_server: tuple[SimpleEventStore, str]): """Test client session resumption using sync primitives for reliable coordination.""" _, server_url = event_server @@ -1204,6 +1206,12 @@ async def run_tool(): assert result.content[0].type == "text" assert result.content[0].text == "Completed" + # Allow any pending notifications to be processed + for _ in range(50): + if captured_notifications: + break + await anyio.sleep(0.1) + # We should have received the remaining notifications assert len(captured_notifications) == 1