11from __future__ import annotations
22
33import base64
4+ import binascii
45import hashlib
56import re
67import secrets
78import string
89import uuid
910from base64 import urlsafe_b64decode
1011from datetime import datetime
11- from json import loads
1212from typing import Any , Dict , Optional , Type , TypedDict , TypeVar , cast
1313from urllib .parse import urlparse
1414
1515from httpx import HTTPStatusError , Response
16- from pydantic import BaseModel
16+ from pydantic import BaseModel , TypeAdapter
1717
18- from .constants import API_VERSION_HEADER_NAME , API_VERSIONS , BASE64URL_REGEX
18+ from .constants import (
19+ API_VERSION_HEADER_NAME ,
20+ API_VERSIONS_2024_01_01_TIMESTAMP ,
21+ BASE64URL_REGEX ,
22+ )
1923from .errors import (
2024 AuthApiError ,
2125 AuthError ,
@@ -136,18 +140,9 @@ def get_error_message(error: Any) -> str:
136140 return next ((error [prop ] for prop in props if filter (prop )), str (error ))
137141
138142
139- def get_error_code (error : Any ) -> str :
140- return error .get ("error_code" , None ) if isinstance (error , dict ) else None
141-
142-
143- def looks_like_http_status_error (exception : Exception ) -> bool :
144- return isinstance (exception , HTTPStatusError )
145-
146-
147- def handle_exception (exception : Exception ) -> AuthError :
148- if not looks_like_http_status_error (exception ):
149- return AuthRetryableError (get_error_message (exception ), 0 )
150- error = cast (HTTPStatusError , exception )
143+ def handle_exception (error : HTTPStatusError | RuntimeError ) -> AuthError :
144+ if not isinstance (error , HTTPStatusError ):
145+ return AuthRetryableError (get_error_message (error ), 0 )
151146 try :
152147 network_error_codes = [502 , 503 , 504 ]
153148 if error .response .status_code in network_error_codes :
@@ -161,8 +156,10 @@ def handle_exception(exception: Exception) -> AuthError:
161156
162157 if (
163158 response_api_version
164- and datetime .timestamp (response_api_version )
165- >= API_VERSIONS .get ("2024-01-01" ).get ("timestamp" )
159+ and (
160+ datetime .timestamp (response_api_version )
161+ >= API_VERSIONS_2024_01_01_TIMESTAMP
162+ )
166163 and isinstance (data , dict )
167164 and data
168165 and isinstance (data .get ("code" ), str )
@@ -180,18 +177,18 @@ def handle_exception(exception: Exception) -> AuthError:
180177 and isinstance (data .get ("weak_password" ), dict )
181178 and data .get ("weak_password" )
182179 and isinstance (data .get ("weak_password" ), list )
183- and len (data . get ( "weak_password" ) )
180+ and len (data [ "weak_password" ] )
184181 ):
185182 return AuthWeakPasswordError (
186183 get_error_message (data ),
187184 error .response .status_code ,
188- data . get ( "weak_password" ) .get ("reasons" ),
185+ data [ "weak_password" ] .get ("reasons" ),
189186 )
190187 elif error_code == "weak_password" :
191188 return AuthWeakPasswordError (
192189 get_error_message (data ),
193190 error .response .status_code ,
194- data . get ( "weak_password" , {}) .get ("reasons" , {}),
191+ data [ "weak_password" ] .get ("reasons" , {}),
195192 )
196193
197194 return AuthApiError (
@@ -224,20 +221,26 @@ class DecodedJWT(TypedDict):
224221 raw : Dict [str , str ]
225222
226223
224+ JWTHeaderParser = TypeAdapter (JWTHeader )
225+ JWTPayloadParser = TypeAdapter (JWTPayload )
226+
227+
227228def decode_jwt (token : str ) -> DecodedJWT :
228229 parts = token .split ("." )
229230 if len (parts ) != 3 :
230231 raise AuthInvalidJwtError ("Invalid JWT structure" )
231232
232- # regex check for base64url
233- for part in parts :
234- if not re .match (BASE64URL_REGEX , part , re .IGNORECASE ):
235- raise AuthInvalidJwtError ("JWT not in base64url format" )
233+ try :
234+ header = base64url_to_bytes (parts [0 ])
235+ payload = base64url_to_bytes (parts [1 ])
236+ signature = base64url_to_bytes (parts [2 ])
237+ except binascii .Error :
238+ raise AuthInvalidJwtError ("Invalid JWT structure" )
236239
237240 return DecodedJWT (
238- header = JWTHeader ( ** loads ( str_from_base64url ( parts [ 0 ])) ),
239- payload = JWTPayload ( ** loads ( str_from_base64url ( parts [ 1 ])) ),
240- signature = base64url_to_bytes ( parts [ 2 ]) ,
241+ header = JWTHeaderParser . validate_json ( header ),
242+ payload = JWTPayloadParser . validate_json ( payload ),
243+ signature = signature ,
241244 raw = {
242245 "header" : parts [0 ],
243246 "payload" : parts [1 ],
0 commit comments