Skip to content

Implements Token Federation for Python Driver #552

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 39 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
56a854f
initial commit
madhav-db May 7, 2025
aedb3bf
update vars
madhav-db May 7, 2025
d06672c
mod
madhav-db May 7, 2025
9aff811
debugging patch
madhav-db May 7, 2025
299b5ae
mod
madhav-db May 7, 2025
10a5016
debug
madhav-db May 7, 2025
3bb9b3d
debug
madhav-db May 7, 2025
708c13b
debug
madhav-db May 7, 2025
a1e9894
fix
madhav-db May 7, 2025
00e015c
fix
madhav-db May 7, 2025
d538b75
fix
madhav-db May 7, 2025
4b48ac9
fix
madhav-db May 7, 2025
e8d4a48
debug
madhav-db May 8, 2025
5b74b60
debug
madhav-db May 8, 2025
edc6027
debug
madhav-db May 8, 2025
3613cb0
debug
madhav-db May 8, 2025
e87b52d
readability
madhav-db May 8, 2025
929191b
separate py script
madhav-db May 8, 2025
82d0be2
addresses codecheck errors
madhav-db May 8, 2025
1e60750
adds unit test
madhav-db May 8, 2025
de48411
Fix: Apply Black formatting to auth and token_federation modules
madhav-db May 8, 2025
d54ba93
Enhance token federation refresh to get fresh external tokens
madhav-db May 8, 2025
aa2d1b9
refresh
madhav-db May 9, 2025
34413f3
fmt
madhav-db May 9, 2025
a93dd4b
clean up
madhav-db May 9, 2025
76df22e
update and add todo for future work
madhav-db May 9, 2025
c37cd01
refactoring
madhav-db May 9, 2025
f2d4516
update test
madhav-db May 9, 2025
aeeca66
fmt
madhav-db May 9, 2025
ae28649
remove idp detection
madhav-db May 9, 2025
541e82f
fmt
madhav-db May 11, 2025
49eab2a
fmt
madhav-db May 11, 2025
e6733cb
Apply black formatting to auth files
madhav-db May 11, 2025
29f95f2
Fix token refresh to use fresh token from provider
madhav-db May 11, 2025
2e12935
general improvements
madhav-db May 12, 2025
e9de21a
minor
madhav-db May 12, 2025
efb9149
test improvements
madhav-db May 12, 2025
7ab4068
Refactor token exchange parameters to be instance-specific in Databri…
madhav-db May 12, 2025
9fc4c0c
Refactor token expiry handling in DatabricksTokenFederationProvider a…
madhav-db May 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions .github/workflows/token-federation-test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
name: Token Federation Test

# Tests token federation functionality with GitHub Actions OIDC tokens
on:
# Manual trigger with required inputs
workflow_dispatch:
inputs:
databricks_host:
description: 'Databricks host URL (e.g., example.cloud.databricks.com)'
required: true
databricks_http_path:
description: 'Databricks HTTP path (e.g., /sql/1.0/warehouses/abc123)'
required: true
identity_federation_client_id:
description: 'Identity federation client ID'
required: true

# Run on PRs that might affect token federation
pull_request:
branches: [main]
paths:
- 'src/databricks/sql/auth/**'
- 'examples/token_federation_*.py'
- 'tests/token_federation/**'
- '.github/workflows/token-federation-test.yml'

# Run on push to main that affects token federation
push:
branches: [main]
paths:
- 'src/databricks/sql/auth/**'
- 'examples/token_federation_*.py'
- 'tests/token_federation/**'
- '.github/workflows/token-federation-test.yml'

permissions:
id-token: write # Required for GitHub OIDC token
contents: read

jobs:
test-token-federation:
name: Test Token Federation
runs-on:
group: databricks-protected-runner-group
labels: linux-ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Set up Python 3.9
uses: actions/setup-python@v5
with:
python-version: '3.9'
cache: 'pip'

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
pip install pyarrow

- name: Get GitHub OIDC token
id: get-id-token
uses: actions/github-script@v7
with:
script: |
const token = await core.getIDToken('https://github.com/databricks')
core.setSecret(token)
core.setOutput('token', token)

