Skip to content

[PECOBLR-587] Azure Service Principal Credential Provider #621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jul 9, 2025
111 changes: 100 additions & 11 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@ requests = "^2.18.1"
oauthlib = "^3.1.0"
openpyxl = "^3.0.10"
urllib3 = ">=1.26"
python-dateutil = "^2.8.0"
pyarrow = [
{ version = ">=14.0.1", python = ">=3.8,<3.13", optional=true },
{ version = ">=18.0.0", python = ">=3.13", optional=true }
]
python-dateutil = "^2.8.0"
pyjwt = "^2.0.0"


[tool.poetry.extras]
pyarrow = ["pyarrow"]

[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]
pytest = "^7.1.2"
mypy = "^1.10.1"
pylint = ">=2.12.0"
Expand Down
55 changes: 20 additions & 35 deletions src/databricks/sql/auth/auth.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,29 @@
from enum import Enum
from typing import Optional, List

from databricks.sql.auth.authenticators import (
AuthProvider,
AccessTokenAuthProvider,
ExternalAuthProvider,
DatabricksOAuthProvider,
AzureServicePrincipalCredentialProvider,
)


class AuthType(Enum):
DATABRICKS_OAUTH = "databricks-oauth"
AZURE_OAUTH = "azure-oauth"
# other supported types (access_token) can be inferred
# we can add more types as needed later


class ClientContext:
def __init__(
self,
hostname: str,
access_token: Optional[str] = None,
auth_type: Optional[str] = None,
oauth_scopes: Optional[List[str]] = None,
oauth_client_id: Optional[str] = None,
oauth_redirect_port_range: Optional[List[int]] = None,
use_cert_as_auth: Optional[str] = None,
tls_client_cert_file: Optional[str] = None,
oauth_persistence=None,
credentials_provider=None,
):
self.hostname = hostname
self.access_token = access_token
self.auth_type = auth_type
self.oauth_scopes = oauth_scopes
self.oauth_client_id = oauth_client_id
self.oauth_redirect_port_range = oauth_redirect_port_range
self.use_cert_as_auth = use_cert_as_auth
self.tls_client_cert_file = tls_client_cert_file
self.oauth_persistence = oauth_persistence
self.credentials_provider = credentials_provider
from databricks.sql.auth.common import AuthType, ClientContext


def get_auth_provider(cfg: ClientContext):
if cfg.credentials_provider:
return ExternalAuthProvider(cfg.credentials_provider)
if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
elif cfg.auth_type == AuthType.AZURE_SP_M2M.value:
return ExternalAuthProvider(
AzureServicePrincipalCredentialProvider(
cfg.hostname,
cfg.azure_client_id,
cfg.azure_client_secret,
cfg.azure_tenant_id,
cfg.azure_workspace_resource_id,
)
)
elif cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]:
assert cfg.oauth_redirect_port_range is not None
assert cfg.oauth_client_id is not None
assert cfg.oauth_scopes is not None
Expand Down Expand Up @@ -102,10 +80,13 @@ def get_client_id_and_redirect_port(use_azure_auth: bool):


def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
# TODO : unify all the auth mechanisms with the Python SDK

auth_type = kwargs.get("auth_type")
(client_id, redirect_port_range) = get_client_id_and_redirect_port(
auth_type == AuthType.AZURE_OAUTH.value
)

if kwargs.get("username") or kwargs.get("password"):
raise ValueError(
"Username/password authentication is no longer supported. "
Expand All @@ -120,6 +101,10 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs):
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
oauth_scopes=PYSQL_OAUTH_SCOPES,
oauth_client_id=kwargs.get("oauth_client_id") or client_id,
azure_client_id=kwargs.get("azure_client_id"),
azure_client_secret=kwargs.get("azure_client_secret"),
azure_tenant_id=kwargs.get("azure_tenant_id"),
azure_workspace_resource_id=kwargs.get("azure_workspace_resource_id"),
oauth_redirect_port_range=[kwargs["oauth_redirect_port"]]
if kwargs.get("oauth_client_id") and kwargs.get("oauth_redirect_port")
else redirect_port_range,
Expand Down
95 changes: 91 additions & 4 deletions src/databricks/sql/auth/authenticators.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import abc
import base64
import logging
from typing import Callable, Dict, List

from databricks.sql.auth.oauth import OAuthManager
from databricks.sql.auth.endpoint import get_oauth_endpoints, infer_cloud_from_host
from databricks.sql.common.http import HttpHeader
from databricks.sql.auth.oauth import (
OAuthManager,
RefreshableTokenSource,
ClientCredentialsTokenSource,
)
from databricks.sql.auth.endpoint import get_oauth_endpoints
from databricks.sql.auth.common import (
AuthType,
get_effective_azure_login_app_id,
get_azure_tenant_id_from_host,
)

# Private API: this is an evolving interface and it will change in the future.
# Please must not depend on it in your applications.
Expand Down Expand Up @@ -146,3 +154,82 @@ def add_headers(self, request_headers: Dict[str, str]):
headers = self._header_factory()
for k, v in headers.items():
request_headers[k] = v


class AzureServicePrincipalCredentialProvider(CredentialsProvider):
"""
A credential provider for Azure Service Principal authentication with Databricks.

This class implements the CredentialsProvider protocol to authenticate requests
to Databricks REST APIs using Azure Active Directory (AAD) service principal
credentials. It handles OAuth 2.0 client credentials flow to obtain access tokens
from Azure AD and automatically refreshes them when they expire.

Attributes:
hostname (str): The Databricks workspace hostname.
azure_client_id (str): The Azure service principal's client ID.
azure_client_secret (str): The Azure service principal's client secret.
azure_tenant_id (str): The Azure AD tenant ID.
azure_workspace_resource_id (str, optional): The Azure workspace resource ID.
"""

