diff --git a/superset/exceptions.py b/superset/exceptions.py index 91a4656595cd5..86a0964556cab 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -356,3 +356,9 @@ def __init__(self, error: str): extra={"error": error}, ) ) + + +class CreateAuthLockFailedException(Exception): + """ + Exception to signalize failure to acquire lock when refreshing token. + """ diff --git a/superset/key_value/types.py b/superset/key_value/types.py index f5865de3231f0..043d75a06fe88 100644 --- a/superset/key_value/types.py +++ b/superset/key_value/types.py @@ -49,6 +49,7 @@ class KeyValueResource(StrEnum): DASHBOARD_PERMALINK = "dashboard_permalink" EXPLORE_PERMALINK = "explore_permalink" METASTORE_CACHE = "superset_metastore_cache" + OAUTH2 = "oauth2" class SharedKey(StrEnum): diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 90a835741dcb0..9db29468b8961 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -17,12 +17,36 @@ from __future__ import annotations +import logging +import uuid +from collections.abc import Iterator +from contextlib import contextmanager from datetime import datetime, timedelta +from typing import TYPE_CHECKING + +import backoff from superset import db from superset.db_engine_specs.base import BaseEngineSpec +from superset.exceptions import CreateAuthLockFailedException +from superset.key_value.exceptions import KeyValueCreateFailedError +from superset.key_value.types import KeyValueResource, PickleKeyValueCodec + +if TYPE_CHECKING: + from superset.models.core import DatabaseUserOAuth2Tokens + + +LOCK_EXPIRATION = timedelta(seconds=30) +logger = logging.getLogger(__name__) +@backoff.on_exception( + backoff.expo, + CreateAuthLockFailedException, + factor=10, + base=2, + max_tries=5, +) def get_oauth2_access_token( database_id: int, user_id: int, @@ -49,21 +73,80 @@ def get_oauth2_access_token( return token.access_token if token.refresh_token: - # refresh access token + return refresh_oauth2_token(database_id, user_id, db_engine_spec, token) + + # since the access token is expired and there's no refresh token, delete the entry + db.session.delete(token) + + return None + + +def integers_to_uuid(a: int, b: int) -> uuid.UUID: # pylint: disable=invalid-name + """ + Generate UUID based on a namespace UUID and the string representation of integer pair. + """ + pair_str = f"{a}-{b}" + return uuid.uuid5(uuid.NAMESPACE_DNS, pair_str) + + +@contextmanager +def AuthLock( # pylint: disable=invalid-name + user_id: int, + database_id: int, +) -> Iterator[None]: + """ + KV global lock for refreshing tokens. + """ + # pylint: disable=import-outside-toplevel + from superset.commands.key_value.create import CreateKeyValueCommand + from superset.commands.key_value.delete import DeleteKeyValueCommand + from superset.commands.key_value.delete_expired import DeleteExpiredKeyValueCommand + + key = integers_to_uuid(user_id, database_id) + logger.debug( + "Acquiring lock to refresh OAuth2 token for user ID %d and database ID %d", + user_id, + database_id, + ) + try: + DeleteExpiredKeyValueCommand(resource=KeyValueResource.OAUTH2).run() + CreateKeyValueCommand( + resource=KeyValueResource.OAUTH2, + codec=PickleKeyValueCodec(), + key=key, + value=True, + expires_on=datetime.now() + LOCK_EXPIRATION, + ).run() + yield + except KeyValueCreateFailedError as ex: + raise CreateAuthLockFailedException("Error acquiring lock") from ex + finally: + DeleteKeyValueCommand(resource=KeyValueResource.OAUTH2, key=key).run() + logger.debug( + "Removed lock to refresh OAuth2 token for user ID %d and database ID %d", + user_id, + database_id, + ) + + +def refresh_oauth2_token( + database_id: int, + user_id: int, + db_engine_spec: type[BaseEngineSpec], + token: DatabaseUserOAuth2Tokens, +) -> str | None: + with AuthLock(user_id, database_id): token_response = db_engine_spec.get_oauth2_fresh_token(token.refresh_token) # store new access token; note that the refresh token might be revoked, in which # case there would be no access token in the response - if "access_token" in token_response: - token.access_token = token_response["access_token"] - token.access_token_expiration = datetime.now() + timedelta( - seconds=token_response["expires_in"] - ) - db.session.add(token) + if "access_token" not in token_response: + return None - return token.access_token + token.access_token = token_response["access_token"] + token.access_token_expiration = datetime.now() + timedelta( + seconds=token_response["expires_in"] + ) + db.session.add(token) - # since the access token is expired and there's no refresh token, delete the entry - db.session.delete(token) - - return None + return token.access_token diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py index 6c859a538f04f..32aa971cface0 100644 --- a/tests/unit_tests/utils/oauth2_tests.py +++ b/tests/unit_tests/utils/oauth2_tests.py @@ -18,11 +18,16 @@ # pylint: disable=invalid-name, disallowed-name from datetime import datetime +from uuid import UUID +import pytest from freezegun import freeze_time from pytest_mock import MockerFixture -from superset.utils.oauth2 import get_oauth2_access_token +from superset.exceptions import CreateAuthLockFailedException +from superset.key_value.exceptions import KeyValueCreateFailedError +from superset.key_value.types import KeyValueResource +from superset.utils.oauth2 import AuthLock, get_oauth2_access_token, integers_to_uuid def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None: @@ -93,3 +98,71 @@ def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None: # check that token was deleted db.session.delete.assert_called_with(token) + + +def test_integers_to_uuid() -> None: + """ + Test `integers_to_uuid`. + """ + assert integers_to_uuid(1, 1) == UUID("4a426d86-eae0-53be-8f9a-8113ffc5a445") + assert integers_to_uuid(2, 1) == UUID("0a81e791-1685-5239-bc04-4cdd6aacc18d") + assert integers_to_uuid(1, 2) == UUID("83b19a49-b4f2-5ac5-9b52-5a63907dd160") + + +def test_AuthLock_happy_path(mocker: MockerFixture) -> None: + """ + Test successfully acquiring the global auth lock. + """ + CreateKeyValueCommand = mocker.patch( + "superset.commands.key_value.create.CreateKeyValueCommand" + ) + DeleteKeyValueCommand = mocker.patch( + "superset.commands.key_value.delete.DeleteKeyValueCommand" + ) + DeleteExpiredKeyValueCommand = mocker.patch( + "superset.commands.key_value.delete_expired.DeleteExpiredKeyValueCommand" + ) + PickleKeyValueCodec = mocker.patch("superset.utils.oauth2.PickleKeyValueCodec") + + with freeze_time("2024-01-01"): + with AuthLock(1, 1): + DeleteExpiredKeyValueCommand.assert_called_with( + resource=KeyValueResource.OAUTH2, + ) + CreateKeyValueCommand.assert_called_with( + resource=KeyValueResource.OAUTH2, + codec=PickleKeyValueCodec(), + key=integers_to_uuid(1, 1), + value=True, + expires_on=datetime(2024, 1, 1, 0, 0, 30), + ) + DeleteKeyValueCommand.assert_not_called() + + DeleteKeyValueCommand.assert_called_with( + resource=KeyValueResource.OAUTH2, + key=integers_to_uuid(1, 1), + ) + + +def test_AuthLock_no_lock(mocker: MockerFixture) -> None: + """ + Test unsuccessfully acquiring the global auth lock. + """ + mocker.patch( + "superset.commands.key_value.create.CreateKeyValueCommand", + side_effect=KeyValueCreateFailedError(), + ) + DeleteKeyValueCommand = mocker.patch( + "superset.commands.key_value.delete.DeleteKeyValueCommand" + ) + + with pytest.raises(CreateAuthLockFailedException) as excinfo: + with AuthLock(1, 1): + pass + assert str(excinfo.value) == "Error acquiring lock" + + # confirm that key was deleted + DeleteKeyValueCommand.assert_called_with( + resource=KeyValueResource.OAUTH2, + key=integers_to_uuid(1, 1), + )