Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/databricks/sql/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def get_auth_provider(cfg: ClientContext, http_client):
else:
raise RuntimeError("No valid authentication settings!")

# Always wrap with token federation (falls back gracefully if not needed)
if base_provider:
# Wrap with token federation only if explicitly enabled via identity_federation_client_id
if base_provider and cfg.identity_federation_client_id:
return TokenFederationProvider(
hostname=cfg.hostname,
external_provider=base_provider,
Expand Down
135 changes: 120 additions & 15 deletions src/databricks/sql/auth/token_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@
logger = logging.getLogger(__name__)


class TokenFederationError(Exception):
"""Base exception for token federation errors."""

pass


class TokenExchangeNotAvailableError(TokenFederationError):
"""Raised when token exchange endpoint is not available (404)."""

pass


class TokenExchangeAuthenticationError(TokenFederationError):
"""Raised when token exchange fails due to authentication issues (401/403)."""

pass


class Token:
"""
Represents an OAuth token with expiration management.
Expand Down Expand Up @@ -72,8 +90,20 @@ class TokenFederationProvider(AuthProvider):
"""
Implementation of Token Federation for Databricks SQL Python driver.

This provider exchanges third-party access tokens for Databricks in-house tokens
when the token issuer is different from the Databricks host.
This provider exchanges third-party access tokens (e.g., Azure AD, AWS IAM) for
Databricks-native tokens when the token issuer differs from the Databricks host.

Token federation is useful for:
- Cross-cloud authentication scenarios
- Unity Catalog access across Azure subscriptions
- Service principal authentication with external identity providers

The provider automatically detects when token exchange is needed by comparing the
token issuer with the Databricks workspace hostname. If exchange fails, it gracefully
falls back to using the external token directly.

Note: Token federation must be explicitly enabled by providing the
identity_federation_client_id parameter during connection setup.
"""

TOKEN_EXCHANGE_ENDPOINT = "/oidc/v1/token"
Expand All @@ -92,9 +122,17 @@ def __init__(

Args:
hostname: The Databricks workspace hostname
external_provider: The external authentication provider
external_provider: The external authentication provider that provides the initial token
http_client: HTTP client for making requests (required)
identity_federation_client_id: Optional client ID for token federation
identity_federation_client_id: Client ID for identity federation (required for token exchange).
This parameter enables token federation and should be provided when:
- Using Service Principal authentication across Azure subscriptions
- Accessing Unity Catalog resources in different Azure tenants
- Configured with your workspace administrator

Without this parameter, the external token will be used directly without exchange.
Contact your Databricks workspace administrator to obtain the appropriate client ID
for your authentication scenario.
"""
if not http_client:
raise ValueError("http_client is required for TokenFederationProvider")
Expand Down Expand Up @@ -143,9 +181,33 @@ def _get_token(self) -> Token:
try:
token = self._exchange_token(access_token)
self._cached_token = token
logger.info(
"Successfully exchanged external token for Databricks token"
)
return token
except TokenExchangeNotAvailableError:
logger.debug(
"Token exchange endpoint not available. Using external token directly. "
"This is expected when token federation is not configured for this workspace."
)
except TokenExchangeAuthenticationError as e:
logger.warning(
"Token exchange failed due to authentication error. Using external token directly. "
"Error: %s. If this persists, verify your identity_federation_client_id configuration.",
e,
)
except TokenFederationError as e:
logger.info(
"Token exchange not performed, using external token directly. "
"Error: %s",
e,
)
except Exception as e:
logger.warning("Token exchange failed, using external token: %s", e)
logger.debug(
"Token exchange failed with unexpected error, using external token directly. "
"Error: %s",
e,
)

# Use external token directly
token = Token(access_token, token_type)
Expand All @@ -163,7 +225,20 @@ def _should_exchange_token(self, access_token: str) -> bool:
return not is_same_host(issuer, self.hostname)

def _exchange_token(self, access_token: str) -> Token:
"""Exchange the external token for a Databricks token."""
"""
Exchange the external token for a Databricks token.

Args:
access_token: The external access token to exchange

Returns:
Token: The exchanged Databricks token

Raises:
TokenExchangeNotAvailableError: If the endpoint is not available (404)
TokenExchangeAuthenticationError: If authentication fails (401/403)
TokenFederationError: For other token exchange errors
"""
token_url = f"{self.hostname.rstrip('/')}{self.TOKEN_EXCHANGE_ENDPOINT}"

data = {
Expand All @@ -184,15 +259,45 @@ def _exchange_token(self, access_token: str) -> Token:

body = urlencode(data)

response = self.http_client.request(
HttpMethod.POST, url=token_url, body=body, headers=headers
)

token_response = json.loads(response.data.decode())

return Token(
token_response["access_token"], token_response.get("token_type", "Bearer")
)
try:
response = self.http_client.request(
HttpMethod.POST, url=token_url, body=body, headers=headers
)

# Check response status code
if response.status == 404:
raise TokenExchangeNotAvailableError(
"Token exchange endpoint not found. Token federation may not be enabled for this workspace."
)
Comment on lines +268 to +271
Copy link

Copilot AI Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] The error messages are hardcoded strings. Consider defining them as class constants (e.g., ERROR_MSG_404, ERROR_MSG_401_403) for better maintainability and consistency, especially since these messages are validated in tests.

Copilot uses AI. Check for mistakes.
elif response.status in (401, 403):
error_detail = (
response.data.decode() if response.data else "No error details"
)
raise TokenExchangeAuthenticationError(
f"Authentication failed during token exchange (HTTP {response.status}): {error_detail}"
)
elif response.status != 200:
error_detail = (
response.data.decode() if response.data else "No error details"
)
raise TokenFederationError(
f"Token exchange failed with HTTP {response.status}: {error_detail}"
)

token_response = json.loads(response.data.decode())

return Token(
token_response["access_token"],
token_response.get("token_type", "Bearer"),
)
except TokenFederationError:
# Re-raise our custom exceptions
raise
except Exception as e:
# Handle unexpected errors (network errors, JSON parsing errors, etc.)
raise TokenFederationError(
f"Unexpected error during token exchange: {str(e)}"
) from e

def _extract_token_from_header(self, auth_header: str) -> Tuple[str, str]:
"""Extract token type and access token from Authorization header."""
Expand Down
70 changes: 54 additions & 16 deletions tests/unit/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def test_get_python_sql_connector_auth_provider_access_token(self):
hostname = "moderakh-test.cloud.databricks.com"
kwargs = {"access_token": "dpi123"}
mock_http_client = MagicMock()
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
auth_provider = get_python_sql_connector_auth_provider(
hostname, mock_http_client, **kwargs
)
self.assertTrue(type(auth_provider).__name__, "AccessTokenAuthProvider")

headers = {}
Expand All @@ -163,10 +165,41 @@ def __call__(self, *args, **kwargs) -> HeaderFactory:
hostname = "moderakh-test.cloud.databricks.com"
kwargs = {"credentials_provider": MyProvider()}
mock_http_client = MagicMock()
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
auth_provider = get_python_sql_connector_auth_provider(
hostname, mock_http_client, **kwargs
)

# Without identity_federation_client_id, should return ExternalAuthProvider directly
self.assertEqual(type(auth_provider).__name__, "ExternalAuthProvider")

headers = {}
auth_provider.add_headers(headers)
self.assertEqual(headers["foo"], "bar")

def test_get_python_sql_connector_auth_provider_with_token_federation(self):
class MyProvider(CredentialsProvider):
def auth_type(self) -> str:
return "mine"

def __call__(self, *args, **kwargs) -> HeaderFactory:
return lambda: {"foo": "bar"}

hostname = "moderakh-test.cloud.databricks.com"
kwargs = {
"credentials_provider": MyProvider(),
"identity_federation_client_id": "test-client-id",
}
mock_http_client = MagicMock()
auth_provider = get_python_sql_connector_auth_provider(
hostname, mock_http_client, **kwargs
)

# With identity_federation_client_id, should wrap with TokenFederationProvider
self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider")
self.assertEqual(type(auth_provider.external_provider).__name__, "ExternalAuthProvider")
self.assertEqual(
type(auth_provider.external_provider).__name__, "ExternalAuthProvider"
)
self.assertEqual(auth_provider.identity_federation_client_id, "test-client-id")

headers = {}
auth_provider.add_headers(headers)
Expand All @@ -181,7 +214,9 @@ def test_get_python_sql_connector_auth_provider_noop(self):
"_use_cert_as_auth": use_cert_as_auth,
}
mock_http_client = MagicMock()
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client, **kwargs)
auth_provider = get_python_sql_connector_auth_provider(
hostname, mock_http_client, **kwargs
)
self.assertTrue(type(auth_provider).__name__, "CredentialProvider")