AZURE_AAD_ENDPOINT = "https://login.microsoftonline.com"
AZURE_TOKEN_ENDPOINT = "oauth2/token"

AZURE_MANAGED_RESOURCE = "https://management.core.windows.net/"

DATABRICKS_AZURE_SP_TOKEN_HEADER = "X-Databricks-Azure-SP-Management-Token"
DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER = (
"X-Databricks-Azure-Workspace-Resource-Id"
)

def __init__(
self,
hostname,
azure_client_id,
azure_client_secret,
azure_tenant_id=None,
azure_workspace_resource_id=None,
):
self.hostname = hostname
self.azure_client_id = azure_client_id
self.azure_client_secret = azure_client_secret
self.azure_workspace_resource_id = azure_workspace_resource_id
self.azure_tenant_id = azure_tenant_id or get_azure_tenant_id_from_host(
hostname
)

def auth_type(self) -> str:
return AuthType.AZURE_SP_M2M.value

def get_token_source(self, resource: str) -> RefreshableTokenSource:
return ClientCredentialsTokenSource(
token_url=f"{self.AZURE_AAD_ENDPOINT}/{self.azure_tenant_id}/{self.AZURE_TOKEN_ENDPOINT}",
client_id=self.azure_client_id,
client_secret=self.azure_client_secret,
extra_params={"resource": resource},
)

def __call__(self, *args, **kwargs) -> HeaderFactory:
inner = self.get_token_source(
resource=get_effective_azure_login_app_id(self.hostname)
)
cloud = self.get_token_source(resource=self.AZURE_MANAGED_RESOURCE)

def header_factory() -> Dict[str, str]:
inner_token = inner.get_token()
cloud_token = cloud.get_token()

headers = {
HttpHeader.AUTHORIZATION.value: f"{inner_token.token_type} {inner_token.access_token}",
self.DATABRICKS_AZURE_SP_TOKEN_HEADER: cloud_token.access_token,
}

if self.azure_workspace_resource_id:
headers[
self.DATABRICKS_AZURE_WORKSPACE_RESOURCE_ID_HEADER
] = self.azure_workspace_resource_id

return headers

return header_factory
100 changes: 100 additions & 0 deletions src/databricks/sql/auth/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from enum import Enum
import logging
from typing import Optional, List
from urllib.parse import urlparse
from databricks.sql.common.http import DatabricksHttpClient, HttpMethod

logger = logging.getLogger(__name__)


class AuthType(Enum):
DATABRICKS_OAUTH = "databricks-oauth"
AZURE_OAUTH = "azure-oauth"
AZURE_SP_M2M = "azure-sp-m2m"


class AzureAppId(Enum):
DEV = (".dev.azuredatabricks.net", "62a912ac-b58e-4c1d-89ea-b2dbfc7358fc")
STAGING = (".staging.azuredatabricks.net", "4a67d088-db5c-48f1-9ff2-0aace800ae68")
PROD = (".azuredatabricks.net", "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d")


class ClientContext:
def __init__(
self,
hostname: str,
access_token: Optional[str] = None,
auth_type: Optional[str] = None,
oauth_scopes: Optional[List[str]] = None,
oauth_client_id: Optional[str] = None,
azure_client_id: Optional[str] = None,
azure_client_secret: Optional[str] = None,
azure_tenant_id: Optional[str] = None,
azure_workspace_resource_id: Optional[str] = None,
oauth_redirect_port_range: Optional[List[int]] = None,
use_cert_as_auth: Optional[str] = None,
tls_client_cert_file: Optional[str] = None,
oauth_persistence=None,
credentials_provider=None,
):
self.hostname = hostname
self.access_token = access_token
self.auth_type = auth_type
self.oauth_scopes = oauth_scopes
self.oauth_client_id = oauth_client_id
self.azure_client_id = azure_client_id
self.azure_client_secret = azure_client_secret
self.azure_tenant_id = azure_tenant_id
self.azure_workspace_resource_id = azure_workspace_resource_id
self.oauth_redirect_port_range = oauth_redirect_port_range
self.use_cert_as_auth = use_cert_as_auth
self.tls_client_cert_file = tls_client_cert_file
self.oauth_persistence = oauth_persistence
self.credentials_provider = credentials_provider


def get_effective_azure_login_app_id(hostname) -> str:
"""
Get the effective Azure login app ID for a given hostname.
This function determines the appropriate Azure login app ID based on the hostname.
If the hostname does not match any of these domains, it returns the default Databricks resource ID.

"""
for azure_app_id in AzureAppId:
domain, app_id = azure_app_id.value
if domain in hostname:
return app_id

# default databricks resource id
return AzureAppId.PROD.value[1]


def get_azure_tenant_id_from_host(host: str, http_client=None) -> str:
"""
Load the Azure tenant ID from the Azure Databricks login page.

This function retrieves the Azure tenant ID by making a request to the Databricks
Azure Active Directory (AAD) authentication endpoint. The endpoint redirects to
the Azure login page, and the tenant ID is extracted from the redirect URL.
"""

if http_client is None:
http_client = DatabricksHttpClient.get_instance()

login_url = f"{host}/aad/auth"
logger.debug("Loading tenant ID from %s", login_url)
with http_client.execute(HttpMethod.GET, login_url, allow_redirects=False) as resp:
if resp.status_code // 100 != 3:
raise ValueError(
f"Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}"
)
entra_id_endpoint = resp.headers.get("Location")
if entra_id_endpoint is None:
raise ValueError(f"No Location header in response from {login_url}")
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
url = urlparse(entra_id_endpoint)
path_segments = url.path.split("/")
if len(path_segments) < 2:
raise ValueError(f"Invalid path in Location header: {url.path}")
return path_segments[1]
Loading
Loading