diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index a8accac0..a5de0d62 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -8,13 +8,17 @@ AzureServicePrincipalCredentialProvider, ) from databricks.sql.auth.common import AuthType, ClientContext +from databricks.sql.auth.token_federation import TokenFederationProvider def get_auth_provider(cfg: ClientContext, http_client): + # Determine the base auth provider + base_provider: Optional[AuthProvider] = None + if cfg.credentials_provider: - return ExternalAuthProvider(cfg.credentials_provider) + base_provider = ExternalAuthProvider(cfg.credentials_provider) elif cfg.auth_type == AuthType.AZURE_SP_M2M.value: - return ExternalAuthProvider( + base_provider = ExternalAuthProvider( AzureServicePrincipalCredentialProvider( cfg.hostname, cfg.azure_client_id, @@ -29,7 +33,7 @@ def get_auth_provider(cfg: ClientContext, http_client): assert cfg.oauth_client_id is not None assert cfg.oauth_scopes is not None - return DatabricksOAuthProvider( + base_provider = DatabricksOAuthProvider( cfg.hostname, cfg.oauth_persistence, cfg.oauth_redirect_port_range, @@ -39,17 +43,17 @@ def get_auth_provider(cfg: ClientContext, http_client): cfg.auth_type, ) elif cfg.access_token is not None: - return AccessTokenAuthProvider(cfg.access_token) + base_provider = AccessTokenAuthProvider(cfg.access_token) elif cfg.use_cert_as_auth and cfg.tls_client_cert_file: # no op authenticator. authentication is performed using ssl certificate outside of headers - return AuthProvider() + base_provider = AuthProvider() else: if ( cfg.oauth_redirect_port_range is not None and cfg.oauth_client_id is not None and cfg.oauth_scopes is not None ): - return DatabricksOAuthProvider( + base_provider = DatabricksOAuthProvider( cfg.hostname, cfg.oauth_persistence, cfg.oauth_redirect_port_range, @@ -61,6 +65,17 @@ def get_auth_provider(cfg: ClientContext, http_client): else: raise RuntimeError("No valid authentication settings!") + # Always wrap with token federation (falls back gracefully if not needed) + if base_provider: + return TokenFederationProvider( + hostname=cfg.hostname, + external_provider=base_provider, + http_client=http_client, + identity_federation_client_id=cfg.identity_federation_client_id, + ) + + return base_provider + PYSQL_OAUTH_SCOPES = ["sql", "offline_access"] PYSQL_OAUTH_CLIENT_ID = "databricks-sql-python" @@ -114,5 +129,6 @@ def get_python_sql_connector_auth_provider(hostname: str, http_client, **kwargs) else redirect_port_range, oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), + identity_federation_client_id=kwargs.get("identity_federation_client_id"), ) return get_auth_provider(cfg, http_client) diff --git a/src/databricks/sql/auth/auth_utils.py b/src/databricks/sql/auth/auth_utils.py new file mode 100644 index 00000000..439aabc5 --- /dev/null +++ b/src/databricks/sql/auth/auth_utils.py @@ -0,0 +1,64 @@ +import logging +import jwt +from datetime import datetime, timedelta +from typing import Optional, Dict, Tuple +from urllib.parse import urlparse + +logger = logging.getLogger(__name__) + + +def parse_hostname(hostname: str) -> str: + """ + Normalize the hostname to include scheme and trailing slash. + + Args: + hostname: The hostname to normalize + + Returns: + Normalized hostname with scheme and trailing slash + """ + if not hostname.startswith("http://") and not hostname.startswith("https://"): + hostname = f"https://{hostname}" + if not hostname.endswith("/"): + hostname = f"{hostname}/" + return hostname + + +def decode_token(access_token: str) -> Optional[Dict]: + """ + Decode a JWT token without verification to extract claims. + + Args: + access_token: The JWT access token to decode + + Returns: + Decoded token claims or None if decoding fails + """ + try: + return jwt.decode(access_token, options={"verify_signature": False}) + except Exception as e: + logger.debug("Failed to decode JWT token: %s", e) + return None + + +def is_same_host(url1: str, url2: str) -> bool: + """ + Check if two URLs have the same host. + + Args: + url1: First URL + url2: Second URL + + Returns: + True if hosts are the same, False otherwise + """ + try: + host1 = urlparse(url1).netloc + host2 = urlparse(url2).netloc + # Handle port differences (e.g., example.com vs example.com:443) + host1_without_port = host1.split(":")[0] + host2_without_port = host2.split(":")[0] + return host1_without_port == host2_without_port + except Exception as e: + logger.debug("Failed to parse URLs: %s", e) + return False diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 679e353f..3e0be0d2 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -37,6 +37,7 @@ def __init__( tls_client_cert_file: Optional[str] = None, oauth_persistence=None, credentials_provider=None, + identity_federation_client_id: Optional[str] = None, # HTTP client configuration parameters ssl_options=None, # SSLOptions type socket_timeout: Optional[float] = None, @@ -65,6 +66,7 @@ def __init__( self.tls_client_cert_file = tls_client_cert_file self.oauth_persistence = oauth_persistence self.credentials_provider = credentials_provider + self.identity_federation_client_id = identity_federation_client_id # HTTP client configuration self.ssl_options = ssl_options diff --git a/src/databricks/sql/auth/token_federation.py b/src/databricks/sql/auth/token_federation.py new file mode 100644 index 00000000..7b62f676 --- /dev/null +++ b/src/databricks/sql/auth/token_federation.py @@ -0,0 +1,206 @@ +import logging +import json +from datetime import datetime, timedelta +from typing import Optional, Dict, Tuple +from urllib.parse import urlencode + +from databricks.sql.auth.authenticators import AuthProvider +from databricks.sql.auth.auth_utils import ( + parse_hostname, + decode_token, + is_same_host, +) +from databricks.sql.common.http import HttpMethod + +logger = logging.getLogger(__name__) + + +class Token: + """ + Represents an OAuth token with expiration management. + """ + + def __init__(self, access_token: str, token_type: str = "Bearer"): + """ + Initialize a token. + + Args: + access_token: The access token string + token_type: The token type (default: Bearer) + """ + self.access_token = access_token + self.token_type = token_type + self.expiry_time = self._calculate_expiry() + + def _calculate_expiry(self) -> datetime: + """ + Calculate the token expiry time from JWT claims. + + Returns: + The token expiry datetime + """ + decoded = decode_token(self.access_token) + if decoded and "exp" in decoded: + # Use JWT exp claim with 1 minute buffer + return datetime.fromtimestamp(decoded["exp"]) - timedelta(minutes=1) + # Default to 1 hour if no expiry info + return datetime.now() + timedelta(hours=1) + + def is_expired(self) -> bool: + """ + Check if the token is expired. + + Returns: + True if token is expired, False otherwise + """ + return datetime.now() >= self.expiry_time + + def to_dict(self) -> Dict[str, str]: + """ + Convert token to dictionary format. + + Returns: + Dictionary with access_token and token_type + """ + return { + "access_token": self.access_token, + "token_type": self.token_type, + } + + +class TokenFederationProvider(AuthProvider): + """ + Implementation of Token Federation for Databricks SQL Python driver. + + This provider exchanges third-party access tokens for Databricks in-house tokens + when the token issuer is different from the Databricks host. + """ + + TOKEN_EXCHANGE_ENDPOINT = "/oidc/v1/token" + TOKEN_EXCHANGE_GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange" + TOKEN_EXCHANGE_SUBJECT_TYPE = "urn:ietf:params:oauth:token-type:jwt" + + def __init__( + self, + hostname: str, + external_provider: AuthProvider, + http_client, + identity_federation_client_id: Optional[str] = None, + ): + """ + Initialize the Token Federation Provider. + + Args: + hostname: The Databricks workspace hostname + external_provider: The external authentication provider + http_client: HTTP client for making requests (required) + identity_federation_client_id: Optional client ID for token federation + """ + if not http_client: + raise ValueError("http_client is required for TokenFederationProvider") + + self.hostname = parse_hostname(hostname) + self.external_provider = external_provider + self.http_client = http_client + self.identity_federation_client_id = identity_federation_client_id + + self._cached_token: Optional[Token] = None + self._external_headers: Dict[str, str] = {} + + def add_headers(self, request_headers: Dict[str, str]): + """Add authentication headers to the request.""" + + if self._cached_token and not self._cached_token.is_expired(): + request_headers[ + "Authorization" + ] = f"{self._cached_token.token_type} {self._cached_token.access_token}" + return + + # Get the external headers first to check if we need token federation + self._external_headers = {} + self.external_provider.add_headers(self._external_headers) + + # If no Authorization header from external provider, pass through all headers + if "Authorization" not in self._external_headers: + request_headers.update(self._external_headers) + return + + token = self._get_token() + request_headers["Authorization"] = f"{token.token_type} {token.access_token}" + + def _get_token(self) -> Token: + """Get or refresh the authentication token.""" + # Check if cached token is still valid + if self._cached_token and not self._cached_token.is_expired(): + return self._cached_token + + # Extract token from already-fetched headers + auth_header = self._external_headers.get("Authorization", "") + token_type, access_token = self._extract_token_from_header(auth_header) + + # Check if token exchange is needed + if self._should_exchange_token(access_token): + try: + token = self._exchange_token(access_token) + self._cached_token = token + return token + except Exception as e: + logger.warning("Token exchange failed, using external token: %s", e) + + # Use external token directly + token = Token(access_token, token_type) + self._cached_token = token + return token + + def _should_exchange_token(self, access_token: str) -> bool: + """Check if the token should be exchanged based on issuer.""" + decoded = decode_token(access_token) + if not decoded: + return False + + issuer = decoded.get("iss", "") + # Check if issuer host is different from Databricks host + return not is_same_host(issuer, self.hostname) + + def _exchange_token(self, access_token: str) -> Token: + """Exchange the external token for a Databricks token.""" + token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}" + + data = { + "grant_type": self.TOKEN_EXCHANGE_GRANT_TYPE, + "subject_token": access_token, + "subject_token_type": self.TOKEN_EXCHANGE_SUBJECT_TYPE, + "scope": "sql", + "return_original_token_if_authenticated": "true", + } + + if self.identity_federation_client_id: + data["client_id"] = self.identity_federation_client_id + + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "*/*", + } + + body = urlencode(data) + + response = self.http_client.request( + HttpMethod.POST, url=token_url, body=body, headers=headers + ) + + token_response = json.loads(response.data.decode()) + + return Token( + token_response["access_token"], token_response.get("token_type", "Bearer") + ) + + def _extract_token_from_header(self, auth_header: str) -> Tuple[str, str]: + """Extract token type and access token from Authorization header.""" + if not auth_header: + raise ValueError("Authorization header is missing") + + parts = auth_header.split(" ", 1) + if len(parts) != 2: + raise ValueError("Invalid Authorization header format") + + return parts[0], parts[1] diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index a5ad7562..d1b94120 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -164,7 +164,9 @@ def __call__(self, *args, **kwargs) -> HeaderFactory: kwargs = {"credentials_provider": MyProvider()} mock_http_client = MagicMock() auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs) - self.assertTrue(type(auth_provider).__name__, "ExternalAuthProvider") + + self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider") + self.assertEqual(type(auth_provider.external_provider).__name__, "ExternalAuthProvider") headers = {} auth_provider.add_headers(headers) @@ -199,8 +201,11 @@ def test_get_python_sql_connector_default_auth(self, mock__initial_get_token): hostname = "foo.cloud.databricks.com" mock_http_client = MagicMock() auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client) - self.assertTrue(type(auth_provider).__name__, "DatabricksOAuthProvider") - self.assertTrue(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID) + + self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider") + self.assertEqual(type(auth_provider.external_provider).__name__, "DatabricksOAuthProvider") + + self.assertEqual(auth_provider.external_provider._client_id, PYSQL_OAUTH_CLIENT_ID) class TestClientCredentialsTokenSource: diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py new file mode 100644 index 00000000..2e671c33 --- /dev/null +++ b/tests/unit/test_token_federation.py @@ -0,0 +1,342 @@ +import pytest +from unittest.mock import Mock, patch +import json +import jwt +from datetime import datetime, timedelta + +from databricks.sql.auth.token_federation import TokenFederationProvider, Token +from databricks.sql.auth.auth_utils import ( + parse_hostname, + decode_token, + is_same_host, +) +from databricks.sql.common.http import HttpMethod + + +@pytest.fixture +def mock_http_client(): + """Fixture for mock HTTP client.""" + return Mock() + + +@pytest.fixture +def mock_external_provider(): + """Fixture for mock external provider.""" + return Mock() + + +@pytest.fixture +def token_federation_provider(mock_http_client, mock_external_provider): + """Fixture for TokenFederationProvider.""" + return TokenFederationProvider( + hostname="https://test.databricks.com/", + external_provider=mock_external_provider, + http_client=mock_http_client, + identity_federation_client_id="test-client-id", + ) + + +def create_mock_token_response( + access_token="databricks-token-456", token_type="Bearer", expires_in=3600 +): + """Helper function to create mock token exchange response.""" + mock_response = Mock() + mock_response.data = json.dumps( + { + "access_token": access_token, + "token_type": token_type, + "expires_in": expires_in, + } + ).encode("utf-8") + mock_response.status = 200 + return mock_response + + +def create_jwt_token(issuer="https://test.databricks.com", exp_hours=1, **kwargs): + """Helper function to create JWT tokens for testing.""" + payload = { + "iss": issuer, + "aud": "databricks", + "exp": int((datetime.now() + timedelta(hours=exp_hours)).timestamp()), + **kwargs, + } + return jwt.encode(payload, "secret", algorithm="HS256") + + +class TestTokenFederationProvider: + """Test TokenFederationProvider functionality.""" + + def test_init_requires_http_client(self, mock_external_provider): + """Test that http_client is required.""" + with pytest.raises(ValueError, match="http_client is required"): + TokenFederationProvider( + hostname="test.databricks.com", + external_provider=mock_external_provider, + http_client=None, + ) + + @pytest.mark.parametrize( + "input_hostname,expected", + [ + ("test.databricks.com", "https://test.databricks.com/"), + ("https://test.databricks.com", "https://test.databricks.com/"), + ("https://test.databricks.com/", "https://test.databricks.com/"), + ("test.databricks.com/", "https://test.databricks.com/"), + ], + ) + def test_hostname_normalization( + self, input_hostname, expected, mock_http_client, mock_external_provider + ): + """Test hostname normalization during initialization.""" + provider = TokenFederationProvider( + hostname=input_hostname, + external_provider=mock_external_provider, + http_client=mock_http_client, + ) + assert provider.hostname == expected + + @pytest.mark.parametrize( + "auth_header,expected_type,expected_token", + [ + ("Bearer test-token-123", "Bearer", "test-token-123"), + ("Basic dGVzdDp0ZXN0", "Basic", "dGVzdDp0ZXN0"), + ], + ) + def test_extract_token_from_valid_header( + self, token_federation_provider, auth_header, expected_type, expected_token + ): + """Test extraction of token from valid Authorization header.""" + token_type, access_token = token_federation_provider._extract_token_from_header( + auth_header + ) + assert token_type == expected_type + assert access_token == expected_token + + @pytest.mark.parametrize( + "invalid_header", + [ + "Bearer", # Missing token + "", # Empty header + "InvalidFormat", # No space separator + ], + ) + def test_extract_token_from_invalid_header( + self, token_federation_provider, invalid_header + ): + """Test extraction fails for invalid Authorization headers.""" + with pytest.raises(ValueError): + token_federation_provider._extract_token_from_header(invalid_header) + + @pytest.mark.parametrize( + "issuer,hostname,should_exchange", + [ + ( + "https://login.microsoftonline.com/tenant-id/", + "https://test.databricks.com/", + True, + ), + ("https://test.databricks.com", "https://test.databricks.com/", False), + ("https://test.databricks.com:443", "https://test.databricks.com/", False), + ("https://accounts.google.com", "https://test.databricks.com/", True), + ], + ) + def test_should_exchange_token( + self, token_federation_provider, issuer, hostname, should_exchange + ): + """Test token exchange decision based on issuer.""" + token_federation_provider.hostname = hostname + jwt_token = create_jwt_token(issuer=issuer) + + result = token_federation_provider._should_exchange_token(jwt_token) + assert result == should_exchange + + def test_should_exchange_token_invalid_jwt(self, token_federation_provider): + """Test that invalid JWT returns False for exchange.""" + result = token_federation_provider._should_exchange_token("invalid-jwt-token") + assert result is False + + def test_exchange_token_success(self, token_federation_provider, mock_http_client): + """Test successful token exchange.""" + access_token = "external-token-123" + mock_http_client.request.return_value = create_mock_token_response() + + result = token_federation_provider._exchange_token(access_token) + + # Verify result is a Token object + assert isinstance(result, Token) + assert result.access_token == "databricks-token-456" + assert result.token_type == "Bearer" + + # Verify the request + mock_http_client.request.assert_called_once() + call_args = mock_http_client.request.call_args + + # Check method and URL + assert call_args[0][0] == HttpMethod.POST + assert call_args[1]["url"] == "https://test.databricks.com/oidc/v1/token" + + # Check body contains expected parameters + from urllib.parse import parse_qs + + body = call_args[1]["body"] + parsed_body = parse_qs(body) + + assert ( + parsed_body["grant_type"][0] + == "urn:ietf:params:oauth:grant-type:token-exchange" + ) + assert parsed_body["subject_token"][0] == access_token + assert ( + parsed_body["subject_token_type"][0] + == "urn:ietf:params:oauth:token-type:jwt" + ) + assert parsed_body["scope"][0] == "sql" + assert parsed_body["client_id"][0] == "test-client-id" + + def test_exchange_token_failure(self, token_federation_provider, mock_http_client): + """Test token exchange failure handling.""" + mock_response = Mock() + mock_response.data = b'{"error": "invalid_request"}' + mock_response.status = 400 + mock_http_client.request.return_value = mock_response + + with pytest.raises(KeyError): # Will raise KeyError due to missing access_token + token_federation_provider._exchange_token("external-token-123") + + @pytest.mark.parametrize( + "external_issuer,should_exchange", + [ + ("https://login.microsoftonline.com/tenant-id/", True), + ("https://test.databricks.com", False), + ], + ) + def test_add_headers_token_exchange( + self, + token_federation_provider, + mock_external_provider, + mock_http_client, + external_issuer, + should_exchange, + ): + """Test adding headers with and without token exchange.""" + # Setup external provider to return a token + external_token = create_jwt_token(issuer=external_issuer) + mock_external_provider.add_headers = Mock( + side_effect=lambda headers: headers.update( + {"Authorization": f"Bearer {external_token}"} + ) + ) + + if should_exchange: + # Mock successful token exchange + mock_http_client.request.return_value = create_mock_token_response() + expected_token = "databricks-token-456" + else: + expected_token = external_token + + headers = {} + token_federation_provider.add_headers(headers) + + assert headers["Authorization"] == f"Bearer {expected_token}" + + def test_token_caching(self, token_federation_provider, mock_external_provider): + """Test that tokens are cached and reused.""" + external_token = create_jwt_token() + mock_external_provider.add_headers = Mock( + side_effect=lambda headers: headers.update( + {"Authorization": f"Bearer {external_token}"} + ) + ) + + # First call + headers1 = {} + token_federation_provider.add_headers(headers1) + + # Second call - should use cached token + headers2 = {} + token_federation_provider.add_headers(headers2) + + # External provider should only be called once + assert mock_external_provider.add_headers.call_count == 1 + + # Both headers should be the same + assert headers1["Authorization"] == headers2["Authorization"] + + def test_token_cache_expiry( + self, token_federation_provider, mock_external_provider + ): + """Test that expired cached tokens are refreshed.""" + call_count = [0] + + def add_headers_side_effect(headers): + call_count[0] += 1 + token = create_jwt_token( + exp_hours=0.001 if call_count[0] == 1 else 1 + ) # First token expires quickly + headers.update({"Authorization": f"Bearer {token}"}) + + mock_external_provider.add_headers = Mock(side_effect=add_headers_side_effect) + + # First call + headers1 = {} + token_federation_provider.add_headers(headers1) + first_token = headers1["Authorization"].split(" ")[1] + + # Force cache expiry + token_federation_provider._cached_token = Token(first_token) + token_federation_provider._cached_token.expiry_time = ( + datetime.now() - timedelta(seconds=1) + ) + + # Second call - should get new token + headers2 = {} + token_federation_provider.add_headers(headers2) + second_token = headers2["Authorization"].split(" ")[1] + + # External provider should be called twice + assert mock_external_provider.add_headers.call_count == 2 + # Tokens should be different + assert first_token != second_token + + +class TestUtilityFunctions: + """Test utility functions used by TokenFederationProvider.""" + + @pytest.mark.parametrize( + "input_hostname,expected", + [ + ("test.databricks.com", "https://test.databricks.com/"), + ("https://test.databricks.com", "https://test.databricks.com/"), + ("https://test.databricks.com/", "https://test.databricks.com/"), + ("test.databricks.com/", "https://test.databricks.com/"), + ], + ) + def test_parse_hostname(self, input_hostname, expected): + """Test hostname parsing.""" + assert parse_hostname(input_hostname) == expected + + @pytest.mark.parametrize( + "url1,url2,expected", + [ + ("https://test.databricks.com", "https://test.databricks.com", True), + ("https://test.databricks.com", "https://test.databricks.com:443", True), + ("https://test1.databricks.com", "https://test2.databricks.com", False), + ("https://login.microsoftonline.com", "https://test.databricks.com", False), + ], + ) + def test_is_same_host(self, url1, url2, expected): + """Test host comparison.""" + assert is_same_host(url1, url2) == expected + + def test_decode_token_valid(self): + """Test decoding a valid JWT token.""" + token = create_jwt_token() + result = decode_token(token) + assert result is not None + assert "iss" in result + assert "exp" in result + + def test_decode_token_invalid(self): + """Test decoding an invalid token.""" + result = decode_token("invalid-token") + assert result is None