Skip to content

Commit 1abf8c7

Browse files
committed
chore: use pydantic to parse jwt, instead of json.loads
1 parent df911cb commit 1abf8c7

File tree

12 files changed

+127
-68
lines changed

12 files changed

+127
-68
lines changed

src/auth/Makefile

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,57 @@
1-
tests: pytest
1+
help::
2+
@echo "Available commands"
3+
@echo " help -- (default) print this message"
4+
5+
tests: mypy pytest
6+
help::
7+
@echo " tests -- run all tests for supabase_auth"
28

39
pytest: start-infra
410
uv run --package supabase_auth pytest --cov=./ --cov-report=xml --cov-report=html -vv
511

12+
mypy:
13+
uv run --package supabase_auth mypy src/supabase_auth tests
14+
help::
15+
@echo " mypy -- run mypy on supabase_auth"
16+
617
start-infra:
718
cd infra &&\
819
docker compose down &&\
920
docker compose up -d
1021
sleep 2
22+
help::
23+
@echo " start-infra -- start containers for tests"
1124

1225
clean-infra:
1326
cd infra &&\
1427
docker compose down --remove-orphans &&\
1528
docker system prune -a --volumes -f
29+
help::
30+
@echo " clean-infra -- delete all stored information about the containers"
1631

1732
stop-infra:
1833
cd infra &&\
1934
docker compose down --remove-orphans
35+
help::
36+
@echo " stop-infra -- stop containers for tests"
2037

2138
sync-infra:
2239
uv run --package supabase_auth scripts/gh-download.py --repo=supabase/gotrue-js --branch=master --folder=infra
40+
help::
41+
@echo " sync-infra -- update locked versions for test containers"
2342

2443
build-sync:
2544
uv run --package supabase_auth scripts/run-unasync.py
45+
help::
46+
@echo " build-sync -- generate _sync from _async code"
2647

2748
clean:
2849
rm -rf htmlcov .pytest_cache .mypy_cache .ruff_cache
2950
rm -f .coverage coverage.xml
51+
help::
52+
@echo " clean -- clean intermediary files"
3053

3154
build:
3255
uv build --package supabase_auth
56+
help::
57+
@echo " build -- invoke uv build on supabase_auth package"