- name: Test token federation with GitHub OIDC token
env:
DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }}
DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }}
IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }}
OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }}
run: python tests/token_federation/github_oidc_test.py
111 changes: 100 additions & 11 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ pyarrow = [
{ version = ">=18.0.0", python = ">=3.13", optional=true }
]
python-dateutil = "^2.8.0"
PyJWT = ">=2.0.0"

[tool.poetry.extras]
pyarrow = ["pyarrow"]

[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]
pytest = "^7.1.2"
mypy = "^1.10.1"
pylint = ">=2.12.0"
Expand Down
81 changes: 81 additions & 0 deletions src/databricks/sql/auth/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
AuthProvider,
AccessTokenAuthProvider,
ExternalAuthProvider,
CredentialsProvider,
DatabricksOAuthProvider,
)


class AuthType(Enum):
DATABRICKS_OAUTH = "databricks-oauth"
AZURE_OAUTH = "azure-oauth"
# TODO: Token federation should be a feature that works with different auth types,
# not an auth type itself. This will be refactored in a future change.
# We will add a use_token_federation flag that can be used with any auth type.
TOKEN_FEDERATION = "token-federation"
# other supported types (access_token) can be inferred
# we can add more types as needed later

Expand All @@ -29,6 +34,7 @@ def __init__(
tls_client_cert_file: Optional[str] = None,
oauth_persistence=None,
credentials_provider=None,
identity_federation_client_id: Optional[str] = None,
):
self.hostname = hostname
self.access_token = access_token
Expand All @@ -40,11 +46,64 @@ def __init__(
self.tls_client_cert_file = tls_client_cert_file
self.oauth_persistence = oauth_persistence
self.credentials_provider = credentials_provider
self.identity_federation_client_id = identity_federation_client_id


def get_auth_provider(cfg: ClientContext):
"""
Get an appropriate auth provider based on the provided configuration.

Token Federation Support:
-----------------------
Currently, token federation is implemented as a separate auth type, but the goal is to
refactor it as a feature that can work with any auth type. The current implementation
is maintained for backward compatibility while the refactoring is planned.

Future refactoring will introduce a `use_token_federation` flag that can be combined
with any auth type to enable token federation.

Args:
cfg: The client context containing configuration parameters

Returns:
An appropriate AuthProvider instance

Raises:
RuntimeError: If no valid authentication settings are provided
"""
# If credentials_provider is explicitly provided
if cfg.credentials_provider:
# If token federation is enabled and credentials provider is provided,
# wrap the credentials provider with DatabricksTokenFederationProvider
if cfg.auth_type == AuthType.TOKEN_FEDERATION.value:
from databricks.sql.auth.token_federation import (
DatabricksTokenFederationProvider,
)

federation_provider = DatabricksTokenFederationProvider(
cfg.credentials_provider,
cfg.hostname,
cfg.identity_federation_client_id,
)
return ExternalAuthProvider(federation_provider)

# If not token federation, just use the credentials provider directly
return ExternalAuthProvider(cfg.credentials_provider)

# If we don't have a credentials provider but have token federation auth type with access token
if cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token:
# Create a simple credentials provider and wrap it with token federation provider
from databricks.sql.auth.token_federation import (
DatabricksTokenFederationProvider,
SimpleCredentialsProvider,
)

simple_provider = SimpleCredentialsProvider(cfg.access_token)
federation_provider = DatabricksTokenFederationProvider(
simple_provider, cfg.hostname, cfg.identity_federation_client_id
)
return ExternalAuthProvider(federation_provider)

if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
assert cfg.oauth_redirect_port_range is not None
assert cfg.oauth_client_id is not None
Expand Down Expand Up @@ -102,6 +161,27 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):


