diff --git a/.github/workflows/token-federation-test.yml b/.github/workflows/token-federation-test.yml new file mode 100644 index 00000000..74b93608 --- /dev/null +++ b/.github/workflows/token-federation-test.yml @@ -0,0 +1,78 @@ +name: Token Federation Test + +# Tests token federation functionality with GitHub Actions OIDC tokens +on: + # Manual trigger with required inputs + workflow_dispatch: + inputs: + databricks_host: + description: 'Databricks host URL (e.g., example.cloud.databricks.com)' + required: true + databricks_http_path: + description: 'Databricks HTTP path (e.g., /sql/1.0/warehouses/abc123)' + required: true + identity_federation_client_id: + description: 'Identity federation client ID' + required: true + + # Run on PRs that might affect token federation + pull_request: + branches: [main] + paths: + - 'src/databricks/sql/auth/**' + - 'examples/token_federation_*.py' + - 'tests/token_federation/**' + - '.github/workflows/token-federation-test.yml' + + # Run on push to main that affects token federation + push: + branches: [main] + paths: + - 'src/databricks/sql/auth/**' + - 'examples/token_federation_*.py' + - 'tests/token_federation/**' + - '.github/workflows/token-federation-test.yml' + +permissions: + id-token: write # Required for GitHub OIDC token + contents: read + +jobs: + test-token-federation: + name: Test Token Federation + runs-on: + group: databricks-protected-runner-group + labels: linux-ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.9 + uses: actions/setup-python@v5 + with: + python-version: '3.9' + cache: 'pip' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e . 
+ pip install pyarrow + + - name: Get GitHub OIDC token + id: get-id-token + uses: actions/github-script@v7 + with: + script: | + const token = await core.getIDToken('https://github.com/databricks') + core.setSecret(token) + core.setOutput('token', token) + + - name: Test token federation with GitHub OIDC token + env: + DATABRICKS_HOST_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_host || secrets.DATABRICKS_HOST_FOR_TF }} + DATABRICKS_HTTP_PATH_FOR_TF: ${{ github.event_name == 'workflow_dispatch' && inputs.databricks_http_path || secrets.DATABRICKS_HTTP_PATH_FOR_TF }} + IDENTITY_FEDERATION_CLIENT_ID: ${{ github.event_name == 'workflow_dispatch' && inputs.identity_federation_client_id || secrets.IDENTITY_FEDERATION_CLIENT_ID }} + OIDC_TOKEN: ${{ steps.get-id-token.outputs.token }} + run: python tests/token_federation/github_oidc_test.py diff --git a/poetry.lock b/poetry.lock index 1bc396c9..67880458 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "astroid" @@ -6,6 +6,7 @@ version = "3.2.4" description = "An abstract syntax tree for Python with inference support." optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "astroid-3.2.4-py3-none-any.whl", hash = "sha256:413658a61eeca6202a59231abb473f932038fbcbf1666587f66d482083413a25"}, {file = "astroid-3.2.4.tar.gz", hash = "sha256:0e14202810b30da1b735827f78f5157be2bbd4a7a59b7707ca0bfc2fb4c0063a"}, @@ -20,6 +21,7 @@ version = "22.12.0" description = "The uncompromising code formatter." 
optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "black-22.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9eedd20838bd5d75b80c9f5487dbcb06836a43833a37846cf1d8c1cc01cef59d"}, {file = "black-22.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:159a46a4947f73387b4d83e87ea006dbb2337eab6c879620a3ba52699b1f4351"}, @@ -55,6 +57,7 @@ version = "2025.1.31" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "certifi-2025.1.31-py3-none-any.whl", hash = "sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe"}, {file = "certifi-2025.1.31.tar.gz", hash = "sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651"}, @@ -66,6 +69,7 @@ version = "3.4.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7" +groups = ["main"] files = [ {file = "charset_normalizer-3.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de"}, {file = "charset_normalizer-3.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176"}, @@ -167,6 +171,7 @@ version = "8.1.8" description = "Composable command line interface toolkit" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, @@ -181,6 +186,8 @@ version = "0.4.6" description = "Cross-platform colored terminal text." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +groups = ["dev"] +markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -192,6 +199,7 @@ version = "0.3.9" description = "serialize all of Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "dill-0.3.9-py3-none-any.whl", hash = "sha256:468dff3b89520b474c0397703366b7b95eebe6303f108adf9b19da1f702be87a"}, {file = "dill-0.3.9.tar.gz", hash = "sha256:81aa267dddf68cbfe8029c42ca9ec6a4ab3b22371d1c450abc54422577b4512c"}, @@ -207,6 +215,7 @@ version = "2.0.0" description = "An implementation of lxml.xmlfile for the standard library" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa"}, {file = "et_xmlfile-2.0.0.tar.gz", hash = "sha256:dab3f4764309081ce75662649be815c4c9081e88f0837825f90fd28317d4da54"}, @@ -218,6 +227,8 @@ version = "1.2.2" description = "Backport of PEP 654 (exception groups)" optional = false python-versions = ">=3.7" +groups = ["dev"] +markers = "python_version <= \"3.10\"" files = [ {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, @@ -232,6 +243,7 @@ version = "3.10" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "idna-3.10-py3-none-any.whl", hash = 
"sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3"}, {file = "idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9"}, @@ -246,6 +258,7 @@ version = "2.1.0" description = "brain-dead simple config-ini parsing" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"}, {file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"}, @@ -257,6 +270,7 @@ version = "5.13.2" description = "A Python utility / library to sort Python imports." optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "isort-5.13.2-py3-none-any.whl", hash = "sha256:8ca5e72a8d85860d5a3fa69b8745237f2939afe12dbf656afbcb47fe72d947a6"}, {file = "isort-5.13.2.tar.gz", hash = "sha256:48fdfcb9face5d58a4f6dde2e72a1fb8dcaf8ab26f95ab49fab84c2ddefb0109"}, @@ -271,6 +285,7 @@ version = "4.3.3" description = "LZ4 Bindings for Python" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "lz4-4.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b891880c187e96339474af2a3b2bfb11a8e4732ff5034be919aa9029484cd201"}, {file = "lz4-4.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:222a7e35137d7539c9c33bb53fcbb26510c5748779364014235afc62b0ec797f"}, @@ -321,6 +336,7 @@ version = "0.7.0" description = "McCabe checker, plugin for flake8" optional = false python-versions = ">=3.6" +groups = ["dev"] files = [ {file = "mccabe-0.7.0-py2.py3-none-any.whl", hash = "sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"}, {file = "mccabe-0.7.0.tar.gz", hash = "sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325"}, @@ -332,6 +348,7 @@ version = "1.14.1" description = "Optional static typing for Python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ 
{file = "mypy-1.14.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52686e37cf13d559f668aa398dd7ddf1f92c5d613e4f8cb262be2fb4fedb0fcb"}, {file = "mypy-1.14.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1fb545ca340537d4b45d3eecdb3def05e913299ca72c290326be19b3804b39c0"}, @@ -391,6 +408,7 @@ version = "1.0.0" description = "Type system extensions for programs checked with the mypy type checker." optional = false python-versions = ">=3.5" +groups = ["dev"] files = [ {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, @@ -402,6 +420,8 @@ version = "1.24.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.8" +groups = ["main", "dev"] +markers = "python_version < \"3.10\"" files = [ {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, @@ -439,6 +459,8 @@ version = "2.2.4" description = "Fundamental package for array computing in Python" optional = false python-versions = ">=3.10" +groups = ["main", "dev"] +markers = "python_version >= \"3.10\"" files = [ {file = "numpy-2.2.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8146f3550d627252269ac42ae660281d673eb6f8b32f113538e0cc2a9aed42b9"}, {file = "numpy-2.2.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e642d86b8f956098b564a45e6f6ce68a22c2c97a04f5acd3f221f57b8cb850ae"}, @@ -503,6 +525,7 @@ version = "3.2.2" description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" optional = false python-versions = ">=3.6" +groups = ["main"] files = [ {file = "oauthlib-3.2.2-py3-none-any.whl", 
hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"}, {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"}, @@ -519,6 +542,7 @@ version = "3.1.5" description = "A Python library to read/write Excel 2010 xlsx/xlsm files" optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "openpyxl-3.1.5-py2.py3-none-any.whl", hash = "sha256:5282c12b107bffeef825f4617dc029afaf41d0ea60823bbb665ef3079dc79de2"}, {file = "openpyxl-3.1.5.tar.gz", hash = "sha256:cf0e3cf56142039133628b5acffe8ef0c12bc902d2aadd3e0fe5878dc08d1050"}, @@ -533,6 +557,7 @@ version = "24.2" description = "Core utilities for Python packages" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759"}, {file = "packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f"}, @@ -544,6 +569,8 @@ version = "2.0.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" files = [ {file = "pandas-2.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e4c7c9f27a4185304c7caf96dc7d91bc60bc162221152de697c98eb0b2648dd8"}, {file = "pandas-2.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f167beed68918d62bffb6ec64f2e1d8a7d297a038f86d4aed056b9493fca407f"}, @@ -573,11 +600,7 @@ files = [ ] [package.dependencies] -numpy = [ - {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, -] +numpy = {version = ">=1.20.3", markers = "python_version < \"3.10\""} python-dateutil = ">=2.8.2" pytz = ">=2020.1" tzdata = ">=2022.1" @@ -611,6 
+634,8 @@ version = "2.2.3" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" files = [ {file = "pandas-2.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:1948ddde24197a0f7add2bdc4ca83bf2b1ef84a1bc8ccffd95eda17fd836ecb5"}, {file = "pandas-2.2.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:381175499d3802cde0eabbaf6324cce0c4f5d52ca6f8c377c29ad442f50f6348"}, @@ -657,7 +682,11 @@ files = [ ] [package.dependencies] -numpy = {version = ">=1.26.0", markers = "python_version >= \"3.12\""} +numpy = [ + {version = ">=1.22.4", markers = "python_version < \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, + {version = ">=1.23.2", markers = "python_version == \"3.11\""}, +] python-dateutil = ">=2.8.2" pytz = ">=2020.1" tzdata = ">=2022.7" @@ -693,6 +722,7 @@ version = "0.12.1" description = "Utility library for gitignore style pattern matching of file paths." optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, @@ -704,6 +734,7 @@ version = "4.3.6" description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." 
optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb"}, {file = "platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907"}, @@ -720,6 +751,7 @@ version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -735,6 +767,8 @@ version = "17.0.0" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-17.0.0-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:a5c8b238d47e48812ee577ee20c9a2779e6a5904f1708ae240f53ecbee7c9f07"}, {file = "pyarrow-17.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db023dc4c6cae1015de9e198d41250688383c3f9af8f565370ab2b4cb5f62655"}, @@ -786,6 +820,8 @@ version = "19.0.1" description = "Python library for Apache Arrow" optional = true python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\" and extra == \"pyarrow\"" files = [ {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:fc28912a2dc924dddc2087679cc8b7263accc71b9ff025a1362b004711661a69"}, {file = "pyarrow-19.0.1-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:fca15aabbe9b8355800d923cc2e82c8ef514af321e18b437c3d782aa884eaeec"}, @@ -834,12 +870,51 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pyjwt" +version = "2.9.0" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = 
">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" +files = [ + {file = "PyJWT-2.9.0-py3-none-any.whl", hash = "sha256:3b02fb0f44517787776cf48f2ae25d8e14f300e6d7545a4315cee571a415e850"}, + {file = "pyjwt-2.9.0.tar.gz", hash = "sha256:7e1e5b56cc735432a7369cbfa0efe50fa113ebecdc04ae6922deba8b84582d0c"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + +[[package]] +name = "pyjwt" +version = "2.10.1" +description = "JSON Web Token implementation in Python" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb"}, + {file = "pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953"}, +] + +[package.extras] +crypto = ["cryptography (>=3.4.0)"] +dev = ["coverage[toml] (==5.0.4)", "cryptography (>=3.4.0)", "pre-commit", "pytest (>=6.0.0,<7.0.0)", "sphinx", "sphinx-rtd-theme", "zope.interface"] +docs = ["sphinx", "sphinx-rtd-theme", "zope.interface"] +tests = ["coverage[toml] (==5.0.4)", "pytest (>=6.0.0,<7.0.0)"] + [[package]] name = "pylint" version = "3.2.7" description = "python code static checker" optional = false python-versions = ">=3.8.0" +groups = ["dev"] files = [ {file = "pylint-3.2.7-py3-none-any.whl", hash = "sha256:02f4aedeac91be69fb3b4bea997ce580a4ac68ce58b89eaefeaf06749df73f4b"}, {file = "pylint-3.2.7.tar.gz", hash = "sha256:1b7a721b575eaeaa7d39db076b6e7743c993ea44f57979127c517c6c572c803e"}, @@ -851,7 +926,7 @@ colorama = {version = ">=0.4.5", markers = "sys_platform == \"win32\""} dill = [ {version = ">=0.2", markers = "python_version < 
\"3.11\""}, {version = ">=0.3.7", markers = "python_version >= \"3.12\""}, - {version = ">=0.3.6", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, + {version = ">=0.3.6", markers = "python_version == \"3.11\""}, ] isort = ">=4.2.5,<5.13.0 || >5.13.0,<6" mccabe = ">=0.6,<0.8" @@ -870,6 +945,7 @@ version = "7.4.4" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.7" +groups = ["dev"] files = [ {file = "pytest-7.4.4-py3-none-any.whl", hash = "sha256:b090cdf5ed60bf4c45261be03239c2c1c22df034fbffe691abe93cd80cea01d8"}, {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, @@ -892,6 +968,7 @@ version = "0.5.2" description = "A py.test plugin that parses environment files before running tests" optional = false python-versions = "*" +groups = ["dev"] files = [ {file = "pytest-dotenv-0.5.2.tar.gz", hash = "sha256:2dc6c3ac6d8764c71c6d2804e902d0ff810fa19692e95fe138aefc9b1aa73732"}, {file = "pytest_dotenv-0.5.2-py3-none-any.whl", hash = "sha256:40a2cece120a213898afaa5407673f6bd924b1fa7eafce6bda0e8abffe2f710f"}, @@ -907,6 +984,7 @@ version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -921,6 +999,7 @@ version = "1.0.1" description = "Read key-value pairs from a .env file and set them as environment variables" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"}, {file = "python_dotenv-1.0.1-py3-none-any.whl", hash 
= "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a"}, @@ -935,6 +1014,7 @@ version = "2025.2" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00"}, {file = "pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3"}, @@ -946,6 +1026,7 @@ version = "2.32.3" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, @@ -967,6 +1048,7 @@ version = "1.17.0" description = "Python 2 and 3 compatibility utilities" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" +groups = ["main"] files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -978,6 +1060,7 @@ version = "0.20.0" description = "Python bindings for the Apache Thrift RPC system" optional = false python-versions = "*" +groups = ["main"] files = [ {file = "thrift-0.20.0.tar.gz", hash = "sha256:4dd662eadf6b8aebe8a41729527bd69adf6ceaa2a8681cbef64d1273b3e8feba"}, ] @@ -996,6 +1079,8 @@ version = "2.2.1" description = "A lil' TOML parser" optional = false python-versions = ">=3.8" +groups = ["dev"] +markers = "python_version <= \"3.10\"" files = [ {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = 
"sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, @@ -1037,6 +1122,7 @@ version = "0.13.2" description = "Style preserving TOML library" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "tomlkit-0.13.2-py3-none-any.whl", hash = "sha256:7a974427f6e119197f670fbbbeae7bef749a6c14e793db934baefc1b5f03efde"}, {file = "tomlkit-0.13.2.tar.gz", hash = "sha256:fff5fe59a87295b278abd31bec92c15d9bc4a06885ab12bcea52c71119392e79"}, @@ -1048,6 +1134,7 @@ version = "4.13.0" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" +groups = ["dev"] files = [ {file = "typing_extensions-4.13.0-py3-none-any.whl", hash = "sha256:c8dd92cc0d6425a97c18fbb9d1954e5ff92c1ca881a309c45f06ebc0b79058e5"}, {file = "typing_extensions-4.13.0.tar.gz", hash = "sha256:0a4ac55a5820789d87e297727d229866c9650f6521b64206413c4fbada24d95b"}, @@ -1059,6 +1146,7 @@ version = "2025.2" description = "Provider of IANA time zone data" optional = false python-versions = ">=2" +groups = ["main"] files = [ {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, @@ -1070,13 +1158,14 @@ version = "2.2.3" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" +groups = ["main"] files = [ {file = "urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac"}, {file = "urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9"}, ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -1085,6 +1174,6 @@ zstd = ["zstandard (>=0.18.0)"] pyarrow = ["pyarrow", "pyarrow"] [metadata] -lock-version = "2.0" +lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "0bd6a6a019693a69a3da5ae312cea625ea73dfc5832b1e4051c7c7d1e76553d8" +content-hash = "aa36901ed7501adeeba5384352904ba06a34d298e400e926201e0fd57f6b6678" diff --git a/pyproject.toml b/pyproject.toml index 7b95a509..7d326b2b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,11 +25,12 @@ pyarrow = [ { version = ">=18.0.0", python = ">=3.13", optional=true } ] python-dateutil = "^2.8.0" +PyJWT = ">=2.0.0" [tool.poetry.extras] pyarrow = ["pyarrow"] -[tool.poetry.dev-dependencies] +[tool.poetry.group.dev.dependencies] pytest = "^7.1.2" mypy = "^1.10.1" pylint = ">=2.12.0" diff --git a/src/databricks/sql/auth/auth.py b/src/databricks/sql/auth/auth.py index 347934ee..3931356d 100755 --- a/src/databricks/sql/auth/auth.py +++ b/src/databricks/sql/auth/auth.py @@ -5,6 +5,7 @@ AuthProvider, AccessTokenAuthProvider, ExternalAuthProvider, + CredentialsProvider, DatabricksOAuthProvider, ) @@ -12,6 +13,10 @@ class AuthType(Enum): DATABRICKS_OAUTH = "databricks-oauth" AZURE_OAUTH = "azure-oauth" + # TODO: Token federation should be a feature that works with different auth types, + # not an auth type itself. This will be refactored in a future change. 
+ # We will add a use_token_federation flag that can be used with any auth type. + TOKEN_FEDERATION = "token-federation" # other supported types (access_token) can be inferred # we can add more types as needed later @@ -29,6 +34,7 @@ def __init__( tls_client_cert_file: Optional[str] = None, oauth_persistence=None, credentials_provider=None, + identity_federation_client_id: Optional[str] = None, ): self.hostname = hostname self.access_token = access_token @@ -40,11 +46,64 @@ def __init__( self.tls_client_cert_file = tls_client_cert_file self.oauth_persistence = oauth_persistence self.credentials_provider = credentials_provider + self.identity_federation_client_id = identity_federation_client_id def get_auth_provider(cfg: ClientContext): + """ + Get an appropriate auth provider based on the provided configuration. + + Token Federation Support: + ----------------------- + Currently, token federation is implemented as a separate auth type, but the goal is to + refactor it as a feature that can work with any auth type. The current implementation + is maintained for backward compatibility while the refactoring is planned. + + Future refactoring will introduce a `use_token_federation` flag that can be combined + with any auth type to enable token federation. 
+ + Args: + cfg: The client context containing configuration parameters + + Returns: + An appropriate AuthProvider instance + + Raises: + RuntimeError: If no valid authentication settings are provided + """ + # If credentials_provider is explicitly provided if cfg.credentials_provider: + # If token federation is enabled and credentials provider is provided, + # wrap the credentials provider with DatabricksTokenFederationProvider + if cfg.auth_type == AuthType.TOKEN_FEDERATION.value: + from databricks.sql.auth.token_federation import ( + DatabricksTokenFederationProvider, + ) + + federation_provider = DatabricksTokenFederationProvider( + cfg.credentials_provider, + cfg.hostname, + cfg.identity_federation_client_id, + ) + return ExternalAuthProvider(federation_provider) + + # If not token federation, just use the credentials provider directly return ExternalAuthProvider(cfg.credentials_provider) + + # If we don't have a credentials provider but have token federation auth type with access token + if cfg.auth_type == AuthType.TOKEN_FEDERATION.value and cfg.access_token: + # Create a simple credentials provider and wrap it with token federation provider + from databricks.sql.auth.token_federation import ( + DatabricksTokenFederationProvider, + SimpleCredentialsProvider, + ) + + simple_provider = SimpleCredentialsProvider(cfg.access_token) + federation_provider = DatabricksTokenFederationProvider( + simple_provider, cfg.hostname, cfg.identity_federation_client_id + ) + return ExternalAuthProvider(federation_provider) + if cfg.auth_type in [AuthType.DATABRICKS_OAUTH.value, AuthType.AZURE_OAUTH.value]: assert cfg.oauth_redirect_port_range is not None assert cfg.oauth_client_id is not None @@ -102,6 +161,27 @@ def get_client_id_and_redirect_port(use_azure_auth: bool): def get_python_sql_connector_auth_provider(hostname: str, **kwargs): + """ + Get an auth provider for the Python SQL connector. 
+ + This function is the main entry point for authentication in the SQL connector. + It processes the parameters and creates an appropriate auth provider. + + TODO: Future refactoring needed: + 1. Add a use_token_federation flag that can be combined with any auth type + 2. Remove TOKEN_FEDERATION as an auth_type while maintaining backward compatibility + 3. Create a token federation wrapper that can wrap any existing auth provider + + Args: + hostname: The Databricks server hostname + **kwargs: Additional configuration parameters + + Returns: + An appropriate AuthProvider instance + + Raises: + ValueError: If username/password authentication is attempted (no longer supported) + """ auth_type = kwargs.get("auth_type") (client_id, redirect_port_range) = get_client_id_and_redirect_port( auth_type == AuthType.AZURE_OAUTH.value @@ -125,5 +205,6 @@ def get_python_sql_connector_auth_provider(hostname: str, **kwargs): else redirect_port_range, oauth_persistence=kwargs.get("experimental_oauth_persistence"), credentials_provider=kwargs.get("credentials_provider"), + identity_federation_client_id=kwargs.get("identity_federation_client_id"), ) return get_auth_provider(cfg) diff --git a/src/databricks/sql/auth/authenticators.py b/src/databricks/sql/auth/authenticators.py index 64eb91bb..c425f088 100644 --- a/src/databricks/sql/auth/authenticators.py +++ b/src/databricks/sql/auth/authenticators.py @@ -26,10 +26,16 @@ class CredentialsProvider(abc.ABC): @abc.abstractmethod def auth_type(self) -> str: + """ + Returns the authentication type for this provider + """ ... @abc.abstractmethod def __call__(self, *args, **kwargs) -> HeaderFactory: + """ + Configure and return a HeaderFactory that provides authentication headers + """ ... 
diff --git a/src/databricks/sql/auth/oidc_utils.py b/src/databricks/sql/auth/oidc_utils.py new file mode 100644 index 00000000..b0421cf7 --- /dev/null +++ b/src/databricks/sql/auth/oidc_utils.py @@ -0,0 +1,58 @@ +import logging +import requests +from typing import Optional + +from databricks.sql.auth.endpoint import ( + get_oauth_endpoints, + infer_cloud_from_host, +) + +logger = logging.getLogger(__name__) + + +class OIDCDiscoveryUtil: + """ + Utility class for OIDC discovery operations. + + This class handles discovery of OIDC endpoints through standard + discovery mechanisms, with fallback to default endpoints if needed. + """ + + # Standard token endpoint path for Databricks workspaces + DEFAULT_TOKEN_PATH = "oidc/v1/token" + + @staticmethod + def discover_token_endpoint(hostname: str) -> str: + """ + Get the token endpoint for the given Databricks hostname. + + For Databricks workspaces, the token endpoint is always at host/oidc/v1/token. + + Args: + hostname: The hostname to get token endpoint for + + Returns: + str: The token endpoint URL + """ + # Format the hostname and return the standard endpoint + hostname = OIDCDiscoveryUtil.format_hostname(hostname) + token_endpoint = f"{hostname}{OIDCDiscoveryUtil.DEFAULT_TOKEN_PATH}" + logger.info(f"Using token endpoint: {token_endpoint}") + return token_endpoint + + @staticmethod + def format_hostname(hostname: str) -> str: + """ + Format hostname to ensure it has proper https:// prefix and trailing slash. + + Args: + hostname: The hostname to format + + Returns: + str: The formatted hostname + """ + if not hostname.startswith("https://"): + hostname = f"https://{hostname}" + if not hostname.endswith("/"): + hostname = f"{hostname}/" + return hostname diff --git a/src/databricks/sql/auth/token.py b/src/databricks/sql/auth/token.py new file mode 100644 index 00000000..5abd1e02 --- /dev/null +++ b/src/databricks/sql/auth/token.py @@ -0,0 +1,65 @@ +""" +Token class for authentication tokens with expiry handling. 
+""" + +from datetime import datetime, timezone, timedelta +from typing import Optional + + +class Token: + """ + Represents an OAuth token with expiry information. + + This class handles token state including expiry calculation. + """ + + # Minimum time buffer before expiry to consider a token still valid (in seconds) + MIN_VALIDITY_BUFFER = 10 + + def __init__( + self, + access_token: str, + token_type: str, + refresh_token: str = "", + expiry: Optional[datetime] = None, + ): + """ + Initialize a Token object. + + Args: + access_token: The access token string + token_type: The token type (usually "Bearer") + refresh_token: Optional refresh token + expiry: Token expiry datetime, must be provided + + Raises: + ValueError: If no expiry is provided + """ + self.access_token = access_token + self.token_type = token_type + self.refresh_token = refresh_token + + # Ensure we have an expiry time + if expiry is None: + raise ValueError("Token expiry must be provided") + + # Ensure expiry is timezone-aware + if expiry.tzinfo is None: + # Convert naive datetime to aware datetime + self.expiry = expiry.replace(tzinfo=timezone.utc) + else: + self.expiry = expiry + + def is_valid(self) -> bool: + """ + Check if the token is valid (has at least MIN_VALIDITY_BUFFER seconds before expiry). 
import logging
from datetime import datetime, timezone
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse

import jwt
import requests

from databricks.sql.auth.authenticators import CredentialsProvider, HeaderFactory
from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil
from databricks.sql.auth.token import Token

logger = logging.getLogger(__name__)


class DatabricksTokenFederationProvider(CredentialsProvider):
    """
    Credentials provider that exchanges a third-party access token for a
    Databricks token.

    Wraps an existing CredentialsProvider: when the wrapped provider's token
    was issued by a host other than the Databricks workspace, it is exchanged
    via the OAuth token-exchange grant (RFC 8693); tokens issued by the
    workspace itself are used as-is. Expired tokens are refreshed
    transparently on the next header request.
    """

    # Headers sent with the token exchange request
    EXCHANGE_HEADERS = {
        "Accept": "*/*",
        "Content-Type": "application/x-www-form-urlencoded",
    }

    # Fixed parameters of the OAuth token-exchange grant
    TOKEN_EXCHANGE_PARAMS = {
        "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
        "scope": "sql",
        "subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
        "return_original_token_if_authenticated": "true",
    }

    # Timeout (seconds) for the token exchange HTTP request, so a stalled
    # endpoint cannot hang the connector indefinitely.
    EXCHANGE_TIMEOUT_SECONDS = 30

    def __init__(
        self,
        credentials_provider: CredentialsProvider,
        hostname: str,
        identity_federation_client_id: Optional[str] = None,
    ):
        """
        Initialize the token federation provider.

        Args:
            credentials_provider: The underlying credentials provider
            hostname: The Databricks hostname
            identity_federation_client_id: Optional client ID for identity
                federation, sent with the exchange request when present
        """
        self.credentials_provider = credentials_provider
        self.hostname = hostname
        self.identity_federation_client_id = identity_federation_client_id
        # Resolved lazily on first __call__
        self.token_endpoint: Optional[str] = None

        # Most recently obtained token, plus the raw headers from the wrapped
        # provider (kept as a best-effort fallback in get_auth_headers).
        self.current_token: Optional[Token] = None
        self.external_headers: Optional[Dict[str, str]] = None

    def auth_type(self) -> str:
        """Return the auth type of the underlying credentials provider."""
        return self.credentials_provider.auth_type()

    @property
    def host(self) -> str:
        """Alias for hostname, for code expecting a ``host`` attribute."""
        return self.hostname

    def __call__(self, *args, **kwargs) -> HeaderFactory:
        """
        Configure and return a HeaderFactory that provides authentication
        headers. Called by ExternalAuthProvider.
        """
        # Invoke the wrapped provider for its side effects (e.g. kicking off
        # an OAuth flow); its headers are re-fetched on every refresh, so the
        # returned factory is intentionally not kept here.
        self.credentials_provider(*args, **kwargs)

        if self.token_endpoint is None:
            self.token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint(
                self.hostname
            )

        return self.get_auth_headers

    def _extract_token_info_from_header(
        self, headers: Dict[str, str]
    ) -> Tuple[str, str]:
        """
        Split an Authorization header into (token_type, token_value).

        Args:
            headers: Headers dictionary

        Returns:
            Tuple[str, str]: Token type and token value

        Raises:
            ValueError: If the Authorization header is missing or malformed
        """
        auth_header = headers.get("Authorization")
        if not auth_header:
            raise ValueError("No Authorization header found")

        parts = auth_header.split(" ", 1)
        if len(parts) != 2:
            raise ValueError(f"Invalid Authorization header format: {auth_header}")

        return parts[0], parts[1]

    def _parse_jwt_claims(self, token: str) -> Dict[str, Any]:
        """
        Decode a JWT's claims without verifying its signature.

        Returns an empty dict if the token cannot be parsed.
        """
        try:
            return jwt.decode(token, options={"verify_signature": False})
        except Exception as e:
            logger.error("Failed to parse JWT: %s", e)
            return {}

    def _get_expiry_from_jwt(self, token: str) -> Optional[datetime]:
        """
        Extract the "exp" claim of a JWT as an aware UTC datetime.

        Returns:
            Optional[datetime]: Expiry if present and valid, None otherwise
        """
        claims = self._parse_jwt_claims(token)

        if "exp" not in claims:
            return None

        try:
            # "exp" is seconds since the epoch (RFC 7519).
            return datetime.fromtimestamp(int(claims["exp"]), tz=timezone.utc)
        except (ValueError, TypeError) as e:
            logger.warning("Invalid JWT expiry value: %s", e)
            return None

    def _is_same_host(self, url1: str, url2: str) -> bool:
        """
        Return True if two URLs share the same host (case-insensitive).

        URLs without a scheme are assumed to be https. Returns False if
        either URL cannot be parsed.
        """
        try:
            if not url1.startswith(("http://", "https://")):
                url1 = f"https://{url1}"
            if not url2.startswith(("http://", "https://")):
                url2 = f"https://{url2}"

            return urlparse(url1).netloc.lower() == urlparse(url2).netloc.lower()
        except Exception as e:
            logger.warning("Error comparing hosts: %s", e)
            return False

    def refresh_token(self) -> Token:
        """
        Fetch a fresh token from the wrapped provider, exchanging it for a
        Databricks token if it was issued by a different host.

        Returns:
            Token: The new token (also cached as ``current_token``)

        Raises:
            ValueError: If a valid token cannot be obtained
        """
        header_factory = self.credentials_provider()
        self.external_headers = header_factory()

        token_type, access_token = self._extract_token_info_from_header(
            self.external_headers
        )

        token_claims = self._parse_jwt_claims(access_token)

        if self._is_same_host(token_claims.get("iss", ""), self.hostname):
            # Issued by the workspace itself: use it directly, no exchange.
            logger.debug("Token from same host, creating token without exchange")

            expiry = self._get_expiry_from_jwt(access_token)
            if expiry is None:
                raise ValueError("Could not determine token expiry from JWT")

            new_token = Token(access_token, token_type, "", expiry)
        else:
            logger.debug("Token from different host, exchanging token")
            new_token = self._exchange_token(access_token)

        self.current_token = new_token
        return new_token

    def get_current_token(self) -> Token:
        """
        Return the cached token if still valid, otherwise refresh it.

        Raises:
            ValueError: If unable to get a valid token
        """
        if self.current_token is not None and self.current_token.is_valid():
            return self.current_token

        return self.refresh_token()

    def get_auth_headers(self) -> Dict[str, str]:
        """
        Build Authorization headers from the current token.

        On failure, falls back to the raw headers from the wrapped provider
        if available, otherwise an empty dict.
        """
        try:
            token = self.get_current_token()
            return {"Authorization": f"{token.token_type} {token.access_token}"}
        except Exception as e:
            logger.error("Error getting auth headers: %s", e)

            if self.external_headers:
                return self.external_headers

            return {}

    def _send_token_exchange_request(
        self, token_exchange_data: Dict[str, str]
    ) -> Dict[str, Any]:
        """
        POST the exchange request to the token endpoint and return the JSON
        response.

        Raises:
            ValueError: If the endpoint is unset or returns a non-200 status
        """
        if not self.token_endpoint:
            raise ValueError("Token endpoint not initialized")

        response = requests.post(
            self.token_endpoint,
            data=token_exchange_data,
            headers=self.EXCHANGE_HEADERS,
            timeout=self.EXCHANGE_TIMEOUT_SECONDS,
        )

        if response.status_code != 200:
            raise ValueError(
                f"Token exchange failed with status code {response.status_code}: "
                f"{response.text}"
            )

        return response.json()

    def _exchange_token(self, access_token: str) -> Token:
        """
        Exchange an external token for a Databricks token.

        Args:
            access_token: External token to exchange

        Returns:
            Token: Exchanged token

        Raises:
            ValueError: If the exchange fails or the response is unusable
        """
        token_exchange_data = dict(self.TOKEN_EXCHANGE_PARAMS)
        token_exchange_data["subject_token"] = access_token

        if self.identity_federation_client_id:
            token_exchange_data["client_id"] = self.identity_federation_client_id

        resp_data = self._send_token_exchange_request(token_exchange_data)

        new_access_token = resp_data.get("access_token")
        if not new_access_token:
            raise ValueError("No access token in exchange response")

        token_type = resp_data.get("token_type", "Bearer")
        refresh_token = resp_data.get("refresh_token", "")

        # The exchanged token carries its own expiry in its JWT claims.
        expiry = self._get_expiry_from_jwt(new_access_token)
        if expiry is None:
            raise ValueError("Unable to determine token expiry from JWT claims")

        return Token(new_access_token, token_type, refresh_token, expiry)


class SimpleCredentialsProvider(CredentialsProvider):
    """A simple credentials provider that returns a fixed token."""

    def __init__(
        self, token: str, token_type: str = "Bearer", auth_type_value: str = "token"
    ):
        """
        Initialize a SimpleCredentialsProvider.

        Args:
            token: The fixed token value
            token_type: Token type used in the Authorization header
            auth_type_value: Value reported by auth_type()
        """
        self.token = token
        self.token_type = token_type
        self.auth_type_value = auth_type_value

    def auth_type(self) -> str:
        """Return the auth type value."""
        return self.auth_type_value

    def __call__(self, *args, **kwargs) -> HeaderFactory:
        """Return a HeaderFactory that yields the fixed Authorization header."""

        def get_headers() -> Dict[str, str]:
            return {"Authorization": f"{self.token_type} {self.token}"}

        return get_headers
+ """ + + def get_headers() -> Dict[str, str]: + return {"Authorization": f"{self.token_type} {self.token}"} + + return get_headers diff --git a/tests/token_federation/github_oidc_test.py b/tests/token_federation/github_oidc_test.py new file mode 100755 index 00000000..10bd8686 --- /dev/null +++ b/tests/token_federation/github_oidc_test.py @@ -0,0 +1,169 @@ +""" +Test script for Databricks SQL token federation with GitHub Actions OIDC tokens. + +This script tests the Databricks SQL connector with token federation +using a GitHub Actions OIDC token. It connects to a Databricks SQL warehouse, +runs a simple query, and shows the connected user. +""" + +import os +import sys +import logging +import jwt +from databricks import sql + + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def decode_jwt(token): + """ + Decode and return the claims from a JWT token. + + Args: + token: The JWT token string + + Returns: + dict: The decoded token claims or empty dict if decoding fails + """ + try: + # Using PyJWT library to decode token without verification + return jwt.decode(token, options={"verify_signature": False}) + except Exception as e: + logger.error(f"Failed to decode token: {str(e)}") + return {} + + +def get_environment_variables(): + """ + Get required environment variables for the test. 
+ + Returns: + tuple: (github_token, host, http_path, identity_federation_client_id) + """ + github_token = os.environ.get("OIDC_TOKEN") + host = os.environ.get("DATABRICKS_HOST_FOR_TF") + http_path = os.environ.get("DATABRICKS_HTTP_PATH_FOR_TF") + identity_federation_client_id = os.environ.get("IDENTITY_FEDERATION_CLIENT_ID") + + # Validate required environment variables + if not github_token: + raise ValueError("OIDC_TOKEN environment variable is required") + if not host: + raise ValueError("DATABRICKS_HOST_FOR_TF environment variable is required") + if not http_path: + raise ValueError("DATABRICKS_HTTP_PATH_FOR_TF environment variable is required") + + return github_token, host, http_path, identity_federation_client_id + + +def display_token_info(claims): + """ + Display token claims for debugging. + + Args: + claims: Dictionary containing JWT token claims + """ + if not claims: + logger.warning("No token claims available to display") + return + + logger.info("=== GitHub OIDC Token Claims ===") + logger.info(f"Token issuer: {claims.get('iss')}") + logger.info(f"Token subject: {claims.get('sub')}") + logger.info(f"Token audience: {claims.get('aud')}") + logger.info(f"Token expiration: {claims.get('exp', 'unknown')}") + logger.info(f"Repository: {claims.get('repository', 'unknown')}") + logger.info(f"Workflow ref: {claims.get('workflow_ref', 'unknown')}") + logger.info(f"Event name: {claims.get('event_name', 'unknown')}") + logger.info("===============================") + + +def test_databricks_connection( + host, http_path, github_token, identity_federation_client_id +): + """ + Test connection to Databricks using token federation. 
+ + Args: + host: Databricks host + http_path: Databricks HTTP path + github_token: GitHub OIDC token + identity_federation_client_id: Identity federation client ID + + Returns: + bool: True if the test is successful, False otherwise + """ + logger.info("=== Testing Connection via Connector ===") + logger.info(f"Connecting to Databricks at {host}{http_path}") + logger.info(f"Using client ID: {identity_federation_client_id}") + + connection_params = { + "server_hostname": host, + "http_path": http_path, + "access_token": github_token, + "auth_type": "token-federation", + } + + # Add identity federation client ID if provided + if identity_federation_client_id: + connection_params[ + "identity_federation_client_id" + ] = identity_federation_client_id + + try: + with sql.connect(**connection_params) as connection: + logger.info("Connection established successfully") + + # Execute a simple query + cursor = connection.cursor() + cursor.execute("SELECT 1 + 1 as result") + result = cursor.fetchall() + logger.info(f"Query result: {result[0][0]}") + + # Show current user + cursor.execute("SELECT current_user() as user") + result = cursor.fetchall() + logger.info(f"Connected as user: {result[0][0]}") + + logger.info("Token federation test successful!") + return True + except Exception as e: + logger.error(f"Error connecting to Databricks: {str(e)}") + return False + + +def main(): + """Main entry point for the test script.""" + try: + # Get environment variables + ( + github_token, + host, + http_path, + identity_federation_client_id, + ) = get_environment_variables() + + # Display token claims + claims = decode_jwt(github_token) + display_token_info(claims) + + # Test Databricks connection + success = test_databricks_connection( + host, http_path, github_token, identity_federation_client_id + ) + + if not success: + logger.error("Token federation test failed") + sys.exit(1) + + except Exception as e: + logger.error(f"Unexpected error: {str(e)}") + sys.exit(1) + + +if 
__name__ == "__main__": + main() diff --git a/tests/unit/test_token_federation.py b/tests/unit/test_token_federation.py new file mode 100644 index 00000000..e4344fd5 --- /dev/null +++ b/tests/unit/test_token_federation.py @@ -0,0 +1,395 @@ +import pytest +from unittest.mock import MagicMock, patch +from datetime import datetime, timezone, timedelta +import jwt + +from databricks.sql.auth.token import Token +from databricks.sql.auth.token_federation import ( + DatabricksTokenFederationProvider, + SimpleCredentialsProvider, +) +from databricks.sql.auth.oidc_utils import OIDCDiscoveryUtil + + +@pytest.fixture +def future_time(): + """Fixture providing a future time for token expiry.""" + return datetime.now(tz=timezone.utc) + timedelta(hours=1) + + +@pytest.fixture +def valid_token(future_time): + """Fixture providing a valid token.""" + return Token("access_token_value", "Bearer", expiry=future_time) + + +class TestToken: + """Tests for the Token class.""" + + def test_valid_token_properties(self, future_time): + """Test that a valid token has the expected properties.""" + # Create token with future expiry + token = Token("access_token_value", "Bearer", expiry=future_time) + + # Verify properties + assert token.access_token == "access_token_value" + assert token.token_type == "Bearer" + assert token.refresh_token == "" + assert token.expiry == future_time + assert token.is_valid() + assert str(token) == "Bearer access_token_value" + + def test_expired_token_is_invalid(self): + """Test that an expired token is recognized as invalid.""" + past_time = datetime.now(tz=timezone.utc) - timedelta(hours=1) + token = Token("expired", "Bearer", expiry=past_time) + + assert not token.is_valid() + + def test_almost_expired_token_is_invalid(self): + """Test that a token about to expire is recognized as invalid.""" + almost_expired = datetime.now(tz=timezone.utc) + timedelta( + seconds=5 + ) # Less than MIN_VALIDITY_BUFFER + token = Token("almost", "Bearer", expiry=almost_expired) 
+ + assert not token.is_valid() + + +class TestSimpleCredentialsProvider: + """Tests for the SimpleCredentialsProvider class.""" + + def test_provider_initialization_and_headers(self): + """Test SimpleCredentialsProvider initialization and header generation.""" + provider = SimpleCredentialsProvider("token1", "Bearer", "token") + + # Check auth type + assert provider.auth_type() == "token" + + # Check header generation + headers = provider()() + assert headers == {"Authorization": "Bearer token1"} + + +class TestOIDCDiscoveryUtil: + """Tests for the OIDCDiscoveryUtil class.""" + + @pytest.mark.parametrize( + "hostname,expected", + [ + # Without protocol and without trailing slash + ("databricks.com", "https://databricks.com/oidc/v1/token"), + # With protocol but without trailing slash + ("https://databricks.com", "https://databricks.com/oidc/v1/token"), + # With protocol and trailing slash + ("https://databricks.com/", "https://databricks.com/oidc/v1/token"), + ], + ) + def test_discover_token_endpoint(self, hostname, expected): + """Test token endpoint creation for various hostname formats.""" + token_endpoint = OIDCDiscoveryUtil.discover_token_endpoint(hostname) + assert token_endpoint == expected + + @pytest.mark.parametrize( + "hostname,expected", + [ + # Without protocol and without trailing slash + ("databricks.com", "https://databricks.com/"), + # With protocol but without trailing slash + ("https://databricks.com", "https://databricks.com/"), + # With protocol and trailing slash + ("https://databricks.com/", "https://databricks.com/"), + ], + ) + def test_format_hostname(self, hostname, expected): + """Test hostname formatting with various input formats.""" + formatted = OIDCDiscoveryUtil.format_hostname(hostname) + assert formatted == expected + + +class TestDatabricksTokenFederationProvider: + """Tests for the DatabricksTokenFederationProvider class.""" + + # ==== Fixtures ==== + @pytest.fixture + def mock_credentials_provider(self): + """Fixture 
providing a mock credentials provider.""" + provider = MagicMock() + provider.auth_type.return_value = "mock_auth_type" + header_factory = MagicMock() + header_factory.return_value = {"Authorization": "Bearer mock_token"} + provider.return_value = header_factory + return provider + + @pytest.fixture + def federation_provider(self, mock_credentials_provider): + """Fixture providing a token federation provider with mocked dependencies.""" + provider = DatabricksTokenFederationProvider( + mock_credentials_provider, "databricks.com", "client_id" + ) + # Initialize token endpoint to avoid discovery during tests + provider.token_endpoint = "https://databricks.com/oidc/v1/token" + return provider + + @pytest.fixture + def mock_dependencies(self): + """Mock all external dependencies of the federation provider.""" + with patch( + "databricks.sql.auth.oidc_utils.OIDCDiscoveryUtil.discover_token_endpoint", + return_value="https://databricks.com/oidc/v1/token", + ) as mock_discover: + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._parse_jwt_claims" + ) as mock_parse_jwt: + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._exchange_token" + ) as mock_exchange: + with patch( + "databricks.sql.auth.token_federation.DatabricksTokenFederationProvider._is_same_host" + ) as mock_is_same_host: + with patch( + "databricks.sql.auth.token_federation.requests.post" + ) as mock_post: + yield { + "discover": mock_discover, + "parse_jwt": mock_parse_jwt, + "exchange": mock_exchange, + "is_same_host": mock_is_same_host, + "post": mock_post, + } + + # ==== Basic functionality tests ==== + def test_provider_initialization(self, federation_provider): + """Test basic provider initialization and properties.""" + assert federation_provider.host == "databricks.com" + assert federation_provider.hostname == "databricks.com" + assert federation_provider.auth_type() == "mock_auth_type" + + # ==== Utility method tests ==== + 
@pytest.mark.parametrize( + "url1,url2,expected", + [ + # Same host with same protocol + ("https://databricks.com", "https://databricks.com", True), + # Different hosts + ("https://databricks.com", "https://different.com", False), + # Same host with different paths + ("https://databricks.com/path", "https://databricks.com/other", True), + # Same host with missing protocol + ("databricks.com", "https://databricks.com", True), + ], + ) + def test_is_same_host(self, federation_provider, url1, url2, expected): + """Test host comparison logic with various URL formats.""" + assert federation_provider._is_same_host(url1, url2) is expected + + @pytest.mark.parametrize( + "headers,expected_result,should_raise", + [ + # Valid Bearer token + ({"Authorization": "Bearer token"}, ("Bearer", "token"), False), + # Valid custom token type + ({"Authorization": "CustomType token"}, ("CustomType", "token"), False), + # Missing Authorization header + ({}, None, True), + # Empty Authorization header + ({"Authorization": ""}, None, True), + # Malformed Authorization header + ({"Authorization": "Bearer"}, None, True), + ], + ) + def test_extract_token_info( + self, federation_provider, headers, expected_result, should_raise + ): + """Test token extraction from headers with various formats.""" + if should_raise: + with pytest.raises(ValueError): + federation_provider._extract_token_info_from_header(headers) + else: + result = federation_provider._extract_token_info_from_header(headers) + assert result == expected_result + + def test_get_expiry_from_jwt(self, federation_provider): + """Test JWT token expiry extraction.""" + # Create a valid JWT token with expiry + expiry_timestamp = int( + (datetime.now(tz=timezone.utc) + timedelta(hours=1)).timestamp() + ) + valid_payload = { + "exp": expiry_timestamp, + "iat": int(datetime.now(tz=timezone.utc).timestamp()), + "sub": "test-subject", + } + valid_token = jwt.encode(valid_payload, "secret", algorithm="HS256") + + # Test with valid token + 
expiry = federation_provider._get_expiry_from_jwt(valid_token) + assert expiry is not None + assert isinstance(expiry, datetime) + assert expiry.tzinfo is not None # Should be timezone-aware + # Allow for small rounding differences + assert ( + abs( + ( + expiry - datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc) + ).total_seconds() + ) + < 1 + ) + + # Test with invalid token format + assert federation_provider._get_expiry_from_jwt("invalid-token") is None + + # Test with token missing expiry claim + token_without_exp = jwt.encode( + {"sub": "test-subject"}, "secret", algorithm="HS256" + ) + assert federation_provider._get_expiry_from_jwt(token_without_exp) is None + + # ==== Core functionality tests ==== + def test_token_reuse_when_valid(self, federation_provider, future_time): + """Test that a valid token is reused without exchange.""" + # Prepare mock for exchange function + with patch.object(federation_provider, "_exchange_token") as mock_exchange: + # Set up a valid token + federation_provider.current_token = Token( + "existing_token", "Bearer", expiry=future_time + ) + federation_provider.external_headers = { + "Authorization": "Bearer external_token" + } + + # Get headers + headers = federation_provider.get_auth_headers() + + # Verify token was reused without exchange + assert headers["Authorization"] == "Bearer existing_token" + mock_exchange.assert_not_called() + + def test_token_exchange_from_different_host( + self, federation_provider, mock_dependencies + ): + """Test token exchange when token is from a different host.""" + # Configure mocks for token from different host + mock_dependencies["parse_jwt"].return_value = { + "iss": "https://login.microsoftonline.com/tenant" + } + mock_dependencies["is_same_host"].return_value = False + + # Configure credentials provider + headers = {"Authorization": "Bearer external_token"} + header_factory = MagicMock(return_value=headers) + mock_creds = MagicMock(return_value=header_factory) + 
federation_provider.credentials_provider = mock_creds + + # Configure mock token exchange + future_time = datetime.now(tz=timezone.utc) + timedelta(hours=1) + exchanged_token = Token("databricks_token", "Bearer", expiry=future_time) + mock_dependencies["exchange"].return_value = exchanged_token + + # Call refresh_token + token = federation_provider.refresh_token() + + # Verify token was exchanged + mock_dependencies["exchange"].assert_called_with("external_token") + assert token.access_token == "databricks_token" + assert federation_provider.current_token == token + + def test_token_from_same_host(self, federation_provider, mock_dependencies): + """Test handling of token from the same host (no exchange needed).""" + # Configure mocks for token from same host + mock_dependencies["parse_jwt"].return_value = {"iss": "https://databricks.com"} + mock_dependencies["is_same_host"].return_value = True + + # Configure credentials provider + headers = {"Authorization": "Bearer databricks_token"} + header_factory = MagicMock(return_value=headers) + mock_creds = MagicMock(return_value=header_factory) + federation_provider.credentials_provider = mock_creds + + # Mock JWT expiry extraction + expiry_time = datetime.now(tz=timezone.utc) + timedelta(hours=2) + with patch.object( + federation_provider, "_get_expiry_from_jwt", return_value=expiry_time + ): + # Call refresh_token + token = federation_provider.refresh_token() + + # Verify no exchange was performed + mock_dependencies["exchange"].assert_not_called() + assert token.access_token == "databricks_token" + assert token.expiry == expiry_time + + def test_call_returns_auth_headers_function( + self, federation_provider, mock_dependencies + ): + """Test that __call__ returns the get_auth_headers method directly.""" + with patch.object( + federation_provider, + "get_auth_headers", + return_value={"Authorization": "Bearer test_token"}, + ) as mock_get_auth: + # Get the header factory from __call__ + result = federation_provider() + 
+ # Verify it's the get_auth_headers method + assert result is federation_provider.get_auth_headers + + # Call the result and verify it returns headers + headers = result() + assert headers == {"Authorization": "Bearer test_token"} + mock_get_auth.assert_called_once() + + def test_token_exchange_success(self, federation_provider): + """Test successful token exchange.""" + # Mock successful response + with patch("databricks.sql.auth.token_federation.requests.post") as mock_post: + # Create a token with a valid expiry + expiry_timestamp = int( + (datetime.now(tz=timezone.utc) + timedelta(hours=1)).timestamp() + ) + + # Configure mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "access_token": "new_token", + "token_type": "Bearer", + "refresh_token": "refresh_value", + } + mock_post.return_value = mock_response + + # Mock JWT expiry extraction to return a valid expiry + with patch.object( + federation_provider, + "_get_expiry_from_jwt", + return_value=datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc), + ): + # Call the exchange method + token = federation_provider._exchange_token("original_token") + + # Verify token properties + assert token.access_token == "new_token" + assert token.token_type == "Bearer" + assert token.refresh_token == "refresh_value" + + # Verify expiry time is correctly set + expiry_datetime = datetime.fromtimestamp( + expiry_timestamp, tz=timezone.utc + ) + assert token.expiry == expiry_datetime + + def test_token_exchange_failure(self, federation_provider): + """Test token exchange failure handling.""" + # Mock error response + with patch("databricks.sql.auth.token_federation.requests.post") as mock_post: + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_post.return_value = mock_response + + # Call the method and expect an exception + with pytest.raises( + ValueError, match="Token exchange failed with status 
code 401" + ): + federation_provider._exchange_token("original_token")