diff --git a/rest_framework_jwt/utils.py b/rest_framework_jwt/utils.py index c72197bc..e935fa54 100644 --- a/rest_framework_jwt/utils.py +++ b/rest_framework_jwt/utils.py @@ -23,7 +23,7 @@ def jwt_get_secret_key(payload=None): """ if api_settings.JWT_GET_USER_SECRET_KEY: User = get_user_model() # noqa: N806 - user = User.objects.get(pk=payload.get('user_id')) + user = User.objects.get(pk=api_settings.JWT_PAYLOAD_GET_USER_ID_HANDLER(payload)) key = str(api_settings.JWT_GET_USER_SECRET_KEY(user)) return key return api_settings.JWT_SECRET_KEY diff --git a/tests/test_utils.py b/tests/test_utils.py index 393c30ab..17589e85 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,11 +4,13 @@ import jwt.exceptions from django.test import TestCase +from django.conf import settings from rest_framework_jwt import utils from rest_framework_jwt.compat import get_user_model from rest_framework_jwt.settings import api_settings, DEFAULTS from tests.models import CustomUserWithoutEmail +from tests.utils import custom_get_user_id, custom_get_user_secret User = get_user_model() @@ -82,6 +84,27 @@ def test_jwt_decode_verify_exp(self): api_settings.JWT_VERIFY_EXPIRATION = True + def test_jwt_get_secret_key(self): + secret = utils.jwt_get_secret_key({'user_id': self.user.pk}) + self.assertEqual(secret, settings.SECRET_KEY) + + def test_jwt_get_secret_key_customer_secret_getter(self): + old = api_settings.JWT_GET_USER_SECRET_KEY + api_settings.JWT_GET_USER_SECRET_KEY = custom_get_user_secret + secret = utils.jwt_get_secret_key({'user_id': self.user.pk}) + api_settings.JWT_GET_USER_SECRET_KEY = old + self.assertEqual(secret, str(self.user.pk)) + + def test_jwt_get_secret_key_customer_id_and_secret_getter(self): + old = api_settings.JWT_GET_USER_SECRET_KEY + api_settings.JWT_GET_USER_SECRET_KEY = custom_get_user_secret + old2 = api_settings.JWT_PAYLOAD_GET_USER_ID_HANDLER + api_settings.JWT_PAYLOAD_GET_USER_ID_HANDLER = custom_get_user_id + secret = utils.jwt_get_secret_key({'custom_uid': self.user.pk}) + api_settings.JWT_PAYLOAD_GET_USER_ID_HANDLER = old2 + api_settings.JWT_GET_USER_SECRET_KEY = old + self.assertEqual(secret, str(self.user.pk)) + class TestAudience(TestCase): def setUp(self): diff --git a/tests/utils.py b/tests/utils.py index a72529bb..6f742e6f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -24,3 +24,11 @@ def jwt_response_payload_handler(token, user=None, request=None): def get_jwt_secret(user): return user.jwt_secret + + +def custom_get_user_secret(user): + return user.pk + + +def custom_get_user_id(payload): + return payload.get('custom_uid', None)