src/auth/pyproject.toml

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ tests = [
4444
lints = [
4545
"ruff >=0.12.1",
4646
"unasync >= 0.6.0",
47+
"python-lsp-server (>=1.12.2,<2.0.0)",
48+
"pylsp-mypy (>=0.7.0,<0.8.0)",
49+
"python-lsp-ruff (>=2.2.2,<3.0.0)",
4750
]
4851
dev = [{ include-group = "lints" }, {include-group = "tests" }]
4952

@@ -76,3 +79,15 @@ asyncio_mode = "auto"
7679
[build-system]
7780
requires = ["uv_build>=0.8.3,<0.9.0"]
7881
build-backend = "uv_build"
82+
83+
[tool.mypy]
84+
python_version = "3.9"
85+
check_untyped_defs = true
86+
allow_redefinition = true
87+
follow_untyped_imports = true # for deprecation module that does not have stubs
88+
89+
no_warn_no_return = true
90+
warn_return_any = true
91+
warn_unused_configs = true
92+
warn_redundant_casts = true
93+
warn_unused_ignores = true

src/auth/scripts/run-unasync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import unasync
44

5-
paths = Path("../src/supabase").glob("**/*.py")
5+
paths = Path("src/supabase_auth").glob("**/*.py")
66
tests = Path("tests").glob("**/*.py")
77

88
rules = (unasync._DEFAULT_RULE,)

src/auth/src/supabase_auth/_async/gotrue_base_api.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from typing import Any, Callable, Dict, Optional, TypeVar, overload
44

5-
from httpx import Response
5+
from httpx import HTTPStatusError, Response
66
from pydantic import BaseModel
77
from typing_extensions import Literal, Self
88

9-
from ..constants import API_VERSION_HEADER_NAME, API_VERSIONS
9+
from ..constants import API_VERSION_HEADER_NAME, API_VERSIONS_2024_01_01_NAME
1010
from ..helpers import handle_exception, model_dump
1111
from ..http_clients import AsyncClient
1212

@@ -96,12 +96,12 @@ async def _request(
9696
query: Optional[Dict[str, str]] = None,
9797
body: Optional[Any] = None,
9898
no_resolve_json: bool = False,
99-
xform: Optional[Callable[[Any], T]] = None,
99+
xform: Optional[Callable[[Response], T]] = None,
100100
) -> Optional[T]:
101101
url = f"{self._url}/{path}"
102102
headers = {**self._headers, **(headers or {})}
103103
if API_VERSION_HEADER_NAME not in headers:
104-
headers[API_VERSION_HEADER_NAME] = API_VERSIONS["2024-01-01"].get("name")
104+
headers[API_VERSION_HEADER_NAME] = API_VERSIONS_2024_01_01_NAME
105105
if "Content-Type" not in headers:
106106
headers["Content-Type"] = "application/json;charset=UTF-8"
107107
if jwt:
@@ -121,5 +121,6 @@ async def _request(
121121
result = response if no_resolve_json else response.json()
122122
if xform:
123123
return xform(result)
124-
except Exception as e:
124+
return None
125+
except (HTTPStatusError, RuntimeError) as e:
125126
raise handle_exception(e)

src/auth/src/supabase_auth/_sync/gotrue_base_api.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
from typing import Any, Callable, Dict, Optional, TypeVar, overload
44

5-
from httpx import Response
5+
from httpx import HTTPStatusError, Response
66
from pydantic import BaseModel
77
from typing_extensions import Literal, Self
88

9-
from ..constants import API_VERSION_HEADER_NAME, API_VERSIONS
9+
from ..constants import API_VERSION_HEADER_NAME, API_VERSIONS_2024_01_01_NAME
1010
from ..helpers import handle_exception, model_dump
1111
from ..http_clients import SyncClient
1212

@@ -96,12 +96,12 @@ def _request(
9696
query: Optional[Dict[str, str]] = None,
9797
body: Optional[Any] = None,
9898
no_resolve_json: bool = False,
99-
xform: Optional[Callable[[Any], T]] = None,
99+
xform: Optional[Callable[[Response], T]] = None,
100100
) -> Optional[T]:
101101
url = f"{self._url}/{path}"
102102
headers = {**self._headers, **(headers or {})}
103103
if API_VERSION_HEADER_NAME not in headers:
104-
headers[API_VERSION_HEADER_NAME] = API_VERSIONS["2024-01-01"].get("name")
104+
headers[API_VERSION_HEADER_NAME] = API_VERSIONS_2024_01_01_NAME
105105
if "Content-Type" not in headers:
106106
headers["Content-Type"] = "application/json;charset=UTF-8"
107107
if jwt:
@@ -121,5 +121,6 @@ def _request(
121121
result = response if no_resolve_json else response.json()
122122
if xform:
123123
return xform(result)
124-
except Exception as e:
124+
return None
125+
except (HTTPStatusError, RuntimeError) as e:
125126
raise handle_exception(e)

src/auth/src/supabase_auth/constants.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
STORAGE_KEY = "supabase.auth.token"
1616

1717
API_VERSION_HEADER_NAME = "X-Supabase-Api-Version"
18-
API_VERSIONS = {
19-
"2024-01-01": {
20-
"timestamp": datetime.timestamp(datetime.strptime("2024-01-01", "%Y-%m-%d")),
21-
"name": "2024-01-01",
22-
},
23-
}
18+
API_VERSIONS_2024_01_01_TIMESTAMP = datetime.timestamp(
19+
datetime.strptime("2024-01-01", "%Y-%m-%d")
20+
)
21+
API_VERSIONS_2024_01_01_NAME = "2024-01-01"
2422
BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$"

src/auth/src/supabase_auth/errors.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,12 @@
8686
"invalid_credentials",
8787
"email_address_not_authorized",
8888
"email_address_invalid",
89+
"invalid_jwt",
8990
]
9091

9192

9293
class AuthError(Exception):
93-
def __init__(self, message: str, code: ErrorCode) -> None:
94+
def __init__(self, message: str, code: ErrorCode | None) -> None:
9495
Exception.__init__(self, message)
9596
self.message = message
9697
self.name = "AuthError"
@@ -101,11 +102,11 @@ class AuthApiErrorDict(TypedDict):
101102
name: str
102103
message: str
103104
status: int
104-
code: ErrorCode
105+
code: ErrorCode | None
105106

106107

107108
class AuthApiError(AuthError):
108-
def __init__(self, message: str, status: int, code: ErrorCode) -> None:
109+
def __init__(self, message: str, status: int, code: Optional[ErrorCode]) -> None:
109110
AuthError.__init__(self, message, code)
110111
self.name = "AuthApiError"
111112
self.status = status
@@ -128,7 +129,9 @@ def __init__(self, message: str, original_error: Exception) -> None:
128129

129130

130131
class CustomAuthError(AuthError):
131-
def __init__(self, message: str, name: str, status: int, code: ErrorCode) -> None:
132+
def __init__(
133+
self, message: str, name: str, status: int, code: Optional[ErrorCode]
134+
) -> None:
132135
AuthError.__init__(self, message, code)
133136
self.name = name
134137
self.status = status
@@ -138,6 +141,7 @@ def to_dict(self) -> AuthApiErrorDict:
138141
"name": self.name,
139142
"message": self.message,
140143
"status": self.status,
144+
"code": self.code,
141145
}
142146

143147

@@ -193,6 +197,7 @@ def to_dict(self) -> AuthImplicitGrantRedirectErrorDict:
193197
"message": self.message,
194198
"status": self.status,
195199
"details": self.details,
200+
"code": self.code,
196201
}
197202

