From 2aae51954fe36183aab5114d3618ee1d63867cce Mon Sep 17 00:00:00 2001 From: David Black Date: Thu, 25 Sep 2025 17:14:18 +1000 Subject: [PATCH 1/5] Sem-ver: bugfix Remove the need and use of asynctest during testing Signed-off-by: David Black --- atlassian_jwt_auth/__init__.py | 15 +++++---------- .../tests/aiohttp/test_public_key_provider.py | 12 ++++-------- .../contrib/tests/aiohttp/test_verifier.py | 11 ++++------- test-requirements.txt | 1 - 4 files changed, 13 insertions(+), 26 deletions(-) diff --git a/atlassian_jwt_auth/__init__.py b/atlassian_jwt_auth/__init__.py index 560ff0a..a738e8c 100644 --- a/atlassian_jwt_auth/__init__.py +++ b/atlassian_jwt_auth/__init__.py @@ -1,15 +1,10 @@ from atlassian_jwt_auth.algorithms import get_permitted_algorithm_names # noqa - +from atlassian_jwt_auth.key import ( + HTTPSPublicKeyRetriever, # noqa + KeyIdentifier, # noqa +) from atlassian_jwt_auth.signer import ( # noqa create_signer, create_signer_from_file_private_key_repository, ) - -from atlassian_jwt_auth.key import ( # noqa - KeyIdentifier, - HTTPSPublicKeyRetriever, -) - -from atlassian_jwt_auth.verifier import ( # noqa - JWTAuthVerifier, -) +from atlassian_jwt_auth.verifier import JWTAuthVerifier # noqa diff --git a/atlassian_jwt_auth/contrib/tests/aiohttp/test_public_key_provider.py b/atlassian_jwt_auth/contrib/tests/aiohttp/test_public_key_provider.py index a980093..649513c 100644 --- a/atlassian_jwt_auth/contrib/tests/aiohttp/test_public_key_provider.py +++ b/atlassian_jwt_auth/contrib/tests/aiohttp/test_public_key_provider.py @@ -1,16 +1,12 @@ import os +from unittest import IsolatedAsyncioTestCase as TestCase from unittest import mock +from unittest.mock import AsyncMock as CoroutineMock +from unittest.mock import Mock import aiohttp from multidict import CIMultiDict -try: - from unittest import IsolatedAsyncioTestCase as TestCase - from unittest.mock import AsyncMock as CoroutineMock - from unittest.mock import Mock -except ImportError: - from asynctest import CoroutineMock, Mock, TestCase - from atlassian_jwt_auth.contrib.aiohttp import HTTPSPublicKeyRetriever from atlassian_jwt_auth.key import PEM_FILE_TYPE from atlassian_jwt_auth.tests import utils @@ -26,7 +22,7 @@ def set_headers(self, headers): def set_text(self, text): self._session.get.return_value.text.return_value = text - def _get_session(self): + def _get_session(self) -> Mock: session = Mock(spec=aiohttp.ClientSession) session.attach_mock(CoroutineMock(), "get") diff --git a/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py b/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py index 2ccfbf1..67587a8 100644 --- a/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py +++ b/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py @@ -1,12 +1,9 @@ import asyncio +from unittest import IsolatedAsyncioTestCase as TestCase +from unittest.mock import AsyncMock as CoroutineMock -try: - from unittest import IsolatedAsyncioTestCase as TestCase - from unittest.mock import AsyncMock as CoroutineMock -except ImportError: - from asynctest import CoroutineMock, TestCase - -from atlassian_jwt_auth.contrib.aiohttp import HTTPSPublicKeyRetriever, JWTAuthVerifier +from atlassian_jwt_auth.contrib.aiohttp import (HTTPSPublicKeyRetriever, + JWTAuthVerifier) from atlassian_jwt_auth.tests import test_verifier, utils diff --git a/test-requirements.txt b/test-requirements.txt index 6cad2af..ab2ed3a 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -3,4 +3,3 @@ flask>=2.0.3,<4.0.0 Django>=3.2.9,<5.0.0 atlassian-httptest==1.0.0 aiohttp==3.12.14 -asynctest==0.13.0 From 706d59cbf786993646beb62940710925290381d8 Mon Sep 17 00:00:00 2001 From: David Black Date: Thu, 25 Sep 2025 17:17:45 +1000 Subject: [PATCH 2/5] Sem-ver: api-break Add type hints to the code base Signed-off-by: David Black --- atlassian_jwt_auth/algorithms.py | 5 +- atlassian_jwt_auth/auth.py | 25 +++++-- .../contrib/aiohttp/__init__.py | 1 + atlassian_jwt_auth/contrib/aiohttp/auth.py | 13 +++- atlassian_jwt_auth/contrib/aiohttp/key.py | 31 +++++--- .../contrib/aiohttp/verifier.py | 11 ++- .../contrib/django/decorators.py | 14 +++- .../contrib/django/middleware.py | 15 ++-- .../contrib/flask_app/decorators.py | 8 ++- atlassian_jwt_auth/contrib/requests.py | 18 ++++- .../contrib/tests/aiohttp/test_auth.py | 8 ++- .../contrib/tests/aiohttp/test_verifier.py | 3 +- .../contrib/tests/test_requests.py | 3 +- atlassian_jwt_auth/contrib/tests/utils.py | 18 +++-- atlassian_jwt_auth/exceptions.py | 7 +- atlassian_jwt_auth/frameworks/common/asap.py | 24 +++++-- .../frameworks/common/backend.py | 31 +++++--- .../frameworks/common/decorators.py | 61 +++++++++------- .../frameworks/common/tests/test_utils.py | 2 +- atlassian_jwt_auth/frameworks/common/utils.py | 13 ++-- .../frameworks/django/__init__.py | 4 +- .../frameworks/django/backend.py | 22 ++++-- .../frameworks/django/decorators.py | 23 ++++-- .../frameworks/django/middleware.py | 13 ++-- .../frameworks/flask/backend.py | 32 ++++++--- .../frameworks/flask/decorators.py | 12 +++- atlassian_jwt_auth/frameworks/wsgi/backend.py | 22 ++++-- .../frameworks/wsgi/middleware.py | 5 +- atlassian_jwt_auth/key.py | 71 +++++++++++-------- atlassian_jwt_auth/signer.py | 66 +++++++++++------ atlassian_jwt_auth/tests/test_key.py | 4 +- atlassian_jwt_auth/tests/utils.py | 47 +++++++++--- atlassian_jwt_auth/verifier.py | 46 +++++++++--- 33 files changed, 479 insertions(+), 199 deletions(-) diff --git a/atlassian_jwt_auth/algorithms.py b/atlassian_jwt_auth/algorithms.py index b208f69..df30ffc 100644 --- a/atlassian_jwt_auth/algorithms.py +++ b/atlassian_jwt_auth/algorithms.py @@ -1,4 +1,7 @@ -def get_permitted_algorithm_names(): +from typing import List + + +def get_permitted_algorithm_names() -> List[str]: """returns permitted algorithm names.""" return [ "RS256", diff --git a/atlassian_jwt_auth/auth.py b/atlassian_jwt_auth/auth.py index edab9b5..a03e6e0 100644 --- a/atlassian_jwt_auth/auth.py +++ b/atlassian_jwt_auth/auth.py @@ -1,25 +1,42 @@ from __future__ import absolute_import +from typing import Any, Iterable, Union + import atlassian_jwt_auth +from atlassian_jwt_auth import KeyIdentifier +from atlassian_jwt_auth.signer import JWTAuthSigner class BaseJWTAuth(object): """Adds a JWT bearer token to the request per the ASAP specification""" - def __init__(self, signer, audience, *args, **kwargs): + def __init__( + self, + signer: JWTAuthSigner, + audience: Union[str, Iterable[str]], + *args: Any, + **kwargs: Any, + ) -> None: self._audience = audience self._signer = signer self._additional_claims = kwargs.get("additional_claims", {}) @classmethod - def create(cls, issuer, key_identifier, private_key_pem, audience, **kwargs): + def create( + cls, + issuer: str, + key_identifier: Union[KeyIdentifier, str], + private_key_pem: str, + audience: Union[str, Iterable[str]], + **kwargs: Any, + ) -> "BaseJWTAuth": """Instantiate a JWTAuth while creating the signer inline""" signer = atlassian_jwt_auth.create_signer( issuer, key_identifier, private_key_pem, **kwargs ) return cls(signer, audience) - def _get_header_value(self): + def _get_header_value(self) -> bytes: return b"Bearer " + self._signer.generate_jwt( self._audience, additional_claims=self._additional_claims - ) + ).encode("utf-8") diff --git a/atlassian_jwt_auth/contrib/aiohttp/__init__.py b/atlassian_jwt_auth/contrib/aiohttp/__init__.py index f02cf8f..1de9136 100644 --- a/atlassian_jwt_auth/contrib/aiohttp/__init__.py +++ b/atlassian_jwt_auth/contrib/aiohttp/__init__.py @@ -5,6 +5,7 @@ if sys.version_info >= (3, 5): try: import aiohttp # noqa + from .auth import JWTAuth # noqa from .key import HTTPSPublicKeyRetriever # noqa from .verifier import JWTAuthVerifier # noqa diff --git a/atlassian_jwt_auth/contrib/aiohttp/auth.py b/atlassian_jwt_auth/contrib/aiohttp/auth.py index 3e76abd..a1aa7aa 100644 --- a/atlassian_jwt_auth/contrib/aiohttp/auth.py +++ b/atlassian_jwt_auth/contrib/aiohttp/auth.py @@ -1,5 +1,8 @@ +from typing import Any, Iterable, Union + from aiohttp import BasicAuth +from atlassian_jwt_auth import KeyIdentifier from atlassian_jwt_auth.auth import BaseJWTAuth @@ -12,10 +15,16 @@ class JWTAuth(BaseJWTAuth, BasicAuth): def __new__(cls, *args, **kwargs): return super().__new__(cls, "") - def encode(self): + def encode(self) -> str: return self._get_header_value().decode(self.encoding) -def create_jwt_auth(issuer, key_identifier, private_key_pem, audience, **kwargs): +def create_jwt_auth( + issuer: str, + key_identifier: Union[KeyIdentifier, str], + private_key_pem: str, + audience: Union[str, Iterable[str]], + **kwargs: Any, +) -> BaseJWTAuth: """Instantiate a JWTAuth while creating the signer inline""" return JWTAuth.create(issuer, key_identifier, private_key_pem, audience, **kwargs) diff --git a/atlassian_jwt_auth/contrib/aiohttp/key.py b/atlassian_jwt_auth/contrib/aiohttp/key.py index cd1357c..fc96ef8 100644 --- a/atlassian_jwt_auth/contrib/aiohttp/key.py +++ b/atlassian_jwt_auth/contrib/aiohttp/key.py @@ -1,15 +1,13 @@ import asyncio import urllib.parse +from asyncio import AbstractEventLoop +from typing import Any, Awaitable, Dict, Optional import aiohttp from atlassian_jwt_auth.exceptions import PublicKeyRetrieverException -from atlassian_jwt_auth.key import ( - PEM_FILE_TYPE, -) -from atlassian_jwt_auth.key import ( - HTTPSPublicKeyRetriever as _HTTPSPublicKeyRetriever, -) +from atlassian_jwt_auth.key import PEM_FILE_TYPE +from atlassian_jwt_auth.key import HTTPSPublicKeyRetriever as _HTTPSPublicKeyRetriever class HTTPSPublicKeyRetriever(_HTTPSPublicKeyRetriever): @@ -17,20 +15,24 @@ class HTTPSPublicKeyRetriever(_HTTPSPublicKeyRetriever): _class_session = None - def __init__(self, base_url, *, loop=None): + def __init__( + self, base_url: str, *, loop: Optional[AbstractEventLoop] = None + ) -> None: if loop is None: loop = asyncio.get_event_loop() self.loop = loop super().__init__(base_url) - def _get_session(self): + def _get_session(self) -> aiohttp.ClientSession: # type: ignore[override] if HTTPSPublicKeyRetriever._class_session is None: HTTPSPublicKeyRetriever._class_session = aiohttp.ClientSession( loop=self.loop ) return HTTPSPublicKeyRetriever._class_session - def _convert_proxies_to_proxy_arg(self, url, requests_kwargs): + def _convert_proxies_to_proxy_arg( + self, url: str, requests_kwargs: Dict[Any, Any] + ) -> Dict[str, Any]: """returns a modified requests_kwargs dict that contains proxy information in a form that aiohttp accepts (it wants proxy information instead of a dict of proxies). @@ -43,11 +45,18 @@ def _convert_proxies_to_proxy_arg(self, url, requests_kwargs): requests_kwargs["proxy"] = proxy return requests_kwargs - async def _retrieve(self, url, requests_kwargs): + async def _retrieve( + self, url: str, requests_kwargs: Dict[Any, Any] + ) -> Awaitable[str]: requests_kwargs = self._convert_proxies_to_proxy_arg(url, requests_kwargs) try: resp = await self._session.get( - url, headers={"accept": PEM_FILE_TYPE}, **requests_kwargs + url, + headers={ + "accept": # type: ignore[misc] + PEM_FILE_TYPE + }, + **requests_kwargs, ) resp.raise_for_status() self._check_content_type(url, resp.headers["content-type"]) diff --git a/atlassian_jwt_auth/contrib/aiohttp/verifier.py b/atlassian_jwt_auth/contrib/aiohttp/verifier.py index 1877d3d..e7aad9e 100644 --- a/atlassian_jwt_auth/contrib/aiohttp/verifier.py +++ b/atlassian_jwt_auth/contrib/aiohttp/verifier.py @@ -1,4 +1,5 @@ import asyncio +from typing import Any, Dict, Iterable, Union import jwt @@ -6,8 +7,14 @@ from atlassian_jwt_auth.verifier import JWTAuthVerifier as _JWTAuthVerifier -class JWTAuthVerifier(_JWTAuthVerifier): - async def verify_jwt(self, a_jwt, audience, leeway=0, **requests_kwargs): +class JWTAuthVerifier(_JWTAuthVerifier): # type: ignore[override] + async def verify_jwt( # type: ignore[override] + self, + a_jwt: str, + audience: Union[str, Iterable[str]], + leeway: int = 0, + **requests_kwargs: Any, + ) -> Dict[Any, Any]: """Verify if the token is correct Returns: diff --git a/atlassian_jwt_auth/contrib/django/decorators.py b/atlassian_jwt_auth/contrib/django/decorators.py index ac29f45..2438fb6 100644 --- a/atlassian_jwt_auth/contrib/django/decorators.py +++ b/atlassian_jwt_auth/contrib/django/decorators.py @@ -1,11 +1,17 @@ +from collections.abc import Callable from functools import wraps +from typing import Iterable, Optional from django.http.response import HttpResponse from atlassian_jwt_auth.frameworks.django.decorators import with_asap -def validate_asap(issuers=None, subjects=None, required=True): +def validate_asap( + issuers: Optional[Iterable[str]] = None, + subjects: Optional[Iterable[str]] = None, + required: bool = True, +) -> Callable: """Decorator to allow endpoint-specific ASAP authorization, assuming ASAP authentication has already occurred. @@ -45,7 +51,11 @@ def validate_asap_wrapper(request, *args, **kwargs): return validate_asap_decorator -def requires_asap(issuers=None, subject_should_match_issuer=None, func=None): +def requires_asap( + issuers: Optional[Iterable[str]] = None, + subject_should_match_issuer: Optional[bool] = None, + func: Optional[Callable] = None, +) -> Callable: """Decorator for Django endpoints to require ASAP :param list issuers: *required The 'iss' claims that this endpoint is from. diff --git a/atlassian_jwt_auth/contrib/django/middleware.py b/atlassian_jwt_auth/contrib/django/middleware.py index 22cea7a..cf22e9e 100644 --- a/atlassian_jwt_auth/contrib/django/middleware.py +++ b/atlassian_jwt_auth/contrib/django/middleware.py @@ -1,3 +1,5 @@ +from typing import Any, Callable, Optional + from django.conf import settings from django.utils.deprecation import MiddlewareMixin @@ -10,7 +12,7 @@ class ProxiedAsapMiddleware(OldStyleASAPMiddleware, MiddlewareMixin): This must come before any authentication middleware.""" - def __init__(self, get_response=None): + def __init__(self, get_response: Optional[Any] = None) -> None: super(ProxiedAsapMiddleware, self).__init__() self.get_response = get_response @@ -25,7 +27,7 @@ def __init__(self, get_response=None): settings, "ASAP_PROXIED_AUTHORIZATION_HEADER", "HTTP_X_ASAP_AUTHORIZATION" ) - def process_request(self, request): + def process_request(self, request) -> Optional[str]: error_response = super(ProxiedAsapMiddleware, self).process_request(request) if error_response: @@ -33,7 +35,7 @@ def process_request(self, request): forwarded_for = request.META.pop(self.xfwd, None) if forwarded_for is None: - return + return None request.asap_forwarded = True request.META["HTTP_X_FORWARDED_FOR"] = forwarded_for @@ -46,10 +48,13 @@ def process_request(self, request): request.META["HTTP_AUTHORIZATION"] = orig_auth if asap_auth is not None: request.META[self.xauth] = asap_auth + return None - def process_view(self, request, view_func, view_args, view_kwargs): + def process_view( + self, request: Any, view_func: Callable, view_args: Any, view_kwargs: Any + ) -> None: if not hasattr(request, "asap_forwarded"): - return + return None # swap headers back into place asap_auth = request.META.pop(self.xauth, None) diff --git a/atlassian_jwt_auth/contrib/flask_app/decorators.py b/atlassian_jwt_auth/contrib/flask_app/decorators.py index 5ff6985..5485085 100644 --- a/atlassian_jwt_auth/contrib/flask_app/decorators.py +++ b/atlassian_jwt_auth/contrib/flask_app/decorators.py @@ -1,7 +1,13 @@ +from typing import Callable, Iterable, Optional + from atlassian_jwt_auth.frameworks.flask.decorators import with_asap -def requires_asap(f, issuers=None, subject_should_match_issuer=None): +def requires_asap( + f: Callable, + issuers: Optional[Iterable[str]] = None, + subject_should_match_issuer: Optional[bool] = None, +) -> Callable: """ Wrapper for Flask endpoints to make them require asap authentication to access. diff --git a/atlassian_jwt_auth/contrib/requests.py b/atlassian_jwt_auth/contrib/requests.py index 84cae64..9a2bba1 100644 --- a/atlassian_jwt_auth/contrib/requests.py +++ b/atlassian_jwt_auth/contrib/requests.py @@ -1,18 +1,30 @@ from __future__ import absolute_import +from typing import Any, Iterable, Union + +import requests from requests.auth import AuthBase +from atlassian_jwt_auth import KeyIdentifier from atlassian_jwt_auth.auth import BaseJWTAuth class JWTAuth(AuthBase, BaseJWTAuth): """Adds a JWT bearer token to the request per the ASAP specification""" - def __call__(self, r): - r.headers["Authorization"] = self._get_header_value() + def __call__( + self, r: requests.models.PreparedRequest + ) -> requests.models.PreparedRequest: + r.headers["Authorization"] = self._get_header_value() # type: ignore[assignment] return r -def create_jwt_auth(issuer, key_identifier, private_key_pem, audience, **kwargs): +def create_jwt_auth( + issuer: str, + key_identifier: Union[KeyIdentifier, str], + private_key_pem: str, + audience: Union[str, Iterable[str]], + **kwargs: Any, +) -> BaseJWTAuth: """Instantiate a JWTAuth while creating the signer inline""" return JWTAuth.create(issuer, key_identifier, private_key_pem, audience, **kwargs) diff --git a/atlassian_jwt_auth/contrib/tests/aiohttp/test_auth.py b/atlassian_jwt_auth/contrib/tests/aiohttp/test_auth.py index ca1732c..421c9a9 100644 --- a/atlassian_jwt_auth/contrib/tests/aiohttp/test_auth.py +++ b/atlassian_jwt_auth/contrib/tests/aiohttp/test_auth.py @@ -1,5 +1,7 @@ import unittest +from typing import Any, Type +from atlassian_jwt_auth.auth import BaseJWTAuth from atlassian_jwt_auth.contrib.aiohttp.auth import JWTAuth, create_jwt_auth from atlassian_jwt_auth.contrib.tests import test_requests from atlassian_jwt_auth.tests import utils @@ -8,12 +10,12 @@ class BaseAuthTest(test_requests.BaseRequestsTest): """tests for the contrib.aiohttp.JWTAuth class""" - auth_cls = JWTAuth + auth_cls: Type[JWTAuth] = JWTAuth - def _get_auth_header(self, auth): + def _get_auth_header(self, auth) -> bytes: return auth.encode().encode("latin1") - def create_jwt_auth(self, *args, **kwargs): + def create_jwt_auth(self, *args: Any, **kwargs: Any) -> BaseJWTAuth: return create_jwt_auth(*args, **kwargs) diff --git a/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py b/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py index 67587a8..798b0b0 100644 --- a/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py +++ b/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py @@ -2,8 +2,7 @@ from unittest import IsolatedAsyncioTestCase as TestCase from unittest.mock import AsyncMock as CoroutineMock -from atlassian_jwt_auth.contrib.aiohttp import (HTTPSPublicKeyRetriever, - JWTAuthVerifier) +from atlassian_jwt_auth.contrib.aiohttp import HTTPSPublicKeyRetriever, JWTAuthVerifier from atlassian_jwt_auth.tests import test_verifier, utils diff --git a/atlassian_jwt_auth/contrib/tests/test_requests.py b/atlassian_jwt_auth/contrib/tests/test_requests.py index b74c614..7d23fd9 100644 --- a/atlassian_jwt_auth/contrib/tests/test_requests.py +++ b/atlassian_jwt_auth/contrib/tests/test_requests.py @@ -1,5 +1,6 @@ import unittest from datetime import timedelta +from typing import Any import jwt from requests import Request @@ -12,7 +13,7 @@ class BaseRequestsTest(object): """tests for the contrib.requests.JWTAuth class""" - auth_cls = JWTAuth + auth_cls: Any = JWTAuth def setUp(self): self._private_key_pem = self.get_new_private_key_in_pem_format() diff --git a/atlassian_jwt_auth/contrib/tests/utils.py b/atlassian_jwt_auth/contrib/tests/utils.py index cb57ef9..f7d5113 100644 --- a/atlassian_jwt_auth/contrib/tests/utils.py +++ b/atlassian_jwt_auth/contrib/tests/utils.py @@ -1,19 +1,23 @@ +from typing import Any, Dict, Type + import atlassian_jwt_auth +from atlassian_jwt_auth import JWTAuthVerifier +from atlassian_jwt_auth.key import BasePublicKeyRetriever -def get_static_retriever_class(keys): - class StaticPublicKeyRetriever(object): - """Retrieves a key from a static list of public keys +def get_static_retriever_class(keys: Dict[str, Any]) -> Type[BasePublicKeyRetriever]: + class StaticPublicKeyRetriever(BasePublicKeyRetriever): + """Retrieves a key from a static dict of public keys (for use in tests only)""" - def __init__(self, *args, **kwargs): - self.keys = keys + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.keys: Dict[str, Any] = keys - def retrieve(self, key_identifier, **requests_kwargs): + def retrieve(self, key_identifier, **requests_kwargs) -> Any: return self.keys[key_identifier.key_id] return StaticPublicKeyRetriever -def static_verifier(keys): +def static_verifier(keys: Dict[str, Any]) -> JWTAuthVerifier: return atlassian_jwt_auth.JWTAuthVerifier(get_static_retriever_class(keys)()) diff --git a/atlassian_jwt_auth/exceptions.py b/atlassian_jwt_auth/exceptions.py index 5fdcde8..d831e0f 100644 --- a/atlassian_jwt_auth/exceptions.py +++ b/atlassian_jwt_auth/exceptions.py @@ -1,3 +1,6 @@ +from typing import Any + + class _WrappedException(object): """Allow wrapping exceptions in a new class while preserving the original as an attribute. @@ -7,7 +10,7 @@ class _WrappedException(object): should be sufficient for most use cases. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: wrapped_args = [arg for arg in args] if args: @@ -25,7 +28,7 @@ class _WithStatus(object): details about the HTTP client library. """ - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: status_code = kwargs.pop("status_code", None) super(_WithStatus, self).__init__(*args, **kwargs) self.status_code = status_code diff --git a/atlassian_jwt_auth/frameworks/common/asap.py b/atlassian_jwt_auth/frameworks/common/asap.py index 4af1bae..622ffe5 100644 --- a/atlassian_jwt_auth/frameworks/common/asap.py +++ b/atlassian_jwt_auth/frameworks/common/asap.py @@ -1,16 +1,25 @@ import logging +from typing import Any, Dict, Iterable, Optional from jwt.exceptions import InvalidIssuerError, InvalidTokenError +from atlassian_jwt_auth import JWTAuthVerifier from atlassian_jwt_auth.exceptions import ( JtiUniquenessException, NoTokenProvidedError, PublicKeyRetrieverException, SubjectDoesNotMatchIssuerException, ) +from atlassian_jwt_auth.frameworks.common.backend import Backend +from atlassian_jwt_auth.frameworks.common.utils import SettingsDict -def _process_asap_token(request, backend, settings, verifier=None): +def _process_asap_token( + request: Any, + backend: Backend, + settings: SettingsDict, + verifier: Optional[JWTAuthVerifier] = None, +) -> Optional[str]: """Verifies an ASAP token, validates the claims, and returns an error response""" logger = logging.getLogger("asap") @@ -21,14 +30,14 @@ def _process_asap_token(request, backend, settings, verifier=None): and not settings.ASAP_REQUIRED and (settings.ASAP_REQUIRED is not None) ): - return + return None try: if token is None: raise NoTokenProvidedError if verifier is None: verifier = backend.get_verifier(settings=settings) asap_claims = verifier.verify_jwt( - token, + token.decode("utf-8") if isinstance(token, bytes) else token, settings.ASAP_VALID_AUDIENCE, leeway=settings.ASAP_VALID_LEEWAY, ) @@ -47,7 +56,7 @@ def _process_asap_token(request, backend, settings, verifier=None): # will return 403 for a missing file to avoid leaking # information. raise - logger.warning("Could not retrieve the matching public key") + logger.warning("Could not retrieve the matching public key") error_response = backend.get_401_response( "Unauthorized: Key not found", request=request ) @@ -77,10 +86,13 @@ def _process_asap_token(request, backend, settings, verifier=None): if error_response is not None and settings.ASAP_REQUIRED: return error_response + return None -def _verify_issuers(asap_claims, issuers=None): +def _verify_issuers( + asap_claims: Dict[Any, Any], issuers: Optional[Iterable[str]] = None +) -> None: """Verify that the issuer in the claims is valid and is expected.""" claim_iss = asap_claims.get("iss") - if issuers and claim_iss not in issuers: + if issuers is not None and claim_iss is not None and claim_iss not in issuers: raise InvalidIssuerError diff --git a/atlassian_jwt_auth/frameworks/common/backend.py b/atlassian_jwt_auth/frameworks/common/backend.py index 3178bd4..1f8359d 100644 --- a/atlassian_jwt_auth/frameworks/common/backend.py +++ b/atlassian_jwt_auth/frameworks/common/backend.py @@ -1,5 +1,6 @@ from abc import ABCMeta, abstractmethod, abstractproperty from functools import lru_cache +from typing import Any, Dict, Optional, Union from atlassian_jwt_auth import HTTPSPublicKeyRetriever, JWTAuthVerifier @@ -7,7 +8,7 @@ @lru_cache(maxsize=20) -def _get_verifier(settings): +def _get_verifier(settings) -> JWTAuthVerifier: """This has been extracted out of Backend to avoid possible memory leaks via retained instance references. """ @@ -58,26 +59,36 @@ class Backend: } @abstractmethod - def get_authorization_header(self, request=None): + def get_authorization_header(self, request: Optional[Any] = None) -> bytes: pass @abstractmethod - def get_401_response(self, data=None, headers=None, request=None): + def get_401_response( + self, + data: Optional[Any] = None, + headers: Optional[Any] = None, + request: Optional[Any] = None, + ) -> Any: pass @abstractmethod - def get_403_response(self, data=None, headers=None, request=None): + def get_403_response( + self, + data: Optional[Any] = None, + headers: Optional[Any] = None, + request: Optional[Any] = None, + ) -> Any: pass @abstractmethod - def set_asap_claims_for_request(self, request, claims): + def set_asap_claims_for_request(self, request: Any, claims: Any) -> None: pass @abstractproperty - def settings(self): + def settings(self) -> SettingsDict: return SettingsDict(self.default_settings) - def get_asap_token(self, request): + def get_asap_token(self, request: Any) -> Optional[bytes]: auth_header = self.get_authorization_header(request) if auth_header is None: @@ -97,16 +108,16 @@ def get_asap_token(self, request): return auth_values[1] - def get_verifier(self, settings=None): + def get_verifier(self, settings: Optional[SettingsDict] = None) -> JWTAuthVerifier: """Returns a verifier for ASAP JWT tokens""" if settings is None: settings = self.settings return self._get_verifier(settings) - def _get_verifier(self, settings): + def _get_verifier(self, settings: SettingsDict) -> JWTAuthVerifier: return _get_verifier(settings) - def _process_settings(self, settings): + def _process_settings(self, settings: Union[SettingsDict, Dict]) -> SettingsDict: valid_issuers = settings.get("ASAP_VALID_ISSUERS") if valid_issuers: settings["ASAP_VALID_ISSUERS"] = set(valid_issuers) diff --git a/atlassian_jwt_auth/frameworks/common/decorators.py b/atlassian_jwt_auth/frameworks/common/decorators.py index ef1d362..0ea7967 100644 --- a/atlassian_jwt_auth/frameworks/common/decorators.py +++ b/atlassian_jwt_auth/frameworks/common/decorators.py @@ -1,24 +1,26 @@ from functools import wraps +from typing import Any, Callable, Dict, Iterable, Optional from jwt.exceptions import InvalidIssuerError, InvalidTokenError from .asap import _process_asap_token, _verify_issuers +from .backend import Backend from .utils import SettingsDict def _with_asap( - func=None, - backend=None, - issuers=None, - required=True, - subject_should_match_issuer=None, -): + func: Optional[Callable] = None, + backend: Optional[Backend] = None, + issuers: Optional[Iterable[str]] = None, + required: bool = True, + subject_should_match_issuer: Optional[bool] = None, +) -> Callable: if backend is None: raise ValueError("Invalid value for backend. Use a subclass instead.") - def with_asap_decorator(func): + def with_asap_decorator(func: Callable): @wraps(func) - def with_asap_wrapper(*args, **kwargs): + def with_asap_wrapper(*args: Any, **kwargs: Any) -> Optional[str]: settings = _update_settings_from_kwargs( backend.settings, issuers=issuers, @@ -46,19 +48,21 @@ def with_asap_wrapper(*args, **kwargs): def _restrict_asap( - func=None, - backend=None, - issuers=None, - required=True, - subject_should_match_issuer=None, + func: Optional[Callable] = None, + backend: Optional[Backend] = None, + issuers: Optional[Iterable[str]] = None, + required: bool = True, + subject_should_match_issuer: Optional[bool] = None, ): """Decorator to allow endpoint-specific ASAP authorization, assuming ASAP authentication has already occurred. """ - def restrict_asap_decorator(func): + def restrict_asap_decorator(func: Callable) -> Optional[Any]: @wraps(func) - def restrict_asap_wrapper(request, *args, **kwargs): + def restrict_asap_wrapper(request, *args, **kwargs) -> Any: + if backend is None: + raise ValueError("Backend cannot be None") settings = _update_settings_from_kwargs( backend.settings, issuers=issuers, @@ -69,18 +73,22 @@ def restrict_asap_wrapper(request, *args, **kwargs): error_response = None if required and not asap_claims: - return backend.get_401_response("Unauthorized", request=request) + if backend is not None: + return backend.get_401_response("Unauthorized", request=request) try: - _verify_issuers(asap_claims, settings.ASAP_VALID_ISSUERS) + if asap_claims is not None: + _verify_issuers(asap_claims, settings.ASAP_VALID_ISSUERS) except InvalidIssuerError: - error_response = backend.get_403_response( - "Forbidden: Invalid token issuer", request=request - ) + if backend is not None: + error_response = backend.get_403_response( + "Forbidden: Invalid token issuer", request=request + ) except InvalidTokenError: - error_response = backend.get_401_response( - "Unauthorized: Invalid token", request=request - ) + if backend is not None: + error_response = backend.get_401_response( + "Unauthorized: Invalid token", request=request + ) if error_response and required: return error_response @@ -96,8 +104,11 @@ def restrict_asap_wrapper(request, *args, **kwargs): def _update_settings_from_kwargs( - settings, issuers=None, required=True, subject_should_match_issuer=None -): + settings: Dict[Any, Any], + issuers: Optional[Iterable] = None, + required: bool = True, + subject_should_match_issuer: Optional[bool] = None, +) -> SettingsDict: settings = settings.copy() if issuers is not None: diff --git a/atlassian_jwt_auth/frameworks/common/tests/test_utils.py b/atlassian_jwt_auth/frameworks/common/tests/test_utils.py index 72ef7fd..e0b3415 100644 --- a/atlassian_jwt_auth/frameworks/common/tests/test_utils.py +++ b/atlassian_jwt_auth/frameworks/common/tests/test_utils.py @@ -6,7 +6,7 @@ class SettingsDictTest(unittest.TestCase): """Tests for the SettingsDict class.""" - def test_hash(self): + def test_hash(self) -> None: """Test that SettingsDict instances can be hashed.""" dictionary_one = {"a": "b", "3": set([1]), "f": None} dictionary_two = {"a": "b", "3": set([1]), "f": None} diff --git a/atlassian_jwt_auth/frameworks/common/utils.py b/atlassian_jwt_auth/frameworks/common/utils.py index d2ef74f..81ff4ba 100644 --- a/atlassian_jwt_auth/frameworks/common/utils.py +++ b/atlassian_jwt_auth/frameworks/common/utils.py @@ -1,14 +1,17 @@ +from typing import Any + + class SettingsDict(dict): - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name not in self: raise AttributeError return self[name] - def __setitem__(self, key, value): + def __setitem__(self, key: Any, value: Any) -> None: raise AttributeError("SettingsDict properties are immutable") - def _hash_key(self): + def _hash_key(self) -> frozenset[Any]: keys_and_values = [] for key, value in self.items(): if isinstance(value, set): @@ -16,8 +19,8 @@ def _hash_key(self): keys_and_values.append("%s %s" % (key, hash(value))) return frozenset(keys_and_values) - def __hash__(self): + def __hash__(self) -> int: # type: ignore[override] return hash(self._hash_key()) - def __eq__(self, other): + def __eq__(self, other) -> bool: return hash(self) == hash(other) diff --git a/atlassian_jwt_auth/frameworks/django/__init__.py b/atlassian_jwt_auth/frameworks/django/__init__.py index 1c5af87..405ea5d 100644 --- a/atlassian_jwt_auth/frameworks/django/__init__.py +++ b/atlassian_jwt_auth/frameworks/django/__init__.py @@ -1,2 +1,2 @@ -from .decorators import with_asap, restrict_asap # noqa -from .middleware import asap_middleware, OldStyleASAPMiddleware # noqa +from .decorators import restrict_asap, with_asap # noqa +from .middleware import OldStyleASAPMiddleware, asap_middleware # noqa diff --git a/atlassian_jwt_auth/frameworks/django/backend.py b/atlassian_jwt_auth/frameworks/django/backend.py index de0b4c6..344deeb 100644 --- a/atlassian_jwt_auth/frameworks/django/backend.py +++ b/atlassian_jwt_auth/frameworks/django/backend.py @@ -1,17 +1,24 @@ +from typing import Any, Optional + from django.conf import settings as django_settings -from django.http import HttpResponse, HttpResponseForbidden +from django.http import HttpRequest, HttpResponse, HttpResponseForbidden from ..common.backend import Backend class DjangoBackend(Backend): - def get_authorization_header(self, request=None): + def get_authorization_header(self, request: Optional[HttpRequest] = None) -> bytes: if request is None: raise ValueError("No request available") return request.META.get("HTTP_AUTHORIZATION", b"") - def get_401_response(self, data=None, headers=None, request=None): + def get_401_response( + self, + data: Optional[Any] = None, + headers: Optional[Any] = None, + request: Optional[HttpRequest] = None, + ) -> HttpResponse: if headers is None: headers = {} @@ -23,7 +30,12 @@ def get_401_response(self, data=None, headers=None, request=None): return response - def get_403_response(self, data=None, headers=None, request=None): + def get_403_response( + self, + data: Optional[Any] = None, + headers: Optional[Any] = None, + request: Optional[HttpRequest] = None, + ) -> HttpResponse: if headers is None: headers = {} @@ -33,7 +45,7 @@ def get_403_response(self, data=None, headers=None, request=None): return response - def set_asap_claims_for_request(self, request, claims): + def set_asap_claims_for_request(self, request: HttpRequest, claims: Any) -> None: request.asap_claims = claims @property diff --git a/atlassian_jwt_auth/frameworks/django/decorators.py b/atlassian_jwt_auth/frameworks/django/decorators.py index 852c284..61aa954 100644 --- a/atlassian_jwt_auth/frameworks/django/decorators.py +++ b/atlassian_jwt_auth/frameworks/django/decorators.py @@ -1,8 +1,16 @@ +from typing import Callable, Iterable, Optional + +from ..common.backend import Backend from ..common.decorators import _restrict_asap, _with_asap from .backend import DjangoBackend -def with_asap(func=None, issuers=None, required=None, subject_should_match_issuer=None): +def with_asap( + func: Optional[Callable] = None, + issuers: Optional[Iterable[str]] = None, + required: bool = True, + subject_should_match_issuer: Optional[bool] = None, +) -> Callable: """Decorator to allow endpoint-specific ASAP authentication. If authentication fails, a 401 or 403 response will be returned. Otherwise, @@ -25,12 +33,12 @@ def with_asap(func=None, issuers=None, required=None, subject_should_match_issue def restrict_asap( - func=None, - backend=None, - issuers=None, - required=True, - subject_should_match_issuer=None, -): + func: Optional[Callable] = None, + backend: Optional[Backend] = None, + issuers: Optional[Iterable[str]] = None, + required: bool = True, + subject_should_match_issuer: Optional[bool] = None, +) -> Callable: """Decorator to allow endpoint-specific ASAP authorization policies. This decorator assumes that request.asap_claims has previously been set by @@ -48,6 +56,7 @@ def restrict_asap( must match the issuer for a token to be considered valid. """ + issuers = issuers if issuers is not None else [] return _restrict_asap( func, DjangoBackend(), issuers, required, subject_should_match_issuer=None ) diff --git a/atlassian_jwt_auth/frameworks/django/middleware.py b/atlassian_jwt_auth/frameworks/django/middleware.py index 319deaf..c87fa08 100644 --- a/atlassian_jwt_auth/frameworks/django/middleware.py +++ b/atlassian_jwt_auth/frameworks/django/middleware.py @@ -1,14 +1,18 @@ +from typing import Any, Callable, Optional + +from django.http import HttpRequest + from ..common.asap import _process_asap_token from .backend import DjangoBackend -def asap_middleware(get_response): +def asap_middleware(get_response: Any) -> Callable: """Middleware to enable ASAP for all requests""" backend = DjangoBackend() settings = backend.settings _verifier = backend.get_verifier(settings=settings) - def middleware(request): + def middleware(request: HttpRequest) -> Any: error_response = _process_asap_token( request, backend, settings, verifier=_verifier ) @@ -24,14 +28,15 @@ class OldStyleASAPMiddleware(object): """Middleware to enable ASAP for all requests (for legacy applications using MIDDLEWARE_CLASSES)""" - def __init__(self): + def __init__(self) -> None: self.backend = DjangoBackend() self.settings = self.backend.settings self._verifier = self.backend.get_verifier(settings=self.settings) - def process_request(self, request): + def process_request(self, request: HttpRequest) -> Optional[str]: error_response = _process_asap_token( request, self.backend, self.settings, verifier=self._verifier ) if error_response is not None: return error_response + return None diff --git a/atlassian_jwt_auth/frameworks/flask/backend.py b/atlassian_jwt_auth/frameworks/flask/backend.py index 9cfadf0..0519aa7 100644 --- a/atlassian_jwt_auth/frameworks/flask/backend.py +++ b/atlassian_jwt_auth/frameworks/flask/backend.py @@ -1,17 +1,28 @@ -from flask import Response, current_app, g +from typing import Any, Optional + +from flask import Request, Response, current_app, g from flask import request as current_req from ..common.backend import Backend +from ..common.utils import SettingsDict class FlaskBackend(Backend): - def get_authorization_header(self, request=None): + def get_authorization_header(self, request: Optional[Request] = None) -> bytes: if request is None: request = current_req - return request.headers.get("AUTHORIZATION", "") - - def get_401_response(self, data=None, headers=None, request=None): + auth_header = request.headers.get("AUTHORIZATION", "") + return ( + auth_header.encode("utf-8") if isinstance(auth_header, str) else auth_header + ) + + def get_401_response( + self, + data: Optional[Any] = None, + headers: Optional[Any] = None, + request: Optional[Request] = None, + ) -> Response: if headers is None: headers = {} @@ -19,14 +30,19 @@ def get_401_response(self, data=None, headers=None, request=None): return Response(data, status=401, headers=headers) - def get_403_response(self, data=None, headers=None, request=None): + def get_403_response( + self, + data: Optional[Any] = None, + headers: Optional[Any] = None, + request: Optional[Request] = None, + ) -> Response: return Response(data, status=403, headers=headers) - def set_asap_claims_for_request(self, request, claims): + def set_asap_claims_for_request(self, request: Request, claims: Any) -> None: g.asap_claims = claims @property - def settings(self): + def settings(self) -> SettingsDict: settings = {} settings.update(self.default_settings) diff --git a/atlassian_jwt_auth/frameworks/flask/decorators.py b/atlassian_jwt_auth/frameworks/flask/decorators.py index 4fbbedf..3f82092 100644 --- a/atlassian_jwt_auth/frameworks/flask/decorators.py +++ b/atlassian_jwt_auth/frameworks/flask/decorators.py @@ -1,8 +1,16 @@ +from collections.abc import Callable +from typing import Iterable, Optional + from ..common.decorators import _with_asap from .backend import FlaskBackend -def with_asap(func=None, issuers=None, required=None, subject_should_match_issuer=None): +def with_asap( + func: Optional[Callable] = None, + issuers: Optional[Iterable[str]] = None, + required: Optional[bool] = None, + subject_should_match_issuer: Optional[bool] = None, +): """Decorator to allow endpoint-specific ASAP authentication. If authentication fails, a 401 or 403 response will be returned. Otherwise, @@ -20,5 +28,5 @@ def with_asap(func=None, issuers=None, required=None, subject_should_match_issue token to be considered valid. """ return _with_asap( - func, FlaskBackend(), issuers, required, subject_should_match_issuer + func, FlaskBackend(), issuers, required or False, subject_should_match_issuer ) diff --git a/atlassian_jwt_auth/frameworks/wsgi/backend.py b/atlassian_jwt_auth/frameworks/wsgi/backend.py index dc8997a..78a0a7b 100644 --- a/atlassian_jwt_auth/frameworks/wsgi/backend.py +++ b/atlassian_jwt_auth/frameworks/wsgi/backend.py @@ -1,9 +1,11 @@ +from typing import Any, Optional + from ..common.backend import Backend from ..common.utils import SettingsDict class WSGIBackend(Backend): - def __init__(self, settings): + def __init__(self, settings) -> None: self._settings = SettingsDict(settings) def get_authorization_header(self, request=None): @@ -12,7 +14,12 @@ def get_authorization_header(self, request=None): return request.environ.get("HTTP_AUTHORIZATION", b"") - def get_401_response(self, data=None, headers=None, request=None): + def get_401_response( + self, + data: Any = None, + headers: Optional[Any] = None, + request: Optional[Any] = None, + ) -> str: if request is None: raise TypeError("request must have a value") @@ -24,7 +31,12 @@ def get_401_response(self, data=None, headers=None, request=None): request.start_response("401 Unauthorized", list(headers.items()), None) return "" - def get_403_response(self, data=None, headers=None, request=None): + def get_403_response( + self, + data: Any = None, + headers: Optional[Any] = None, + request: Optional[Any] = None, + ) -> str: if request is None: raise TypeError("request must have a value") @@ -34,11 +46,11 @@ def get_403_response(self, data=None, headers=None, request=None): request.start_response("403 Forbidden", list(headers.items()), None) return "" - def set_asap_claims_for_request(self, request, claims): + def set_asap_claims_for_request(self, request: Any, claims: Any) -> None: request.environ["ATL_ASAP_CLAIMS"] = claims @property - def settings(self): + def settings(self) -> SettingsDict: settings = {} settings.update(self.default_settings) diff --git a/atlassian_jwt_auth/frameworks/wsgi/middleware.py b/atlassian_jwt_auth/frameworks/wsgi/middleware.py index 4e3fc5b..1d11a25 100644 --- a/atlassian_jwt_auth/frameworks/wsgi/middleware.py +++ b/atlassian_jwt_auth/frameworks/wsgi/middleware.py @@ -1,4 +1,5 @@ from collections import namedtuple +from typing import Any from ..common.asap import _process_asap_token from .backend import WSGIBackend @@ -7,12 +8,12 @@ class ASAPMiddleware(object): - def __init__(self, handler, settings): + def __init__(self, handler: Any, settings: Any) -> None: self._next = handler self._backend = WSGIBackend(settings) self._verifier = self._backend.get_verifier() - def __call__(self, environ, start_response): + def __call__(self, environ: Any, start_response: Any): settings = self._backend.settings request = Request(environ, start_response) error_response = _process_asap_token( diff --git a/atlassian_jwt_auth/key.py b/atlassian_jwt_auth/key.py index 2bbac6b..9cb80b8 100644 --- a/atlassian_jwt_auth/key.py +++ b/atlassian_jwt_auth/key.py @@ -3,6 +3,7 @@ import os import re from email.message import EmailMessage +from typing import Any, Generator, Iterable, Tuple, Union from urllib.parse import unquote_plus import cachecontrol @@ -25,15 +26,15 @@ class KeyIdentifier(object): """This class represents a key identifier""" - def __init__(self, identifier): + def __init__(self, identifier: str) -> None: self.__key_id = validate_key_identifier(identifier) @property - def key_id(self): + def key_id(self) -> str: return self.__key_id -def validate_key_identifier(identifier): +def validate_key_identifier(identifier: str) -> str: """returns a validated key identifier.""" regex = re.compile(r"^[\w.\-\+/]*$") _error_msg = "Invalid key identifier %s" % identifier @@ -51,7 +52,7 @@ def validate_key_identifier(identifier): return identifier -def _get_key_id_from_jwt_header(a_jwt): +def _get_key_id_from_jwt_header(a_jwt: Union[str, bytes]) -> KeyIdentifier: """returns the key identifier from a jwt header.""" header = jwt.get_unverified_header(a_jwt) return KeyIdentifier(header["kid"]) @@ -60,7 +61,7 @@ def _get_key_id_from_jwt_header(a_jwt): class BasePublicKeyRetriever(object): """Base class for retrieving a public key.""" - def retrieve(self, key_identifier, **kwargs): + def retrieve(self, key_identifier: Union[KeyIdentifier, str], **kwargs) -> Any: raise NotImplementedError() @@ -73,7 +74,7 @@ class HTTPSPublicKeyRetriever(BasePublicKeyRetriever): # HTTPSPublicKeyRetriever: _class_session = None - def __init__(self, base_url): + def __init__(self, base_url: str) -> None: if base_url is None or not base_url.startswith("https://"): raise PublicKeyRetrieverException("The base url must start with https://") if not base_url.endswith("/"): @@ -82,14 +83,16 @@ def __init__(self, base_url): self._session = self._get_session() self._proxies = requests.utils.get_environ_proxies(self.base_url) - def _get_session(self): + def _get_session(self) -> requests.Session: if HTTPSPublicKeyRetriever._class_session is None: session = cachecontrol.CacheControl(requests.Session()) session.trust_env = False HTTPSPublicKeyRetriever._class_session = session return HTTPSPublicKeyRetriever._class_session - def retrieve(self, key_identifier, **requests_kwargs): + def retrieve( + self, key_identifier: Union[KeyIdentifier, str], **requests_kwargs: Any + ) -> Any: """returns the public key for given key_identifier.""" if not isinstance(key_identifier, KeyIdentifier): key_identifier = KeyIdentifier(key_identifier) @@ -100,12 +103,12 @@ def retrieve(self, key_identifier, **requests_kwargs): return self._retrieve(url, requests_kwargs) except requests.RequestException as e: try: - status_code = e.response.status_code + status_code = e.response.status_code if e.response else None except AttributeError: status_code = None raise PublicKeyRetrieverException(e, status_code=status_code) - def _retrieve(self, url, requests_kwargs): + def _retrieve(self, url: str, requests_kwargs: Any) -> Any: resp = self._session.get( url, headers={"accept": PEM_FILE_TYPE}, **requests_kwargs ) @@ -113,7 +116,7 @@ def _retrieve(self, url, requests_kwargs): self._check_content_type(url, resp.headers["content-type"]) return resp.text - def _check_content_type(self, url, content_type): + def _check_content_type(self, url: str, content_type: str): msg = EmailMessage() msg["content-type"] = content_type media_type = msg.get_content_type() @@ -129,15 +132,19 @@ class HTTPSMultiRepositoryPublicKeyRetriever(BasePublicKeyRetriever): repository locations based upon key ids. """ - def __init__(self, key_repository_urls): + def __init__(self, key_repository_urls: Iterable[str]) -> None: if not isinstance(key_repository_urls, list): raise TypeError("keystore_urls must be a list of urls.") self._retrievers = self._create_retrievers(key_repository_urls) - def _create_retrievers(self, key_repository_urls): + def _create_retrievers( + self, key_repository_urls: Iterable[str] + ) -> Iterable[BasePublicKeyRetriever]: return [HTTPSPublicKeyRetriever(url) for url in key_repository_urls] - def handle_retrieval_exception(self, retriever, exception): + def handle_retrieval_exception( + self, retriever: BasePublicKeyRetriever, exception: Exception + ): """Handles working with exceptions encountered during key retrieval. """ @@ -148,7 +155,9 @@ def handle_retrieval_exception(self, retriever, exception): if exception.status_code is None or exception.status_code < 500: raise - def retrieve(self, key_identifier, **requests_kwargs): + def retrieve( + self, key_identifier: Union[KeyIdentifier, str], **requests_kwargs: Any + ) -> Any: for retriever in self._retrievers: try: return retriever.retrieve(key_identifier, **requests_kwargs) @@ -159,7 +168,7 @@ def retrieve(self, key_identifier, **requests_kwargs): "Unable to retrieve public key from store", extra={ "underlying_error": str(e), - "key repository": retriever.base_url, + "key repository": getattr(retriever, "base_url", "unknown"), }, ) raise PublicKeyRetrieverException("Cannot load key from key repositories") @@ -168,7 +177,7 @@ def retrieve(self, key_identifier, **requests_kwargs): class BasePrivateKeyRetriever(object): """This is the base private key retriever class.""" - def load(self, issuer): + def load(self, issuer: str) -> Tuple[Union[KeyIdentifier], Union[str, bytes]]: """returns the key identifier and private key pem found for the given issuer. """ @@ -180,10 +189,10 @@ class DataUriPrivateKeyRetriever(BasePrivateKeyRetriever): private key from the supplied data uri. """ - def __init__(self, data_uri): + def __init__(self, data_uri: str) -> None: self._data_uri = data_uri - def load(self, issuer): + def load(self, issuer: str) -> Tuple[Union[KeyIdentifier], Union[str, bytes]]: if not self._data_uri.startswith("data:application/pkcs8;kid="): raise PrivateKeyRetrieverException("Unrecognised data uri format.") splitted = self._data_uri.split(";") @@ -207,14 +216,18 @@ class StaticPrivateKeyRetriever(BasePrivateKeyRetriever): initially provided to it in calls to load. """ - def __init__(self, key_identifier, private_key_pem): + def __init__( + self, + key_identifier: Union[KeyIdentifier, str], + private_key_pem: Union[str, bytes], + ) -> None: if not isinstance(key_identifier, KeyIdentifier): key_identifier = KeyIdentifier(key_identifier) - self.key_identifier = key_identifier - self.private_key_pem = private_key_pem + self.key_identifier: KeyIdentifier = key_identifier + self.private_key_pem: Union[str, bytes] = private_key_pem - def load(self, issuer): + def load(self, issuer: str) -> Tuple[Union[KeyIdentifier], Union[str, bytes]]: return self.key_identifier, self.private_key_pem @@ -224,17 +237,17 @@ class FilePrivateKeyRetriever(BasePrivateKeyRetriever): repository path. """ - def __init__(self, private_key_repository_path): + def __init__(self, private_key_repository_path: str) -> None: self.private_key_repository = FilePrivateKeyRepository( private_key_repository_path ) - def load(self, issuer): + def load(self, issuer: str) -> Tuple[KeyIdentifier, str]: key_identifier = self._find_last_key_id(issuer) private_key_pem = self.private_key_repository.load_key(key_identifier) return key_identifier, private_key_pem - def _find_last_key_id(self, issuer): + def _find_last_key_id(self, issuer) -> KeyIdentifier: key_identifiers = list(self.private_key_repository.find_valid_key_ids(issuer)) if key_identifiers: @@ -246,16 +259,16 @@ def _find_last_key_id(self, issuer): class FilePrivateKeyRepository(object): """This class represents a file backed private key repository.""" - def __init__(self, path): + def __init__(self, path) -> None: self.path = path - def find_valid_key_ids(self, issuer): + def find_valid_key_ids(self, issuer: str) -> Generator[KeyIdentifier, Any, None]: issuer_directory = os.path.join(self.path, issuer) for filename in sorted(os.listdir(issuer_directory)): if filename.endswith(".pem"): yield KeyIdentifier("%s/%s" % (issuer, filename)) - def load_key(self, key_identifier): + def load_key(self, key_identifier: KeyIdentifier) -> str: key_filename = os.path.join(self.path, key_identifier.key_id) with open(key_filename, "rb") as f: return f.read().decode("utf-8") diff --git a/atlassian_jwt_auth/signer.py b/atlassian_jwt_auth/signer.py index ec65599..2bf12cd 100644 --- a/atlassian_jwt_auth/signer.py +++ b/atlassian_jwt_auth/signer.py @@ -1,22 +1,26 @@ import calendar import datetime import random +from typing import Any, Dict, Iterable, Optional, Union import jwt from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from atlassian_jwt_auth import algorithms, key +from atlassian_jwt_auth.key import BasePrivateKeyRetriever, KeyIdentifier class JWTAuthSigner(object): - def __init__(self, issuer, private_key_retriever, **kwargs): + def __init__( + self, issuer: str, private_key_retriever: BasePrivateKeyRetriever, **kwargs: Any + ) -> None: self.issuer = issuer self.private_key_retriever = private_key_retriever self.lifetime = kwargs.get("lifetime", datetime.timedelta(minutes=1)) self.algorithm = kwargs.get("algorithm", "RS256") self.subject = kwargs.get("subject", None) - self._private_keys_cache = dict() + self._private_keys_cache: Dict[str, Any] = dict() if self.algorithm not in set(algorithms.get_permitted_algorithm_names()): raise ValueError("Algorithm, '%s', is not permitted." % self.algorithm) @@ -25,24 +29,31 @@ def __init__(self, issuer, private_key_retriever, **kwargs): "lifetime, '%s',exceeds the allowed 1 hour max" % (self.lifetime) ) - def _obtain_private_key(self, key_identifier, private_key_pem): + def _obtain_private_key( + self, key_identifier: KeyIdentifier, private_key_pem: Union[str, bytes] + ): """returns a loaded instance of the given private key either from cache or from the given private_key_pem. """ priv_key = self._private_keys_cache.get(key_identifier.key_id, None) if priv_key is not None: return priv_key + private_key_bytes: bytes if not isinstance(private_key_pem, bytes): - private_key_pem = private_key_pem.encode() + private_key_bytes = private_key_pem.encode() + else: + private_key_bytes = private_key_pem priv_key = serialization.load_pem_private_key( - private_key_pem, password=None, backend=default_backend() + private_key_bytes, password=None, backend=default_backend() ) if len(self._private_keys_cache) > 10: self._private_keys_cache = dict() self._private_keys_cache[key_identifier.key_id] = priv_key return priv_key - def _generate_claims(self, audience, **kwargs): + def _generate_claims( + self, audience: Union[str, Iterable[str]], **kwargs: Any + ) -> Dict[Any, Any]: """returns a new dictionary of claims.""" now = self._now() claims = { @@ -58,10 +69,10 @@ def _generate_claims(self, audience, **kwargs): claims.update(kwargs.get("additional_claims", {})) return claims - def _now(self): + def _now(self) -> datetime.datetime: return datetime.datetime.now(datetime.timezone.utc) - def generate_jwt(self, audience, **kwargs): + def generate_jwt(self, audience: Union[str, Iterable[str]], **kwargs: Any) -> str: """returns a new signed jwt for use.""" key_identifier, private_key_pem = self.private_key_retriever.load(self.issuer) private_key = self._obtain_private_key(key_identifier, private_key_pem) @@ -69,31 +80,39 @@ def generate_jwt(self, audience, **kwargs): self._generate_claims(audience, **kwargs), key=private_key, algorithm=self.algorithm, - headers={"kid": key_identifier.key_id}, + headers={ + "kid": key_identifier.key_id + if isinstance(key_identifier, KeyIdentifier) + else key_identifier + }, ) - if isinstance(token, str): - token = token.encode("utf-8") + if isinstance(token, bytes): + return token.decode("utf-8") return token class TokenReusingJWTAuthSigner(JWTAuthSigner): - def __init__(self, issuer, private_key_retriever, **kwargs): + def __init__( + self, issuer: str, private_key_retriever: BasePrivateKeyRetriever, **kwargs: Any + ) -> None: super(TokenReusingJWTAuthSigner, self).__init__( issuer, private_key_retriever, **kwargs ) self.reuse_threshold = kwargs.get("reuse_jwt_threshold", 0.95) - def get_cached_token(self, audience, **kwargs): + def get_cached_token( + self, audience: Union[str, Iterable[str]], **kwargs: Any + ) -> Optional[str]: """returns the cached token. If there is no matching cached token then None is returned. """ return getattr(self, "_previous_token", None) - def set_cached_token(self, value): + def set_cached_token(self, value: Any) -> None: """sets the cached token.""" self._previous_token = value - def can_reuse_token(self, existing_token, claims): + def can_reuse_token(self, existing_token, claims) -> bool: """returns True if the provided existing token can be reused for the claims provided. """ @@ -120,7 +139,7 @@ def can_reuse_token(self, existing_token, claims): return False return True - def generate_jwt(self, audience, **kwargs): + def generate_jwt(self, audience: Union[str, Iterable[str]], **kwargs: Any) -> str: existing_token = self.get_cached_token(audience, **kwargs) claims = self._generate_claims(audience, **kwargs) if existing_token and self.can_reuse_token(existing_token, claims): @@ -130,14 +149,21 @@ def generate_jwt(self, audience, **kwargs): return token -def _create_signer(issuer, private_key_retriever, **kwargs): +def _create_signer( + issuer: str, private_key_retriever: BasePrivateKeyRetriever, **kwargs: Any +) -> JWTAuthSigner: signer_cls = JWTAuthSigner if kwargs.get("reuse_jwts", None): signer_cls = TokenReusingJWTAuthSigner return signer_cls(issuer, private_key_retriever, **kwargs) -def create_signer(issuer, key_identifier, private_key_pem, **kwargs): +def create_signer( + issuer: str, + key_identifier: Union[KeyIdentifier, str], + private_key_pem: Union[str, bytes], + **kwargs: Any, +) -> JWTAuthSigner: private_key_retriever = key.StaticPrivateKeyRetriever( key_identifier, private_key_pem ) @@ -145,7 +171,7 @@ def create_signer(issuer, key_identifier, private_key_pem, **kwargs): def create_signer_from_file_private_key_repository( - issuer, private_key_repository, **kwargs -): + issuer: str, private_key_repository: str, **kwargs: Any +) -> JWTAuthSigner: private_key_retriever = key.FilePrivateKeyRetriever(private_key_repository) return _create_signer(issuer, private_key_retriever, **kwargs) diff --git a/atlassian_jwt_auth/tests/test_key.py b/atlassian_jwt_auth/tests/test_key.py index 6e3a37d..34d3127 100644 --- a/atlassian_jwt_auth/tests/test_key.py +++ b/atlassian_jwt_auth/tests/test_key.py @@ -6,7 +6,7 @@ class TestKeyModule(unittest.TestCase): """tests for the key module.""" - def test_key_identifier_with_invalid_keys(self): + def test_key_identifier_with_invalid_keys(self) -> None: """test that invalid key identifiers are not permitted.""" keys = [ "../aha", @@ -27,7 +27,7 @@ def test_key_identifier_with_invalid_keys(self): with self.assertRaises(ValueError): atlassian_jwt_auth.KeyIdentifier(identifier=key) - def test_key_identifier_with_valid_keys(self): + def test_key_identifier_with_valid_keys(self) -> None: """test that valid keys work as expected.""" for key in ["oa.oo/a", "oo.sasdf.asdf/yes", "oo/o"]: key_id = atlassian_jwt_auth.KeyIdentifier(identifier=key) diff --git a/atlassian_jwt_auth/tests/utils.py b/atlassian_jwt_auth/tests/utils.py index 75bf629..e617a0b 100644 --- a/atlassian_jwt_auth/tests/utils.py +++ b/atlassian_jwt_auth/tests/utils.py @@ -1,11 +1,15 @@ +from typing import Any, Iterable, Optional, Protocol, Union + from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ec, rsa import atlassian_jwt_auth +from atlassian_jwt_auth import KeyIdentifier +from atlassian_jwt_auth.signer import JWTAuthSigner -def get_new_rsa_private_key_in_pem_format(): +def get_new_rsa_private_key_in_pem_format() -> bytes: """returns a new rsa key in pem format.""" private_key = rsa.generate_private_key( key_size=2048, backend=default_backend(), public_exponent=65537 @@ -17,7 +21,7 @@ def get_new_rsa_private_key_in_pem_format(): ) -def get_public_key_pem_for_private_key_pem(private_key_pem): +def get_public_key_pem_for_private_key_pem(private_key_pem: bytes) -> bytes: private_key = serialization.load_pem_private_key( private_key_pem, password=None, backend=default_backend() ) @@ -28,7 +32,7 @@ def get_public_key_pem_for_private_key_pem(private_key_pem): ) -def get_example_jwt_auth_signer(**kwargs): +def get_example_jwt_auth_signer(**kwargs: Any) -> JWTAuthSigner: """returns an example jwt_auth_signer instance.""" issuer = kwargs.get("issuer", "egissuer") key_id = kwargs.get("key_id", "%s/a" % issuer) @@ -37,7 +41,13 @@ def get_example_jwt_auth_signer(**kwargs): return atlassian_jwt_auth.create_signer(issuer, key_id, key, algorithm=algorithm) -def create_token(issuer, audience, key_id, private_key, subject=None): +def create_token( + issuer: str, + audience: Union[str, Iterable[str]], + key_id: Union[KeyIdentifier, str], + private_key: str, + subject: Optional[str] = None, +): """ " returns a token based upon the supplied parameters.""" signer = atlassian_jwt_auth.create_signer( issuer, key_id, private_key, subject=subject @@ -50,27 +60,46 @@ class BaseJWTAlgorithmTestMixin(object): jwt algorithms easier. """ - def get_new_private_key_in_pem_format(self): + def get_new_private_key_in_pem_format(self) -> bytes: """returns a new private key in pem format.""" raise NotImplementedError("not implemented.") -class RS256KeyTestMixin(object): +class UnitTestProtocol(Protocol): + def assertEqual(self, a, b): ... + + def assertIsNotNone(self, a): ... + + def assertTrue(self, a): ... + + def assertIn(self, a, b): ... + + def assertNotEqual(self, a, b): ... + + +class KeyMixInProtocol(Protocol): + @property + def algorithm(self) -> str: ... + + def get_new_private_key_in_pem_format(self) -> bytes: ... + + +class RS256KeyTestMixin(KeyMixInProtocol): """Private rs256 test mixin.""" @property - def algorithm(self): + def algorithm(self) -> str: return "RS256" def get_new_private_key_in_pem_format(self): return get_new_rsa_private_key_in_pem_format() -class ES256KeyTestMixin(object): +class ES256KeyTestMixin(KeyMixInProtocol): """Private es256 test mixin.""" @property - def algorithm(self): + def algorithm(self) -> str: return "ES256" def get_new_private_key_in_pem_format(self): diff --git a/atlassian_jwt_auth/verifier.py b/atlassian_jwt_auth/verifier.py index ccae9ea..03e3261 100644 --- a/atlassian_jwt_auth/verifier.py +++ b/atlassian_jwt_auth/verifier.py @@ -1,17 +1,28 @@ from collections import OrderedDict from functools import lru_cache +from typing import Any, Dict, Iterable, Optional, Sequence, Union import jwt import jwt.api_jwt from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey +from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PublicKey +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey +from jwt import PyJWK from jwt.exceptions import InvalidAlgorithmError -from atlassian_jwt_auth import algorithms, exceptions, key +from atlassian_jwt_auth import KeyIdentifier, algorithms, exceptions, key +from atlassian_jwt_auth.key import BasePublicKeyRetriever + +AllowedPublicKeys = Union[ + RSAPublicKey, EllipticCurvePublicKey, Ed25519PublicKey, Ed448PublicKey +] @lru_cache(maxsize=10) -def _load_public_key(algorithms, public_key, algorithm): +def _load_public_key( + algorithms: Sequence[str], public_key: str, algorithm: Optional[str] +) -> Any: """Returns a public key object instance given the public key and algorithm. @@ -20,7 +31,7 @@ def _load_public_key(algorithms, public_key, algorithm): """ if isinstance(public_key, (RSAPublicKey, EllipticCurvePublicKey)): return public_key - if algorithm not in algorithms: + if algorithm is None or algorithm not in algorithms: raise InvalidAlgorithmError("The specified alg value is not allowed") py_jws = jwt.api_jws.PyJWS(algorithms=algorithms) alg_obj = py_jws._algorithms[algorithm] @@ -30,16 +41,20 @@ def _load_public_key(algorithms, public_key, algorithm): class JWTAuthVerifier(object): """This class can be used to verify a JWT.""" - def __init__(self, public_key_retriever, **kwargs): + def __init__( + self, public_key_retriever: BasePublicKeyRetriever, **kwargs: Any + ) -> None: self.public_key_retriever = public_key_retriever self.algorithms = algorithms.get_permitted_algorithm_names() - self._seen_jti = OrderedDict() + self._seen_jti: OrderedDict[str, None] = OrderedDict() self._subject_should_match_issuer = kwargs.get( "subject_should_match_issuer", True ) self._check_jti_uniqueness = kwargs.get("check_jti_uniqueness", False) - def verify_jwt(self, a_jwt, audience, leeway=0, **requests_kwargs): + def verify_jwt( + self, a_jwt: str, audience: str, leeway: int = 0, **requests_kwargs: Any + ) -> Dict[Any, Any]: """Verify if the token is correct Returns: @@ -52,21 +67,30 @@ def verify_jwt(self, a_jwt, audience, leeway=0, **requests_kwargs): public_key = self._retrieve_pub_key(key_identifier, requests_kwargs) alg = jwt.get_unverified_header(a_jwt).get("alg", None) - public_key_obj = self._load_public_key(public_key, alg) + public_key_obj = self._load_public_key(public_key, alg or "RS256") return self._decode_jwt( a_jwt, key_identifier, public_key_obj, audience=audience, leeway=leeway ) - def _retrieve_pub_key(self, key_identifier, requests_kwargs): + def _retrieve_pub_key( + self, key_identifier: Union[KeyIdentifier, str], requests_kwargs: Any + ) -> str: return self.public_key_retriever.retrieve(key_identifier, **requests_kwargs) - def _load_public_key(self, public_key, algorithm): + def _load_public_key(self, public_key: str, algorithm: Optional[str]) -> Any: """Returns a public key object instance given the public key and algorithm. """ return _load_public_key(tuple(self.algorithms), public_key, algorithm) - def _decode_jwt(self, a_jwt, key_identifier, jwt_key, audience=None, leeway=0): + def _decode_jwt( + self, + a_jwt: str, + key_identifier: KeyIdentifier, + jwt_key: Union[AllowedPublicKeys, PyJWK, str, bytes], + audience: Optional[Union[str, Iterable[str]]] = None, + leeway: int = 0, + ) -> Dict[Any, Any]: """Decode JWT and check if it's valid""" options = { "verify_signature": True, @@ -113,7 +137,7 @@ def _decode_jwt(self, a_jwt, key_identifier, jwt_key, audience=None, leeway=0): self._check_jti(_jti) return claims - def _check_jti(self, jti): + def _check_jti(self, jti: str) -> None: """Checks that the given jti has not been already been used.""" if jti in self._seen_jti: raise exceptions.JtiUniquenessException( From 69d8f30abe29c51305f4d1014b746f43d1719841 Mon Sep 17 00:00:00 2001 From: David Black Date: Thu, 25 Sep 2025 17:18:28 +1000 Subject: [PATCH 3/5] Sem-ver: bugfix Configure mypy and run it in ci Signed-off-by: David Black --- .github/workflows/build.yml | 3 ++- pyproject.toml | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 70b43b2..be2d603 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,11 +19,12 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip wheel setuptools - pip install -q ruff + pip install -q ruff mypy - name: Lint run: | ruff check . ruff format --check . + mypy . - name: Test run: | pip install wheel diff --git a/pyproject.toml b/pyproject.toml index 5199082..e12922e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,3 +15,13 @@ line-ending = "auto" [tool.ruff.lint] extend-select = ["I"] +[tool.mypy] +warn_unused_configs = true + +[[tool.mypy.overrides]] +module = "django.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "httptest.*" +ignore_missing_imports = true From d99e87b3e4ebefe2224df1ca94cde51d4a04c889 Mon Sep 17 00:00:00 2001 From: David Black Date: Thu, 25 Sep 2025 17:19:01 +1000 Subject: [PATCH 4/5] Sem-ver: feature Provide type information to consumers of the library Signed-off-by: David Black Sem-ver: bugfix Have mypy install stubs as part of ci Signed-off-by: David Black Sem-ver: bugfix Typing fix ups - this commit will be merged into an earlier one Signed-off-by: David Black Remove the unused UnitTestProtocol code Signed-off-by: David Black --- .github/workflows/build.yml | 2 ++ atlassian_jwt_auth/auth.py | 4 +-- atlassian_jwt_auth/frameworks/common/asap.py | 8 ++++-- .../frameworks/common/decorators.py | 25 +++++++++---------- .../frameworks/django/decorators.py | 7 ++++-- .../frameworks/django/tests/test_django.py | 4 +-- .../frameworks/flask/decorators.py | 4 ++- atlassian_jwt_auth/py.typed | 0 atlassian_jwt_auth/signer.py | 12 ++++----- atlassian_jwt_auth/tests/utils.py | 14 +---------- atlassian_jwt_auth/verifier.py | 8 ++++-- test-requirements.txt | 2 ++ 12 files changed, 46 insertions(+), 44 deletions(-) create mode 100644 atlassian_jwt_auth/py.typed diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index be2d603..7509a0f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,6 +18,8 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | + pip install -r requirements.txt + pip install -r test-requirements.txt python -m pip install --upgrade pip wheel setuptools pip install -q ruff mypy - name: Lint diff --git a/atlassian_jwt_auth/auth.py b/atlassian_jwt_auth/auth.py index a03e6e0..0cec861 100644 --- a/atlassian_jwt_auth/auth.py +++ b/atlassian_jwt_auth/auth.py @@ -26,7 +26,7 @@ def create( cls, issuer: str, key_identifier: Union[KeyIdentifier, str], - private_key_pem: str, + private_key_pem: Union[str, bytes], audience: Union[str, Iterable[str]], **kwargs: Any, ) -> "BaseJWTAuth": @@ -39,4 +39,4 @@ def create( def _get_header_value(self) -> bytes: return b"Bearer " + self._signer.generate_jwt( self._audience, additional_claims=self._additional_claims - ).encode("utf-8") + ) diff --git a/atlassian_jwt_auth/frameworks/common/asap.py b/atlassian_jwt_auth/frameworks/common/asap.py index 622ffe5..9a54fc1 100644 --- a/atlassian_jwt_auth/frameworks/common/asap.py +++ b/atlassian_jwt_auth/frameworks/common/asap.py @@ -37,7 +37,7 @@ def _process_asap_token( if verifier is None: verifier = backend.get_verifier(settings=settings) asap_claims = verifier.verify_jwt( - token.decode("utf-8") if isinstance(token, bytes) else token, + token, settings.ASAP_VALID_AUDIENCE, leeway=settings.ASAP_VALID_LEEWAY, ) @@ -94,5 +94,9 @@ def _verify_issuers( ) -> None: """Verify that the issuer in the claims is valid and is expected.""" claim_iss = asap_claims.get("iss") - if issuers is not None and claim_iss is not None and claim_iss not in issuers: + if issuers is None: + return None + + if claim_iss is None or claim_iss not in issuers: raise InvalidIssuerError + return None diff --git a/atlassian_jwt_auth/frameworks/common/decorators.py b/atlassian_jwt_auth/frameworks/common/decorators.py index 0ea7967..2b2d6ad 100644 --- a/atlassian_jwt_auth/frameworks/common/decorators.py +++ b/atlassian_jwt_auth/frameworks/common/decorators.py @@ -72,23 +72,22 @@ def restrict_asap_wrapper(request, *args, **kwargs) -> Any: asap_claims = getattr(request, "asap_claims", None) error_response = None - if required and not asap_claims: - if backend is not None: + if not asap_claims: + if required: return backend.get_401_response("Unauthorized", request=request) - + else: + # Claims are not required and asap claims are not present. + return func(request, *args, **kwargs) try: - if asap_claims is not None: - _verify_issuers(asap_claims, settings.ASAP_VALID_ISSUERS) + _verify_issuers(asap_claims, settings.ASAP_VALID_ISSUERS) except InvalidIssuerError: - if backend is not None: - error_response = backend.get_403_response( - "Forbidden: Invalid token issuer", request=request - ) + error_response = backend.get_403_response( + "Forbidden: Invalid token issuer", request=request + ) except InvalidTokenError: - if backend is not None: - error_response = backend.get_401_response( - "Unauthorized: Invalid token", request=request - ) + error_response = backend.get_401_response( + "Unauthorized: Invalid token", request=request + ) if error_response and required: return error_response diff --git a/atlassian_jwt_auth/frameworks/django/decorators.py b/atlassian_jwt_auth/frameworks/django/decorators.py index 61aa954..aab43a0 100644 --- a/atlassian_jwt_auth/frameworks/django/decorators.py +++ b/atlassian_jwt_auth/frameworks/django/decorators.py @@ -56,7 +56,10 @@ def restrict_asap( must match the issuer for a token to be considered valid. """ - issuers = issuers if issuers is not None else [] return _restrict_asap( - func, DjangoBackend(), issuers, required, subject_should_match_issuer=None + func, + DjangoBackend(), + issuers, + required, + subject_should_match_issuer=None, ) diff --git a/atlassian_jwt_auth/frameworks/django/tests/test_django.py b/atlassian_jwt_auth/frameworks/django/tests/test_django.py index 95c39ab..026bbc1 100644 --- a/atlassian_jwt_auth/frameworks/django/tests/test_django.py +++ b/atlassian_jwt_auth/frameworks/django/tests/test_django.py @@ -68,7 +68,7 @@ def check_response( token=None, authorization=None, retriever_key=None, - ): + ) -> None: if authorization is None: if token is None: if private_key is None: @@ -211,7 +211,7 @@ def test_request_non_decorated_subject_is_rejected(self): def test_request_using_settings_only_is_allowed(self): self.check_response("unneeded", "two") - def test_request_subject_does_not_need_to_match_issuer_from_settings(self): + def test_request_subject_does_not_need_to_match_issuer_from_settings(self) -> None: self.test_settings["ASAP_SUBJECT_SHOULD_MATCH_ISSUER"] = False self.check_response("needed", "one", 200, subject="different_than_is") diff --git a/atlassian_jwt_auth/frameworks/flask/decorators.py b/atlassian_jwt_auth/frameworks/flask/decorators.py index 3f82092..e81fb4e 100644 --- a/atlassian_jwt_auth/frameworks/flask/decorators.py +++ b/atlassian_jwt_auth/frameworks/flask/decorators.py @@ -27,6 +27,8 @@ def with_asap( must match the issuer for a token to be considered valid. """ + if required is None: + required = True return _with_asap( - func, FlaskBackend(), issuers, required or False, subject_should_match_issuer + func, FlaskBackend(), issuers, required, subject_should_match_issuer ) diff --git a/atlassian_jwt_auth/py.typed b/atlassian_jwt_auth/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/atlassian_jwt_auth/signer.py b/atlassian_jwt_auth/signer.py index 2bf12cd..5487aac 100644 --- a/atlassian_jwt_auth/signer.py +++ b/atlassian_jwt_auth/signer.py @@ -72,11 +72,11 @@ def _generate_claims( def _now(self) -> datetime.datetime: return datetime.datetime.now(datetime.timezone.utc) - def generate_jwt(self, audience: Union[str, Iterable[str]], **kwargs: Any) -> str: + def generate_jwt(self, audience: Union[str, Iterable[str]], **kwargs: Any) -> bytes: """returns a new signed jwt for use.""" key_identifier, private_key_pem = self.private_key_retriever.load(self.issuer) private_key = self._obtain_private_key(key_identifier, private_key_pem) - token = jwt.encode( + token: str = jwt.encode( self._generate_claims(audience, **kwargs), key=private_key, algorithm=self.algorithm, @@ -86,9 +86,7 @@ def generate_jwt(self, audience: Union[str, Iterable[str]], **kwargs: Any) -> st else key_identifier }, ) - if isinstance(token, bytes): - return token.decode("utf-8") - return token + return token.encode("utf-8") class TokenReusingJWTAuthSigner(JWTAuthSigner): @@ -102,7 +100,7 @@ def __init__( def get_cached_token( self, audience: Union[str, Iterable[str]], **kwargs: Any - ) -> Optional[str]: + ) -> Optional[bytes]: """returns the cached token. If there is no matching cached token then None is returned. """ @@ -139,7 +137,7 @@ def can_reuse_token(self, existing_token, claims) -> bool: return False return True - def generate_jwt(self, audience: Union[str, Iterable[str]], **kwargs: Any) -> str: + def generate_jwt(self, audience: Union[str, Iterable[str]], **kwargs: Any) -> bytes: existing_token = self.get_cached_token(audience, **kwargs) claims = self._generate_claims(audience, **kwargs) if existing_token and self.can_reuse_token(existing_token, claims): diff --git a/atlassian_jwt_auth/tests/utils.py b/atlassian_jwt_auth/tests/utils.py index e617a0b..bf60215 100644 --- a/atlassian_jwt_auth/tests/utils.py +++ b/atlassian_jwt_auth/tests/utils.py @@ -45,7 +45,7 @@ def create_token( issuer: str, audience: Union[str, Iterable[str]], key_id: Union[KeyIdentifier, str], - private_key: str, + private_key: Union[str, bytes], subject: Optional[str] = None, ): """ " returns a token based upon the supplied parameters.""" @@ -65,18 +65,6 @@ def get_new_private_key_in_pem_format(self) -> bytes: raise NotImplementedError("not implemented.") -class UnitTestProtocol(Protocol): - def assertEqual(self, a, b): ... - - def assertIsNotNone(self, a): ... - - def assertTrue(self, a): ... - - def assertIn(self, a, b): ... - - def assertNotEqual(self, a, b): ... - - class KeyMixInProtocol(Protocol): @property def algorithm(self) -> str: ... diff --git a/atlassian_jwt_auth/verifier.py b/atlassian_jwt_auth/verifier.py index 03e3261..4575f18 100644 --- a/atlassian_jwt_auth/verifier.py +++ b/atlassian_jwt_auth/verifier.py @@ -53,7 +53,11 @@ def __init__( self._check_jti_uniqueness = kwargs.get("check_jti_uniqueness", False) def verify_jwt( - self, a_jwt: str, audience: str, leeway: int = 0, **requests_kwargs: Any + self, + a_jwt: Union[str, bytes], + audience: str, + leeway: int = 0, + **requests_kwargs: Any, ) -> Dict[Any, Any]: """Verify if the token is correct @@ -85,7 +89,7 @@ def _load_public_key(self, public_key: str, algorithm: Optional[str]) -> Any: def _decode_jwt( self, - a_jwt: str, + a_jwt: Union[str, bytes], key_identifier: KeyIdentifier, jwt_key: Union[AllowedPublicKeys, PyJWK, str, bytes], audience: Optional[Union[str, Iterable[str]]] = None, diff --git a/test-requirements.txt b/test-requirements.txt index ab2ed3a..9dd0d92 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -3,3 +3,5 @@ flask>=2.0.3,<4.0.0 Django>=3.2.9,<5.0.0 atlassian-httptest==1.0.0 aiohttp==3.12.14 +types-requests +types-setuptools From 52eb4efebda5d9c75554d2ca086eade5034ddd24 Mon Sep 17 00:00:00 2001 From: David Black Date: Mon, 29 Sep 2025 12:48:45 +1000 Subject: [PATCH 5/5] Sem-ver: bugfix Use __all__ instead of noqa Signed-off-by: David Black --- atlassian_jwt_auth/__init__.py | 19 ++++++++++---- .../contrib/aiohttp/__init__.py | 25 +++++++------------ .../contrib/flask_app/__init__.py | 6 ++++- .../frameworks/django/__init__.py | 6 +++-- .../frameworks/flask/__init__.py | 4 ++- .../frameworks/wsgi/__init__.py | 4 ++- 6 files changed, 38 insertions(+), 26 deletions(-) diff --git a/atlassian_jwt_auth/__init__.py b/atlassian_jwt_auth/__init__.py index a738e8c..1041b98 100644 --- a/atlassian_jwt_auth/__init__.py +++ b/atlassian_jwt_auth/__init__.py @@ -1,10 +1,19 @@ -from atlassian_jwt_auth.algorithms import get_permitted_algorithm_names # noqa +from atlassian_jwt_auth.algorithms import get_permitted_algorithm_names from atlassian_jwt_auth.key import ( - HTTPSPublicKeyRetriever, # noqa - KeyIdentifier, # noqa + HTTPSPublicKeyRetriever, + KeyIdentifier, ) -from atlassian_jwt_auth.signer import ( # noqa +from atlassian_jwt_auth.signer import ( create_signer, create_signer_from_file_private_key_repository, ) -from atlassian_jwt_auth.verifier import JWTAuthVerifier # noqa +from atlassian_jwt_auth.verifier import JWTAuthVerifier + +__all__ = [ + "get_permitted_algorithm_names", + "HTTPSPublicKeyRetriever", + "KeyIdentifier", + "create_signer", + "create_signer_from_file_private_key_repository", + "JWTAuthVerifier", +] diff --git a/atlassian_jwt_auth/contrib/aiohttp/__init__.py b/atlassian_jwt_auth/contrib/aiohttp/__init__.py index 1de9136..916c28f 100644 --- a/atlassian_jwt_auth/contrib/aiohttp/__init__.py +++ b/atlassian_jwt_auth/contrib/aiohttp/__init__.py @@ -1,18 +1,11 @@ """Provide asyncio support""" -import sys - -if sys.version_info >= (3, 5): - try: - import aiohttp # noqa - - from .auth import JWTAuth # noqa - from .key import HTTPSPublicKeyRetriever # noqa - from .verifier import JWTAuthVerifier # noqa - except ImportError as e: - import warnings - - warnings.warn(str(e)) - - -del sys +from .auth import JWTAuth +from .key import HTTPSPublicKeyRetriever +from .verifier import JWTAuthVerifier + +__all__ = [ + "JWTAuth", + "HTTPSPublicKeyRetriever", + "JWTAuthVerifier", +] diff --git a/atlassian_jwt_auth/contrib/flask_app/__init__.py b/atlassian_jwt_auth/contrib/flask_app/__init__.py index 062704e..5acc92b 100644 --- a/atlassian_jwt_auth/contrib/flask_app/__init__.py +++ b/atlassian_jwt_auth/contrib/flask_app/__init__.py @@ -1,6 +1,10 @@ import warnings -from .decorators import requires_asap # noqa +from .decorators import requires_asap + +__all__ = [ + "requires_asap", +] warnings.warn( "The atlassian_jwt_auth.contrib.flask_app package is deprecated in 4.0.0 " diff --git a/atlassian_jwt_auth/frameworks/django/__init__.py b/atlassian_jwt_auth/frameworks/django/__init__.py index 405ea5d..2b29c0d 100644 --- a/atlassian_jwt_auth/frameworks/django/__init__.py +++ b/atlassian_jwt_auth/frameworks/django/__init__.py @@ -1,2 +1,4 @@ -from .decorators import restrict_asap, with_asap # noqa -from .middleware import OldStyleASAPMiddleware, asap_middleware # noqa +from .decorators import restrict_asap, with_asap +from .middleware import OldStyleASAPMiddleware, asap_middleware + +__all__ = ["restrict_asap", "with_asap", "OldStyleASAPMiddleware", "asap_middleware"] diff --git a/atlassian_jwt_auth/frameworks/flask/__init__.py b/atlassian_jwt_auth/frameworks/flask/__init__.py index 7acd6ad..3394e9c 100644 --- a/atlassian_jwt_auth/frameworks/flask/__init__.py +++ b/atlassian_jwt_auth/frameworks/flask/__init__.py @@ -1 +1,3 @@ -from .decorators import with_asap # noqa +from .decorators import with_asap + +__all__ = ["with_asap"] diff --git a/atlassian_jwt_auth/frameworks/wsgi/__init__.py b/atlassian_jwt_auth/frameworks/wsgi/__init__.py index 5c67962..7a72ba0 100644 --- a/atlassian_jwt_auth/frameworks/wsgi/__init__.py +++ b/atlassian_jwt_auth/frameworks/wsgi/__init__.py @@ -1 +1,3 @@ -from .middleware import ASAPMiddleware # noqa +from .middleware import ASAPMiddleware + +__all__ = ["ASAPMiddleware"]