def test_get_python_sql_connector_basic_auth(self):
Expand All @@ -191,7 +226,9 @@ def test_get_python_sql_connector_basic_auth(self):
}
mock_http_client = MagicMock()
with self.assertRaises(ValueError) as e:
get_python_sql_connector_auth_provider("foo.cloud.databricks.com", mock_http_client, **kwargs)
get_python_sql_connector_auth_provider(
"foo.cloud.databricks.com", mock_http_client, **kwargs
)
self.assertIn(
"Username/password authentication is no longer supported", str(e.exception)
)
Expand All @@ -200,12 +237,13 @@ def test_get_python_sql_connector_basic_auth(self):
def test_get_python_sql_connector_default_auth(self, mock__initial_get_token):
hostname = "foo.cloud.databricks.com"
mock_http_client = MagicMock()
auth_provider = get_python_sql_connector_auth_provider(hostname, mock_http_client)

self.assertEqual(type(auth_provider).__name__, "TokenFederationProvider")
self.assertEqual(type(auth_provider.external_provider).__name__, "DatabricksOAuthProvider")
auth_provider = get_python_sql_connector_auth_provider(
hostname, mock_http_client
)

self.assertEqual(auth_provider.external_provider._client_id, PYSQL_OAUTH_CLIENT_ID)
# Without identity_federation_client_id, should return DatabricksOAuthProvider directly
self.assertEqual(type(auth_provider).__name__, "DatabricksOAuthProvider")
self.assertEqual(auth_provider._client_id, PYSQL_OAUTH_CLIENT_ID)


class TestClientCredentialsTokenSource:
Expand Down Expand Up @@ -264,16 +302,16 @@ def test_no_token_refresh__when_token_is_not_expired(

def test_get_token_success(self, token_source, http_response):
mock_http_client = MagicMock()

with patch.object(token_source, "_http_client", mock_http_client):
# Create a mock response with the expected format
mock_response = MagicMock()
mock_response.status = 200
mock_response.data.decode.return_value = '{"access_token": "abc123", "token_type": "Bearer", "refresh_token": null}'

# Mock the request method to return the response directly
mock_http_client.request.return_value = mock_response

token = token_source.get_token()

# Assert
Expand All @@ -284,16 +322,16 @@ def test_get_token_success(self, token_source, http_response):

def test_get_token_failure(self, token_source, http_response):
mock_http_client = MagicMock()

with patch.object(token_source, "_http_client", mock_http_client):
# Create a mock response with error
mock_response = MagicMock()
mock_response.status = 400
mock_response.data.decode.return_value = "Bad Request"

# Mock the request method to return the response directly
mock_http_client.request.return_value = mock_response

with pytest.raises(Exception) as e:
token_source.get_token()
assert "Failed to get token: 400" in str(e.value)
Expand Down
Loading
Loading