diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index a5de0d622..a6c606b6e 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -65,8 +65,8 @@ def get_auth_provider(cfg: ClientContext, http_client): else: raise RuntimeError("No valid authentication settings!") - # Always wrap with token federation (falls back gracefully if not needed) - if base_provider: + # Wrap with token federation only if explicitly enabled via identity_federation_client_id + if base_provider and cfg.identity_federation_client_id: return TokenFederationProvider( hostname=cfg.hostname, external_provider=base_provider, diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py index 7b62f6762..e38efc5a9 100644 --- a/src/databricks/sql/auth/token_federation.py +++ b/src/databricks/sql/auth/token_federation.py @@ -15,6 +15,24 @@ logger = logging.getLogger(__name__) +class TokenFederationError(Exception): + """Base exception for token federation errors.""" + + pass + + +class TokenExchangeNotAvailableError(TokenFederationError): + """Raised when token exchange endpoint is not available (404).""" + + pass + + +class TokenExchangeAuthenticationError(TokenFederationError): + """Raised when token exchange fails due to authentication issues (401/403).""" + + pass + + class Token: """ Represents an OAuth token with expiration management. @@ -72,8 +90,20 @@ class TokenFederationProvider(AuthProvider): """ Implementation of Token Federation for Databricks SQL Python driver. - This provider exchanges third-party access tokens for Databricks in-house tokens - when the token issuer is different from the Databricks host. + This provider exchanges third-party access tokens (e.g., Azure AD, AWS IAM) for + Databricks-native tokens when the token issuer differs from the Databricks host. 
+ + Token federation is useful for: + - Cross-cloud authentication scenarios + - Unity Catalog access across Azure subscriptions + - Service principal authentication with external identity providers + + The provider automatically detects when token exchange is needed by comparing the + token issuer with the Databricks workspace hostname. If exchange fails, it gracefully + falls back to using the external token directly. + + Note: Token federation must be explicitly enabled by providing the + identity_federation_client_id parameter during connection setup. """ TOKEN_EXCHANGE_ENDPOINT = "/oidc/v1/token" @@ -92,9 +122,17 @@ def __init__( Args: hostname: The Databricks workspace hostname - external_provider: The external authentication provider + external_provider: The external authentication provider that provides the initial token http_client: HTTP client for making requests (required) - identity_federation_client_id: Optional client ID for token federation + identity_federation_client_id: Client ID for identity federation (required for token exchange). + This parameter enables token federation and should be provided when: + - Using Service Principal authentication across Azure subscriptions + - Accessing Unity Catalog resources in different Azure tenants + - Token federation has been configured with your workspace administrator + + Without this parameter, the external token will be used directly without exchange. + Contact your Databricks workspace administrator to obtain the appropriate client ID + for your authentication scenario. """ if not http_client: raise ValueError("http_client is required for TokenFederationProvider") @@ -143,9 +181,33 @@ def _get_token(self) -> Token: try: token = self._exchange_token(access_token) self._cached_token = token + logger.info( + "Successfully exchanged external token for Databricks token" + ) return token + except TokenExchangeNotAvailableError: + logger.debug( + "Token exchange endpoint not available. Using external token directly. 
" + "This is expected when token federation is not configured for this workspace." + ) + except TokenExchangeAuthenticationError as e: + logger.warning( + "Token exchange failed due to authentication error. Using external token directly. " + "Error: %s. If this persists, verify your identity_federation_client_id configuration.", + e, + ) + except TokenFederationError as e: + logger.info( + "Token exchange not performed, using external token directly. " + "Error: %s", + e, + ) except Exception as e: - logger.warning("Token exchange failed, using external token: %s", e) + logger.debug( + "Token exchange failed with unexpected error, using external token directly. " + "Error: %s", + e, + ) # Use external token directly token = Token(access_token, token_type) @@ -163,7 +225,20 @@ def _should_exchange_token(self, access_token: str) -> bool: return not is_same_host(issuer, self.hostname) def _exchange_token(self, access_token: str) -> Token: - """Exchange the external token for a Databricks token.""" + """ + Exchange the external token for a Databricks token. 
+ + Args: + access_token: The external access token to exchange + + Returns: + Token: The exchanged Databricks token + + Raises: + TokenExchangeNotAvailableError: If the endpoint is not available (404) + TokenExchangeAuthenticationError: If authentication fails (401/403) + TokenFederationError: For other token exchange errors + """ token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}" data = { @@ -184,15 +259,45 @@ def _exchange_token(self, access_token: str) -> Token: body = urlencode(data) - response = self.http_client.request( - HttpMethod.POST, url=token_url, body=body, headers=headers - ) - - token_response = json.loads(response.data.decode()) - - return Token( - token_response["access_token"], token_response.get("token_type", "Bearer") - ) + try: + response = self.http_client.request( + HttpMethod.POST, url=token_url, body=body, headers=headers + ) + + # Check response status code + if response.status == 404: + raise TokenExchangeNotAvailableError( + "Token exchange endpoint not found. Token federation may not be enabled for this workspace." + ) + elif response.status in (401, 403): + error_detail = ( + response.data.decode() if response.data else "No error details" + ) + raise TokenExchangeAuthenticationError( + f"Authentication failed during token exchange (HTTP {response.status}): {error_detail}" + ) + elif response.status != 200: + error_detail = ( + response.data.decode() if response.data else "No error details" + ) + raise TokenFederationError( + f"Token exchange failed with HTTP {response.status}: {error_detail}" + ) + + token_response = json.loads(response.data.decode()) + + return Token( + token_response["access_token"], + token_response.get("token_type", "Bearer"), + ) + except TokenFederationError: + # Re-raise our custom exceptions + raise + except Exception as e: + # Handle unexpected errors (network errors, JSON parsing errors, etc.) 
+ raise TokenFederationError( + f"Unexpected error during token exchange: {str(e)}" + ) from e def _extract_token_from_header(self, auth_header: str) -> Tuple[str, str]: """Extract token type and access token from Authorization header.""" diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index d1b941208..701c95799 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -145,7 +145,9 @@ def test_get_python_sql_connector_auth_provider_access_token(self): hostname = "moderakh-test.cloud.databricks.com" kwargs = {"access_token": "dpi123"} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider") headers = {} @@ -163,10 +165,41 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: hostname = "moderakh-test.cloud.databricks.com" kwargs = {"credentials_provider": MyProvider()} mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) + + # Without identity_federation_client_id, should return ExternalAuthProvider directly + self.assertEqual(type(auth_provider).__name__, "ExternalAuthProvider") + + headers = {} + auth_provider.add_headers(headers) + self.assertEqual(headers["foo"], "bar") + + def test_get_python_sql_connector_auth_provider_with_token_federation(self): + class MyProvider(CredentialsProvider): + def auth_type(self) -> str: + return "mine" + + def __call__(self, *args, **kwargs) -> HeaderFactory: + return lambda: {"foo": "bar"} + + hostname = "moderakh-test.cloud.databricks.com" + kwargs = { + "credentials_provider": MyProvider(), + "identity_federation_client_id": "test-client-id", + } + mock_http_client = MagicMock() + 
auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) + # With identity_federation_client_id, should wrap with TokenFederationProvider self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider") - self.assertEqual(type(auth_provider.external_provider).__name__, "ExternalAuthProvider") + self.assertEqual( + type(auth_provider.external_provider).__name__, "ExternalAuthProvider" + ) + self.assertEqual(auth_provider.identity_federation_client_id, "test-client-id") headers = {} auth_provider.add_headers(headers) @@ -181,7 +214,9 @@ def test_get_python_sql_connector_auth_provider_noop(self): "_use_cert_as_auth": use_cert_as_auth, } mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) + auth_provider = get_python_sql_connector_auth_provider( + hostname, mock_http_client, **kwargs + ) self.assertTrue(type(auth_provider).__name__, "CredentialProvider") def test_get_python_sql_connector_basic_auth(self): @@ -191,7 +226,9 @@ def test_get_python_sql_connector_basic_auth(self): } mock_http_client = MagicMock() with self.assertRaises(ValueError) as e: - get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs) + get_python_sql_connector_auth_provider( + "foo.cloud.databricks.com", mock_http_client, **kwargs + ) self.assertIn( "Username/password authentication is no longer supported", str(e.exception) ) @@ -200,12 +237,13 @@ def test_get_python_sql_connector_basic_auth(self): def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" mock_http_client = MagicMock() - auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client) - - self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider") - self.assertEqual(type(auth_provider.external_provider).__name__, "DatabricksOAuthProvider") + auth_provider = 
get_python_sql_connector_auth_provider( + hostname, mock_http_client + ) - self.assertEqual(auth_provider.external_provider._client_id, PYSQL_OAUTH_CLIENT_ID) + # Without identity_federation_client_id, should return DatabricksOAuthProvider directly + self.assertEqual(type(auth_provider).__name__, "DatabricksOAuthProvider") + self.assertEqual(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID) class TestClientCredentialsTokenSource: @@ -264,16 +302,16 @@ def test_no_token_refresh__when_token_is_not_expired( def test_get_token_success(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with the expected format mock_response = MagicMock() mock_response.status = 200 mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}' - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + token = token_source.get_token() # Assert @@ -284,16 +322,16 @@ def test_get_token_success(self, token_source, http_response): def test_get_token_failure(self, token_source, http_response): mock_http_client = MagicMock() - + with patch.object(token_source, "_http_client", mock_http_client): # Create a mock response with error mock_response = MagicMock() mock_response.status = 400 mock_response.data.decode.return_value = "Bad Request" - + # Mock the request method to return the response directly mock_http_client.request.return_value = mock_response - + with pytest.raises(Exception) as e: token_source.get_token() assert "Failed to get token: 400" in str(e.value) diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py index 2e671c33e..7a40eb961 100644 --- a/tests/unit/test_token_federation.py +++ b/tests/unit/test_token_federation.py @@ -4,7 +4,13 @@ import jwt from datetime import datetime, timedelta -from 
databricks.sql.auth.token_federation import TokenFederationProvider, Token +from databricks.sql.auth.token_federation import ( + TokenFederationProvider, + Token, + TokenFederationError, + TokenExchangeNotAvailableError, + TokenExchangeAuthenticationError, +) from databricks.sql.auth.auth_utils import ( parse_hostname, decode_token, @@ -200,9 +206,51 @@ def test_exchange_token_failure(self, token_federation_provider, mock_http_clien mock_response.status = 400 mock_http_client.request.return_value = mock_response - with pytest.raises(KeyError): # Will raise KeyError due to missing access_token + with pytest.raises(TokenFederationError): + token_federation_provider._exchange_token("external-token-123") + + def test_exchange_token_404_error( + self, token_federation_provider, mock_http_client + ): + """Test token exchange with 404 error raises TokenExchangeNotAvailableError.""" + mock_response = Mock() + mock_response.status = 404 + mock_response.data = b"Not Found" + mock_http_client.request.return_value = mock_response + + with pytest.raises(TokenExchangeNotAvailableError) as exc_info: token_federation_provider._exchange_token("external-token-123") + assert "not found" in str(exc_info.value).lower() + + def test_exchange_token_401_error( + self, token_federation_provider, mock_http_client + ): + """Test token exchange with 401 error raises TokenExchangeAuthenticationError.""" + mock_response = Mock() + mock_response.status = 401 + mock_response.data = b"Unauthorized" + mock_http_client.request.return_value = mock_response + + with pytest.raises(TokenExchangeAuthenticationError) as exc_info: + token_federation_provider._exchange_token("external-token-123") + + assert "authentication failed" in str(exc_info.value).lower() + + def test_exchange_token_403_error( + self, token_federation_provider, mock_http_client + ): + """Test token exchange with 403 error raises TokenExchangeAuthenticationError.""" + mock_response = Mock() + mock_response.status = 403 + 
mock_response.data = b"Forbidden" + mock_http_client.request.return_value = mock_response + + with pytest.raises(TokenExchangeAuthenticationError) as exc_info: + token_federation_provider._exchange_token("external-token-123") + + assert "authentication failed" in str(exc_info.value).lower() + @pytest.mark.parametrize( "external_issuer,should_exchange", [ @@ -298,6 +346,54 @@ def add_headers_side_effect(headers): # Tokens should be different assert first_token != second_token + def test_token_exchange_fallback_on_404( + self, token_federation_provider, mock_external_provider, mock_http_client + ): + """Test that token exchange falls back gracefully on 404 error.""" + # Setup external provider to return an external token + external_token = create_jwt_token(issuer="https://login.microsoftonline.com") + mock_external_provider.add_headers = Mock( + side_effect=lambda headers: headers.update( + {"Authorization": f"Bearer {external_token}"} + ) + ) + + # Mock 404 error on token exchange + mock_response = Mock() + mock_response.status = 404 + mock_response.data = b"Not Found" + mock_http_client.request.return_value = mock_response + + headers = {} + token_federation_provider.add_headers(headers) + + # Should fall back to external token + assert headers["Authorization"] == f"Bearer {external_token}" + + def test_token_exchange_fallback_on_auth_error( + self, token_federation_provider, mock_external_provider, mock_http_client + ): + """Test that token exchange falls back gracefully on authentication error.""" + # Setup external provider to return an external token + external_token = create_jwt_token(issuer="https://login.microsoftonline.com") + mock_external_provider.add_headers = Mock( + side_effect=lambda headers: headers.update( + {"Authorization": f"Bearer {external_token}"} + ) + ) + + # Mock 401 error on token exchange + mock_response = Mock() + mock_response.status = 401 + mock_response.data = b"Unauthorized" + mock_http_client.request.return_value = mock_response + + 
headers = {} + token_federation_provider.add_headers(headers) + + # Should fall back to external token + assert headers["Authorization"] == f"Bearer {external_token}" + class TestUtilityFunctions: """Test utility functions used by TokenFederationProvider."""