diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 2e9da00..70b43b2 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -5,7 +5,7 @@ on: [push] jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 strategy: matrix: python-version: [3.9, "3.10", "3.11", "3.12", "3.13"] @@ -19,11 +19,11 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip wheel setuptools - pip install -q pycodestyle==2.9.1 flake8==5.0.4 + pip install -q ruff - name: Lint run: | - pycodestyle . - flake8 . + ruff check . + ruff format --check . - name: Test run: | pip install wheel diff --git a/atlassian_jwt_auth/algorithms.py b/atlassian_jwt_auth/algorithms.py index c1c8962..b208f69 100644 --- a/atlassian_jwt_auth/algorithms.py +++ b/atlassian_jwt_auth/algorithms.py @@ -1,13 +1,13 @@ def get_permitted_algorithm_names(): - """ returns permitted algorithm names. """ + """returns permitted algorithm names.""" return [ - 'RS256', - 'RS384', - 'RS512', - 'ES256', - 'ES384', - 'ES512', - 'PS256', - 'PS384', - 'PS512' + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", ] diff --git a/atlassian_jwt_auth/auth.py b/atlassian_jwt_auth/auth.py index f0a1032..edab9b5 100644 --- a/atlassian_jwt_auth/auth.py +++ b/atlassian_jwt_auth/auth.py @@ -9,16 +9,17 @@ class BaseJWTAuth(object): def __init__(self, signer, audience, *args, **kwargs): self._audience = audience self._signer = signer - self._additional_claims = kwargs.get('additional_claims', {}) + self._additional_claims = kwargs.get("additional_claims", {}) @classmethod - def create(cls, issuer, key_identifier, private_key_pem, audience, - **kwargs): + def create(cls, issuer, key_identifier, private_key_pem, audience, **kwargs): """Instantiate a JWTAuth while creating the signer inline""" - signer = atlassian_jwt_auth.create_signer(issuer, key_identifier, - private_key_pem, **kwargs) + signer = atlassian_jwt_auth.create_signer( + issuer, key_identifier, private_key_pem, **kwargs + ) return cls(signer, audience) def _get_header_value(self): - return b'Bearer ' + self._signer.generate_jwt( - self._audience, additional_claims=self._additional_claims) + return b"Bearer " + self._signer.generate_jwt( + self._audience, additional_claims=self._additional_claims + ) diff --git a/atlassian_jwt_auth/contrib/aiohttp/__init__.py b/atlassian_jwt_auth/contrib/aiohttp/__init__.py index 38ed0d3..f02cf8f 100644 --- a/atlassian_jwt_auth/contrib/aiohttp/__init__.py +++ b/atlassian_jwt_auth/contrib/aiohttp/__init__.py @@ -1,4 +1,5 @@ """Provide asyncio support""" + import sys if sys.version_info >= (3, 5): @@ -9,6 +10,7 @@ from .verifier import JWTAuthVerifier # noqa except ImportError as e: import warnings + warnings.warn(str(e)) diff --git a/atlassian_jwt_auth/contrib/aiohttp/auth.py b/atlassian_jwt_auth/contrib/aiohttp/auth.py index e601617..3e76abd 100644 --- a/atlassian_jwt_auth/contrib/aiohttp/auth.py +++ b/atlassian_jwt_auth/contrib/aiohttp/auth.py @@ -8,15 +8,14 @@ class JWTAuth(BaseJWTAuth, BasicAuth): It should be aiohttp.BasicAuth subclass, so redefine its `__new__` method. """ + def __new__(cls, *args, **kwargs): - return super().__new__(cls, '') + return super().__new__(cls, "") def encode(self): return self._get_header_value().decode(self.encoding) -def create_jwt_auth( - issuer, key_identifier, private_key_pem, audience, **kwargs): +def create_jwt_auth(issuer, key_identifier, private_key_pem, audience, **kwargs): """Instantiate a JWTAuth while creating the signer inline""" - return JWTAuth.create( - issuer, key_identifier, private_key_pem, audience, **kwargs) + return JWTAuth.create(issuer, key_identifier, private_key_pem, audience, **kwargs) diff --git a/atlassian_jwt_auth/contrib/aiohttp/key.py b/atlassian_jwt_auth/contrib/aiohttp/key.py index cbeaa08..cd1357c 100644 --- a/atlassian_jwt_auth/contrib/aiohttp/key.py +++ b/atlassian_jwt_auth/contrib/aiohttp/key.py @@ -6,12 +6,15 @@ from atlassian_jwt_auth.exceptions import PublicKeyRetrieverException from atlassian_jwt_auth.key import ( PEM_FILE_TYPE, - HTTPSPublicKeyRetriever as _HTTPSPublicKeyRetriever +) +from atlassian_jwt_auth.key import ( + HTTPSPublicKeyRetriever as _HTTPSPublicKeyRetriever, ) class HTTPSPublicKeyRetriever(_HTTPSPublicKeyRetriever): """A class for retrieving JWT public keys with aiohttp""" + _class_session = None def __init__(self, base_url, *, loop=None): @@ -23,32 +26,32 @@ def __init__(self, base_url, *, loop=None): def _get_session(self): if HTTPSPublicKeyRetriever._class_session is None: HTTPSPublicKeyRetriever._class_session = aiohttp.ClientSession( - loop=self.loop) + loop=self.loop + ) return HTTPSPublicKeyRetriever._class_session def _convert_proxies_to_proxy_arg(self, url, requests_kwargs): - """ returns a modified requests_kwargs dict that contains proxy - information in a form that aiohttp accepts - (it wants proxy information instead of a dict of proxies). + """returns a modified requests_kwargs dict that contains proxy + information in a form that aiohttp accepts + (it wants proxy information instead of a dict of proxies). """ proxy = None - if 'proxies' in requests_kwargs: + if "proxies" in requests_kwargs: scheme = urllib.parse.urlparse(url).scheme - proxy = requests_kwargs['proxies'].get(scheme, None) - del requests_kwargs['proxies'] - requests_kwargs['proxy'] = proxy + proxy = requests_kwargs["proxies"].get(scheme, None) + del requests_kwargs["proxies"] + requests_kwargs["proxy"] = proxy return requests_kwargs async def _retrieve(self, url, requests_kwargs): - requests_kwargs = self._convert_proxies_to_proxy_arg( - url, requests_kwargs) + requests_kwargs = self._convert_proxies_to_proxy_arg(url, requests_kwargs) try: - resp = await self._session.get(url, headers={'accept': - PEM_FILE_TYPE}, - **requests_kwargs) + resp = await self._session.get( + url, headers={"accept": PEM_FILE_TYPE}, **requests_kwargs + ) resp.raise_for_status() - self._check_content_type(url, resp.headers['content-type']) + self._check_content_type(url, resp.headers["content-type"]) return await resp.text() except aiohttp.ClientError as e: - status_code = getattr(e, 'code', None) + status_code = getattr(e, "code", None) raise PublicKeyRetrieverException(e, status_code=status_code) diff --git a/atlassian_jwt_auth/contrib/aiohttp/verifier.py b/atlassian_jwt_auth/contrib/aiohttp/verifier.py index 388da76..1877d3d 100644 --- a/atlassian_jwt_auth/contrib/aiohttp/verifier.py +++ b/atlassian_jwt_auth/contrib/aiohttp/verifier.py @@ -22,8 +22,8 @@ async def verify_jwt(self, a_jwt, audience, leeway=0, **requests_kwargs): if asyncio.iscoroutine(public_key): public_key = await public_key - alg = jwt.get_unverified_header(a_jwt).get('alg', None) + alg = jwt.get_unverified_header(a_jwt).get("alg", None) public_key_obj = self._load_public_key(public_key, alg) return self._decode_jwt( - a_jwt, key_identifier, public_key_obj, - audience=audience, leeway=leeway) + a_jwt, key_identifier, public_key_obj, audience=audience, leeway=leeway + ) diff --git a/atlassian_jwt_auth/contrib/django/__init__.py b/atlassian_jwt_auth/contrib/django/__init__.py index 2b1bd6f..9f7e5b8 100644 --- a/atlassian_jwt_auth/contrib/django/__init__.py +++ b/atlassian_jwt_auth/contrib/django/__init__.py @@ -1,8 +1,8 @@ import warnings - warnings.warn( "The atlassian_jwt_auth.contrib.django package is deprecated in 4.0.0 " "in favour of atlassian_jwt_auth.frameworks.django.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/atlassian_jwt_auth/contrib/django/decorators.py b/atlassian_jwt_auth/contrib/django/decorators.py index 8474956..ac29f45 100644 --- a/atlassian_jwt_auth/contrib/django/decorators.py +++ b/atlassian_jwt_auth/contrib/django/decorators.py @@ -16,30 +16,32 @@ def validate_asap(issuers=None, subjects=None, required=True): :param boolean required: Whether or not to require ASAP on this endpoint. Note that requirements will be still be verified if claims are present. """ + def validate_asap_decorator(func): @wraps(func) def validate_asap_wrapper(request, *args, **kwargs): - asap_claims = getattr(request, 'asap_claims', None) + asap_claims = getattr(request, "asap_claims", None) if required and not asap_claims: - message = 'Unauthorized: Invalid or missing token' + message = "Unauthorized: Invalid or missing token" response = HttpResponse(message, status=401) - response['WWW-Authenticate'] = 'Bearer' + response["WWW-Authenticate"] = "Bearer" return response if asap_claims: - iss = asap_claims['iss'] + iss = asap_claims["iss"] if issuers and iss not in issuers: - message = 'Forbidden: Invalid token issuer' + message = "Forbidden: Invalid token issuer" return HttpResponse(message, status=403) - sub = asap_claims.get('sub') + sub = asap_claims.get("sub") if subjects and sub not in subjects: - message = 'Forbidden: Invalid token subject' + message = "Forbidden: Invalid token subject" return HttpResponse(message, status=403) return func(request, *args, **kwargs) return validate_asap_wrapper + return validate_asap_decorator @@ -48,7 +50,9 @@ def requires_asap(issuers=None, subject_should_match_issuer=None, func=None): :param list issuers: *required The 'iss' claims that this endpoint is from. """ - return with_asap(func=func, - required=True, - issuers=issuers, - subject_should_match_issuer=subject_should_match_issuer) + return with_asap( + func=func, + required=True, + issuers=issuers, + subject_should_match_issuer=subject_should_match_issuer, + ) diff --git a/atlassian_jwt_auth/contrib/django/middleware.py b/atlassian_jwt_auth/contrib/django/middleware.py index 9612b51..22cea7a 100644 --- a/atlassian_jwt_auth/contrib/django/middleware.py +++ b/atlassian_jwt_auth/contrib/django/middleware.py @@ -1,9 +1,7 @@ from django.conf import settings from django.utils.deprecation import MiddlewareMixin -from atlassian_jwt_auth.frameworks.django.middleware import ( - OldStyleASAPMiddleware -) +from atlassian_jwt_auth.frameworks.django.middleware import OldStyleASAPMiddleware class ProxiedAsapMiddleware(OldStyleASAPMiddleware, MiddlewareMixin): @@ -18,17 +16,17 @@ def __init__(self, get_response=None): # Rely on this header to tell us if a request has been forwarded # from an ASAP-enabled service; will overwrite X-Forwarded-For - self.xfwd = getattr(settings, 'ASAP_PROXIED_FORWARDED_FOR_HEADER', - 'HTTP_X_ASAP_FORWARDED_FOR') + self.xfwd = getattr( + settings, "ASAP_PROXIED_FORWARDED_FOR_HEADER", "HTTP_X_ASAP_FORWARDED_FOR" + ) # This header won't always be set, i.e. some users will be anonymous - self.xauth = getattr(settings, 'ASAP_PROXIED_AUTHORIZATION_HEADER', - 'HTTP_X_ASAP_AUTHORIZATION') + self.xauth = getattr( + settings, "ASAP_PROXIED_AUTHORIZATION_HEADER", "HTTP_X_ASAP_AUTHORIZATION" + ) def process_request(self, request): - error_response = super(ProxiedAsapMiddleware, self).process_request( - request - ) + error_response = super(ProxiedAsapMiddleware, self).process_request(request) if error_response: return error_response @@ -38,26 +36,26 @@ def process_request(self, request): return request.asap_forwarded = True - request.META['HTTP_X_FORWARDED_FOR'] = forwarded_for + request.META["HTTP_X_FORWARDED_FOR"] = forwarded_for - asap_auth = request.META.pop('HTTP_AUTHORIZATION', None) + asap_auth = request.META.pop("HTTP_AUTHORIZATION", None) orig_auth = request.META.pop(self.xauth, None) # Swap original client header in to allow regular auth middleware if orig_auth is not None: - request.META['HTTP_AUTHORIZATION'] = orig_auth + request.META["HTTP_AUTHORIZATION"] = orig_auth if asap_auth is not None: request.META[self.xauth] = asap_auth def process_view(self, request, view_func, view_args, view_kwargs): - if not hasattr(request, 'asap_forwarded'): + if not hasattr(request, "asap_forwarded"): return # swap headers back into place asap_auth = request.META.pop(self.xauth, None) - orig_auth = request.META.pop('HTTP_AUTHORIZATION', None) + orig_auth = request.META.pop("HTTP_AUTHORIZATION", None) if asap_auth is not None: - request.META['HTTP_AUTHORIZATION'] = asap_auth + request.META["HTTP_AUTHORIZATION"] = asap_auth if orig_auth is not None: request.META[self.xauth] = orig_auth diff --git a/atlassian_jwt_auth/contrib/flask_app/__init__.py b/atlassian_jwt_auth/contrib/flask_app/__init__.py index 557e614..062704e 100644 --- a/atlassian_jwt_auth/contrib/flask_app/__init__.py +++ b/atlassian_jwt_auth/contrib/flask_app/__init__.py @@ -2,9 +2,9 @@ from .decorators import requires_asap # noqa - warnings.warn( "The atlassian_jwt_auth.contrib.flask_app package is deprecated in 4.0.0 " "in favour of atlassian_jwt_auth.frameworks.flask.", - DeprecationWarning, stacklevel=2 + DeprecationWarning, + stacklevel=2, ) diff --git a/atlassian_jwt_auth/contrib/flask_app/decorators.py b/atlassian_jwt_auth/contrib/flask_app/decorators.py index b45a64b..5ff6985 100644 --- a/atlassian_jwt_auth/contrib/flask_app/decorators.py +++ b/atlassian_jwt_auth/contrib/flask_app/decorators.py @@ -7,7 +7,9 @@ def requires_asap(f, issuers=None, subject_should_match_issuer=None): access. """ - return with_asap(func=f, - required=True, - issuers=issuers, - subject_should_match_issuer=subject_should_match_issuer) + return with_asap( + func=f, + required=True, + issuers=issuers, + subject_should_match_issuer=subject_should_match_issuer, + ) diff --git a/atlassian_jwt_auth/contrib/requests.py b/atlassian_jwt_auth/contrib/requests.py index 5eafdc5..84cae64 100644 --- a/atlassian_jwt_auth/contrib/requests.py +++ b/atlassian_jwt_auth/contrib/requests.py @@ -1,20 +1,18 @@ from __future__ import absolute_import -from atlassian_jwt_auth.auth import BaseJWTAuth - from requests.auth import AuthBase +from atlassian_jwt_auth.auth import BaseJWTAuth + class JWTAuth(AuthBase, BaseJWTAuth): """Adds a JWT bearer token to the request per the ASAP specification""" def __call__(self, r): - r.headers['Authorization'] = self._get_header_value() + r.headers["Authorization"] = self._get_header_value() return r -def create_jwt_auth( - issuer, key_identifier, private_key_pem, audience, **kwargs): +def create_jwt_auth(issuer, key_identifier, private_key_pem, audience, **kwargs): """Instantiate a JWTAuth while creating the signer inline""" - return JWTAuth.create( - issuer, key_identifier, private_key_pem, audience, **kwargs) + return JWTAuth.create(issuer, key_identifier, private_key_pem, audience, **kwargs) diff --git a/atlassian_jwt_auth/contrib/tests/aiohttp/test_auth.py b/atlassian_jwt_auth/contrib/tests/aiohttp/test_auth.py index 95b64c9..ca1732c 100644 --- a/atlassian_jwt_auth/contrib/tests/aiohttp/test_auth.py +++ b/atlassian_jwt_auth/contrib/tests/aiohttp/test_auth.py @@ -1,28 +1,25 @@ import unittest -from atlassian_jwt_auth.contrib.aiohttp.auth import create_jwt_auth, JWTAuth -from atlassian_jwt_auth.tests import utils +from atlassian_jwt_auth.contrib.aiohttp.auth import JWTAuth, create_jwt_auth from atlassian_jwt_auth.contrib.tests import test_requests +from atlassian_jwt_auth.tests import utils class BaseAuthTest(test_requests.BaseRequestsTest): - """ tests for the contrib.aiohttp.JWTAuth class """ + """tests for the contrib.aiohttp.JWTAuth class""" + auth_cls = JWTAuth def _get_auth_header(self, auth): - return auth.encode().encode('latin1') + return auth.encode().encode("latin1") def create_jwt_auth(self, *args, **kwargs): return create_jwt_auth(*args, **kwargs) -class RequestsRS256Test(BaseAuthTest, - utils.RS256KeyTestMixin, - unittest.TestCase): +class RequestsRS256Test(BaseAuthTest, utils.RS256KeyTestMixin, unittest.TestCase): pass -class RequestsES256Test(BaseAuthTest, - utils.ES256KeyTestMixin, - unittest.TestCase): +class RequestsES256Test(BaseAuthTest, utils.ES256KeyTestMixin, unittest.TestCase): pass diff --git a/atlassian_jwt_auth/contrib/tests/aiohttp/test_public_key_provider.py b/atlassian_jwt_auth/contrib/tests/aiohttp/test_public_key_provider.py index 30633f7..a980093 100644 --- a/atlassian_jwt_auth/contrib/tests/aiohttp/test_public_key_provider.py +++ b/atlassian_jwt_auth/contrib/tests/aiohttp/test_public_key_provider.py @@ -9,7 +9,7 @@ from unittest.mock import AsyncMock as CoroutineMock from unittest.mock import Mock except ImportError: - from asynctest import CoroutineMock, TestCase, Mock + from asynctest import CoroutineMock, Mock, TestCase from atlassian_jwt_auth.contrib.aiohttp import HTTPSPublicKeyRetriever from atlassian_jwt_auth.key import PEM_FILE_TYPE @@ -20,7 +20,6 @@ class DummyHTTPSPublicKeyRetriever(HTTPSPublicKeyRetriever): - def set_headers(self, headers): self._session.get.return_value.headers.update(headers) @@ -29,12 +28,12 @@ def set_text(self, text): def _get_session(self): session = Mock(spec=aiohttp.ClientSession) - session.attach_mock(CoroutineMock(), 'get') + session.attach_mock(CoroutineMock(), "get") resp = session.get.return_value resp.headers = CIMultiDict({"content-type": PEM_FILE_TYPE}) - resp.text = CoroutineMock(return_value='i-am-a-public-key') - resp.raise_for_status = Mock(name='raise_for_status') + resp.text = CoroutineMock(return_value="i-am-a-public-key") + resp.raise_for_status = Mock(name="raise_for_status") return session @@ -44,65 +43,62 @@ class BaseHTTPSPublicKeyRetrieverTestMixin(object): def setUp(self): self._private_key_pem = self.get_new_private_key_in_pem_format() self._public_key_pem = utils.get_public_key_pem_for_private_key_pem( - self._private_key_pem) - self.base_url = 'https://example.com' + self._private_key_pem + ) + self.base_url = "https://example.com" async def test_retrieve(self): """Check if retrieve method returns public key""" retriever = DummyHTTPSPublicKeyRetriever(self.base_url) retriever.set_text(self._public_key_pem) - self.assertEqual( - await retriever.retrieve('example/eg'), - self._public_key_pem) + self.assertEqual(await retriever.retrieve("example/eg"), self._public_key_pem) async def test_retrieve_with_charset_in_content_type_h(self): """Check if retrieve method correctly checks content-type""" - headers = {'content-type': 'application/x-pem-file;charset=UTF-8'} + headers = {"content-type": "application/x-pem-file;charset=UTF-8"} retriever = DummyHTTPSPublicKeyRetriever(self.base_url) retriever.set_text(self._public_key_pem) retriever.set_headers(headers) - self.assertEqual( - await retriever.retrieve('example/eg'), - self._public_key_pem) + self.assertEqual(await retriever.retrieve("example/eg"), self._public_key_pem) async def test_retrieve_fails_with_different_content_type(self): """ Check if retrieve method raises an error for incorrect content-type """ - headers = {'content-type': 'different/not-supported'} + headers = {"content-type": "different/not-supported"} retriever = DummyHTTPSPublicKeyRetriever(self.base_url) retriever.set_text(self._public_key_pem) retriever.set_headers(headers) with self.assertRaises(ValueError): - await retriever.retrieve('example/eg') + await retriever.retrieve("example/eg") async def test_retrieve_session_uses_env_proxy(self): - """ tests that the underlying session makes use of environmental - proxy configured. + """tests that the underlying session makes use of environmental + proxy configured. """ - proxy_location = 'https://example.proxy' - key_id = 'example/eg' - expected_proxies, proxy_dict = get_expected_and_os_proxies_dict( - proxy_location) + proxy_location = "https://example.proxy" + key_id = "example/eg" + expected_proxies, proxy_dict = get_expected_and_os_proxies_dict(proxy_location) with mock.patch.dict(os.environ, proxy_dict, clear=True): retriever = DummyHTTPSPublicKeyRetriever(self.base_url) self.assertEqual(retriever._proxies, expected_proxies) await retriever.retrieve(key_id) retriever._session.get.assert_called_once_with( - f'{self.base_url}/{key_id}', headers={'accept': PEM_FILE_TYPE}, - proxy=expected_proxies[self.base_url.split(':')[0]] + f"{self.base_url}/{key_id}", + headers={"accept": PEM_FILE_TYPE}, + proxy=expected_proxies[self.base_url.split(":")[0]], ) -class RS256HTTPSPublicKeyRetrieverTest(utils.RS256KeyTestMixin, - BaseHTTPSPublicKeyRetrieverTestMixin, - TestCase): +class RS256HTTPSPublicKeyRetrieverTest( + utils.RS256KeyTestMixin, BaseHTTPSPublicKeyRetrieverTestMixin, TestCase +): """Tests for aiohttp.HTTPSPublicKeyRetriever class for RS256 algorithm""" -class ES256HTTPSPublicKeyRetrieverTest(utils.RS256KeyTestMixin, - BaseHTTPSPublicKeyRetrieverTestMixin, - TestCase): +class ES256HTTPSPublicKeyRetrieverTest( + utils.RS256KeyTestMixin, BaseHTTPSPublicKeyRetrieverTestMixin, TestCase +): """Tests for aiohttp.HTTPSPublicKeyRetriever class for ES256 algorithm""" diff --git a/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py b/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py index 9dc4187..2ccfbf1 100644 --- a/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py +++ b/atlassian_jwt_auth/contrib/tests/aiohttp/test_verifier.py @@ -4,15 +4,13 @@ from unittest import IsolatedAsyncioTestCase as TestCase from unittest.mock import AsyncMock as CoroutineMock except ImportError: - from asynctest import TestCase, CoroutineMock + from asynctest import CoroutineMock, TestCase -from atlassian_jwt_auth.contrib.aiohttp import (HTTPSPublicKeyRetriever, - JWTAuthVerifier) +from atlassian_jwt_auth.contrib.aiohttp import HTTPSPublicKeyRetriever, JWTAuthVerifier from atlassian_jwt_auth.tests import test_verifier, utils class SyncJWTAuthVerifier(JWTAuthVerifier): - def __init__(self, *args, loop=None, **kwargs): if loop is None: loop = asyncio.get_event_loop() @@ -20,9 +18,7 @@ def __init__(self, *args, loop=None, **kwargs): super().__init__(*args, **kwargs) def verify_jwt(self, *args, **kwargs): - return self.loop.run_until_complete( - super().verify_jwt(*args, **kwargs) - ) + return self.loop.run_until_complete(super().verify_jwt(*args, **kwargs)) class JWTAuthVerifierTestMixin(test_verifier.BaseJWTAuthVerifierTest): @@ -39,10 +35,12 @@ def _setup_jwt_auth_verifier(self, pub_key_pem, **kwargs): class JWTAuthVerifierRS256Test( - utils.RS256KeyTestMixin, JWTAuthVerifierTestMixin, TestCase): + utils.RS256KeyTestMixin, JWTAuthVerifierTestMixin, TestCase +): """Tests for aiohttp.JWTAuthVerifier class for RS256 algorithm""" class JWTAuthVerifierES256Test( - utils.ES256KeyTestMixin, JWTAuthVerifierTestMixin, TestCase): + utils.ES256KeyTestMixin, JWTAuthVerifierTestMixin, TestCase +): """Tests for aiohttp.JWTAuthVerifier class for ES256 algorithm""" diff --git a/atlassian_jwt_auth/contrib/tests/test_requests.py b/atlassian_jwt_auth/contrib/tests/test_requests.py index f264945..b74c614 100644 --- a/atlassian_jwt_auth/contrib/tests/test_requests.py +++ b/atlassian_jwt_auth/contrib/tests/test_requests.py @@ -5,35 +5,40 @@ from requests import Request import atlassian_jwt_auth -from atlassian_jwt_auth.tests import utils from atlassian_jwt_auth.contrib.requests import JWTAuth, create_jwt_auth +from atlassian_jwt_auth.tests import utils class BaseRequestsTest(object): + """tests for the contrib.requests.JWTAuth class""" - """ tests for the contrib.requests.JWTAuth class """ auth_cls = JWTAuth def setUp(self): self._private_key_pem = self.get_new_private_key_in_pem_format() self._public_key_pem = utils.get_public_key_pem_for_private_key_pem( - self._private_key_pem) + self._private_key_pem + ) def assert_authorization_header_is_valid(self, auth): - """ asserts that the given request contains a valid Authorization - header. + """asserts that the given request contains a valid Authorization + header. """ auth_header = self._get_auth_header(auth) - bearer = auth_header.split(b' ')[1] + bearer = auth_header.split(b" ")[1] # Decode the JWT (verifying the signature and aud match) # an exception is thrown if this fails algorithms = atlassian_jwt_auth.get_permitted_algorithm_names() - return jwt.decode(bearer, self._public_key_pem.decode(), - audience='audience', algorithms=algorithms) + return jwt.decode( + bearer, + self._public_key_pem.decode(), + audience="audience", + algorithms=algorithms, + ) def _get_auth_header(self, auth): request = auth(Request()) - auth_header = request.headers['Authorization'] + auth_header = request.headers["Authorization"] return auth_header def create_jwt_auth(self, *args, **kwargs): @@ -42,108 +47,150 @@ def create_jwt_auth(self, *args, **kwargs): def test_JWTAuth_make_authenticated_request(self): """Verify a valid Authorization header is added by JWTAuth""" jwt_auth_signer = atlassian_jwt_auth.create_signer( - 'issuer', - 'issuer/key', + "issuer", + "issuer/key", self._private_key_pem.decode(), - algorithm=self.algorithm) - auth = self.auth_cls(jwt_auth_signer, 'audience') + algorithm=self.algorithm, + ) + auth = self.auth_cls(jwt_auth_signer, "audience") self.assert_authorization_header_is_valid(auth) def test_create_jwt_auth(self): """Verify a valid Authorization header is added by JWTAuth""" - auth = self.create_jwt_auth('issuer', 'issuer/key', - self._private_key_pem.decode(), 'audience', - algorithm=self.algorithm) + auth = self.create_jwt_auth( + "issuer", + "issuer/key", + self._private_key_pem.decode(), + "audience", + algorithm=self.algorithm, + ) self.assert_authorization_header_is_valid(auth) def test_create_jwt_auth_with_additional_claims(self): - """ Verify a Valid Authorization header is added by JWTAuth and - contains the additional claims when provided. + """Verify a Valid Authorization header is added by JWTAuth and + contains the additional claims when provided. """ jwt_auth_signer = atlassian_jwt_auth.create_signer( - 'issuer', - 'issuer/key', + "issuer", + "issuer/key", self._private_key_pem.decode(), - algorithm=self.algorithm) - auth = self.auth_cls(jwt_auth_signer, 'audience', - additional_claims={'example': 'claim'}) + algorithm=self.algorithm, + ) + auth = self.auth_cls( + jwt_auth_signer, "audience", additional_claims={"example": "claim"} + ) token = self.assert_authorization_header_is_valid(auth) - self.assertEqual(token.get('example'), 'claim') + self.assertEqual(token.get("example"), "claim") def test_do_not_reuse_jwts(self): - auth = self.create_jwt_auth('issuer', 'issuer/key', - self._private_key_pem.decode(), 'audience', - algorithm=self.algorithm) + auth = self.create_jwt_auth( + "issuer", + "issuer/key", + self._private_key_pem.decode(), + "audience", + algorithm=self.algorithm, + ) auth_header = self._get_auth_header(auth) self.assertNotEqual(auth_header, self._get_auth_header(auth)) def test_reuse_jwts(self): - auth = self.create_jwt_auth('issuer', 'issuer/key', - self._private_key_pem.decode(), 'audience', - algorithm=self.algorithm, reuse_jwts=True) + auth = self.create_jwt_auth( + "issuer", + "issuer/key", + self._private_key_pem.decode(), + "audience", + algorithm=self.algorithm, + reuse_jwts=True, + ) auth_header = self._get_auth_header(auth) self.assertEqual(auth_header, self._get_auth_header(auth)) def test_do_not_reuse_jwt_if_audience_changes(self): - auth = self.create_jwt_auth('issuer', 'issuer/key', - self._private_key_pem.decode(), 'audience', - algorithm=self.algorithm, reuse_jwts=True) + auth = self.create_jwt_auth( + "issuer", + "issuer/key", + self._private_key_pem.decode(), + "audience", + algorithm=self.algorithm, + reuse_jwts=True, + ) auth_header = self._get_auth_header(auth) - auth._audience = 'not-' + auth._audience + auth._audience = "not-" + auth._audience self.assertNotEqual(auth_header, self._get_auth_header(auth)) def test_do_not_reuse_jwt_if_issuer_changes(self): - auth = self.create_jwt_auth('issuer', 'issuer/key', - self._private_key_pem.decode(), 'audience', - algorithm=self.algorithm, reuse_jwts=True) + auth = self.create_jwt_auth( + "issuer", + "issuer/key", + self._private_key_pem.decode(), + "audience", + algorithm=self.algorithm, + reuse_jwts=True, + ) auth_header = self._get_auth_header(auth) - auth._signer.issuer = 'not-' + auth._signer.issuer + auth._signer.issuer = "not-" + auth._signer.issuer self.assertNotEqual(auth_header, self._get_auth_header(auth)) def test_do_not_reuse_jwt_if_lifetime_changes(self): - auth = self.create_jwt_auth('issuer', 'issuer/key', - self._private_key_pem.decode(), 'audience', - algorithm=self.algorithm, reuse_jwts=True) + auth = self.create_jwt_auth( + "issuer", + "issuer/key", + self._private_key_pem.decode(), + "audience", + algorithm=self.algorithm, + reuse_jwts=True, + ) auth_header = self._get_auth_header(auth) auth._signer.lifetime = auth._signer.lifetime - timedelta(seconds=1) self.assertNotEqual(auth_header, self._get_auth_header(auth)) def test_do_not_reuse_jwt_if_subject_changes(self): - auth = self.create_jwt_auth('issuer', 'issuer/key', - self._private_key_pem.decode(), 'audience', - algorithm=self.algorithm, reuse_jwts=True, - subject='subject') + auth = self.create_jwt_auth( + "issuer", + "issuer/key", + self._private_key_pem.decode(), + "audience", + algorithm=self.algorithm, + reuse_jwts=True, + subject="subject", + ) auth_header = self._get_auth_header(auth) - auth._signer.subject = 'not-' + auth._signer.subject + auth._signer.subject = "not-" + auth._signer.subject self.assertNotEqual(auth_header, self._get_auth_header(auth)) def test_do_not_reuse_jwt_if_additional_claims_change(self): - auth = self.create_jwt_auth('issuer', 'issuer/key', - self._private_key_pem.decode(), 'audience', - algorithm=self.algorithm, reuse_jwts=True) + auth = self.create_jwt_auth( + "issuer", + "issuer/key", + self._private_key_pem.decode(), + "audience", + algorithm=self.algorithm, + reuse_jwts=True, + ) auth_header = self._get_auth_header(auth) - auth._additional_claims['foo'] = 'bar' + auth._additional_claims["foo"] = "bar" self.assertNotEqual(auth_header, self._get_auth_header(auth)) def test_reuse_jwt_with_additional_claims(self): # calculating the cache key with additional claims is non-trivial - auth = self.create_jwt_auth('issuer', 'issuer/key', - self._private_key_pem.decode(), 'audience', - algorithm=self.algorithm, reuse_jwts=True) - auth._additional_claims['foo'] = 'bar' - auth._additional_claims['fool'] = 'blah' - auth._additional_claims['foot'] = 'quux' + auth = self.create_jwt_auth( + "issuer", + "issuer/key", + self._private_key_pem.decode(), + "audience", + algorithm=self.algorithm, + reuse_jwts=True, + ) + auth._additional_claims["foo"] = "bar" + auth._additional_claims["fool"] = "blah" + auth._additional_claims["foot"] = "quux" auth_header = self._get_auth_header(auth) self.assertEqual(auth_header, self._get_auth_header(auth)) -class RequestsRS256Test(BaseRequestsTest, - utils.RS256KeyTestMixin, - unittest.TestCase): +class RequestsRS256Test(BaseRequestsTest, utils.RS256KeyTestMixin, unittest.TestCase): pass -class RequestsES256Test(BaseRequestsTest, - utils.ES256KeyTestMixin, - unittest.TestCase): +class RequestsES256Test(BaseRequestsTest, utils.ES256KeyTestMixin, unittest.TestCase): pass diff --git a/atlassian_jwt_auth/contrib/tests/utils.py b/atlassian_jwt_auth/contrib/tests/utils.py index 6980147..cb57ef9 100644 --- a/atlassian_jwt_auth/contrib/tests/utils.py +++ b/atlassian_jwt_auth/contrib/tests/utils.py @@ -2,10 +2,9 @@ def get_static_retriever_class(keys): - class StaticPublicKeyRetriever(object): - """ Retrieves a key from a static list of public keys - (for use in tests only) """ + """Retrieves a key from a static list of public keys + (for use in tests only)""" def __init__(self, *args, **kwargs): self.keys = keys @@ -17,6 +16,4 @@ def retrieve(self, key_identifier, **requests_kwargs): def static_verifier(keys): - return atlassian_jwt_auth.JWTAuthVerifier( - get_static_retriever_class(keys)() - ) + return atlassian_jwt_auth.JWTAuthVerifier(get_static_retriever_class(keys)()) diff --git a/atlassian_jwt_auth/exceptions.py b/atlassian_jwt_auth/exceptions.py index c90299f..5fdcde8 100644 --- a/atlassian_jwt_auth/exceptions.py +++ b/atlassian_jwt_auth/exceptions.py @@ -13,10 +13,8 @@ def __init__(self, *args, **kwargs): if args: orig = args[0] if isinstance(orig, Exception): - wrapped_args[0] = str(orig) - self.original_exception = getattr(orig, 'original_exception', - orig) + self.original_exception = getattr(orig, "original_exception", orig) super(_WrappedException, self).__init__(*wrapped_args, **kwargs) @@ -28,7 +26,7 @@ class _WithStatus(object): """ def __init__(self, *args, **kwargs): - status_code = kwargs.pop('status_code', None) + status_code = kwargs.pop("status_code", None) super(_WithStatus, self).__init__(*args, **kwargs) self.status_code = status_code @@ -54,13 +52,14 @@ class KeyIdentifierException(ASAPAuthenticationException): class JtiUniquenessException(ASAPAuthenticationException): - """Raise when a JTI is seen more than once. """ + """Raise when a JTI is seen more than once.""" class SubjectDoesNotMatchIssuerException(ASAPAuthenticationException): - """Raise when the subject and issuer differ. """ + """Raise when the subject and issuer differ.""" class NoTokenProvidedError(ASAPAuthenticationException): """Raise when no token is provided""" + pass diff --git a/atlassian_jwt_auth/frameworks/common/asap.py b/atlassian_jwt_auth/frameworks/common/asap.py index 7afa7e6..4af1bae 100644 --- a/atlassian_jwt_auth/frameworks/common/asap.py +++ b/atlassian_jwt_auth/frameworks/common/asap.py @@ -3,21 +3,24 @@ from jwt.exceptions import InvalidIssuerError, InvalidTokenError from atlassian_jwt_auth.exceptions import ( - PublicKeyRetrieverException, - NoTokenProvidedError, JtiUniquenessException, + NoTokenProvidedError, + PublicKeyRetrieverException, SubjectDoesNotMatchIssuerException, ) def _process_asap_token(request, backend, settings, verifier=None): - """ Verifies an ASAP token, validates the claims, and returns an error + """Verifies an ASAP token, validates the claims, and returns an error response""" - logger = logging.getLogger('asap') + logger = logging.getLogger("asap") token = backend.get_asap_token(request) error_response = None - if token is None and not settings.ASAP_REQUIRED and ( - settings.ASAP_REQUIRED is not None): + if ( + token is None + and not settings.ASAP_REQUIRED + and (settings.ASAP_REQUIRED is not None) + ): return try: if token is None: @@ -33,10 +36,8 @@ def _process_asap_token(request, backend, settings, verifier=None): _verify_issuers(asap_claims, settings.ASAP_VALID_ISSUERS) backend.set_asap_claims_for_request(request, asap_claims) except NoTokenProvidedError: - logger.info('No token provided') - error_response = backend.get_401_response( - 'Unauthorized', request=request - ) + logger.info("No token provided") + error_response = backend.get_401_response("Unauthorized", request=request) except PublicKeyRetrieverException as e: if e.status_code not in (403, 404): # Any error other than "not found" is a problem and should @@ -46,32 +47,32 @@ def _process_asap_token(request, backend, settings, verifier=None): # will return 403 for a missing file to avoid leaking # information. raise - logger.warning('Could not retrieve the matching public key') + logger.warning("Could not retrieve the matching public key") error_response = backend.get_401_response( - 'Unauthorized: Key not found', request=request + "Unauthorized: Key not found", request=request ) except InvalidIssuerError: - logger.warning('Invalid token - issuer') + logger.warning("Invalid token - issuer") error_response = backend.get_403_response( - 'Forbidden: Invalid token issuer', request=request + "Forbidden: Invalid token issuer", request=request ) except InvalidTokenError: - logger.warning('Invalid token') + logger.warning("Invalid token") error_response = backend.get_401_response( - 'Unauthorized: Invalid token', request=request + "Unauthorized: Invalid token", request=request ) except JtiUniquenessException: - logger.warning('Invalid token - duplicate jti') + logger.warning("Invalid token - duplicate jti") error_response = backend.get_401_response( - 'Unauthorized: Invalid token - duplicate jti', request=request + "Unauthorized: Invalid token - duplicate jti", request=request ) except SubjectDoesNotMatchIssuerException: - logger.warning('Invalid token - subject and issuer do not match') + logger.warning("Invalid token - subject and issuer do not match") error_response = backend.get_401_response( - 'Unauthorized: Subject and Issuer do not match', request=request + "Unauthorized: Subject and Issuer do not match", request=request ) except Exception: - logger.exception('An error occured while checking an asap token') + logger.exception("An error occured while checking an asap token") raise if error_response is not None and settings.ASAP_REQUIRED: @@ -80,6 +81,6 @@ def _process_asap_token(request, backend, settings, verifier=None): def _verify_issuers(asap_claims, issuers=None): """Verify that the issuer in the claims is valid and is expected.""" - claim_iss = asap_claims.get('iss') + claim_iss = asap_claims.get("iss") if issuers and claim_iss not in issuers: raise InvalidIssuerError diff --git a/atlassian_jwt_auth/frameworks/common/backend.py b/atlassian_jwt_auth/frameworks/common/backend.py index 6ee6ce2..3178bd4 100644 --- a/atlassian_jwt_auth/frameworks/common/backend.py +++ b/atlassian_jwt_auth/frameworks/common/backend.py @@ -8,61 +8,53 @@ @lru_cache(maxsize=20) def _get_verifier(settings): - """ This has been extracted out of Backend to avoid possible memory - leaks via retained instance references. + """This has been extracted out of Backend to avoid possible memory + leaks via retained instance references. """ retriever = settings.ASAP_KEY_RETRIEVER_CLASS( base_url=settings.ASAP_PUBLICKEY_REPOSITORY ) kwargs = {} if settings.ASAP_SUBJECT_SHOULD_MATCH_ISSUER is not None: - kwargs = {'subject_should_match_issuer': - settings.ASAP_SUBJECT_SHOULD_MATCH_ISSUER} + kwargs = { + "subject_should_match_issuer": settings.ASAP_SUBJECT_SHOULD_MATCH_ISSUER + } if settings.ASAP_CHECK_JTI_UNIQUENESS is not None: - kwargs['check_jti_uniqueness'] = settings.ASAP_CHECK_JTI_UNIQUENESS - return JWTAuthVerifier( - retriever, - **kwargs - ) + kwargs["check_jti_uniqueness"] = settings.ASAP_CHECK_JTI_UNIQUENESS + return JWTAuthVerifier(retriever, **kwargs) -class Backend(): +class Backend: """Abstract class representing a web framework backend Backends allow specific implementation details of web frameworks to be abstracted away from the underlying logic of ASAP. """ + __metaclass__ = ABCMeta - default_headers_401 = {'WWW-Authenticate': 'Bearer'} + default_headers_401 = {"WWW-Authenticate": "Bearer"} default_settings = { # The class to be instantiated to retrieve public keys - 'ASAP_KEY_RETRIEVER_CLASS': HTTPSPublicKeyRetriever, - + "ASAP_KEY_RETRIEVER_CLASS": HTTPSPublicKeyRetriever, # The repository URL where the key retriever can fetch public keys - 'ASAP_PUBLICKEY_REPOSITORY': None, - + "ASAP_PUBLICKEY_REPOSITORY": None, # Whether or not ASAP authentication is required # This is primarily useful when phasing in ASAP authentication - 'ASAP_REQUIRED': True, - + "ASAP_REQUIRED": True, # The valid audience value expected when authenticating tokens - 'ASAP_VALID_AUDIENCE': None, - + "ASAP_VALID_AUDIENCE": None, # The amount of leeway to apply when evaluating token expiration # timestamps - 'ASAP_VALID_LEEWAY': 0, - + "ASAP_VALID_LEEWAY": 0, # An iterable of valid token issuers allowed to authenticate # (this can be overridden at the decorator level) - 'ASAP_VALID_ISSUERS': None, - + "ASAP_VALID_ISSUERS": None, # Enforce that the ASAP subject must match the issuer - 'ASAP_SUBJECT_SHOULD_MATCH_ISSUER': None, - + "ASAP_SUBJECT_SHOULD_MATCH_ISSUER": None, # Enforce that tokens have a unique JTI # Set this to True to enforce JTI uniqueness checking. - 'ASAP_CHECK_JTI_UNIQUENESS': None, + "ASAP_CHECK_JTI_UNIQUENESS": None, } @abstractmethod @@ -97,10 +89,10 @@ def get_asap_token(self, request): # headers, but some libraries allow sending bytes (Django tests) # and some (requests) always send str so we need to convert if # that is the case to properly support Python 3. - auth_header = auth_header.encode(encoding='iso-8859-1') + auth_header = auth_header.encode(encoding="iso-8859-1") - auth_values = auth_header.split(b' ') - if len(auth_values) != 2 or auth_values[0].lower() != b'bearer': + auth_values = auth_header.split(b" ") + if len(auth_values) != 2 or auth_values[0].lower() != b"bearer": return None return auth_values[1] @@ -115,12 +107,12 @@ def _get_verifier(self, settings): return _get_verifier(settings) def _process_settings(self, settings): - valid_issuers = settings.get('ASAP_VALID_ISSUERS') + valid_issuers = settings.get("ASAP_VALID_ISSUERS") if valid_issuers: - settings['ASAP_VALID_ISSUERS'] = set(valid_issuers) + settings["ASAP_VALID_ISSUERS"] = set(valid_issuers) - valid_aud = settings.get('ASAP_VALID_AUDIENCE') + valid_aud = settings.get("ASAP_VALID_AUDIENCE") if valid_aud and isinstance(valid_aud, list): - settings['ASAP_VALID_AUDIENCE'] = set(valid_aud) + settings["ASAP_VALID_AUDIENCE"] = set(valid_aud) return SettingsDict(settings) diff --git a/atlassian_jwt_auth/frameworks/common/decorators.py b/atlassian_jwt_auth/frameworks/common/decorators.py index e95405e..ef1d362 100644 --- a/atlassian_jwt_auth/frameworks/common/decorators.py +++ b/atlassian_jwt_auth/frameworks/common/decorators.py @@ -1,33 +1,36 @@ from functools import wraps + from jwt.exceptions import InvalidIssuerError, InvalidTokenError from .asap import _process_asap_token, _verify_issuers from .utils import SettingsDict -def _with_asap(func=None, backend=None, issuers=None, required=True, - subject_should_match_issuer=None): +def _with_asap( + func=None, + backend=None, + issuers=None, + required=True, + subject_should_match_issuer=None, +): if backend is None: - raise ValueError( - 'Invalid value for backend. Use a subclass instead.' - ) + raise ValueError("Invalid value for backend. Use a subclass instead.") def with_asap_decorator(func): @wraps(func) def with_asap_wrapper(*args, **kwargs): settings = _update_settings_from_kwargs( backend.settings, - issuers=issuers, required=required, - subject_should_match_issuer=subject_should_match_issuer + issuers=issuers, + required=required, + subject_should_match_issuer=subject_should_match_issuer, ) request = None if len(args) > 0: request = args[0] - error_response = _process_asap_token( - request, backend, settings - ) + error_response = _process_asap_token(request, backend, settings) if error_response is not None: return error_response @@ -42,8 +45,13 @@ def with_asap_wrapper(*args, **kwargs): return with_asap_decorator -def _restrict_asap(func=None, backend=None, issuers=None, - required=True, subject_should_match_issuer=None): +def _restrict_asap( + func=None, + backend=None, + issuers=None, + required=True, + subject_should_match_issuer=None, +): """Decorator to allow endpoint-specific ASAP authorization, assuming ASAP authentication has already occurred. """ @@ -53,26 +61,25 @@ def restrict_asap_decorator(func): def restrict_asap_wrapper(request, *args, **kwargs): settings = _update_settings_from_kwargs( backend.settings, - issuers=issuers, required=required, - subject_should_match_issuer=subject_should_match_issuer + issuers=issuers, + required=required, + subject_should_match_issuer=subject_should_match_issuer, ) - asap_claims = getattr(request, 'asap_claims', None) + asap_claims = getattr(request, "asap_claims", None) error_response = None if required and not asap_claims: - return backend.get_401_response( - 'Unauthorized', request=request - ) + return backend.get_401_response("Unauthorized", request=request) try: _verify_issuers(asap_claims, settings.ASAP_VALID_ISSUERS) except InvalidIssuerError: error_response = backend.get_403_response( - 'Forbidden: Invalid token issuer', request=request + "Forbidden: Invalid token issuer", request=request ) except InvalidTokenError: error_response = backend.get_401_response( - 'Unauthorized: Invalid token', request=request + "Unauthorized: Invalid token", request=request ) if error_response and required: @@ -88,19 +95,18 @@ def restrict_asap_wrapper(request, *args, **kwargs): return restrict_asap_decorator -def _update_settings_from_kwargs(settings, issuers=None, required=True, - subject_should_match_issuer=None): +def _update_settings_from_kwargs( + settings, issuers=None, required=True, subject_should_match_issuer=None +): settings = settings.copy() if issuers is not None: - settings['ASAP_VALID_ISSUERS'] = set(issuers) + settings["ASAP_VALID_ISSUERS"] = set(issuers) if required is not None: - settings['ASAP_REQUIRED'] = required + settings["ASAP_REQUIRED"] = required if subject_should_match_issuer is not None: - settings['ASAP_SUBJECT_SHOULD_MATCH_ISSUER'] = ( - subject_should_match_issuer - ) + settings["ASAP_SUBJECT_SHOULD_MATCH_ISSUER"] = subject_should_match_issuer return SettingsDict(settings) diff --git a/atlassian_jwt_auth/frameworks/common/tests/test_utils.py b/atlassian_jwt_auth/frameworks/common/tests/test_utils.py index 0c6d3a0..72ef7fd 100644 --- a/atlassian_jwt_auth/frameworks/common/tests/test_utils.py +++ b/atlassian_jwt_auth/frameworks/common/tests/test_utils.py @@ -4,13 +4,13 @@ class SettingsDictTest(unittest.TestCase): - """ Tests for the SettingsDict class. """ + """Tests for the SettingsDict class.""" def test_hash(self): - """ Test that SettingsDict instances can be hashed. """ - dictionary_one = {'a': 'b', '3': set([1]), 'f': None} - dictionary_two = {'a': 'b', '3': set([1]), 'f': None} - dictionary_three = {'a': 'b', '3': set([1]), 'diff': '333'} + """Test that SettingsDict instances can be hashed.""" + dictionary_one = {"a": "b", "3": set([1]), "f": None} + dictionary_two = {"a": "b", "3": set([1]), "f": None} + dictionary_three = {"a": "b", "3": set([1]), "diff": "333"} settings_one = utils.SettingsDict(dictionary_one) settings_two = utils.SettingsDict(dictionary_two) settings_three = utils.SettingsDict(dictionary_three) diff --git a/atlassian_jwt_auth/frameworks/common/utils.py b/atlassian_jwt_auth/frameworks/common/utils.py index fe46684..d2ef74f 100644 --- a/atlassian_jwt_auth/frameworks/common/utils.py +++ b/atlassian_jwt_auth/frameworks/common/utils.py @@ -6,7 +6,7 @@ def __getattr__(self, name): return self[name] def __setitem__(self, key, value): - raise AttributeError('SettingsDict properties are immutable') + raise AttributeError("SettingsDict properties are immutable") def _hash_key(self): keys_and_values = [] diff --git a/atlassian_jwt_auth/frameworks/django/backend.py b/atlassian_jwt_auth/frameworks/django/backend.py index 3f3e5fd..de0b4c6 100644 --- a/atlassian_jwt_auth/frameworks/django/backend.py +++ b/atlassian_jwt_auth/frameworks/django/backend.py @@ -7,9 +7,9 @@ class DjangoBackend(Backend): def get_authorization_header(self, request=None): if request is None: - raise ValueError('No request available') + raise ValueError("No request available") - return request.META.get('HTTP_AUTHORIZATION', b'') + return request.META.get("HTTP_AUTHORIZATION", b"") def get_401_response(self, data=None, headers=None, request=None): if headers is None: diff --git a/atlassian_jwt_auth/frameworks/django/decorators.py b/atlassian_jwt_auth/frameworks/django/decorators.py index 8ce7bca..852c284 100644 --- a/atlassian_jwt_auth/frameworks/django/decorators.py +++ b/atlassian_jwt_auth/frameworks/django/decorators.py @@ -1,9 +1,8 @@ -from ..common.decorators import _with_asap, _restrict_asap +from ..common.decorators import _restrict_asap, _with_asap from .backend import DjangoBackend -def with_asap(func=None, issuers=None, required=None, - subject_should_match_issuer=None): +def with_asap(func=None, issuers=None, required=None, subject_should_match_issuer=None): """Decorator to allow endpoint-specific ASAP authentication. If authentication fails, a 401 or 403 response will be returned. Otherwise, @@ -21,13 +20,17 @@ def with_asap(func=None, issuers=None, required=None, token to be considered valid. """ return _with_asap( - func, DjangoBackend(), issuers, required, - subject_should_match_issuer + func, DjangoBackend(), issuers, required, subject_should_match_issuer ) -def restrict_asap(func=None, backend=None, issuers=None, - required=True, subject_should_match_issuer=None): +def restrict_asap( + func=None, + backend=None, + issuers=None, + required=True, + subject_should_match_issuer=None, +): """Decorator to allow endpoint-specific ASAP authorization policies. This decorator assumes that request.asap_claims has previously been set by @@ -46,6 +49,5 @@ def restrict_asap(func=None, backend=None, issuers=None, token to be considered valid. """ return _restrict_asap( - func, DjangoBackend(), issuers, required, - subject_should_match_issuer=None + func, DjangoBackend(), issuers, required, subject_should_match_issuer=None ) diff --git a/atlassian_jwt_auth/frameworks/django/middleware.py b/atlassian_jwt_auth/frameworks/django/middleware.py index 91cf996..319deaf 100644 --- a/atlassian_jwt_auth/frameworks/django/middleware.py +++ b/atlassian_jwt_auth/frameworks/django/middleware.py @@ -9,8 +9,9 @@ def asap_middleware(get_response): _verifier = backend.get_verifier(settings=settings) def middleware(request): - error_response = _process_asap_token(request, backend, settings, - verifier=_verifier) + error_response = _process_asap_token( + request, backend, settings, verifier=_verifier + ) if error_response is not None: return error_response diff --git a/atlassian_jwt_auth/frameworks/django/tests/settings.py b/atlassian_jwt_auth/frameworks/django/tests/settings.py index c1e381b..256e9bb 100644 --- a/atlassian_jwt_auth/frameworks/django/tests/settings.py +++ b/atlassian_jwt_auth/frameworks/django/tests/settings.py @@ -8,49 +8,48 @@ # See https://docs.djangoproject.com/en/3.2/howto/deployment/checklist/ # SECURITY WARNING: keep the secret key used in production secret! -SECRET_KEY = 'django-insecure-5i@^w(cnsaqrx*3co@!&wd' \ - 'vgp4wflkgw$qt#9j@e#tyxg!wdzd' +SECRET_KEY = "django-insecure-5i@^w(cnsaqrx*3co@!&wdvgp4wflkgw$qt#9j@e#tyxg!wdzd" # SECURITY WARNING: don't run with debug turned on in production! DEBUG = True -ALLOWED_HOSTS = ['*'] +ALLOWED_HOSTS = ["*"] # Application definition INSTALLED_APPS = [ - 'django.contrib.admin', - 'django.contrib.auth', - 'django.contrib.contenttypes', - 'django.contrib.sessions', - 'django.contrib.messages', - 'django.contrib.staticfiles', + "django.contrib.admin", + "django.contrib.auth", + "django.contrib.contenttypes", + "django.contrib.sessions", + "django.contrib.messages", + "django.contrib.staticfiles", ] MIDDLEWARE = [ - 'django.middleware.security.SecurityMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.middleware.common.CommonMiddleware', - 'django.middleware.csrf.CsrfViewMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - 'django.middleware.clickjacking.XFrameOptionsMiddleware', + "django.middleware.security.SecurityMiddleware", + "django.contrib.sessions.middleware.SessionMiddleware", + "django.middleware.common.CommonMiddleware", + "django.middleware.csrf.CsrfViewMiddleware", + "django.contrib.auth.middleware.AuthenticationMiddleware", + "django.contrib.messages.middleware.MessageMiddleware", + "django.middleware.clickjacking.XFrameOptionsMiddleware", ] -ROOT_URLCONF = 'atlassian_jwt_auth.frameworks.django.tests.urls' +ROOT_URLCONF = "atlassian_jwt_auth.frameworks.django.tests.urls" TEMPLATES = [ { - 'BACKEND': 'django.template.backends.django.DjangoTemplates', - 'DIRS': [], - 'APP_DIRS': True, - 'OPTIONS': { - 'context_processors': [ - 'django.template.context_processors.debug', - 'django.template.context_processors.request', - 'django.contrib.auth.context_processors.auth', - 'django.contrib.messages.context_processors.messages', + "BACKEND": "django.template.backends.django.DjangoTemplates", + "DIRS": [], + "APP_DIRS": True, + "OPTIONS": { + "context_processors": [ + "django.template.context_processors.debug", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", ], }, }, @@ -60,9 +59,9 @@ # https://docs.djangoproject.com/en/3.2/ref/settings/#databases DATABASES = { - 'default': { - 'ENGINE': 'django.db.backends.sqlite3', - 'NAME': None, + "default": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": None, } } @@ -72,20 +71,17 @@ AUTH_PASSWORD_VALIDATORS = [ { - 'NAME': 'django.contrib.auth.password_validation.' - 'UserAttributeSimilarityValidator', + "NAME": "django.contrib.auth.password_validation." + "UserAttributeSimilarityValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.' - 'MinimumLengthValidator', + "NAME": "django.contrib.auth.password_validation.MinimumLengthValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.' - 'CommonPasswordValidator', + "NAME": "django.contrib.auth.password_validation.CommonPasswordValidator", }, { - 'NAME': 'django.contrib.auth.password_validation.' - 'NumericPasswordValidator', + "NAME": "django.contrib.auth.password_validation.NumericPasswordValidator", }, ] @@ -93,9 +89,9 @@ # Internationalization # https://docs.djangoproject.com/en/3.2/topics/i18n/ -LANGUAGE_CODE = 'en-us' +LANGUAGE_CODE = "en-us" -TIME_ZONE = 'UTC' +TIME_ZONE = "UTC" USE_I18N = True @@ -107,13 +103,13 @@ # Static files (CSS, JavaScript, Images) # https://docs.djangoproject.com/en/3.2/howto/static-files/ -STATIC_URL = '/static/' +STATIC_URL = "/static/" -ASAP_VALID_AUDIENCE = 'server-app' -ASAP_VALID_ISSUERS = ('client-app', 'whitelist') +ASAP_VALID_AUDIENCE = "server-app" +ASAP_VALID_ISSUERS = ("client-app", "whitelist") ASAP_PUBLICKEY_REPOSITORY = None # Default primary key field type # https://docs.djangoproject.com/en/3.2/ref/settings/#default-auto-field -DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' +DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField" diff --git a/atlassian_jwt_auth/frameworks/django/tests/test_django.py b/atlassian_jwt_auth/frameworks/django/tests/test_django.py index 1d8d9b5..95c39ab 100644 --- a/atlassian_jwt_auth/frameworks/django/tests/test_django.py +++ b/atlassian_jwt_auth/frameworks/django/tests/test_django.py @@ -2,7 +2,7 @@ import django from django.test.testcases import SimpleTestCase -from django.test.utils import override_settings, modify_settings +from django.test.utils import modify_settings, override_settings try: from django.urls import reverse @@ -14,18 +14,18 @@ ) from atlassian_jwt_auth.tests import utils from atlassian_jwt_auth.tests.utils import ( - create_token, RS256KeyTestMixin, + create_token, ) class DjangoAsapMixin(object): - @classmethod def setUpClass(cls): os.environ.setdefault( - 'DJANGO_SETTINGS_MODULE', - 'atlassian_jwt_auth.frameworks.django.tests.settings') + "DJANGO_SETTINGS_MODULE", + "atlassian_jwt_auth.frameworks.django.tests.settings", + ) django.setup() super(DjangoAsapMixin, cls).setUpClass() @@ -33,7 +33,7 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): super(DjangoAsapMixin, cls).tearDownClass() - del os.environ['DJANGO_SETTINGS_MODULE'] + del os.environ["DJANGO_SETTINGS_MODULE"] def setUp(self): super(DjangoAsapMixin, self).setUp() @@ -42,313 +42,368 @@ def setUp(self): self._private_key_pem ) - self.retriever = get_static_retriever_class({ - 'client-app/key01': self._public_key_pem - }) + self.retriever = get_static_retriever_class( + {"client-app/key01": self._public_key_pem} + ) - self.test_settings = { - 'ASAP_KEY_RETRIEVER_CLASS': self.retriever - } + self.test_settings = {"ASAP_KEY_RETRIEVER_CLASS": self.retriever} -@modify_settings(MIDDLEWARE={ - 'prepend': 'atlassian_jwt_auth.frameworks.django.asap_middleware', -}) +@modify_settings( + MIDDLEWARE={ + "prepend": "atlassian_jwt_auth.frameworks.django.asap_middleware", + } +) class TestAsapMiddleware(DjangoAsapMixin, RS256KeyTestMixin, SimpleTestCase): - - def check_response(self, - view_name, - response_content='', - status_code=200, - issuer='client-app', - audience='server-app', - key_id='client-app/key01', - subject=None, - private_key=None, - token=None, - authorization=None, - retriever_key=None): + def check_response( + self, + view_name, + response_content="", + status_code=200, + issuer="client-app", + audience="server-app", + key_id="client-app/key01", + subject=None, + private_key=None, + token=None, + authorization=None, + retriever_key=None, + ): if authorization is None: if token is None: if private_key is None: private_key = self._private_key_pem - token = create_token(issuer=issuer, audience=audience, - key_id=key_id, private_key=private_key, - subject=subject) - authorization = b'Bearer ' + token + token = create_token( + issuer=issuer, + audience=audience, + key_id=key_id, + private_key=private_key, + subject=subject, + ) + authorization = b"Bearer " + token test_settings = self.test_settings.copy() if retriever_key is not None: - retriever = get_static_retriever_class({ - retriever_key: self._public_key_pem - }) - test_settings['ASAP_KEY_RETRIEVER_CLASS'] = retriever + retriever = get_static_retriever_class( + {retriever_key: self._public_key_pem} + ) + test_settings["ASAP_KEY_RETRIEVER_CLASS"] = retriever with override_settings(**test_settings): - response = self.client.get(reverse(view_name), - HTTP_AUTHORIZATION=authorization) + response = self.client.get( + reverse(view_name), HTTP_AUTHORIZATION=authorization + ) - self.assertContains(response, response_content, - status_code=status_code) + self.assertContains(response, response_content, status_code=status_code) def test_request_with_valid_token_is_allowed(self): - self.check_response('needed', 'one', 200) + self.check_response("needed", "one", 200) def test_request_with_valid_token_multiple_allowed_auds(self): - audiences = ['server-app', 'another_one'] - self.test_settings['ASAP_VALID_AUDIENCE'] = audiences + audiences = ["server-app", "another_one"] + self.test_settings["ASAP_VALID_AUDIENCE"] = audiences for aud in audiences: - self.check_response('needed', 'one', 200, audience=aud) + self.check_response("needed", "one", 200, audience=aud) def test_request_with_valid_token_multiple_allowed_auds_invalid_aud(self): - audiences = ['server-app', 'another_one'] - self.test_settings['ASAP_VALID_AUDIENCE'] = audiences - self.check_response('needed', 'Unauthorized', 401, audience="invalid") + audiences = ["server-app", "another_one"] + self.test_settings["ASAP_VALID_AUDIENCE"] = audiences + self.check_response("needed", "Unauthorized", 401, audience="invalid") def test_request_with_duplicate_jti_is_rejected_as_per_setting(self): - self.test_settings['ASAP_CHECK_JTI_UNIQUENESS'] = True + self.test_settings["ASAP_CHECK_JTI_UNIQUENESS"] = True token = create_token( - issuer='client-app', audience='server-app', - key_id='client-app/key01', private_key=self._private_key_pem + issuer="client-app", + audience="server-app", + key_id="client-app/key01", + private_key=self._private_key_pem, ) - str_auth = 'Bearer ' + token.decode(encoding='iso-8859-1') - self.check_response('needed', 'one', 200, authorization=str_auth) - self.check_response('needed', 'duplicate jti', 401, - authorization=str_auth) + str_auth = "Bearer " + token.decode(encoding="iso-8859-1") + self.check_response("needed", "one", 200, authorization=str_auth) + self.check_response("needed", "duplicate jti", 401, authorization=str_auth) def _assert_request_with_duplicate_jti_is_accepted(self): token = create_token( - issuer='client-app', audience='server-app', - key_id='client-app/key01', private_key=self._private_key_pem + issuer="client-app", + audience="server-app", + key_id="client-app/key01", + private_key=self._private_key_pem, ) - str_auth = 'Bearer ' + token.decode(encoding='iso-8859-1') - self.check_response('needed', 'one', 200, authorization=str_auth) - self.check_response('needed', 'one', 200, authorization=str_auth) + str_auth = "Bearer " + token.decode(encoding="iso-8859-1") + self.check_response("needed", "one", 200, authorization=str_auth) + self.check_response("needed", "one", 200, authorization=str_auth) def test_request_with_duplicate_jti_is_accepted(self): self._assert_request_with_duplicate_jti_is_accepted() def test_request_with_duplicate_jti_is_accepted_as_per_setting(self): - self.test_settings['ASAP_CHECK_JTI_UNIQUENESS'] = False + self.test_settings["ASAP_CHECK_JTI_UNIQUENESS"] = False self._assert_request_with_duplicate_jti_is_accepted() def test_request_with_string_headers_is_allowed(self): token = create_token( - issuer='client-app', audience='server-app', - key_id='client-app/key01', private_key=self._private_key_pem + issuer="client-app", + audience="server-app", + key_id="client-app/key01", + private_key=self._private_key_pem, ) - str_auth = 'Bearer ' + token.decode(encoding='iso-8859-1') - self.check_response('needed', 'one', 200, authorization=str_auth) + str_auth = "Bearer " + token.decode(encoding="iso-8859-1") + self.check_response("needed", "one", 200, authorization=str_auth) def test_request_with_invalid_audience_is_rejected(self): - self.check_response('needed', 'Unauthorized', 401, - audience='invalid') + self.check_response("needed", "Unauthorized", 401, audience="invalid") def test_request_with_invalid_token_is_rejected(self): - self.check_response('needed', 'Unauthorized', 401, - authorization='Bearer invalid') + self.check_response( + "needed", "Unauthorized", 401, authorization="Bearer invalid" + ) def test_request_without_token_is_rejected(self): with override_settings(**self.test_settings): - response = self.client.get(reverse('needed')) + response = self.client.get(reverse("needed")) - self.assertContains(response, 'Unauthorized', - status_code=401) + self.assertContains(response, "Unauthorized", status_code=401) def test_request_with_invalid_issuer_is_rejected(self): - self.check_response('needed', 'Forbidden', 403, - issuer='something-invalid', - key_id='something-invalid/key01', - retriever_key='something-invalid/key01') + self.check_response( + "needed", + "Forbidden", + 403, + issuer="something-invalid", + key_id="something-invalid/key01", + retriever_key="something-invalid/key01", + ) def test_request_non_whitelisted_decorated_issuer_is_rejected(self): - self.check_response('needed', 'Forbidden', 403, - issuer='unexpected', - key_id='unexpected/key01', - retriever_key='unexpected/key01') + self.check_response( + "needed", + "Forbidden", + 403, + issuer="unexpected", + key_id="unexpected/key01", + retriever_key="unexpected/key01", + ) def test_request_non_decorated_issuer_is_rejected(self): - self.check_response('restricted_issuer', 'Forbidden', 403) + self.check_response("restricted_issuer", "Forbidden", 403) def test_request_decorated_issuer_is_allowed(self): - self.check_response('restricted_issuer', 'three', - issuer='whitelist', - key_id='whitelist/key01', - retriever_key='whitelist/key01') + self.check_response( + "restricted_issuer", + "three", + issuer="whitelist", + key_id="whitelist/key01", + retriever_key="whitelist/key01", + ) # TODO: modify JWTAuthSigner to allow non-issuer subjects and update the # decorated subject test cases def test_request_non_decorated_subject_is_rejected(self): - self.check_response('restricted_subject', 'Forbidden', 403, - issuer='whitelist', - key_id='whitelist/key01', - retriever_key='whitelist/key01') + self.check_response( + "restricted_subject", + "Forbidden", + 403, + issuer="whitelist", + key_id="whitelist/key01", + retriever_key="whitelist/key01", + ) def test_request_using_settings_only_is_allowed(self): - self.check_response('unneeded', 'two') + self.check_response("unneeded", "two") def test_request_subject_does_not_need_to_match_issuer_from_settings(self): - self.test_settings['ASAP_SUBJECT_SHOULD_MATCH_ISSUER'] = False - self.check_response('needed', 'one', 200, subject='different_than_is') + self.test_settings["ASAP_SUBJECT_SHOULD_MATCH_ISSUER"] = False + self.check_response("needed", "one", 200, subject="different_than_is") def test_request_subject_and_issue_not_matching(self): self.check_response( - 'needed', - 'Subject and Issuer do not match', + "needed", + "Subject and Issuer do not match", 401, - subject='different_than_is', + subject="different_than_is", ) class TestAsapDecorator(DjangoAsapMixin, RS256KeyTestMixin, SimpleTestCase): def test_request_with_valid_token_is_allowed(self): token = create_token( - issuer='client-app', audience='server-app', - key_id='client-app/key01', private_key=self._private_key_pem + issuer="client-app", + audience="server-app", + key_id="client-app/key01", + private_key=self._private_key_pem, ) with override_settings(**self.test_settings): - response = self.client.get(reverse('expected'), - HTTP_AUTHORIZATION=b'Bearer ' + token) + response = self.client.get( + reverse("expected"), HTTP_AUTHORIZATION=b"Bearer " + token + ) - self.assertContains(response, 'Greatest Success!', status_code=200) + self.assertContains(response, "Greatest Success!", status_code=200) def test_request_with_string_headers_is_allowed(self): token = create_token( - issuer='client-app', audience='server-app', - key_id='client-app/key01', private_key=self._private_key_pem + issuer="client-app", + audience="server-app", + key_id="client-app/key01", + private_key=self._private_key_pem, ) - str_token = token.decode(encoding='iso-8859-1') + str_token = token.decode(encoding="iso-8859-1") with override_settings(**self.test_settings): - response = self.client.get(reverse('expected'), - HTTP_AUTHORIZATION='Bearer ' + - str_token) + response = self.client.get( + reverse("expected"), HTTP_AUTHORIZATION="Bearer " + str_token + ) - self.assertContains(response, 'Greatest Success!', status_code=200) + self.assertContains(response, "Greatest Success!", status_code=200) def test_request_with_invalid_audience_is_rejected(self): token = create_token( - issuer='client-app', audience='something-invalid', - key_id='client-app/key01', private_key=self._private_key_pem + issuer="client-app", + audience="something-invalid", + key_id="client-app/key01", + private_key=self._private_key_pem, ) with override_settings(**self.test_settings): - response = self.client.get(reverse('expected'), - HTTP_AUTHORIZATION=b'Bearer ' + token) + response = self.client.get( + reverse("expected"), HTTP_AUTHORIZATION=b"Bearer " + token + ) - self.assertContains(response, 'Unauthorized: Invalid token', - status_code=401) + self.assertContains(response, "Unauthorized: Invalid token", status_code=401) def test_request_with_invalid_token_is_rejected(self): with override_settings(**self.test_settings): response = self.client.get( - reverse('expected'), - HTTP_AUTHORIZATION=b'Bearer notavalidtoken') + reverse("expected"), HTTP_AUTHORIZATION=b"Bearer notavalidtoken" + ) - self.assertContains(response, 'Unauthorized: Invalid token', - status_code=401) + self.assertContains(response, "Unauthorized: Invalid token", status_code=401) def test_request_without_token_is_rejected(self): with override_settings(**self.test_settings): - response = self.client.get(reverse('expected')) + response = self.client.get(reverse("expected")) - self.assertContains(response, 'Unauthorized', - status_code=401) + self.assertContains(response, "Unauthorized", status_code=401) def test_request_with_invalid_issuer_is_rejected(self): - retriever = get_static_retriever_class({ - 'something-invalid/key01': self._public_key_pem - }) + retriever = get_static_retriever_class( + {"something-invalid/key01": self._public_key_pem} + ) token = create_token( - issuer='something-invalid', audience='server-app', - key_id='something-invalid/key01', private_key=self._private_key_pem + issuer="something-invalid", + audience="server-app", + key_id="something-invalid/key01", + private_key=self._private_key_pem, ) with override_settings(ASAP_KEY_RETRIEVER_CLASS=retriever): - response = self.client.get(reverse('expected'), - HTTP_AUTHORIZATION=b'Bearer ' + token) + response = self.client.get( + reverse("expected"), HTTP_AUTHORIZATION=b"Bearer " + token + ) - self.assertContains(response, 'Forbidden: Invalid token issuer', - status_code=403) + self.assertContains( + response, "Forbidden: Invalid token issuer", status_code=403 + ) def test_request_non_decorated_issuer_is_rejected(self): token = create_token( - issuer='client-app', audience='server-app', - key_id='client-app/key01', private_key=self._private_key_pem + issuer="client-app", + audience="server-app", + key_id="client-app/key01", + private_key=self._private_key_pem, ) with override_settings(**self.test_settings): - response = self.client.get(reverse('decorated'), - HTTP_AUTHORIZATION=b'Bearer ' + token) + response = self.client.get( + reverse("decorated"), HTTP_AUTHORIZATION=b"Bearer " + token + ) - self.assertContains(response, 'Forbidden: Invalid token issuer', - status_code=403) + self.assertContains( + response, "Forbidden: Invalid token issuer", status_code=403 + ) def test_request_decorated_issuer_is_allowed(self): - retriever = get_static_retriever_class({ - 'whitelist/key01': self._public_key_pem - }) + retriever = get_static_retriever_class( + {"whitelist/key01": self._public_key_pem} + ) token = create_token( - issuer='whitelist', audience='server-app', - key_id='whitelist/key01', private_key=self._private_key_pem + issuer="whitelist", + audience="server-app", + key_id="whitelist/key01", + private_key=self._private_key_pem, ) with override_settings(ASAP_KEY_RETRIEVER_CLASS=retriever): - response = self.client.get(reverse('decorated'), - HTTP_AUTHORIZATION=b'Bearer ' + token) + response = self.client.get( + reverse("decorated"), HTTP_AUTHORIZATION=b"Bearer " + token + ) - self.assertContains(response, 'Only the right issuer is allowed.') + self.assertContains(response, "Only the right issuer is allowed.") def test_request_using_settings_only_is_allowed(self): token = create_token( - issuer='client-app', audience='server-app', - key_id='client-app/key01', private_key=self._private_key_pem + issuer="client-app", + audience="server-app", + key_id="client-app/key01", + private_key=self._private_key_pem, ) with override_settings(**self.test_settings): - response = self.client.get(reverse('settings'), - HTTP_AUTHORIZATION=b'Bearer ' + token) + response = self.client.get( + reverse("settings"), HTTP_AUTHORIZATION=b"Bearer " + token + ) - self.assertContains(response, 'Any settings issuer is allowed.') + self.assertContains(response, "Any settings issuer is allowed.") def test_request_subject_does_not_need_to_match_issuer(self): token = create_token( - issuer='client-app', audience='server-app', - key_id='client-app/key01', private_key=self._private_key_pem, - subject='not-client-app', + issuer="client-app", + audience="server-app", + key_id="client-app/key01", + private_key=self._private_key_pem, + subject="not-client-app", ) with override_settings(**self.test_settings): response = self.client.get( - reverse('subject_does_not_need_to_match_issuer'), - HTTP_AUTHORIZATION=b'Bearer ' + token) + reverse("subject_does_not_need_to_match_issuer"), + HTTP_AUTHORIZATION=b"Bearer " + token, + ) - self.assertContains(response, 'Subject does not need to match issuer.') + self.assertContains(response, "Subject does not need to match issuer.") def test_request_subject_does_need_to_match_issuer_override_settings(self): - """ tests that the with_asap decorator can override the - ASAP_SUBJECT_SHOULD_MATCH_ISSUER setting. + """tests that the with_asap decorator can override the + ASAP_SUBJECT_SHOULD_MATCH_ISSUER setting. """ token = create_token( - issuer='client-app', audience='server-app', - key_id='client-app/key01', private_key=self._private_key_pem, - subject='not-client-app', + issuer="client-app", + audience="server-app", + key_id="client-app/key01", + private_key=self._private_key_pem, + subject="not-client-app", ) - with override_settings(**dict( - self.test_settings, ASAP_SUBJECT_SHOULD_MATCH_ISSUER=False)): + with override_settings( + **dict(self.test_settings, ASAP_SUBJECT_SHOULD_MATCH_ISSUER=False) + ): response = self.client.get( - reverse('subject_does_need_to_match_issuer'), - HTTP_AUTHORIZATION=b'Bearer ' + token) + reverse("subject_does_need_to_match_issuer"), + HTTP_AUTHORIZATION=b"Bearer " + token, + ) self.assertContains( response, - 'Unauthorized: Subject and Issuer do not match', - status_code=401 + "Unauthorized: Subject and Issuer do not match", + status_code=401, ) def test_request_subject_does_not_need_to_match_issuer_from_settings(self): token = create_token( - issuer='client-app', audience='server-app', - key_id='client-app/key01', private_key=self._private_key_pem, - subject='not-client-app', + issuer="client-app", + audience="server-app", + key_id="client-app/key01", + private_key=self._private_key_pem, + subject="not-client-app", ) - with override_settings(**dict( - self.test_settings, ASAP_SUBJECT_SHOULD_MATCH_ISSUER=False)): + with override_settings( + **dict(self.test_settings, ASAP_SUBJECT_SHOULD_MATCH_ISSUER=False) + ): response = self.client.get( - reverse('subject_does_not_need_to_match_issuer_from_settings'), - HTTP_AUTHORIZATION=b'Bearer ' + token) + reverse("subject_does_not_need_to_match_issuer_from_settings"), + HTTP_AUTHORIZATION=b"Bearer " + token, + ) self.assertContains( - response, 'Subject does not need to match issuer (settings).') + response, "Subject does not need to match issuer (settings)." + ) diff --git a/atlassian_jwt_auth/frameworks/django/tests/urls.py b/atlassian_jwt_auth/frameworks/django/tests/urls.py index 616a58b..2c7dbd5 100644 --- a/atlassian_jwt_auth/frameworks/django/tests/urls.py +++ b/atlassian_jwt_auth/frameworks/django/tests/urls.py @@ -2,28 +2,36 @@ from atlassian_jwt_auth.frameworks.django.tests import views - urlpatterns = [ - path('asap/expected', views.expected_view, name='expected'), - path(r'^asap/unexpected', views.unexpected_view, name='unexpected'), - path('^asap/decorated', views.decorated_view, name='decorated'), - path('asap/settings', views.settings_view, name='settings'), - - path('asap/subject_does_not_need_to_match_issuer', - views.subject_does_not_need_to_match_issuer_view, - name='subject_does_not_need_to_match_issuer'), - path('asap/subject_does_need_to_match_issuer_view', - views.subject_does_need_to_match_issuer_view, - name='subject_does_need_to_match_issuer'), - - path('asap/subject_does_not_need_to_match_issuer_from_settings', - views.subject_does_not_need_to_match_issuer_from_settings_view, - name='subject_does_not_need_to_match_issuer_from_settings'), - - path('asap/needed', views.needed_view, name='needed'), - path(r'asap/unneeded', views.unneeded_view, name='unneeded'), - path(r'asap/restricted_issuer', views.restricted_issuer_view, - name='restricted_issuer'), - path('asap/restricted_subject', views.restricted_subject_view, - name='restricted_subject'), + path("asap/expected", views.expected_view, name="expected"), + path(r"^asap/unexpected", views.unexpected_view, name="unexpected"), + path("^asap/decorated", views.decorated_view, name="decorated"), + path("asap/settings", views.settings_view, name="settings"), + path( + "asap/subject_does_not_need_to_match_issuer", + views.subject_does_not_need_to_match_issuer_view, + name="subject_does_not_need_to_match_issuer", + ), + path( + "asap/subject_does_need_to_match_issuer_view", + views.subject_does_need_to_match_issuer_view, + name="subject_does_need_to_match_issuer", + ), + path( + "asap/subject_does_not_need_to_match_issuer_from_settings", + views.subject_does_not_need_to_match_issuer_from_settings_view, + name="subject_does_not_need_to_match_issuer_from_settings", + ), + path("asap/needed", views.needed_view, name="needed"), + path(r"asap/unneeded", views.unneeded_view, name="unneeded"), + path( + r"asap/restricted_issuer", + views.restricted_issuer_view, + name="restricted_issuer", + ), + path( + "asap/restricted_subject", + views.restricted_subject_view, + name="restricted_subject", + ), ] diff --git a/atlassian_jwt_auth/frameworks/django/tests/views.py b/atlassian_jwt_auth/frameworks/django/tests/views.py index 732edad..45aef00 100644 --- a/atlassian_jwt_auth/frameworks/django/tests/views.py +++ b/atlassian_jwt_auth/frameworks/django/tests/views.py @@ -1,60 +1,59 @@ from django.http import HttpResponse -from atlassian_jwt_auth.frameworks.django import with_asap, restrict_asap -from atlassian_jwt_auth.contrib.django.decorators import (requires_asap, - validate_asap) +from atlassian_jwt_auth.contrib.django.decorators import requires_asap, validate_asap +from atlassian_jwt_auth.frameworks.django import restrict_asap, with_asap -@with_asap(issuers=['client-app']) +@with_asap(issuers=["client-app"]) def expected_view(request): - return HttpResponse('Greatest Success!') + return HttpResponse("Greatest Success!") -@with_asap(issuers=['unexpected']) +@with_asap(issuers=["unexpected"]) def unexpected_view(request): - return HttpResponse('This should fail.') + return HttpResponse("This should fail.") -@with_asap(issuers=['whitelist']) +@with_asap(issuers=["whitelist"]) def decorated_view(request): - return HttpResponse('Only the right issuer is allowed.') + return HttpResponse("Only the right issuer is allowed.") @requires_asap() def settings_view(request): - return HttpResponse('Any settings issuer is allowed.') + return HttpResponse("Any settings issuer is allowed.") @with_asap(subject_should_match_issuer=False) def subject_does_not_need_to_match_issuer_view(request): - return HttpResponse('Subject does not need to match issuer.') + return HttpResponse("Subject does not need to match issuer.") @with_asap(subject_should_match_issuer=True) def subject_does_need_to_match_issuer_view(request): - return HttpResponse('Subject does need to match issuer.') + return HttpResponse("Subject does need to match issuer.") @with_asap() def subject_does_not_need_to_match_issuer_from_settings_view(request): - return HttpResponse('Subject does not need to match issuer (settings).') + return HttpResponse("Subject does not need to match issuer (settings).") @restrict_asap def needed_view(request): - return HttpResponse('one') + return HttpResponse("one") @restrict_asap(required=False) def unneeded_view(request): - return HttpResponse('two') + return HttpResponse("two") -@restrict_asap(issuers=['whitelist']) +@restrict_asap(issuers=["whitelist"]) def restricted_issuer_view(request): - return HttpResponse('three') + return HttpResponse("three") -@validate_asap(subjects=['client-app']) +@validate_asap(subjects=["client-app"]) def restricted_subject_view(request): - return HttpResponse('four') + return HttpResponse("four") diff --git a/atlassian_jwt_auth/frameworks/flask/backend.py b/atlassian_jwt_auth/frameworks/flask/backend.py index 73a8acd..9cfadf0 100644 --- a/atlassian_jwt_auth/frameworks/flask/backend.py +++ b/atlassian_jwt_auth/frameworks/flask/backend.py @@ -1,4 +1,5 @@ -from flask import Response, current_app, g, request as current_req +from flask import Response, current_app, g +from flask import request as current_req from ..common.backend import Backend @@ -8,7 +9,7 @@ def get_authorization_header(self, request=None): if request is None: request = current_req - return request.headers.get('AUTHORIZATION', '') + return request.headers.get("AUTHORIZATION", "") def get_401_response(self, data=None, headers=None, request=None): if headers is None: diff --git a/atlassian_jwt_auth/frameworks/flask/decorators.py b/atlassian_jwt_auth/frameworks/flask/decorators.py index 0db29d4..4fbbedf 100644 --- a/atlassian_jwt_auth/frameworks/flask/decorators.py +++ b/atlassian_jwt_auth/frameworks/flask/decorators.py @@ -2,8 +2,7 @@ from .backend import FlaskBackend -def with_asap(func=None, issuers=None, required=None, - subject_should_match_issuer=None): +def with_asap(func=None, issuers=None, required=None, subject_should_match_issuer=None): """Decorator to allow endpoint-specific ASAP authentication. If authentication fails, a 401 or 403 response will be returned. Otherwise, @@ -21,6 +20,5 @@ def with_asap(func=None, issuers=None, required=None, token to be considered valid. """ return _with_asap( - func, FlaskBackend(), issuers, required, - subject_should_match_issuer + func, FlaskBackend(), issuers, required, subject_should_match_issuer ) diff --git a/atlassian_jwt_auth/frameworks/flask/tests/test_flask.py b/atlassian_jwt_auth/frameworks/flask/tests/test_flask.py index 23c366f..dc560ec 100644 --- a/atlassian_jwt_auth/frameworks/flask/tests/test_flask.py +++ b/atlassian_jwt_auth/frameworks/flask/tests/test_flask.py @@ -13,11 +13,13 @@ def get_app(): app = Flask(__name__) - app.config.update({ - 'ASAP_VALID_AUDIENCE': 'server-app', - 'ASAP_VALID_ISSUERS': ('client-app',), - 'ASAP_PUBLICKEY_REPOSITORY': None - }) + app.config.update( + { + "ASAP_VALID_AUDIENCE": "server-app", + "ASAP_VALID_ISSUERS": ("client-app",), + "ASAP_PUBLICKEY_REPOSITORY": None, + } + ) @app.route("/") @requires_asap @@ -25,7 +27,7 @@ def view(): return "OK" @app.route("/restricted-to-another-client/") - @with_asap(issuers=['another-client']) + @with_asap(issuers=["another-client"]) def view_for_another_client_app(): return "OK" @@ -33,7 +35,7 @@ def view_for_another_client_app(): class FlaskTests(utils.RS256KeyTestMixin, unittest.TestCase): - """ tests for the atlassian_jwt_auth.contrib.tests.flask """ + """tests for the atlassian_jwt_auth.contrib.tests.flask""" def setUp(self): self._private_key_pem = self.get_new_private_key_in_pem_format() @@ -44,61 +46,56 @@ def setUp(self): self.app = get_app() self.client = self.app.test_client() - retriever = get_static_retriever_class({ - 'client-app/key01': self._public_key_pem - }) - self.app.config['ASAP_KEY_RETRIEVER_CLASS'] = retriever + retriever = get_static_retriever_class( + {"client-app/key01": self._public_key_pem} + ) + self.app.config["ASAP_KEY_RETRIEVER_CLASS"] = retriever - def send_request(self, token, url='/'): - """ returns the response of sending a request containing the given - token sent in the Authorization header. + def send_request(self, token, url="/"): + """returns the response of sending a request containing the given + token sent in the Authorization header. """ # Note: We send the auth header as a string and not bytes here # due to how Werkzeug's Header code works. - return self.client.get(url, headers={ - 'Authorization': (b'Bearer ' + token).decode('iso-8859-1') - }) + return self.client.get( + url, headers={"Authorization": (b"Bearer " + token).decode("iso-8859-1")} + ) def test_request_with_valid_token_is_allowed(self): token = create_token( - 'client-app', 'server-app', - 'client-app/key01', self._private_key_pem + "client-app", "server-app", "client-app/key01", self._private_key_pem ) self.assertEqual(self.send_request(token).status_code, 200) def test_request_with_valid_token_multiple_allowed_auds(self): - audiences = ['server-app', 'another_one'] - self.app.config['ASAP_VALID_AUDIENCE'] = audiences + audiences = ["server-app", "another_one"] + self.app.config["ASAP_VALID_AUDIENCE"] = audiences for aud in audiences: token = create_token( - 'client-app', aud, - 'client-app/key01', self._private_key_pem + "client-app", aud, "client-app/key01", self._private_key_pem ) self.assertEqual(self.send_request(token).status_code, 200) def test_request_with_valid_token_multiple_allowed_auds_invalid_aud(self): - audiences = ['server-app', 'another_one'] - self.app.config['ASAP_VALID_AUDIENCE'] = audiences + audiences = ["server-app", "another_one"] + self.app.config["ASAP_VALID_AUDIENCE"] = audiences token = create_token( - 'client-app', "invalid", - 'client-app/key01', self._private_key_pem + "client-app", "invalid", "client-app/key01", self._private_key_pem ) self.assertEqual(self.send_request(token).status_code, 401) def test_request_with_duplicate_jti_is_rejected_as_per_setting(self): - self.app.config['ASAP_CHECK_JTI_UNIQUENESS'] = True + self.app.config["ASAP_CHECK_JTI_UNIQUENESS"] = True token = create_token( - 'client-app', 'server-app', - 'client-app/key01', self._private_key_pem + "client-app", "server-app", "client-app/key01", self._private_key_pem ) self.assertEqual(self.send_request(token).status_code, 200) self.assertEqual(self.send_request(token).status_code, 401) def _assert_request_with_duplicate_jti_is_accepted(self): token = create_token( - 'client-app', 'server-app', - 'client-app/key01', self._private_key_pem + "client-app", "server-app", "client-app/key01", self._private_key_pem ) self.assertEqual(self.send_request(token).status_code, 200) self.assertEqual(self.send_request(token).status_code, 200) @@ -107,55 +104,57 @@ def test_request_with_duplicate_jti_is_accepted(self): self._assert_request_with_duplicate_jti_is_accepted() def test_request_with_duplicate_jti_is_accepted_as_per_setting(self): - self.app.config['ASAP_CHECK_JTI_UNIQUENESS'] = False + self.app.config["ASAP_CHECK_JTI_UNIQUENESS"] = False self._assert_request_with_duplicate_jti_is_accepted() def test_request_with_invalid_audience_is_rejected(self): token = create_token( - 'client-app', 'invalid-audience', - 'client-app/key01', self._private_key_pem + "client-app", "invalid-audience", "client-app/key01", self._private_key_pem ) self.assertEqual(self.send_request(token).status_code, 401) def test_request_with_invalid_token_is_rejected(self): - response = self.send_request(b'notavalidtoken') + response = self.send_request(b"notavalidtoken") self.assertEqual(response.status_code, 401) def test_request_with_invalid_issuer_is_rejected(self): # Try with a different audience with a valid signature - self.app.config['ASAP_KEY_RETRIEVER_CLASS'] = ( - get_static_retriever_class({ - 'another-client/key01': self._public_key_pem - }) + self.app.config["ASAP_KEY_RETRIEVER_CLASS"] = get_static_retriever_class( + {"another-client/key01": self._public_key_pem} ) token = create_token( - 'another-client', 'server-app', - 'another-client/key01', self._private_key_pem + "another-client", + "server-app", + "another-client/key01", + self._private_key_pem, ) self.assertEqual(self.send_request(token).status_code, 403) def test_decorated_request_with_invalid_issuer_is_rejected(self): # Try with a different audience with a valid signature token = create_token( - 'client-app', 'server-app', - 'client-app/key01', self._private_key_pem + "client-app", "server-app", "client-app/key01", self._private_key_pem ) - url = '/restricted-to-another-client/' + url = "/restricted-to-another-client/" self.assertEqual(self.send_request(token, url=url).status_code, 403) def test_request_subject_and_issue_not_matching(self): token = create_token( - 'client-app', 'server-app', - 'client-app/key01', self._private_key_pem, - subject='different' + "client-app", + "server-app", + "client-app/key01", + self._private_key_pem, + subject="different", ) self.assertEqual(self.send_request(token).status_code, 401) def test_request_subject_does_not_need_to_match_issuer_from_settings(self): - self.app.config['ASAP_SUBJECT_SHOULD_MATCH_ISSUER'] = False + self.app.config["ASAP_SUBJECT_SHOULD_MATCH_ISSUER"] = False token = create_token( - 'client-app', 'server-app', - 'client-app/key01', self._private_key_pem, - subject='different' + "client-app", + "server-app", + "client-app/key01", + self._private_key_pem, + subject="different", ) self.assertEqual(self.send_request(token).status_code, 200) diff --git a/atlassian_jwt_auth/frameworks/wsgi/backend.py b/atlassian_jwt_auth/frameworks/wsgi/backend.py index 513f302..dc8997a 100644 --- a/atlassian_jwt_auth/frameworks/wsgi/backend.py +++ b/atlassian_jwt_auth/frameworks/wsgi/backend.py @@ -8,9 +8,9 @@ def __init__(self, settings): def get_authorization_header(self, request=None): if request is None: - raise ValueError('No request available') + raise ValueError("No request available") - return request.environ.get('HTTP_AUTHORIZATION', b'') + return request.environ.get("HTTP_AUTHORIZATION", b"") def get_401_response(self, data=None, headers=None, request=None): if request is None: @@ -21,7 +21,7 @@ def get_401_response(self, data=None, headers=None, request=None): headers.update(self.default_headers_401) - request.start_response('401 Unauthorized', list(headers.items()), None) + request.start_response("401 Unauthorized", list(headers.items()), None) return "" def get_403_response(self, data=None, headers=None, request=None): @@ -31,11 +31,11 @@ def get_403_response(self, data=None, headers=None, request=None): if headers is None: headers = {} - request.start_response('403 Forbidden', list(headers.items()), None) + request.start_response("403 Forbidden", list(headers.items()), None) return "" def set_asap_claims_for_request(self, request, claims): - request.environ['ATL_ASAP_CLAIMS'] = claims + request.environ["ATL_ASAP_CLAIMS"] = claims @property def settings(self): diff --git a/atlassian_jwt_auth/frameworks/wsgi/middleware.py b/atlassian_jwt_auth/frameworks/wsgi/middleware.py index b8ee4b6..4e3fc5b 100644 --- a/atlassian_jwt_auth/frameworks/wsgi/middleware.py +++ b/atlassian_jwt_auth/frameworks/wsgi/middleware.py @@ -1,8 +1,9 @@ from collections import namedtuple + from ..common.asap import _process_asap_token from .backend import WSGIBackend -Request = namedtuple('Request', ['environ', 'start_response']) +Request = namedtuple("Request", ["environ", "start_response"]) class ASAPMiddleware(object): diff --git a/atlassian_jwt_auth/frameworks/wsgi/tests/test_wsgi.py b/atlassian_jwt_auth/frameworks/wsgi/tests/test_wsgi.py index 71fb0a2..a4f82ce 100644 --- a/atlassian_jwt_auth/frameworks/wsgi/tests/test_wsgi.py +++ b/atlassian_jwt_auth/frameworks/wsgi/tests/test_wsgi.py @@ -9,12 +9,12 @@ def app(environ, start_response): - start_response('200 OK', [], None) + start_response("200 OK", [], None) return "OK" class WsgiTests(utils.RS256KeyTestMixin, unittest.TestCase): - """ tests for the atlassian_jwt_auth.contrib.tests.flask """ + """tests for the atlassian_jwt_auth.contrib.tests.flask""" def setUp(self): self._private_key_pem = self.get_new_private_key_in_pem_format() @@ -22,110 +22,114 @@ def setUp(self): self._private_key_pem ) - retriever = get_static_retriever_class({ - 'client-app/key01': self._public_key_pem - }) + retriever = get_static_retriever_class( + {"client-app/key01": self._public_key_pem} + ) self.config = { - 'ASAP_VALID_AUDIENCE': 'server-app', - 'ASAP_VALID_ISSUERS': ('client-app',), - 'ASAP_KEY_RETRIEVER_CLASS': retriever + "ASAP_VALID_AUDIENCE": "server-app", + "ASAP_VALID_ISSUERS": ("client-app",), + "ASAP_KEY_RETRIEVER_CLASS": retriever, } def get_app_with_middleware(self, config): return ASAPMiddleware(app, config) - def send_request(self, url='/', config=None, token=None, application=None): - """ returns the response of sending a request containing the given - token sent in the Authorization header. + def send_request(self, url="/", config=None, token=None, application=None): + """returns the response of sending a request containing the given + token sent in the Authorization header. """ resp_info = {} def start_response(status, response_headers, exc_info=None): - resp_info['status'] = status - resp_info['headers'] = response_headers + resp_info["status"] = status + resp_info["headers"] = response_headers environ = {} if token: - environ['HTTP_AUTHORIZATION'] = b'Bearer ' + token + environ["HTTP_AUTHORIZATION"] = b"Bearer " + token if application is None: application = self.get_app_with_middleware(config or self.config) return application(environ, start_response), resp_info, environ def test_request_with_valid_token_is_allowed(self): token = create_token( - 'client-app', 'server-app', - 'client-app/key01', self._private_key_pem + "client-app", "server-app", "client-app/key01", self._private_key_pem ) body, resp_info, environ = self.send_request(token=token) - self.assertEqual(resp_info['status'], '200 OK') - self.assertIn('ATL_ASAP_CLAIMS', environ) + self.assertEqual(resp_info["status"], "200 OK") + self.assertIn("ATL_ASAP_CLAIMS", environ) def test_request_with_duplicate_jti_is_rejected_as_per_setting(self): - self.config['ASAP_CHECK_JTI_UNIQUENESS'] = True + self.config["ASAP_CHECK_JTI_UNIQUENESS"] = True token = create_token( - 'client-app', 'server-app', - 'client-app/key01', self._private_key_pem + "client-app", "server-app", "client-app/key01", self._private_key_pem ) application = self.get_app_with_middleware(self.config) body, resp_info, environ = self.send_request( - token=token, application=application) - self.assertEqual(resp_info['status'], '200 OK') + token=token, application=application + ) + self.assertEqual(resp_info["status"], "200 OK") body, resp_info, environ = self.send_request( - token=token, application=application) - self.assertEqual(resp_info['status'], '401 Unauthorized') + token=token, application=application + ) + self.assertEqual(resp_info["status"], "401 Unauthorized") def _assert_request_with_duplicate_jti_is_accepted(self): token = create_token( - 'client-app', 'server-app', - 'client-app/key01', self._private_key_pem + "client-app", "server-app", "client-app/key01", self._private_key_pem ) application = self.get_app_with_middleware(self.config) body, resp_info, environ = self.send_request( - token=token, application=application) - self.assertEqual(resp_info['status'], '200 OK') + token=token, application=application + ) + self.assertEqual(resp_info["status"], "200 OK") body, resp_info, environ = self.send_request( - token=token, application=application) - self.assertEqual(resp_info['status'], '200 OK') + token=token, application=application + ) + self.assertEqual(resp_info["status"], "200 OK") def test_request_with_duplicate_jti_is_accepted(self): self._assert_request_with_duplicate_jti_is_accepted() def test_request_with_duplicate_jti_is_accepted_as_per_setting(self): - self.config['ASAP_CHECK_JTI_UNIQUENESS'] = False + self.config["ASAP_CHECK_JTI_UNIQUENESS"] = False self._assert_request_with_duplicate_jti_is_accepted() def test_request_with_invalid_audience_is_rejected(self): token = create_token( - 'client-app', 'invalid-audience', - 'client-app/key01', self._private_key_pem + "client-app", "invalid-audience", "client-app/key01", self._private_key_pem ) body, resp_info, environ = self.send_request(token=token) - self.assertEqual(resp_info['status'], '401 Unauthorized') - self.assertNotIn('ATL_ASAP_CLAIMS', environ) + self.assertEqual(resp_info["status"], "401 Unauthorized") + self.assertNotIn("ATL_ASAP_CLAIMS", environ) def test_request_with_invalid_token_is_rejected(self): - body, resp_info, environ = self.send_request(token=b'notavalidtoken') - self.assertEqual(resp_info['status'], '401 Unauthorized') - self.assertNotIn('ATL_ASAP_CLAIMS', environ) + body, resp_info, environ = self.send_request(token=b"notavalidtoken") + self.assertEqual(resp_info["status"], "401 Unauthorized") + self.assertNotIn("ATL_ASAP_CLAIMS", environ) def test_request_subject_and_issue_not_matching(self): token = create_token( - 'client-app', 'server-app', - 'client-app/key01', self._private_key_pem, - subject='different' + "client-app", + "server-app", + "client-app/key01", + self._private_key_pem, + subject="different", ) body, resp_info, environ = self.send_request(token=token) - self.assertEqual(resp_info['status'], '401 Unauthorized') - self.assertNotIn('ATL_ASAP_CLAIMS', environ) + self.assertEqual(resp_info["status"], "401 Unauthorized") + self.assertNotIn("ATL_ASAP_CLAIMS", environ) def test_request_subject_does_not_need_to_match_issuer_from_settings(self): - self.config['ASAP_SUBJECT_SHOULD_MATCH_ISSUER'] = False + self.config["ASAP_SUBJECT_SHOULD_MATCH_ISSUER"] = False token = create_token( - 'client-app', 'server-app', - 'client-app/key01', self._private_key_pem, - subject='different' + "client-app", + "server-app", + "client-app/key01", + self._private_key_pem, + subject="different", ) body, resp_info, environ = self.send_request(token=token) - self.assertEqual(resp_info['status'], '200 OK') - self.assertIn('ATL_ASAP_CLAIMS', environ) + self.assertEqual(resp_info["status"], "200 OK") + self.assertIn("ATL_ASAP_CLAIMS", environ) diff --git a/atlassian_jwt_auth/key.py b/atlassian_jwt_auth/key.py index 6984e01..2bbac6b 100644 --- a/atlassian_jwt_auth/key.py +++ b/atlassian_jwt_auth/key.py @@ -2,8 +2,8 @@ import logging import os import re -from urllib.parse import unquote_plus from email.message import EmailMessage +from urllib.parse import unquote_plus import cachecontrol import cryptography.hazmat.backends @@ -11,19 +11,19 @@ import requests import requests.utils from cryptography.hazmat.primitives import serialization -from requests.exceptions import RequestException, ConnectionError - -from atlassian_jwt_auth.exceptions import (KeyIdentifierException, - PublicKeyRetrieverException, - PrivateKeyRetrieverException) +from requests.exceptions import ConnectionError, RequestException +from atlassian_jwt_auth.exceptions import ( + KeyIdentifierException, + PrivateKeyRetrieverException, + PublicKeyRetrieverException, +) -PEM_FILE_TYPE = 'application/x-pem-file' +PEM_FILE_TYPE = "application/x-pem-file" class KeyIdentifier(object): - - """ This class represents a key identifier """ + """This class represents a key identifier""" def __init__(self, identifier): self.__key_id = validate_key_identifier(identifier) @@ -34,9 +34,9 @@ def key_id(self): def validate_key_identifier(identifier): - """ returns a validated key identifier. """ - regex = re.compile(r'^[\w.\-\+/]*$') - _error_msg = 'Invalid key identifier %s' % identifier + """returns a validated key identifier.""" + regex = re.compile(r"^[\w.\-\+/]*$") + _error_msg = "Invalid key identifier %s" % identifier if not identifier: raise KeyIdentifierException(_error_msg) if not regex.match(identifier): @@ -44,41 +44,40 @@ def validate_key_identifier(identifier): normalised = os.path.normpath(identifier) if normalised != identifier: raise KeyIdentifierException(_error_msg) - if normalised.startswith('/'): + if normalised.startswith("/"): raise KeyIdentifierException(_error_msg) - if '..' in normalised: + if ".." in normalised: raise KeyIdentifierException(_error_msg) return identifier def _get_key_id_from_jwt_header(a_jwt): - """ returns the key identifier from a jwt header. """ + """returns the key identifier from a jwt header.""" header = jwt.get_unverified_header(a_jwt) - return KeyIdentifier(header['kid']) + return KeyIdentifier(header["kid"]) class BasePublicKeyRetriever(object): - """ Base class for retrieving a public key. """ + """Base class for retrieving a public key.""" def retrieve(self, key_identifier, **kwargs): raise NotImplementedError() class HTTPSPublicKeyRetriever(BasePublicKeyRetriever): - - """ This class retrieves public key from a https location based upon the - given key id. + """This class retrieves public key from a https location based upon the + given key id. """ + # Use a static requests session, reused/shared by all instances of # HTTPSPublicKeyRetriever: _class_session = None def __init__(self, base_url): - if base_url is None or not base_url.startswith('https://'): - raise PublicKeyRetrieverException( - 'The base url must start with https://') - if not base_url.endswith('/'): - base_url += '/' + if base_url is None or not base_url.startswith("https://"): + raise PublicKeyRetrieverException("The base url must start with https://") + if not base_url.endswith("/"): + base_url += "/" self.base_url = base_url self._session = self._get_session() self._proxies = requests.utils.get_environ_proxies(self.base_url) @@ -91,11 +90,11 @@ def _get_session(self): return HTTPSPublicKeyRetriever._class_session def retrieve(self, key_identifier, **requests_kwargs): - """ returns the public key for given key_identifier. """ + """returns the public key for given key_identifier.""" if not isinstance(key_identifier, KeyIdentifier): key_identifier = KeyIdentifier(key_identifier) - if self._proxies and 'proxies' not in requests_kwargs: - requests_kwargs['proxies'] = self._proxies + if self._proxies and "proxies" not in requests_kwargs: + requests_kwargs["proxies"] = self._proxies url = self.base_url + key_identifier.key_id try: return self._retrieve(url, requests_kwargs) @@ -107,44 +106,43 @@ def retrieve(self, key_identifier, **requests_kwargs): raise PublicKeyRetrieverException(e, status_code=status_code) def _retrieve(self, url, requests_kwargs): - resp = self._session.get(url, headers={'accept': PEM_FILE_TYPE}, - **requests_kwargs) + resp = self._session.get( + url, headers={"accept": PEM_FILE_TYPE}, **requests_kwargs + ) resp.raise_for_status() - self._check_content_type(url, resp.headers['content-type']) + self._check_content_type(url, resp.headers["content-type"]) return resp.text def _check_content_type(self, url, content_type): msg = EmailMessage() - msg['content-type'] = content_type + msg["content-type"] = content_type media_type = msg.get_content_type() if media_type.lower() != PEM_FILE_TYPE.lower(): raise PublicKeyRetrieverException( - "Invalid content-type, '%s', for url '%s' ." % - (content_type, url)) + "Invalid content-type, '%s', for url '%s' ." % (content_type, url) + ) class HTTPSMultiRepositoryPublicKeyRetriever(BasePublicKeyRetriever): - """ This class retrieves public key from the supplied https key - repository locations based upon key ids. + """This class retrieves public key from the supplied https key + repository locations based upon key ids. """ def __init__(self, key_repository_urls): if not isinstance(key_repository_urls, list): - raise TypeError('keystore_urls must be a list of urls.') + raise TypeError("keystore_urls must be a list of urls.") self._retrievers = self._create_retrievers(key_repository_urls) def _create_retrievers(self, key_repository_urls): - return [HTTPSPublicKeyRetriever(url) for url - in key_repository_urls] + return [HTTPSPublicKeyRetriever(url) for url in key_repository_urls] def handle_retrieval_exception(self, retriever, exception): - """ Handles working with exceptions encountered during key - retrieval. + """Handles working with exceptions encountered during key + retrieval. """ if isinstance(exception, PublicKeyRetrieverException): - original_exception = getattr( - exception, 'original_exception', None) + original_exception = getattr(exception, "original_exception", None) if isinstance(original_exception, ConnectionError): return if exception.status_code is None or exception.status_code < 500: @@ -158,53 +156,55 @@ def retrieve(self, key_identifier, **requests_kwargs): self.handle_retrieval_exception(retriever, e) logger = logging.getLogger(__name__) logger.warning( - 'Unable to retrieve public key from store', - extra={'underlying_error': str(e), - 'key repository': retriever.base_url}) - raise PublicKeyRetrieverException( - 'Cannot load key from key repositories') + "Unable to retrieve public key from store", + extra={ + "underlying_error": str(e), + "key repository": retriever.base_url, + }, + ) + raise PublicKeyRetrieverException("Cannot load key from key repositories") class BasePrivateKeyRetriever(object): - """ This is the base private key retriever class. """ + """This is the base private key retriever class.""" def load(self, issuer): - """ returns the key identifier and private key pem found - for the given issuer. + """returns the key identifier and private key pem found + for the given issuer. """ - raise NotImplementedError('Not implemented.') + raise NotImplementedError("Not implemented.") class DataUriPrivateKeyRetriever(BasePrivateKeyRetriever): - """ This class can be used to retrieve the key identifier and - private key from the supplied data uri. + """This class can be used to retrieve the key identifier and + private key from the supplied data uri. """ def __init__(self, data_uri): self._data_uri = data_uri def load(self, issuer): - if not self._data_uri.startswith('data:application/pkcs8;kid='): - raise PrivateKeyRetrieverException('Unrecognised data uri format.') - splitted = self._data_uri.split(';') - key_identifier = KeyIdentifier(unquote_plus( - splitted[1][len('kid='):])) - key_data = base64.b64decode(splitted[-1].split(',')[-1]) + if not self._data_uri.startswith("data:application/pkcs8;kid="): + raise PrivateKeyRetrieverException("Unrecognised data uri format.") + splitted = self._data_uri.split(";") + key_identifier = KeyIdentifier(unquote_plus(splitted[1][len("kid=") :])) + key_data = base64.b64decode(splitted[-1].split(",")[-1]) key = serialization.load_der_private_key( key_data, password=None, - backend=cryptography.hazmat.backends.default_backend()) + backend=cryptography.hazmat.backends.default_backend(), + ) private_key_pem = key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) - return key_identifier, private_key_pem.decode('utf-8') + return key_identifier, private_key_pem.decode("utf-8") class StaticPrivateKeyRetriever(BasePrivateKeyRetriever): - """ This class simply returns the key_identifier and private_key_pem - initially provided to it in calls to load. + """This class simply returns the key_identifier and private_key_pem + initially provided to it in calls to load. """ def __init__(self, key_identifier, private_key_pem): @@ -219,14 +219,15 @@ def load(self, issuer): class FilePrivateKeyRetriever(BasePrivateKeyRetriever): - """ This class can be used to retrieve the latest key identifier and - private key for a given issuer found under its private key - repository path. + """This class can be used to retrieve the latest key identifier and + private key for a given issuer found under its private key + repository path. """ def __init__(self, private_key_repository_path): self.private_key_repository = FilePrivateKeyRepository( - private_key_repository_path) + private_key_repository_path + ) def load(self, issuer): key_identifier = self._find_last_key_id(issuer) @@ -234,17 +235,16 @@ def load(self, issuer): return key_identifier, private_key_pem def _find_last_key_id(self, issuer): - key_identifiers = list( - self.private_key_repository.find_valid_key_ids(issuer)) + key_identifiers = list(self.private_key_repository.find_valid_key_ids(issuer)) if key_identifiers: return key_identifiers[-1] else: - raise IOError('Issuer has no valid keys: %s' % issuer) + raise IOError("Issuer has no valid keys: %s" % issuer) class FilePrivateKeyRepository(object): - """ This class represents a file backed private key repository. """ + """This class represents a file backed private key repository.""" def __init__(self, path): self.path = path @@ -252,10 +252,10 @@ def __init__(self, path): def find_valid_key_ids(self, issuer): issuer_directory = os.path.join(self.path, issuer) for filename in sorted(os.listdir(issuer_directory)): - if filename.endswith('.pem'): - yield KeyIdentifier('%s/%s' % (issuer, filename)) + if filename.endswith(".pem"): + yield KeyIdentifier("%s/%s" % (issuer, filename)) def load_key(self, key_identifier): key_filename = os.path.join(self.path, key_identifier.key_id) - with open(key_filename, 'rb') as f: - return f.read().decode('utf-8') + with open(key_filename, "rb") as f: + return f.read().decode("utf-8") diff --git a/atlassian_jwt_auth/signer.py b/atlassian_jwt_auth/signer.py index 968c758..ec65599 100644 --- a/atlassian_jwt_auth/signer.py +++ b/atlassian_jwt_auth/signer.py @@ -6,31 +6,28 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization -from atlassian_jwt_auth import algorithms -from atlassian_jwt_auth import key +from atlassian_jwt_auth import algorithms, key class JWTAuthSigner(object): - def __init__(self, issuer, private_key_retriever, **kwargs): self.issuer = issuer self.private_key_retriever = private_key_retriever - self.lifetime = kwargs.get('lifetime', datetime.timedelta(minutes=1)) - self.algorithm = kwargs.get('algorithm', 'RS256') - self.subject = kwargs.get('subject', None) + self.lifetime = kwargs.get("lifetime", datetime.timedelta(minutes=1)) + self.algorithm = kwargs.get("algorithm", "RS256") + self.subject = kwargs.get("subject", None) self._private_keys_cache = dict() - if self.algorithm not in set( - algorithms.get_permitted_algorithm_names()): - raise ValueError("Algorithm, '%s', is not permitted." % - self.algorithm) + if self.algorithm not in set(algorithms.get_permitted_algorithm_names()): + raise ValueError("Algorithm, '%s', is not permitted." % self.algorithm) if self.lifetime > datetime.timedelta(hours=1): - raise ValueError("lifetime, '%s',exceeds the allowed 1 hour max" % - (self.lifetime)) + raise ValueError( + "lifetime, '%s',exceeds the allowed 1 hour max" % (self.lifetime) + ) def _obtain_private_key(self, key_identifier, private_key_pem): - """ returns a loaded instance of the given private key either from - cache or from the given private_key_pem. + """returns a loaded instance of the given private key either from + cache or from the given private_key_pem. """ priv_key = self._private_keys_cache.get(key_identifier.key_id, None) if priv_key is not None: @@ -38,9 +35,7 @@ def _obtain_private_key(self, key_identifier, private_key_pem): if not isinstance(private_key_pem, bytes): private_key_pem = private_key_pem.encode() priv_key = serialization.load_pem_private_key( - private_key_pem, - password=None, - backend=default_backend() + private_key_pem, password=None, backend=default_backend() ) if len(self._private_keys_cache) > 10: self._private_keys_cache = dict() @@ -48,78 +43,78 @@ def _obtain_private_key(self, key_identifier, private_key_pem): return priv_key def _generate_claims(self, audience, **kwargs): - """ returns a new dictionary of claims. """ + """returns a new dictionary of claims.""" now = self._now() claims = { - 'iss': self.issuer, - 'exp': now + self.lifetime, - 'iat': now, - 'aud': audience, - 'jti': '%s:%s' % ( - now.strftime('%s'), random.SystemRandom().getrandbits(32)), - 'nbf': now, - 'sub': self.subject or self.issuer, + "iss": self.issuer, + "exp": now + self.lifetime, + "iat": now, + "aud": audience, + "jti": "%s:%s" + % (now.strftime("%s"), random.SystemRandom().getrandbits(32)), + "nbf": now, + "sub": self.subject or self.issuer, } - claims.update(kwargs.get('additional_claims', {})) + claims.update(kwargs.get("additional_claims", {})) return claims def _now(self): return datetime.datetime.now(datetime.timezone.utc) def generate_jwt(self, audience, **kwargs): - """ returns a new signed jwt for use. """ - key_identifier, private_key_pem = self.private_key_retriever.load( - self.issuer) - private_key = self._obtain_private_key( - key_identifier, private_key_pem) + """returns a new signed jwt for use.""" + key_identifier, private_key_pem = self.private_key_retriever.load(self.issuer) + private_key = self._obtain_private_key(key_identifier, private_key_pem) token = jwt.encode( self._generate_claims(audience, **kwargs), key=private_key, algorithm=self.algorithm, - headers={'kid': key_identifier.key_id}) + headers={"kid": key_identifier.key_id}, + ) if isinstance(token, str): - token = token.encode('utf-8') + token = token.encode("utf-8") return token class TokenReusingJWTAuthSigner(JWTAuthSigner): - def __init__(self, issuer, private_key_retriever, **kwargs): super(TokenReusingJWTAuthSigner, self).__init__( - issuer, private_key_retriever, **kwargs) - self.reuse_threshold = kwargs.get('reuse_jwt_threshold', 0.95) + issuer, private_key_retriever, **kwargs + ) + self.reuse_threshold = kwargs.get("reuse_jwt_threshold", 0.95) def get_cached_token(self, audience, **kwargs): - """ returns the cached token. If there is no matching cached token - then None is returned. + """returns the cached token. If there is no matching cached token + then None is returned. """ - return getattr(self, '_previous_token', None) + return getattr(self, "_previous_token", None) def set_cached_token(self, value): - """ sets the cached token.""" + """sets the cached token.""" self._previous_token = value def can_reuse_token(self, existing_token, claims): - """ returns True if the provided existing token can be reused - for the claims provided. + """returns True if the provided existing token can be reused + for the claims provided. """ if existing_token is None: return False existing_claims = jwt.decode( - existing_token, options={'verify_signature': False}) - existing_lifetime = (int(existing_claims['exp']) - - int(existing_claims['iat'])) - this_lifetime = (claims['exp'] - claims['iat']).total_seconds() + existing_token, options={"verify_signature": False} + ) + existing_lifetime = int(existing_claims["exp"]) - int(existing_claims["iat"]) + this_lifetime = (claims["exp"] - claims["iat"]).total_seconds() if existing_lifetime != this_lifetime: return False - about_to_expire = int(existing_claims['iat']) + ( - self.reuse_threshold * existing_lifetime) + about_to_expire = int(existing_claims["iat"]) + ( + self.reuse_threshold * existing_lifetime + ) if calendar.timegm(self._now().utctimetuple()) > about_to_expire: return False if set(claims.keys()) != set(existing_claims.keys()): return False for dict_key, val in claims.items(): - if dict_key in ['exp', 'iat', 'jti', 'nbf']: + if dict_key in ["exp", "iat", "jti", "nbf"]: continue if existing_claims[dict_key] != val: return False @@ -130,26 +125,27 @@ def generate_jwt(self, audience, **kwargs): claims = self._generate_claims(audience, **kwargs) if existing_token and self.can_reuse_token(existing_token, claims): return existing_token - token = super(TokenReusingJWTAuthSigner, self).generate_jwt( - audience, **kwargs) + token = super(TokenReusingJWTAuthSigner, self).generate_jwt(audience, **kwargs) self.set_cached_token(token) return token def _create_signer(issuer, private_key_retriever, **kwargs): signer_cls = JWTAuthSigner - if kwargs.get('reuse_jwts', None): + if kwargs.get("reuse_jwts", None): signer_cls = TokenReusingJWTAuthSigner return signer_cls(issuer, private_key_retriever, **kwargs) def create_signer(issuer, key_identifier, private_key_pem, **kwargs): private_key_retriever = key.StaticPrivateKeyRetriever( - key_identifier, private_key_pem) + key_identifier, private_key_pem + ) return _create_signer(issuer, private_key_retriever, **kwargs) def create_signer_from_file_private_key_repository( - issuer, private_key_repository, **kwargs): + issuer, private_key_repository, **kwargs +): private_key_retriever = key.FilePrivateKeyRetriever(private_key_repository) return _create_signer(issuer, private_key_retriever, **kwargs) diff --git a/atlassian_jwt_auth/tests/test_key.py b/atlassian_jwt_auth/tests/test_key.py index 487a75b..6e3a37d 100644 --- a/atlassian_jwt_auth/tests/test_key.py +++ b/atlassian_jwt_auth/tests/test_key.py @@ -4,21 +4,31 @@ class TestKeyModule(unittest.TestCase): - - """ tests for the key module. """ + """tests for the key module.""" def test_key_identifier_with_invalid_keys(self): - """ test that invalid key identifiers are not permitted. """ - keys = ['../aha', '/a', r'\c:a', 'lk2j34/#$', 'a../../a', 'a/;a', - ' ', ' / ', ' /', - u'dir/some\0thing', 'a/#a', 'a/a?x', 'a/a;', - ] + """test that invalid key identifiers are not permitted.""" + keys = [ + "../aha", + "/a", + r"\c:a", + "lk2j34/#$", + "a../../a", + "a/;a", + " ", + " / ", + " /", + "dir/some\0thing", + "a/#a", + "a/a?x", + "a/a;", + ] for key in keys: with self.assertRaises(ValueError): atlassian_jwt_auth.KeyIdentifier(identifier=key) def test_key_identifier_with_valid_keys(self): - """ test that valid keys work as expected. """ - for key in ['oa.oo/a', 'oo.sasdf.asdf/yes', 'oo/o']: + """test that valid keys work as expected.""" + for key in ["oa.oo/a", "oo.sasdf.asdf/yes", "oo/o"]: key_id = atlassian_jwt_auth.KeyIdentifier(identifier=key) self.assertEqual(key_id.key_id, key) diff --git a/atlassian_jwt_auth/tests/test_private_key_provider.py b/atlassian_jwt_auth/tests/test_private_key_provider.py index 25c298f..1e33c33 100644 --- a/atlassian_jwt_auth/tests/test_private_key_provider.py +++ b/atlassian_jwt_auth/tests/test_private_key_provider.py @@ -2,67 +2,68 @@ import unittest from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.serialization import load_pem_private_key from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.serialization import load_pem_private_key +from atlassian_jwt_auth.key import DataUriPrivateKeyRetriever from atlassian_jwt_auth.signer import JWTAuthSigner from atlassian_jwt_auth.tests import utils -from atlassian_jwt_auth.key import DataUriPrivateKeyRetriever def convert_key_pem_format_to_der_format(private_key_pem): - private_key = load_pem_private_key(private_key_pem, - password=None, - backend=default_backend()) + private_key = load_pem_private_key( + private_key_pem, password=None, backend=default_backend() + ) return private_key.private_bytes( encoding=serialization.Encoding.DER, format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) class BaseDataUriPrivateKeyRetrieverTest(object): - """ tests for the DataUriPrivateKeyRetriever class. """ + """tests for the DataUriPrivateKeyRetriever class.""" def setUp(self): self._private_key_pem = self.get_new_private_key_in_pem_format() self._public_key_pem = utils.get_public_key_pem_for_private_key_pem( - self._private_key_pem) + self._private_key_pem + ) self._private_key_der = convert_key_pem_format_to_der_format( - self._private_key_pem) + self._private_key_pem + ) def get_example_data_uri(self, private_key_der): - return ('data:application/pkcs8;kid=example%2Feg;base64,' + - base64.b64encode(private_key_der).decode('utf-8')) + return "data:application/pkcs8;kid=example%2Feg;base64," + base64.b64encode( + private_key_der + ).decode("utf-8") def test_load_data_uri(self): - """ tests that a valid data uri is correctly loaded. """ - expected_kid = 'example/eg' + """tests that a valid data uri is correctly loaded.""" + expected_kid = "example/eg" data_uri = self.get_example_data_uri(self._private_key_der) provider = DataUriPrivateKeyRetriever(data_uri) - kid, private_key_pem = provider.load('example') + kid, private_key_pem = provider.load("example") self.assertEqual(kid.key_id, expected_kid) - self.assertEqual(private_key_pem, - self._private_key_pem.decode('utf-8')) + self.assertEqual(private_key_pem, self._private_key_pem.decode("utf-8")) def test_load_data_uri_can_be_used_with_a_signer(self): - """ tests that the data uri private key retriever can be used with a - signer to generate a jwt. + """tests that the data uri private key retriever can be used with a + signer to generate a jwt. """ data_uri = self.get_example_data_uri(self._private_key_der) provider = DataUriPrivateKeyRetriever(data_uri) - jwt_auth_signer = JWTAuthSigner( - 'issuer', provider, algorithm=self.algorithm) - jwt_auth_signer.generate_jwt('aud') + jwt_auth_signer = JWTAuthSigner("issuer", provider, algorithm=self.algorithm) + jwt_auth_signer.generate_jwt("aud") -class DataUriPrivateKeyRetrieverRS256Test(BaseDataUriPrivateKeyRetrieverTest, - utils.RS256KeyTestMixin, - unittest.TestCase): +class DataUriPrivateKeyRetrieverRS256Test( + BaseDataUriPrivateKeyRetrieverTest, utils.RS256KeyTestMixin, unittest.TestCase +): pass -class DataUriPrivateKeyRetrieverES256Test(BaseDataUriPrivateKeyRetrieverTest, - utils.ES256KeyTestMixin, - unittest.TestCase): +class DataUriPrivateKeyRetrieverES256Test( + BaseDataUriPrivateKeyRetrieverTest, utils.ES256KeyTestMixin, unittest.TestCase +): pass diff --git a/atlassian_jwt_auth/tests/test_public_key_provider.py b/atlassian_jwt_auth/tests/test_public_key_provider.py index cbdd8a4..5f71861 100644 --- a/atlassian_jwt_auth/tests/test_public_key_provider.py +++ b/atlassian_jwt_auth/tests/test_public_key_provider.py @@ -7,62 +7,59 @@ import requests from atlassian_jwt_auth.key import ( - HTTPSPublicKeyRetriever, - HTTPSMultiRepositoryPublicKeyRetriever, PEM_FILE_TYPE, + HTTPSMultiRepositoryPublicKeyRetriever, + HTTPSPublicKeyRetriever, ) from atlassian_jwt_auth.tests import utils def get_expected_and_os_proxies_dict(proxy_location): - """ returns expected proxy & environmental - proxy dictionary based upon the provided proxy location. + """returns expected proxy & environmental + proxy dictionary based upon the provided proxy location. """ expected_proxies = { - 'http': proxy_location, - 'https': proxy_location, - } - os_proxy_dict = { - 'HTTP_PROXY': proxy_location, - 'HTTPS_PROXY': proxy_location + "http": proxy_location, + "https": proxy_location, } + os_proxy_dict = {"HTTP_PROXY": proxy_location, "HTTPS_PROXY": proxy_location} return expected_proxies, os_proxy_dict class BaseHTTPSPublicKeyRetrieverTest(object): - """ tests for the HTTPSPublicKeyRetriever class. """ + """tests for the HTTPSPublicKeyRetriever class.""" def create_retriever(self, url): - """ returns a public key retriever created using the given url. """ + """returns a public key retriever created using the given url.""" return HTTPSPublicKeyRetriever(url) def setUp(self): self._private_key_pem = self.get_new_private_key_in_pem_format() self._public_key_pem = utils.get_public_key_pem_for_private_key_pem( - self._private_key_pem) - self.base_url = 'https://example.com' + self._private_key_pem + ) + self.base_url = "https://example.com" def test_https_public_key_retriever_does_not_support_http_url(self): - """ tests that HTTPSPublicKeyRetriever does not support http:// - base urls. + """tests that HTTPSPublicKeyRetriever does not support http:// + base urls. """ with self.assertRaises(ValueError): - self.create_retriever('http://example.com') + self.create_retriever("http://example.com") def test_https_public_key_retriever_does_not_support_none_url(self): - """ tests that HTTPSPublicKeyRetriever does not support None - base urls. + """tests that HTTPSPublicKeyRetriever does not support None + base urls. """ with self.assertRaises(ValueError): self.create_retriever(None) def test_https_public_key_retriever_session_uses_env_proxy(self): - """ tests that the underlying session makes use of environmental - proxy configured. + """tests that the underlying session makes use of environmental + proxy configured. """ - proxy_location = 'https://example.proxy' - expected_proxies, proxy_dict = get_expected_and_os_proxies_dict( - proxy_location) + proxy_location = "https://example.proxy" + expected_proxies, proxy_dict = get_expected_and_os_proxies_dict(proxy_location) with mock.patch.dict(os.environ, proxy_dict, clear=True): retriever = self.create_retriever(self.base_url) key_retrievers = [retriever] @@ -72,106 +69,99 @@ def test_https_public_key_retriever_session_uses_env_proxy(self): self.assertEqual(key_retriever._proxies, expected_proxies) def test_https_public_key_retriever_supports_https_url(self): - """ tests that HTTPSPublicKeyRetriever supports https:// - base urls. + """tests that HTTPSPublicKeyRetriever supports https:// + base urls. """ self.create_retriever(self.base_url) - @mock.patch.object(requests.Session, 'get') + @mock.patch.object(requests.Session, "get") def test_retrieve(self, mock_get_method): - """ tests that the retrieve method works expected. """ - _setup_mock_response_for_retriever( - mock_get_method, self._public_key_pem) + """tests that the retrieve method works expected.""" + _setup_mock_response_for_retriever(mock_get_method, self._public_key_pem) retriever = self.create_retriever(self.base_url) - self.assertEqual( - retriever.retrieve('example/eg'), - self._public_key_pem) + self.assertEqual(retriever.retrieve("example/eg"), self._public_key_pem) - @mock.patch.object(requests.Session, 'get') + @mock.patch.object(requests.Session, "get") def test_retrieve_with_proxy(self, mock_get_method): - """ tests that the retrieve method works as expected when a proxy - should be used. + """tests that the retrieve method works as expected when a proxy + should be used. """ - proxy_location = 'https://example.proxy' - key_id = 'example/eg' - expected_proxies, proxy_dict = get_expected_and_os_proxies_dict( - proxy_location) - _setup_mock_response_for_retriever( - mock_get_method, self._public_key_pem) + proxy_location = "https://example.proxy" + key_id = "example/eg" + expected_proxies, proxy_dict = get_expected_and_os_proxies_dict(proxy_location) + _setup_mock_response_for_retriever(mock_get_method, self._public_key_pem) with mock.patch.dict(os.environ, proxy_dict, clear=True): retriever = self.create_retriever(self.base_url) retriever.retrieve(key_id) mock_get_method.assert_called_once_with( - '%s/%s' % (self.base_url, key_id), - headers={'accept': PEM_FILE_TYPE}, - proxies=expected_proxies + "%s/%s" % (self.base_url, key_id), + headers={"accept": PEM_FILE_TYPE}, + proxies=expected_proxies, ) - @mock.patch.object(requests.Session, 'get') + @mock.patch.object(requests.Session, "get") def test_retrieve_with_proxy_explicitly_set(self, mock_get_method): - """ tests that the retrieve method works as expected when a proxy - should be used and has been explicitly provided. + """tests that the retrieve method works as expected when a proxy + should be used and has been explicitly provided. """ - proxy_location = 'https://example.proxy' - explicit_proxy_location = 'https://explicit.proxy' - key_id = 'example/eg' + proxy_location = "https://example.proxy" + explicit_proxy_location = "https://explicit.proxy" + key_id = "example/eg" _, proxy_dict = get_expected_and_os_proxies_dict(proxy_location) - expected_proxies, _ = get_expected_and_os_proxies_dict( - explicit_proxy_location) - _setup_mock_response_for_retriever( - mock_get_method, self._public_key_pem) + expected_proxies, _ = get_expected_and_os_proxies_dict(explicit_proxy_location) + _setup_mock_response_for_retriever(mock_get_method, self._public_key_pem) with mock.patch.dict(os.environ, proxy_dict, clear=True): retriever = self.create_retriever(self.base_url) retriever.retrieve(key_id, proxies=expected_proxies) mock_get_method.assert_called_once_with( - '%s/%s' % (self.base_url, key_id), - headers={'accept': PEM_FILE_TYPE}, - proxies=expected_proxies + "%s/%s" % (self.base_url, key_id), + headers={"accept": PEM_FILE_TYPE}, + proxies=expected_proxies, ) - @mock.patch.object(requests.Session, 'get') + @mock.patch.object(requests.Session, "get") def test_retrieve_with_charset_in_content_type_h(self, mock_get_method): - """ tests that the retrieve method works expected when there is - a charset in the response content-type header. + """tests that the retrieve method works expected when there is + a charset in the response content-type header. """ - headers = {'content-type': 'application/x-pem-file;charset=UTF-8'} + headers = {"content-type": "application/x-pem-file;charset=UTF-8"} _setup_mock_response_for_retriever( - mock_get_method, self._public_key_pem, headers) + mock_get_method, self._public_key_pem, headers + ) retriever = self.create_retriever(self.base_url) - self.assertEqual( - retriever.retrieve('example/eg'), - self._public_key_pem) + self.assertEqual(retriever.retrieve("example/eg"), self._public_key_pem) - @mock.patch.object(requests.Session, 'get') + @mock.patch.object(requests.Session, "get") def test_retrieve_fails_with_different_content_type(self, mock_get_method): - """ tests that the retrieve method fails when the response is for a - media type that is not supported. + """tests that the retrieve method fails when the response is for a + media type that is not supported. """ - headers = {'content-type': 'different/not-supported'} + headers = {"content-type": "different/not-supported"} _setup_mock_response_for_retriever( - mock_get_method, self._public_key_pem, headers) + mock_get_method, self._public_key_pem, headers + ) retriever = self.create_retriever(self.base_url) with self.assertRaises(ValueError): - retriever.retrieve('example/eg') - - @mock.patch.object(requests.Session, 'get', - side_effect=requests.exceptions.HTTPError( - mock.Mock(response=mock.Mock(status_code=403)), - 'forbidden')) + retriever.retrieve("example/eg") + + @mock.patch.object( + requests.Session, + "get", + side_effect=requests.exceptions.HTTPError( + mock.Mock(response=mock.Mock(status_code=403)), "forbidden" + ), + ) def test_retrieve_fails_with_forbidden_error(self, mock_get_method): - """ tests that the retrieve method fails when the response is an + """tests that the retrieve method fails when the response is an 403 forbidden error. """ - _setup_mock_response_for_retriever( - mock_get_method, self._public_key_pem) + _setup_mock_response_for_retriever(mock_get_method, self._public_key_pem) retriever = self.create_retriever(self.base_url) with self.assertRaises(ValueError): - retriever.retrieve('example/eg') - + retriever.retrieve("example/eg") -class CachedHTTPPublicKeyRetrieverTest(utils.ES256KeyTestMixin, - unittest.TestCase): +class CachedHTTPPublicKeyRetrieverTest(utils.ES256KeyTestMixin, unittest.TestCase): class HTTPPublicKeyRetriever(HTTPSPublicKeyRetriever): """A subclass of HTTPSPublicKeyRetriever that allows us to use plain HTTP during testing so we don't have to run an actual SSL server. @@ -179,116 +169,117 @@ class HTTPPublicKeyRetriever(HTTPSPublicKeyRetriever): def __init__(self, base_url): # pretend to the super class that this is an HTTPS url - super(CachedHTTPPublicKeyRetrieverTest.HTTPPublicKeyRetriever, - self).__init__( - re.sub(r'^http', 'https', base_url, flags=re.IGNORECASE)) + super( + CachedHTTPPublicKeyRetrieverTest.HTTPPublicKeyRetriever, self + ).__init__(re.sub(r"^http", "https", base_url, flags=re.IGNORECASE)) self.base_url = base_url def setUp(self): super(CachedHTTPPublicKeyRetrieverTest, self).setUp() self._private_key_pem = self.get_new_private_key_in_pem_format() self._public_key_pem = utils.get_public_key_pem_for_private_key_pem( - self._private_key_pem) + self._private_key_pem + ) def test_http_caching(self): """Asserts that our use of requests properly caches keys between invocations across different `HTTPSPublicKeyRetriever` instances. """ + def wsgi(environ, start_response): - print(environ['PATH_INFO']) - start_response('200 OK', [ - ('content-type', 'application/x-pem-file;charset=UTF-8'), - ('Cache-Control', 'public,max-age=300,stale-while-revalidate=' - '300,stale-if-error=300'), - ('Last-Modified', 'Sun, 18 Jan 1970 18:14:21 GMT')]) + print(environ["PATH_INFO"]) + start_response( + "200 OK", + [ + ("content-type", "application/x-pem-file;charset=UTF-8"), + ( + "Cache-Control", + "public,max-age=300,stale-while-revalidate=" + "300,stale-if-error=300", + ), + ("Last-Modified", "Sun, 18 Jan 1970 18:14:21 GMT"), + ], + ) return [self._public_key_pem] with httptest.testserver(wsgi) as server: - retriever = self.HTTPPublicKeyRetriever(server.url()) - retriever.retrieve('example/eg') + retriever.retrieve("example/eg") retriever = self.HTTPPublicKeyRetriever(server.url()) - retriever.retrieve('example/eg') + retriever.retrieve("example/eg") - self.assertEqual(1, len(server.log()), - msg='HTTP caching should suppress second GET') + self.assertEqual( + 1, len(server.log()), msg="HTTP caching should suppress second GET" + ) -class BaseHTTPSMultiRepositoryPublicKeyRetrieverTest( - BaseHTTPSPublicKeyRetrieverTest): - """ tests for the HTTPSMultiRepositoryPublicKeyRetriever class. """ +class BaseHTTPSMultiRepositoryPublicKeyRetrieverTest(BaseHTTPSPublicKeyRetrieverTest): + """tests for the HTTPSMultiRepositoryPublicKeyRetriever class.""" def create_retriever(self, url): - """ returns a public key retriever created using the given url. """ + """returns a public key retriever created using the given url.""" return HTTPSMultiRepositoryPublicKeyRetriever([url]) def setUp(self): self._private_key_pem = self.get_new_private_key_in_pem_format() self._public_key_pem = utils.get_public_key_pem_for_private_key_pem( - self._private_key_pem) - self.keystore_urls = ['https://example.com', 'https://example.ly'] + self._private_key_pem + ) + self.keystore_urls = ["https://example.com", "https://example.ly"] self.base_url = self.keystore_urls[0] def test_https_multi_public_key_retriever_does_not_support_strings(self): - """ tests that HTTPSMultiRepositoryPublicKeyRetriever does not - support a string key repository url. + """tests that HTTPSMultiRepositoryPublicKeyRetriever does not + support a string key repository url. """ with self.assertRaises(TypeError): - HTTPSMultiRepositoryPublicKeyRetriever('https://example.com') + HTTPSMultiRepositoryPublicKeyRetriever("https://example.com") - @mock.patch.object(requests.Session, 'get') + @mock.patch.object(requests.Session, "get") def test_retrieve(self, mock_get_method): - """ tests that the retrieve method works expected. """ - _setup_mock_response_for_retriever( - mock_get_method, self._public_key_pem) + """tests that the retrieve method works expected.""" + _setup_mock_response_for_retriever(mock_get_method, self._public_key_pem) retriever = HTTPSMultiRepositoryPublicKeyRetriever(self.keystore_urls) - self.assertEqual( - retriever.retrieve('example/eg'), - self._public_key_pem) + self.assertEqual(retriever.retrieve("example/eg"), self._public_key_pem) - @mock.patch.object(requests.Session, 'get') + @mock.patch.object(requests.Session, "get") def test_retrieve_with_500_error(self, mock_get_method): - """ tests that the retrieve method works as expected - when the first key repository returns a server error response. + """tests that the retrieve method works as expected + when the first key repository returns a server error response. """ retriever = HTTPSMultiRepositoryPublicKeyRetriever(self.keystore_urls) - _setup_mock_response_for_retriever( - mock_get_method, self._public_key_pem) + _setup_mock_response_for_retriever(mock_get_method, self._public_key_pem) valid_response = mock_get_method.return_value del mock_get_method.return_value server_exception = requests.exceptions.HTTPError( - response=mock.Mock(status_code=500)) + response=mock.Mock(status_code=500) + ) mock_get_method.side_effect = [server_exception, valid_response] - self.assertEqual( - retriever.retrieve('example/eg'), - self._public_key_pem) + self.assertEqual(retriever.retrieve("example/eg"), self._public_key_pem) - @mock.patch.object(requests.Session, 'get') + @mock.patch.object(requests.Session, "get") def test_retrieve_with_connection_error(self, mock_get_method): - """ tests that the retrieve method works as expected - when the first key repository encounters a connection error. + """tests that the retrieve method works as expected + when the first key repository encounters a connection error. """ retriever = HTTPSMultiRepositoryPublicKeyRetriever(self.keystore_urls) - _setup_mock_response_for_retriever( - mock_get_method, self._public_key_pem) + _setup_mock_response_for_retriever(mock_get_method, self._public_key_pem) valid_response = mock_get_method.return_value del mock_get_method.return_value connection_exception = requests.exceptions.ConnectionError( - response=mock.Mock(status_code=None)) + response=mock.Mock(status_code=None) + ) mock_get_method.side_effect = [connection_exception, valid_response] - self.assertEqual( - retriever.retrieve('example/eg'), - self._public_key_pem) + self.assertEqual(retriever.retrieve("example/eg"), self._public_key_pem) -def _setup_mock_response_for_retriever( - mock_method, public_key_pem, headers=None): - """ returns a setup mock response for use with a https public key - retriever. +def _setup_mock_response_for_retriever(mock_method, public_key_pem, headers=None): + """returns a setup mock response for use with a https public key + retriever. """ if headers is None: - headers = {'content-type': 'application/x-pem-file'} + headers = {"content-type": "application/x-pem-file"} mock_response = mock.Mock() mock_response.headers = headers mock_response.text = public_key_pem @@ -296,27 +287,29 @@ def _setup_mock_response_for_retriever( return mock_method -class HTTPSPublicKeyRetrieverRS256Test(BaseHTTPSPublicKeyRetrieverTest, - utils.RS256KeyTestMixin, - unittest.TestCase): +class HTTPSPublicKeyRetrieverRS256Test( + BaseHTTPSPublicKeyRetrieverTest, utils.RS256KeyTestMixin, unittest.TestCase +): pass -class HTTPSPublicKeyRetrieverES256Test(BaseHTTPSPublicKeyRetrieverTest, - utils.ES256KeyTestMixin, - unittest.TestCase): +class HTTPSPublicKeyRetrieverES256Test( + BaseHTTPSPublicKeyRetrieverTest, utils.ES256KeyTestMixin, unittest.TestCase +): pass class HTTPSMultiRepositoryPublicKeyRetrieverRS256Test( - BaseHTTPSMultiRepositoryPublicKeyRetrieverTest, - utils.RS256KeyTestMixin, - unittest.TestCase): + BaseHTTPSMultiRepositoryPublicKeyRetrieverTest, + utils.RS256KeyTestMixin, + unittest.TestCase, +): pass class HTTPSMultiRepositoryPublicKeyRetrieverES256Test( - BaseHTTPSMultiRepositoryPublicKeyRetrieverTest, - utils.ES256KeyTestMixin, - unittest.TestCase): + BaseHTTPSMultiRepositoryPublicKeyRetrieverTest, + utils.ES256KeyTestMixin, + unittest.TestCase, +): pass diff --git a/atlassian_jwt_auth/tests/test_signer.py b/atlassian_jwt_auth/tests/test_signer.py index 5d83145..c577456 100644 --- a/atlassian_jwt_auth/tests/test_signer.py +++ b/atlassian_jwt_auth/tests/test_signer.py @@ -9,60 +9,59 @@ class BaseJWTAuthSignerTest(object): - - """ tests for the JWTAuthSigner class. """ + """tests for the JWTAuthSigner class.""" def setUp(self): self._private_key_pem = self.get_new_private_key_in_pem_format() def test__generate_claims(self): - """ tests that _generate_claims works as expected. """ + """tests that _generate_claims works as expected.""" expected_now = datetime.datetime(year=2001, day=1, month=1) - expected_audience = 'example_aud' - expected_iss = 'eg' - expected_key_id = 'eg/ex' + expected_audience = "example_aud" + expected_iss = "eg" + expected_key_id = "eg/ex" jwt_auth_signer = atlassian_jwt_auth.create_signer( - expected_iss, - expected_key_id, - self._private_key_pem) + expected_iss, expected_key_id, self._private_key_pem + ) jwt_auth_signer._now = lambda: expected_now - for additional_claims in [{}, {'extra': 'thing'}]: + for additional_claims in [{}, {"extra": "thing"}]: expected_claims = { - 'iss': expected_iss, - 'exp': expected_now + datetime.timedelta(minutes=1), - 'iat': expected_now, - 'aud': expected_audience, - 'nbf': expected_now, - 'sub': expected_iss, + "iss": expected_iss, + "exp": expected_now + datetime.timedelta(minutes=1), + "iat": expected_now, + "aud": expected_audience, + "nbf": expected_now, + "sub": expected_iss, } expected_claims.update(additional_claims) claims = jwt_auth_signer._generate_claims( - expected_audience, - additional_claims=additional_claims) - self.assertIsNotNone(claims['jti']) - del claims['jti'] + expected_audience, additional_claims=additional_claims + ) + self.assertIsNotNone(claims["jti"]) + del claims["jti"] self.assertEqual(claims, expected_claims) def test_jti_changes(self): - """ tests that the jti of a claim changes. """ + """tests that the jti of a claim changes.""" expected_now = datetime.datetime(year=2001, day=1, month=1) - aud = 'aud' + aud = "aud" jwt_auth_signer = utils.get_example_jwt_auth_signer( - algorithm=self.algorithm, private_key_pem=self._private_key_pem) + algorithm=self.algorithm, private_key_pem=self._private_key_pem + ) jwt_auth_signer._now = lambda: expected_now - first = jwt_auth_signer._generate_claims(aud)['jti'] - second = jwt_auth_signer._generate_claims(aud)['jti'] + first = jwt_auth_signer._generate_claims(aud)["jti"] + second = jwt_auth_signer._generate_claims(aud)["jti"] self.assertNotEqual(first, second) - self.assertTrue(str(expected_now.strftime('%s')) in first) - self.assertTrue(str(expected_now.strftime('%s')) in second) + self.assertTrue(str(expected_now.strftime("%s")) in first) + self.assertTrue(str(expected_now.strftime("%s")) in second) - @mock.patch('jwt.encode') + @mock.patch("jwt.encode") def test_generate_jwt(self, m_jwt_encode): - """ tests that generate_jwt works as expected. """ - expected_aud = 'aud_x' - expected_claims = {'eg': 'ex'} - expected_key_id = 'key_id' - expected_issuer = 'a_issuer' + """tests that generate_jwt works as expected.""" + expected_aud = "aud_x" + expected_claims = {"eg": "ex"} + expected_key_id = "key_id" + expected_issuer = "a_issuer" jwt_auth_signer = atlassian_jwt_auth.create_signer( expected_issuer, expected_key_id, @@ -75,28 +74,27 @@ def test_generate_jwt(self, m_jwt_encode): expected_claims, key=mock.ANY, algorithm=self.algorithm, - headers={'kid': expected_key_id}) + headers={"kid": expected_key_id}, + ) for name, args, kwargs in m_jwt_encode.mock_calls: if not kwargs: - self.assertEqual(args[0], 'utf-8') + self.assertEqual(args[0], "utf-8") continue - call_private_key = kwargs['key'].private_bytes( + call_private_key = kwargs["key"].private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) self.assertEqual(call_private_key, self._private_key_pem) class JWTAuthSignerRS256Test( - BaseJWTAuthSignerTest, - utils.RS256KeyTestMixin, - unittest.TestCase): + BaseJWTAuthSignerTest, utils.RS256KeyTestMixin, unittest.TestCase +): pass class JWTAuthSignerES256Test( - BaseJWTAuthSignerTest, - utils.ES256KeyTestMixin, - unittest.TestCase): + BaseJWTAuthSignerTest, utils.ES256KeyTestMixin, unittest.TestCase +): pass diff --git a/atlassian_jwt_auth/tests/test_signer_private_key_repo.py b/atlassian_jwt_auth/tests/test_signer_private_key_repo.py index f1cfdb6..9bb519b 100644 --- a/atlassian_jwt_auth/tests/test_signer_private_key_repo.py +++ b/atlassian_jwt_auth/tests/test_signer_private_key_repo.py @@ -3,33 +3,30 @@ import tempfile import unittest - import atlassian_jwt_auth from atlassian_jwt_auth import key from atlassian_jwt_auth.tests import utils class BaseJWTAuthSignerWithFilePrivateKeyRetrieverTest(object): - - """ tests for the JWTAuthSigner using the FilePrivateKeyRetriever. """ + """tests for the JWTAuthSigner using the FilePrivateKeyRetriever.""" def setUp(self): - self.test_dir = tempfile.mkdtemp(prefix='atlassian-jwt-p-tests') - self.key_dir = os.path.join(self.test_dir, 'jwtprivatekeys') - for dir in ['invalid-issuer', 'issuer-with-many-keys', - 'valid-issuer']: + self.test_dir = tempfile.mkdtemp(prefix="atlassian-jwt-p-tests") + self.key_dir = os.path.join(self.test_dir, "jwtprivatekeys") + for dir in ["invalid-issuer", "issuer-with-many-keys", "valid-issuer"]: os.makedirs(os.path.join(self.key_dir, dir)) self._private_key_pem = self.get_new_private_key_in_pem_format() for file_loc in [ - 'invalid-issuer/key-tests-pem.new', - 'issuer-with-many-keys/key1.pem.new', - 'issuer-with-many-keys/key2.pem', - 'issuer-with-many-keys/key3.pem', - 'issuer-with-many-keys/key4.pem.new', - 'valid-issuer/key-for-tests.pem' + "invalid-issuer/key-tests-pem.new", + "issuer-with-many-keys/key1.pem.new", + "issuer-with-many-keys/key2.pem", + "issuer-with-many-keys/key3.pem", + "issuer-with-many-keys/key4.pem.new", + "valid-issuer/key-for-tests.pem", ]: file_location = os.path.join(self.key_dir, file_loc) - with open(file_location, 'wb') as f: + with open(file_location, "wb") as f: f.write(self._private_key_pem) def tearDown(self): @@ -37,42 +34,44 @@ def tearDown(self): shutil.rmtree(self.test_dir) def create_signer_for_issuer(self, issuer): - return \ - atlassian_jwt_auth.create_signer_from_file_private_key_repository( - issuer, self.key_dir, algorithm=self.algorithm) + return atlassian_jwt_auth.create_signer_from_file_private_key_repository( + issuer, self.key_dir, algorithm=self.algorithm + ) def test_succeeds_if_issuer_has_one_valid_key(self): - signer = self.create_signer_for_issuer('valid-issuer') - token = signer.generate_jwt('audience') + signer = self.create_signer_for_issuer("valid-issuer") + token = signer.generate_jwt("audience") self.assertIsNotNone(token) def test_picks_last_valid_key_id(self): - signer = self.create_signer_for_issuer('issuer-with-many-keys') - token = signer.generate_jwt('audience') + signer = self.create_signer_for_issuer("issuer-with-many-keys") + token = signer.generate_jwt("audience") key_identifier = key._get_key_id_from_jwt_header(token) - expected_key_id = 'issuer-with-many-keys/key3.pem' + expected_key_id = "issuer-with-many-keys/key3.pem" self.assertEqual(key_identifier.key_id, expected_key_id) def test_fails_if_issuer_has_no_valid_keys(self): - signer = self.create_signer_for_issuer('invalid-issuer') - with self.assertRaisesRegex(IOError, 'Issuer has no valid keys'): - signer.generate_jwt('audience') + signer = self.create_signer_for_issuer("invalid-issuer") + with self.assertRaisesRegex(IOError, "Issuer has no valid keys"): + signer.generate_jwt("audience") def test_fails_if_issuer_does_not_exist(self): - signer = self.create_signer_for_issuer('this-does-not-exist') - with self.assertRaisesRegex(OSError, 'No such file or directory'): - signer.generate_jwt('audience') + signer = self.create_signer_for_issuer("this-does-not-exist") + with self.assertRaisesRegex(OSError, "No such file or directory"): + signer.generate_jwt("audience") class JWTAuthSignerWithFilePrivateKeyRetrieverRS256Test( - BaseJWTAuthSignerWithFilePrivateKeyRetrieverTest, - utils.RS256KeyTestMixin, - unittest.TestCase): + BaseJWTAuthSignerWithFilePrivateKeyRetrieverTest, + utils.RS256KeyTestMixin, + unittest.TestCase, +): pass class JWTAuthSignerWithFilePrivateKeyRetrieverES256Test( - BaseJWTAuthSignerWithFilePrivateKeyRetrieverTest, - utils.ES256KeyTestMixin, - unittest.TestCase): + BaseJWTAuthSignerWithFilePrivateKeyRetrieverTest, + utils.ES256KeyTestMixin, + unittest.TestCase, +): pass diff --git a/atlassian_jwt_auth/tests/test_verifier.py b/atlassian_jwt_auth/tests/test_verifier.py index 670d044..87aa9f5 100644 --- a/atlassian_jwt_auth/tests/test_verifier.py +++ b/atlassian_jwt_auth/tests/test_verifier.py @@ -14,37 +14,37 @@ class NoneAlgorithmJwtAuthSigner(atlassian_jwt_auth.signer.JWTAuthSigner): - """ A JWTAuthSigner that generates JWTs using the none algorithm - and supports specifying arbitrary alg jwt header values. + """A JWTAuthSigner that generates JWTs using the none algorithm + and supports specifying arbitrary alg jwt header values. """ def generate_jwt(self, audience, **kwargs): - alg_header = kwargs.get('alg_header', 'none') - key_identifier, private_key_pem = self.private_key_retriever.load( - self.issuer) - return jwt.encode(self._generate_claims(audience, **kwargs), - algorithm=None, - key=None, - headers={'kid': key_identifier.key_id, - 'alg': alg_header}) + alg_header = kwargs.get("alg_header", "none") + key_identifier, private_key_pem = self.private_key_retriever.load(self.issuer) + return jwt.encode( + self._generate_claims(audience, **kwargs), + algorithm=None, + key=None, + headers={"kid": key_identifier.key_id, "alg": alg_header}, + ) class BaseJWTAuthVerifierTest(object): - - """ tests for the JWTAuthVerifier class. """ + """tests for the JWTAuthVerifier class.""" def setUp(self): self._private_key_pem = self.get_new_private_key_in_pem_format() self._public_key_pem = utils.get_public_key_pem_for_private_key_pem( - self._private_key_pem) - self._example_aud = 'aud_x' - self._example_issuer = 'egissuer' - self._example_key_id = '%s/a' % self._example_issuer + self._private_key_pem + ) + self._example_aud = "aud_x" + self._example_issuer = "egissuer" + self._example_key_id = "%s/a" % self._example_issuer self._jwt_auth_signer = atlassian_jwt_auth.create_signer( self._example_issuer, self._example_key_id, self._private_key_pem.decode(), - algorithm=self.algorithm + algorithm=self.algorithm, ) def _setup_mock_public_key_retriever(self, pub_key_pem): @@ -57,61 +57,61 @@ def _setup_jwt_auth_verifier(self, pub_key_pem, **kwargs): return atlassian_jwt_auth.JWTAuthVerifier(m_public_key_ret, **kwargs) def test_verify_jwt_with_valid_jwt(self): - """ test that verify_jwt verifies a valid jwt. """ + """test that verify_jwt verifies a valid jwt.""" verifier = self._setup_jwt_auth_verifier(self._public_key_pem) - signed_jwt = self._jwt_auth_signer.generate_jwt( - self._example_aud) + signed_jwt = self._jwt_auth_signer.generate_jwt(self._example_aud) v_claims = verifier.verify_jwt(signed_jwt, self._example_aud) self.assertIsNotNone(v_claims) - self.assertEqual(v_claims['aud'], self._example_aud) - self.assertEqual(v_claims['iss'], self._example_issuer) + self.assertEqual(v_claims["aud"], self._example_aud) + self.assertEqual(v_claims["iss"], self._example_issuer) def test_verify_jwt_with_none_algorithm(self): - """ tests that verify_jwt does not accept jwt that use the none - algorithm. + """tests that verify_jwt does not accept jwt that use the none + algorithm. """ verifier = self._setup_jwt_auth_verifier(self._public_key_pem) private_key_ret = atlassian_jwt_auth.key.StaticPrivateKeyRetriever( - self._example_key_id, self._private_key_pem.decode()) + self._example_key_id, self._private_key_pem.decode() + ) jwt_signer = NoneAlgorithmJwtAuthSigner( issuer=self._example_issuer, private_key_retriever=private_key_ret, ) - for algorithm in ['none', 'None', 'nOne', 'nonE', 'NONE']: - if algorithm != 'none': - jwt.register_algorithm( - algorithm, jwt.algorithms.NoneAlgorithm()) - jwt_token = jwt_signer.generate_jwt( - self._example_aud, alg_header=algorithm) - if algorithm != 'none': + for algorithm in ["none", "None", "nOne", "nonE", "NONE"]: + if algorithm != "none": + jwt.register_algorithm(algorithm, jwt.algorithms.NoneAlgorithm()) + jwt_token = jwt_signer.generate_jwt(self._example_aud, alg_header=algorithm) + if algorithm != "none": jwt.unregister_algorithm(algorithm) jwt_headers = jwt.get_unverified_header(jwt_token) - self.assertEqual(jwt_headers['alg'], algorithm) + self.assertEqual(jwt_headers["alg"], algorithm) with self.assertRaises(jwt.exceptions.InvalidAlgorithmError): verifier.verify_jwt(jwt_token, self._example_aud) def test_verify_jwt_with_key_identifier_not_starting_with_issuer(self): - """ tests that verify_jwt rejects a jwt if the key identifier does - not start with the claimed issuer. + """tests that verify_jwt rejects a jwt if the key identifier does + not start with the claimed issuer. """ verifier = self._setup_jwt_auth_verifier(self._public_key_pem) signer = atlassian_jwt_auth.create_signer( - 'issuer', 'issuerx', self._private_key_pem.decode(), + "issuer", + "issuerx", + self._private_key_pem.decode(), algorithm=self.algorithm, ) a_jwt = signer.generate_jwt(self._example_aud) - with self.assertRaisesRegex(ValueError, 'Issuer does not own'): + with self.assertRaisesRegex(ValueError, "Issuer does not own"): verifier.verify_jwt(a_jwt, self._example_aud) - @mock.patch('atlassian_jwt_auth.verifier.jwt.decode') + @mock.patch("atlassian_jwt_auth.verifier.jwt.decode") def test_verify_jwt_with_non_matching_sub_and_iss(self, m_j_decode): - """ tests that verify_jwt rejects a jwt if the claims - contains a subject which does not match the issuer. + """tests that verify_jwt rejects a jwt if the claims + contains a subject which does not match the issuer. """ - expected_msg = 'Issuer does not match the subject' + expected_msg = "Issuer does not match the subject" m_j_decode.return_value = { - 'iss': self._example_issuer, - 'sub': self._example_issuer[::-1] + "iss": self._example_issuer, + "sub": self._example_issuer[::-1], } a_jwt = self._jwt_auth_signer.generate_jwt(self._example_aud) verifier = self._setup_jwt_auth_verifier(self._public_key_pem) @@ -122,16 +122,16 @@ def test_verify_jwt_with_non_matching_sub_and_iss(self, m_j_decode): with self.assertRaisesRegex(exception, expected_msg): verifier.verify_jwt(a_jwt, self._example_aud) - @mock.patch('atlassian_jwt_auth.verifier.jwt.decode') + @mock.patch("atlassian_jwt_auth.verifier.jwt.decode") def test_verify_jwt_with_jwt_lasting_gt_max_time(self, m_j_decode): - """ tests that verify_jwt rejects a jwt if the claims - period of validity is greater than the allowed maximum. + """tests that verify_jwt rejects a jwt if the claims + period of validity is greater than the allowed maximum. """ - expected_msg = 'exceeds the maximum' + expected_msg = "exceeds the maximum" claims = self._jwt_auth_signer._generate_claims(self._example_aud) - claims['iat'] = claims['exp'] - datetime.timedelta(minutes=61) - for key in ['iat', 'exp']: - claims[key] = claims[key].strftime('%s') + claims["iat"] = claims["exp"] - datetime.timedelta(minutes=61) + for key in ["iat", "exp"]: + claims[key] = claims[key].strftime("%s") m_j_decode.return_value = claims a_jwt = self._jwt_auth_signer.generate_jwt(self._example_aud) verifier = self._setup_jwt_auth_verifier(self._public_key_pem) @@ -139,74 +139,74 @@ def test_verify_jwt_with_jwt_lasting_gt_max_time(self, m_j_decode): verifier.verify_jwt(a_jwt, self._example_aud) def test_verify_jwt_with_jwt_with_already_seen_jti(self): - """ tests that verify_jwt rejects a jwt if the jti - has already been seen. + """tests that verify_jwt rejects a jwt if the jti + has already been seen. """ verifier = self._setup_jwt_auth_verifier( - self._public_key_pem, check_jti_uniqueness=True) - a_jwt = self._jwt_auth_signer.generate_jwt( - self._example_aud) - self.assertIsNotNone(verifier.verify_jwt( - a_jwt, - self._example_aud)) + self._public_key_pem, check_jti_uniqueness=True + ) + a_jwt = self._jwt_auth_signer.generate_jwt(self._example_aud) + self.assertIsNotNone(verifier.verify_jwt(a_jwt, self._example_aud)) for exception in [ - ValueError, - atlassian_jwt_auth.exceptions.JtiUniquenessException]: - with self.assertRaisesRegex(exception, 'has already been used'): + ValueError, + atlassian_jwt_auth.exceptions.JtiUniquenessException, + ]: + with self.assertRaisesRegex(exception, "has already been used"): verifier.verify_jwt(a_jwt, self._example_aud) def assert_jwt_accepted_more_than_once(self, verifier, a_jwt): - """ asserts that the given jwt is accepted more than once. """ + """asserts that the given jwt is accepted more than once.""" for i in range(0, 3): - self.assertIsNotNone( - verifier.verify_jwt(a_jwt, self._example_aud)) + self.assertIsNotNone(verifier.verify_jwt(a_jwt, self._example_aud)) def test_verify_jwt_with_already_seen_jti_with_uniqueness_disabled(self): - """ tests that verify_jwt accepts a jwt if the jti - has already been seen and the verifier has been set - to not check the uniqueness of jti. + """tests that verify_jwt accepts a jwt if the jti + has already been seen and the verifier has been set + to not check the uniqueness of jti. """ verifier = self._setup_jwt_auth_verifier( - self._public_key_pem, check_jti_uniqueness=False) + self._public_key_pem, check_jti_uniqueness=False + ) a_jwt = self._jwt_auth_signer.generate_jwt(self._example_aud) self.assert_jwt_accepted_more_than_once(verifier, a_jwt) def test_verify_jwt_with_already_seen_jti_default(self): - """ tests that verify_jwt by default accepts a jwt if the jti - has already been seen. + """tests that verify_jwt by default accepts a jwt if the jti + has already been seen. """ - verifier = self._setup_jwt_auth_verifier( - self._public_key_pem) + verifier = self._setup_jwt_auth_verifier(self._public_key_pem) a_jwt = self._jwt_auth_signer.generate_jwt(self._example_aud) self.assert_jwt_accepted_more_than_once(verifier, a_jwt) def test_verify_jwt_subject_should_match_issuer(self): verifier = self._setup_jwt_auth_verifier( - self._public_key_pem, subject_should_match_issuer=True) + self._public_key_pem, subject_should_match_issuer=True + ) a_jwt = self._jwt_auth_signer.generate_jwt( - self._example_aud, - additional_claims={'sub': 'not-' + self._example_issuer}) - with self.assertRaisesRegex(ValueError, - 'Issuer does not match the subject.'): + self._example_aud, additional_claims={"sub": "not-" + self._example_issuer} + ) + with self.assertRaisesRegex(ValueError, "Issuer does not match the subject."): verifier.verify_jwt(a_jwt, self._example_aud) def test_verify_jwt_subject_does_not_need_to_match_issuer(self): verifier = self._setup_jwt_auth_verifier( - self._public_key_pem, subject_should_match_issuer=False) + self._public_key_pem, subject_should_match_issuer=False + ) a_jwt = self._jwt_auth_signer.generate_jwt( - self._example_aud, - additional_claims={'sub': 'not-' + self._example_issuer}) + self._example_aud, additional_claims={"sub": "not-" + self._example_issuer} + ) self.assertIsNotNone(verifier.verify_jwt(a_jwt, self._example_aud)) - @mock.patch('atlassian_jwt_auth.verifier.jwt.decode') + @mock.patch("atlassian_jwt_auth.verifier.jwt.decode") def test_verify_jwt_with_missing_aud_claim(self, m_j_decode): - """ tests that verify_jwt rejects jwt that do not have an aud - claim. + """tests that verify_jwt rejects jwt that do not have an aud + claim. """ - expected_msg = ('Claims validity, the aud claim must be provided and ' - 'cannot be empty.') + expected_msg = ( + "Claims validity, the aud claim must be provided and cannot be empty." + ) claims = self._jwt_auth_signer._generate_claims(self._example_aud) - del claims['aud'] + del claims["aud"] m_j_decode.return_value = claims a_jwt = self._jwt_auth_signer.generate_jwt(self._example_aud) verifier = self._setup_jwt_auth_verifier(self._public_key_pem) @@ -214,39 +214,40 @@ def test_verify_jwt_with_missing_aud_claim(self, m_j_decode): verifier.verify_jwt(a_jwt, self._example_aud) def test_verify_jwt_with_none_aud(self): - """ tests that verify_jwt rejects jwt that have a None aud claim. """ + """tests that verify_jwt rejects jwt that have a None aud claim.""" verifier = self._setup_jwt_auth_verifier(self._public_key_pem) a_jwt = self._jwt_auth_signer.generate_jwt( - self._example_aud, - additional_claims={'aud': None}) - exceptions = (jwt.exceptions.InvalidAudienceError, - jwt.exceptions.InvalidTokenError) + self._example_aud, additional_claims={"aud": None} + ) + exceptions = ( + jwt.exceptions.InvalidAudienceError, + jwt.exceptions.InvalidTokenError, + ) with self.assertRaises(exceptions) as cm: verifier.verify_jwt(a_jwt, self._example_aud) if not isinstance(cm.exception, jwt.exceptions.InvalidAudienceError): - self.assertIn('aud', str(cm.exception)) + self.assertIn("aud", str(cm.exception)) def test_verify_jwt_with_non_matching_aud(self): - """ tests that verify_jwt rejects a jwt if the aud claim does not - match the given & expected audience. + """tests that verify_jwt rejects a jwt if the aud claim does not + match the given & expected audience. """ verifier = self._setup_jwt_auth_verifier(self._public_key_pem) a_jwt = self._jwt_auth_signer.generate_jwt( self._example_aud, - additional_claims={'aud': self._example_aud + '-different'}) + additional_claims={"aud": self._example_aud + "-different"}, + ) with self.assertRaises(jwt.exceptions.InvalidAudienceError): verifier.verify_jwt(a_jwt, self._example_aud) class JWTAuthVerifierRS256Test( - BaseJWTAuthVerifierTest, - utils.RS256KeyTestMixin, - unittest.TestCase): + BaseJWTAuthVerifierTest, utils.RS256KeyTestMixin, unittest.TestCase +): pass class JWTAuthVerifierES256Test( - BaseJWTAuthVerifierTest, - utils.ES256KeyTestMixin, - unittest.TestCase): + BaseJWTAuthVerifierTest, utils.ES256KeyTestMixin, unittest.TestCase +): pass diff --git a/atlassian_jwt_auth/tests/utils.py b/atlassian_jwt_auth/tests/utils.py index 4e58caf..75bf629 100644 --- a/atlassian_jwt_auth/tests/utils.py +++ b/atlassian_jwt_auth/tests/utils.py @@ -1,89 +1,82 @@ from cryptography.hazmat.backends import default_backend -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ec, rsa import atlassian_jwt_auth def get_new_rsa_private_key_in_pem_format(): - """ returns a new rsa key in pem format. """ + """returns a new rsa key in pem format.""" private_key = rsa.generate_private_key( - key_size=2048, backend=default_backend(), public_exponent=65537) + key_size=2048, backend=default_backend(), public_exponent=65537 + ) return private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) def get_public_key_pem_for_private_key_pem(private_key_pem): private_key = serialization.load_pem_private_key( - private_key_pem, - password=None, - backend=default_backend() + private_key_pem, password=None, backend=default_backend() ) public_key = private_key.public_key() return public_key.public_bytes( encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo + format=serialization.PublicFormat.SubjectPublicKeyInfo, ) def get_example_jwt_auth_signer(**kwargs): - """ returns an example jwt_auth_signer instance. """ - issuer = kwargs.get('issuer', 'egissuer') - key_id = kwargs.get('key_id', '%s/a' % issuer) - key = kwargs.get( - 'private_key_pem', get_new_rsa_private_key_in_pem_format()) - algorithm = kwargs.get('algorithm', 'RS256') - return atlassian_jwt_auth.create_signer( - issuer, key_id, key, algorithm=algorithm) + """returns an example jwt_auth_signer instance.""" + issuer = kwargs.get("issuer", "egissuer") + key_id = kwargs.get("key_id", "%s/a" % issuer) + key = kwargs.get("private_key_pem", get_new_rsa_private_key_in_pem_format()) + algorithm = kwargs.get("algorithm", "RS256") + return atlassian_jwt_auth.create_signer(issuer, key_id, key, algorithm=algorithm) def create_token(issuer, audience, key_id, private_key, subject=None): - """" returns a token based upon the supplied parameters. """ + """ " returns a token based upon the supplied parameters.""" signer = atlassian_jwt_auth.create_signer( - issuer, key_id, private_key, subject=subject) + issuer, key_id, private_key, subject=subject + ) return signer.generate_jwt(audience) class BaseJWTAlgorithmTestMixin(object): - - """ A mixin class to make testing different support for different - jwt algorithms easier. + """A mixin class to make testing different support for different + jwt algorithms easier. """ def get_new_private_key_in_pem_format(self): - """ returns a new private key in pem format. """ + """returns a new private key in pem format.""" raise NotImplementedError("not implemented.") class RS256KeyTestMixin(object): - - """ Private rs256 test mixin. """ + """Private rs256 test mixin.""" @property def algorithm(self): - return 'RS256' + return "RS256" def get_new_private_key_in_pem_format(self): return get_new_rsa_private_key_in_pem_format() class ES256KeyTestMixin(object): - - """ Private es256 test mixin. """ + """Private es256 test mixin.""" @property def algorithm(self): - return 'ES256' + return "ES256" def get_new_private_key_in_pem_format(self): - private_key = ec.generate_private_key( - ec.SECP256R1(), default_backend()) + private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) return private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) diff --git a/atlassian_jwt_auth/verifier.py b/atlassian_jwt_auth/verifier.py index c34367f..ccae9ea 100644 --- a/atlassian_jwt_auth/verifier.py +++ b/atlassian_jwt_auth/verifier.py @@ -3,49 +3,41 @@ import jwt import jwt.api_jwt -from cryptography.hazmat.primitives.asymmetric.ec import ( - EllipticCurvePublicKey -) -from cryptography.hazmat.primitives.asymmetric.rsa import ( - RSAPublicKey -) +from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey from jwt.exceptions import InvalidAlgorithmError -from atlassian_jwt_auth import algorithms -from atlassian_jwt_auth import key -from atlassian_jwt_auth import exceptions +from atlassian_jwt_auth import algorithms, exceptions, key @lru_cache(maxsize=10) def _load_public_key(algorithms, public_key, algorithm): - """ Returns a public key object instance given the public key and - algorithm. + """Returns a public key object instance given the public key and + algorithm. - This has been extracted out of JWTAuthVerifier to avoid possible memory - leaks via retained instance references. + This has been extracted out of JWTAuthVerifier to avoid possible memory + leaks via retained instance references. """ if isinstance(public_key, (RSAPublicKey, EllipticCurvePublicKey)): return public_key if algorithm not in algorithms: - raise InvalidAlgorithmError( - 'The specified alg value is not allowed') + raise InvalidAlgorithmError("The specified alg value is not allowed") py_jws = jwt.api_jws.PyJWS(algorithms=algorithms) alg_obj = py_jws._algorithms[algorithm] return alg_obj.prepare_key(public_key) class JWTAuthVerifier(object): - - """ This class can be used to verify a JWT. """ + """This class can be used to verify a JWT.""" def __init__(self, public_key_retriever, **kwargs): self.public_key_retriever = public_key_retriever self.algorithms = algorithms.get_permitted_algorithm_names() self._seen_jti = OrderedDict() self._subject_should_match_issuer = kwargs.get( - 'subject_should_match_issuer', True) - self._check_jti_uniqueness = kwargs.get( - 'check_jti_uniqueness', False) + "subject_should_match_issuer", True + ) + self._check_jti_uniqueness = kwargs.get("check_jti_uniqueness", False) def verify_jwt(self, a_jwt, audience, leeway=0, **requests_kwargs): """Verify if the token is correct @@ -59,30 +51,28 @@ def verify_jwt(self, a_jwt, audience, leeway=0, **requests_kwargs): key_identifier = key._get_key_id_from_jwt_header(a_jwt) public_key = self._retrieve_pub_key(key_identifier, requests_kwargs) - alg = jwt.get_unverified_header(a_jwt).get('alg', None) + alg = jwt.get_unverified_header(a_jwt).get("alg", None) public_key_obj = self._load_public_key(public_key, alg) return self._decode_jwt( - a_jwt, key_identifier, public_key_obj, - audience=audience, leeway=leeway) + a_jwt, key_identifier, public_key_obj, audience=audience, leeway=leeway + ) def _retrieve_pub_key(self, key_identifier, requests_kwargs): - return self.public_key_retriever.retrieve( - key_identifier, **requests_kwargs) + return self.public_key_retriever.retrieve(key_identifier, **requests_kwargs) def _load_public_key(self, public_key, algorithm): - """ Returns a public key object instance given the public key and - algorithm. + """Returns a public key object instance given the public key and + algorithm. """ return _load_public_key(tuple(self.algorithms), public_key, algorithm) - def _decode_jwt(self, a_jwt, key_identifier, jwt_key, - audience=None, leeway=0): + def _decode_jwt(self, a_jwt, key_identifier, jwt_key, audience=None, leeway=0): """Decode JWT and check if it's valid""" options = { - 'verify_signature': True, - 'require': ['exp', 'iat'], - 'require_exp': True, - 'require_iat': True, + "verify_signature": True, + "require": ["exp", "iat"], + "require_exp": True, + "require_iat": True, } claims = jwt.decode( @@ -91,29 +81,34 @@ def _decode_jwt(self, a_jwt, key_identifier, jwt_key, algorithms=self.algorithms, options=options, audience=audience, - leeway=leeway) + leeway=leeway, + ) - if (not key_identifier.key_id.startswith('%s/' % claims['iss']) and - key_identifier.key_id != claims['iss']): - raise ValueError('Issuer does not own the supplied public key') + if ( + not key_identifier.key_id.startswith("%s/" % claims["iss"]) + and key_identifier.key_id != claims["iss"] + ): + raise ValueError("Issuer does not own the supplied public key") if self._subject_should_match_issuer and ( - claims.get('sub') and claims['iss'] != claims['sub']): + claims.get("sub") and claims["iss"] != claims["sub"] + ): raise exceptions.SubjectDoesNotMatchIssuerException( - 'Issuer does not match the subject.') + "Issuer does not match the subject." + ) - _aud = claims.get('aud', None) + _aud = claims.get("aud", None) if _aud is None: - _msg = ("Claims validity, the aud claim must be provided and " - "cannot be empty.") + _msg = ( + "Claims validity, the aud claim must be provided and cannot be empty." + ) raise KeyError(_msg) - _exp = int(claims['exp']) - _iat = int(claims['iat']) + _exp = int(claims["exp"]) + _iat = int(claims["iat"]) if _exp - _iat > 3600: - _msg = ("Claims validity, '%s', exceeds the maximum 1 hour." % - (_exp - _iat)) + _msg = "Claims validity, '%s', exceeds the maximum 1 hour." % (_exp - _iat) raise ValueError(_msg) - _jti = claims['jti'] + _jti = claims["jti"] if self._check_jti_uniqueness: self._check_jti(_jti) return claims @@ -122,7 +117,8 @@ def _check_jti(self, jti): """Checks that the given jti has not been already been used.""" if jti in self._seen_jti: raise exceptions.JtiUniquenessException( - "The jti, '%s', has already been used." % jti) + "The jti, '%s', has already been used." % jti + ) self._seen_jti[jti] = None while len(self._seen_jti) > 1000: self._seen_jti.popitem(last=False) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..5199082 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,17 @@ +[tool.ruff] +# Same as Black. +line-length = 88 +indent-width = 4 + +[tool.ruff.format] +# Apply Black like config +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +docstring-code-format = false +docstring-code-line-length = "dynamic" +line-ending = "auto" + +[tool.ruff.lint] +extend-select = ["I"] + diff --git a/setup.py b/setup.py index 595f8a4..6e64fbe 100644 --- a/setup.py +++ b/setup.py @@ -1,10 +1,9 @@ #!/usr/bin/env python from setuptools import setup - setup( - setup_requires=['pbr<7.0.0'], + setup_requires=["pbr<7.0.0"], pbr=True, - platforms=['any'], + platforms=["any"], zip_safe=False, )