198203

@@ -207,6 +212,10 @@ def __init__(self, message: str, status: int) -> None:
207212
)
208213

209214

215+
class AuthApiErrorWithReasonsDict(AuthApiErrorDict):
216+
reasons: List[str]
217+
218+
210219
class AuthWeakPasswordError(CustomAuthError):
211220
def __init__(self, message: str, status: int, reasons: List[str]) -> None:
212221
CustomAuthError.__init__(
@@ -218,12 +227,13 @@ def __init__(self, message: str, status: int, reasons: List[str]) -> None:
218227
)
219228
self.reasons = reasons
220229

221-
def to_dict(self) -> AuthApiErrorDict:
230+
def to_dict(self) -> AuthApiErrorWithReasonsDict:
222231
return {
223232
"name": self.name,
224233
"message": self.message,
225234
"status": self.status,
226235
"reasons": self.reasons,
236+
"code": self.code,
227237
}
228238

229239

src/auth/src/supabase_auth/helpers.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
11
from __future__ import annotations
22

33
import base64
4+
import binascii
45
import hashlib
56
import re
67
import secrets
78
import string
89
import uuid
910
from base64 import urlsafe_b64decode
1011
from datetime import datetime
11-
from json import loads
1212
from typing import Any, Dict, Optional, Type, TypedDict, TypeVar, cast
1313
from urllib.parse import urlparse
1414

1515
from httpx import HTTPStatusError, Response
16-
from pydantic import BaseModel
16+
from pydantic import BaseModel, TypeAdapter
1717

