Skip to content

Commit 2e7224c

Browse files
authored
Merge pull request #891 from liudonggalaxy/Dongliu/fix_invalid_key_id_error_handling
2 parents bf39e2f + 8e46995 commit 2e7224c

2 files changed

Lines changed: 99 additions & 2 deletions

File tree

authlib/oauth2/rfc9068/token_validator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"""
88

99
from joserfc import jwt
10-
from joserfc.errors import DecodeError
1110
from joserfc.errors import JoseError
1211

1312
from authlib._joserfc_helpers import import_any_key
@@ -115,7 +114,7 @@ def authenticate_token(self, token_string):
115114
try:
116115
token = jwt.decode(token_string, key=key)
117116
return JWTAccessTokenClaims(token.claims, token.header, claims_options)
118-
except DecodeError as exc:
117+
except JoseError as exc:
119118
raise InvalidTokenError(
120119
realm=self.realm, extra_attributes=self.extra_attributes
121120
) from exc
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Tests for RFC 9068 JWTBearerTokenValidator.authenticate_token.
2+
3+
Verifies that all joserfc errors raised during JWT decoding — including
4+
InvalidKeyIdError — are converted to authlib's InvalidTokenError.
5+
"""
6+
7+
import time
8+
9+
import pytest
10+
from joserfc import jwt
11+
from joserfc.jwk import KeySet
12+
from joserfc.jwk import OctKey
13+
from joserfc.jwk import RSAKey
14+
15+
from authlib.oauth2.rfc6750.errors import InvalidTokenError
16+
from authlib.oauth2.rfc9068 import JWTBearerTokenValidator
17+
18+
ISSUER = "https://auth.example.com"
19+
RESOURCE_SERVER = "https://api.example.com"
20+
21+
22+
def _make_validator(key):
23+
class Validator(JWTBearerTokenValidator):
24+
def get_jwks(self):
25+
return key
26+
27+
return Validator(issuer=ISSUER, resource_server=RESOURCE_SERVER)
28+
29+
30+
def _encode_token(key, claims=None):
31+
base_claims = {
32+
"iss": ISSUER,
33+
"aud": RESOURCE_SERVER,
34+
"sub": "user-1",
35+
"client_id": "client-1",
36+
"iat": int(time.time()),
37+
"exp": int(time.time()) + 3600,
38+
"jti": "unique-jti",
39+
}
40+
if claims:
41+
base_claims.update(claims)
42+
return jwt.encode({"alg": "HS256"}, base_claims, key)
43+
44+
45+
def test_valid_token():
46+
key = OctKey.generate_key()
47+
validator = _make_validator(key)
48+
49+
token_string = _encode_token(key)
50+
token = validator.authenticate_token(token_string)
51+
52+
assert token is not None
53+
assert token["sub"] == "user-1"
54+
55+
56+
def test_invalid_signature_raises_invalid_token_error():
57+
signing_key = OctKey.generate_key()
58+
wrong_key = OctKey.generate_key()
59+
validator = _make_validator(wrong_key)
60+
61+
token_string = _encode_token(signing_key)
62+
63+
with pytest.raises(InvalidTokenError):
64+
validator.authenticate_token(token_string)
65+
66+
67+
def test_mismatched_kid_raises_invalid_token_error():
68+
"""InvalidKeyIdError (not a subclass of DecodeError) must be caught."""
69+
key = RSAKey.generate_key(2048, private=True, parameters={"kid": "key-1"})
70+
key_set = KeySet(keys=[key])
71+
validator = _make_validator(key_set)
72+
73+
# Encode a token with a kid that doesn't exist in the key set
74+
other_key = RSAKey.generate_key(2048, private=True, parameters={"kid": "key-999"})
75+
token_string = jwt.encode(
76+
{"alg": "RS256", "kid": "key-999"},
77+
{
78+
"iss": ISSUER,
79+
"aud": RESOURCE_SERVER,
80+
"sub": "user-1",
81+
"client_id": "client-1",
82+
"iat": int(time.time()),
83+
"exp": int(time.time()) + 3600,
84+
"jti": "unique-jti",
85+
},
86+
other_key,
87+
)
88+
89+
with pytest.raises(InvalidTokenError):
90+
validator.authenticate_token(token_string)
91+
92+
93+
def test_garbage_token_raises_invalid_token_error():
94+
key = OctKey.generate_key()
95+
validator = _make_validator(key)
96+
97+
with pytest.raises(InvalidTokenError):
98+
validator.authenticate_token("not.a.jwt")

0 commit comments

Comments
 (0)