def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
"""
Get an auth provider for the Python SQL connector.

This function is the main entry point for authentication in the SQL connector.
It processes the parameters and creates an appropriate auth provider.

TODO: Future refactoring needed:
1. Add a use_token_federation flag that can be combined with any auth type
2. Remove TOKEN_FEDERATION as an auth_type while maintaining backward compatibility
3. Create a token federation wrapper that can wrap any existing auth provider

Args:
hostname: The Databricks server hostname
**kwargs: Additional configuration parameters

Returns:
An appropriate AuthProvider instance

Raises:
ValueError: If username/password authentication is attempted (no longer supported)
"""
auth_type = kwargs.get("auth_type")
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
auth_type == AuthType.AZURE_OAUTH.value
Expand All @@ -125,5 +205,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
else redirect_port_range,
oauth_persistence=kwargs.get("experimental_oauth_persistence"),
credentials_provider=kwargs.get("credentials_provider"),
identity_federation_client_id=kwargs.get("identity_federation_client_id"),
)
return get_auth_provider(cfg)
6 changes: 6 additions & 0 deletions src/databricks/sql/auth/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,16 @@ class CredentialsProvider(abc.ABC):

@abc.abstractmethod
def auth_type(self) -> str:
"""
Returns the authentication type for this provider
"""
...

@abc.abstractmethod
def __call__(self, *args, **kwargs) -> HeaderFactory:
"""
Configure and return a HeaderFactory that provides authentication headers
"""
...


Expand Down
58 changes: 58 additions & 0 deletions src/databricks/sql/auth/oidc_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import logging
import requests
from typing import Optional

from databricks.sql.auth.endpoint import (
get_oauth_endpoints,
infer_cloud_from_host,
)

logger = logging.getLogger(__name__)


class OIDCDiscoveryUtil:
"""
Utility class for OIDC discovery operations.

This class handles discovery of OIDC endpoints through standard
discovery mechanisms, with fallback to default endpoints if needed.
"""

# Standard token endpoint path for Databricks workspaces
DEFAULT_TOKEN_PATH = "oidc/v1/token"

@staticmethod
def discover_token_endpoint(hostname: str) -> str:
"""
Get the token endpoint for the given Databricks hostname.

For Databricks workspaces, the token endpoint is always at host/oidc/v1/token.

Args:
hostname: The hostname to get token endpoint for

Returns:
str: The token endpoint URL
"""
# Format the hostname and return the standard endpoint
hostname = OIDCDiscoveryUtil.format_hostname(hostname)
token_endpoint = f"{hostname}{OIDCDiscoveryUtil.DEFAULT_TOKEN_PATH}"
logger.info(f"Using token endpoint: {token_endpoint}")
return token_endpoint

@staticmethod
def format_hostname(hostname: str) -> str:
"""
Format hostname to ensure it has proper https:// prefix and trailing slash.

Args:
hostname: The hostname to format

Returns:
str: The formatted hostname
"""
if not hostname.startswith("https://"):
hostname = f"https://{hostname}"
if not hostname.endswith("/"):
hostname = f"{hostname}/"
return hostname
65 changes: 65 additions & 0 deletions src/databricks/sql/auth/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Token class for authentication tokens with expiry handling.
"""

from datetime import datetime, timezone, timedelta
from typing import Optional


class Token:
"""
Represents an OAuth token with expiry information.

This class handles token state including expiry calculation.
"""

# Minimum time buffer before expiry to consider a token still valid (in seconds)
MIN_VALIDITY_BUFFER = 10

def __init__(
self,
access_token: str,
token_type: str,
refresh_token: str = "",
expiry: Optional[datetime] = None,
):
"""
Initialize a Token object.

Args:
access_token: The access token string
token_type: The token type (usually "Bearer")
refresh_token: Optional refresh token
expiry: Token expiry datetime, must be provided

Raises:
ValueError: If no expiry is provided
"""
self.access_token = access_token
self.token_type = token_type
self.refresh_token = refresh_token

# Ensure we have an expiry time
if expiry is None:
raise ValueError("Token expiry must be provided")

# Ensure expiry is timezone-aware
if expiry.tzinfo is None:
# Convert naive datetime to aware datetime
self.expiry = expiry.replace(tzinfo=timezone.utc)
else:
self.expiry = expiry

def is_valid(self) -> bool:
"""
Check if the token is valid (has at least MIN_VALIDITY_BUFFER seconds before expiry).

Returns:
bool: True if the token is valid, False otherwise
"""
buffer = timedelta(seconds=self.MIN_VALIDITY_BUFFER)
return datetime.now(tz=timezone.utc) + buffer < self.expiry

def __str__(self) -> str:
"""Return the token as a string in the format used for Authorization headers."""
return f"{self.token_type} {self.access_token}"
Loading
Loading