18-
from .constants import API_VERSION_HEADER_NAME, API_VERSIONS, BASE64URL_REGEX
18+
from .constants import (
19+
API_VERSION_HEADER_NAME,
20+
API_VERSIONS_2024_01_01_TIMESTAMP,
21+
BASE64URL_REGEX,
22+
)
1923
from .errors import (
2024
AuthApiError,
2125
AuthError,
@@ -136,18 +140,9 @@ def get_error_message(error: Any) -> str:
136140
return next((error[prop] for prop in props if filter(prop)), str(error))
137141

138142

139-
def get_error_code(error: Any) -> str:
140-
return error.get("error_code", None) if isinstance(error, dict) else None
141-
142-
143-
def looks_like_http_status_error(exception: Exception) -> bool:
144-
return isinstance(exception, HTTPStatusError)
145-
146-
147-
def handle_exception(exception: Exception) -> AuthError:
148-
if not looks_like_http_status_error(exception):
149-
return AuthRetryableError(get_error_message(exception), 0)
150-
error = cast(HTTPStatusError, exception)
143+
def handle_exception(error: HTTPStatusError | RuntimeError) -> AuthError:
144+
if not isinstance(error, HTTPStatusError):
145+
return AuthRetryableError(get_error_message(error), 0)
151146
try:
152147
network_error_codes = [502, 503, 504]
153148
if error.response.status_code in network_error_codes:
@@ -161,8 +156,10 @@ def handle_exception(exception: Exception) -> AuthError:
161156

162157
if (
163158
response_api_version
164-
and datetime.timestamp(response_api_version)
165-
>= API_VERSIONS.get("2024-01-01").get("timestamp")
159+
and (
160+
datetime.timestamp(response_api_version)
161+
>= API_VERSIONS_2024_01_01_TIMESTAMP
162+
)
166163
and isinstance(data, dict)
167164
and data
168165
and isinstance(data.get("code"), str)
@@ -180,18 +177,18 @@ def handle_exception(exception: Exception) -> AuthError:
180177
and isinstance(data.get("weak_password"), dict)
181178
and data.get("weak_password")
182179
and isinstance(data.get("weak_password"), list)
183-
and len(data.get("weak_password"))
180+
and len(data["weak_password"])
184181
):
185182
return AuthWeakPasswordError(
186183
get_error_message(data),
187184
error.response.status_code,
188-
data.get("weak_password").get("reasons"),
185+
data["weak_password"].get("reasons"),
189186
)
190187
elif error_code == "weak_password":
191188
return AuthWeakPasswordError(
192189
get_error_message(data),
193190
error.response.status_code,
194-
data.get("weak_password", {}).get("reasons", {}),
191+
data["weak_password"].get("reasons", {}),
195192
)
196193

197194
return AuthApiError(
@@ -224,20 +221,26 @@ class DecodedJWT(TypedDict):
224221
raw: Dict[str, str]
225222

226223

224+
JWTHeaderParser = TypeAdapter(JWTHeader)
225+
JWTPayloadParser = TypeAdapter(JWTPayload)
226+
227+
227228
def decode_jwt(token: str) -> DecodedJWT:
228229
parts = token.split(".")
229230
if len(parts) != 3:
230231
raise AuthInvalidJwtError("Invalid JWT structure")
231232

232-
# regex check for base64url
233-
for part in parts:
234-
if not re.match(BASE64URL_REGEX, part, re.IGNORECASE):
235-
raise AuthInvalidJwtError("JWT not in base64url format")
233+
try:
234+
header = base64url_to_bytes(parts[0])
235+
payload = base64url_to_bytes(parts[1])
236+
signature = base64url_to_bytes(parts[2])
237+
except binascii.Error:
238+
raise AuthInvalidJwtError("Invalid JWT structure")
236239

237240
return DecodedJWT(
238-
header=JWTHeader(**loads(str_from_base64url(parts[0]))),
239-
payload=JWTPayload(**loads(str_from_base64url(parts[1]))),
240-
signature=base64url_to_bytes(parts[2]),
241+
header=JWTHeaderParser.validate_json(header),
242+
payload=JWTPayloadParser.validate_json(payload),
243+
signature=signature,
241244
raw={
242245
"header": parts[0],
243246
"payload": parts[1],

0 commit comments

Comments
 (0)