Skip to content

Commit cd0cffa

Browse files
committed
chore: partial commit, more fixes
1 parent 1abf8c7 commit cd0cffa

File tree

12 files changed

+350
-433
lines changed

12 files changed

+350
-433
lines changed
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
from __future__ import annotations
22

3-
from ._async.gotrue_admin_api import AsyncGoTrueAdminAPI # type: ignore # noqa: F401
4-
from ._async.gotrue_client import AsyncGoTrueClient # type: ignore # noqa: F401
3+
from ._async.gotrue_admin_api import AsyncGoTrueAdminAPI
4+
from ._async.gotrue_client import AsyncGoTrueClient
55
from ._async.storage import (
6-
AsyncMemoryStorage, # type: ignore # noqa: F401
7-
AsyncSupportedStorage, # type: ignore # noqa: F401
6+
AsyncMemoryStorage,
7+
AsyncSupportedStorage,
88
)
9-
from ._sync.gotrue_admin_api import SyncGoTrueAdminAPI # type: ignore # noqa: F401
10-
from ._sync.gotrue_client import SyncGoTrueClient # type: ignore # noqa: F401
9+
from ._sync.gotrue_admin_api import SyncGoTrueAdminAPI
10+
from ._sync.gotrue_client import SyncGoTrueClient
1111
from ._sync.storage import (
12-
SyncMemoryStorage, # type: ignore # noqa: F401
13-
SyncSupportedStorage, # type: ignore # noqa: F401
12+
SyncMemoryStorage,
13+
SyncSupportedStorage,
1414
)
15-
from .types import * # type: ignore # noqa: F401, F403
15+
from .types import *
1616
from .version import __version__

src/auth/src/supabase_auth/_async/gotrue_admin_api.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from __future__ import annotations
22

3-
from functools import partial
4-
from typing import Dict, List, Optional
3+
from typing import Any, Dict, List, Optional
4+
5+
from httpx import Response
6+
from pydantic import TypeAdapter
57

68
from ..helpers import (
79
is_valid_uuid,
@@ -26,6 +28,8 @@
2628
from .gotrue_admin_mfa_api import AsyncGoTrueAdminMFAAPI
2729
from .gotrue_base_api import AsyncGoTrueBaseAPI
2830

31+
UserList = TypeAdapter(List[User])
32+
2933

3034
class AsyncGoTrueAdminAPI(AsyncGoTrueBaseAPI):
3135
def __init__(
@@ -45,15 +49,16 @@ def __init__(
4549
verify=verify,
4650
proxy=proxy,
4751
)
52+
# TODO(@o-santi): why is is this done this way?
4853
self.mfa = AsyncGoTrueAdminMFAAPI()
49-
self.mfa.list_factors = self._list_factors
50-
self.mfa.delete_factor = self._delete_factor
54+
self.mfa.list_factors = self._list_factors # type: ignore
55+
self.mfa.delete_factor = self._delete_factor # type: ignore
5156

5257
async def sign_out(self, jwt: str, scope: SignOutScope = "global") -> None:
5358
"""
5459
Removes a logged-in session.
5560
"""
56-
return await self._request(
61+
await self._request(
5762
"POST",
5863
"logout",
5964
query={"scope": scope},
@@ -69,19 +74,19 @@ async def invite_user_by_email(
6974
"""
7075
Sends an invite link to an email address.
7176
"""
72-
return await self._request(
77+
response = await self._request(
7378
"POST",
7479
"invite",
7580
body={"email": email, "data": options.get("data")},
7681
redirect_to=options.get("redirect_to"),
77-
xform=parse_user_response,
7882
)
83+
return parse_user_response(response)
7984

8085
async def generate_link(self, params: GenerateLinkParams) -> GenerateLinkResponse:
8186
"""
8287
Generates email links and OTPs to be sent via a custom email provider.
8388
"""
84-
return await self._request(
89+
response = await self._request(
8590
"POST",
8691
"admin/generate_link",
8792
body={
@@ -92,8 +97,8 @@ async def generate_link(self, params: GenerateLinkParams) -> GenerateLinkRespons
9297
"data": params.get("options", {}).get("data"),
9398
},
9499
redirect_to=params.get("options", {}).get("redirect_to"),
95-
xform=parse_link_response,
96100
)
101+
return parse_link_response(response)
97102

98103
# User Admin API
99104

@@ -104,30 +109,28 @@ async def create_user(self, attributes: AdminUserAttributes) -> UserResponse:
104109
This function should only be called on a server.
105110
Never expose your `service_role` key in the browser.
106111
"""
107-
return await self._request(
112+
response = await self._request(
108113
"POST",
109114
"admin/users",
110115
body=attributes,
111-
xform=parse_user_response,
112116
)
117+
return parse_user_response(response)
113118

114-
async def list_users(self, page: int = None, per_page: int = None) -> List[User]:
119+
async def list_users(
120+
self, page: Optional[int] = None, per_page: Optional[int] = None
121+
) -> List[User]:
115122
"""
116123
Get a list of users.
117124
118125
This function should only be called on a server.
119126
Never expose your `service_role` key in the browser.
120127
"""
121-
return await self._request(
128+
response = await self._request(
122129
"GET",
123130
"admin/users",
124-
query={"page": page, "per_page": per_page},
125-
xform=lambda data: (
126-
[model_validate(User, user) for user in data["users"]]
127-
if "users" in data
128-
else []
129-
),
131+
query={"page": str(page), "per_page": str(per_page)},
130132
)
133+
return UserList.validate_json(response.content)
131134

132135
async def get_user_by_id(self, uid: str) -> UserResponse:
133136
"""
@@ -138,11 +141,11 @@ async def get_user_by_id(self, uid: str) -> UserResponse:
138141
"""
139142
self._validate_uuid(uid)
140143

141-
return await self._request(
144+
response = await self._request(
142145
"GET",
143146
f"admin/users/{uid}",
144-
xform=parse_user_response,
145147
)
148+
return parse_user_response(response)
146149

147150
async def update_user_by_id(
148151
self,
@@ -156,12 +159,12 @@ async def update_user_by_id(
156159
Never expose your `service_role` key in the browser.
157160
"""
158161
self._validate_uuid(uid)
159-
return await self._request(
162+
response = await self._request(
160163
"PUT",
161164
f"admin/users/{uid}",
162165
body=attributes,
163-
xform=parse_user_response,
164166
)
167+
return parse_user_response(response)
165168

166169
async def delete_user(self, id: str, should_soft_delete: bool = False) -> None:
167170
"""
@@ -172,31 +175,33 @@ async def delete_user(self, id: str, should_soft_delete: bool = False) -> None:
172175
"""
173176
self._validate_uuid(id)
174177
body = {"should_soft_delete": should_soft_delete}
175-
return await self._request("DELETE", f"admin/users/{id}", body=body)
178+
await self._request("DELETE", f"admin/users/{id}", body=body)
176179

177180
async def _list_factors(
178181
self,
179182
params: AuthMFAAdminListFactorsParams,
180183
) -> AuthMFAAdminListFactorsResponse:
181184
self._validate_uuid(params.get("user_id"))
182-
return await self._request(
185+
response = await self._request(
183186
"GET",
184187
f"admin/users/{params.get('user_id')}/factors",
185-
xform=partial(model_validate, AuthMFAAdminListFactorsResponse),
186188
)
189+
return model_validate(AuthMFAAdminListFactorsResponse, response.content)
187190

188191
async def _delete_factor(
189192
self,
190193
params: AuthMFAAdminDeleteFactorParams,
191194
) -> AuthMFAAdminDeleteFactorResponse:
192195
self._validate_uuid(params.get("user_id"))
193196
self._validate_uuid(params.get("id"))
194-
return await self._request(
197+
response = await self._request(
195198
"DELETE",
196199
f"admin/users/{params.get('user_id')}/factors/{params.get('id')}",
197-
xform=partial(model_validate, AuthMFAAdminDeleteFactorResponse),
198200
)
201+
return model_validate(AuthMFAAdminDeleteFactorResponse, response.content)
199202

200-
def _validate_uuid(self, id: str) -> None:
203+
def _validate_uuid(self, id: str | None) -> None:
204+
if id is None:
205+
raise ValueError("Invalid id, id cannot be none")
201206
if not is_valid_uuid(id):
202207
raise ValueError(f"Invalid id, '{id}' is not a valid uuid")

src/auth/src/supabase_auth/_async/gotrue_base_api.py

Lines changed: 3 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
from ..helpers import handle_exception, model_dump
1111
from ..http_clients import AsyncClient
1212

13-
T = TypeVar("T")
14-
1513

1614
class AsyncGoTrueBaseAPI:
1715
def __init__(
@@ -41,50 +39,6 @@ async def __aexit__(self, exc_t, exc_v, exc_tb) -> None:
4139
async def close(self) -> None:
4240
await self._http_client.aclose()
4341

44-
@overload
45-
async def _request(
46-
self,
47-
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
48-
path: str,
49-
*,
50-
jwt: Optional[str] = None,
51-
redirect_to: Optional[str] = None,
52-
headers: Optional[Dict[str, str]] = None,
53-
query: Optional[Dict[str, str]] = None,
54-
body: Optional[Any] = None,
55-
no_resolve_json: Literal[False] = False,
56-
xform: Callable[[Any], T],
57-
) -> T: ... # pragma: no cover
58-
59-
@overload
60-
async def _request(
61-
self,
62-
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
63-
path: str,
64-
*,
65-
jwt: Optional[str] = None,
66-
redirect_to: Optional[str] = None,
67-
headers: Optional[Dict[str, str]] = None,
68-
query: Optional[Dict[str, str]] = None,
69-
body: Optional[Any] = None,
70-
no_resolve_json: Literal[True],
71-
xform: Callable[[Response], T],
72-
) -> T: ... # pragma: no cover
73-
74-
@overload
75-
async def _request(
76-
self,
77-
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
78-
path: str,
79-
*,
80-
jwt: Optional[str] = None,
81-
redirect_to: Optional[str] = None,
82-
headers: Optional[Dict[str, str]] = None,
83-
query: Optional[Dict[str, str]] = None,
84-
body: Optional[Any] = None,
85-
no_resolve_json: bool = False,
86-
) -> None: ... # pragma: no cover
87-
8842
async def _request(
8943
self,
9044
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
@@ -93,11 +47,10 @@ async def _request(
9347
jwt: Optional[str] = None,
9448
redirect_to: Optional[str] = None,
9549
headers: Optional[Dict[str, str]] = None,
96-
query: Optional[Dict[str, str]] = None,
50+
query: Optional[Dict[str, str | None]] = None,
9751
body: Optional[Any] = None,
9852
no_resolve_json: bool = False,
99-
xform: Optional[Callable[[Response], T]] = None,
100-
) -> Optional[T]:
53+
) -> Response:
10154
url = f"{self._url}/{path}"
10255
headers = {**self._headers, **(headers or {})}
10356
if API_VERSION_HEADER_NAME not in headers:
@@ -118,9 +71,6 @@ async def _request(
11871
json=model_dump(body) if isinstance(body, BaseModel) else body,
11972
)
12073
response.raise_for_status()
121-
result = response if no_resolve_json else response.json()
122-
if xform:
123-
return xform(result)
124-
return None
74+
return response
12575
except (HTTPStatusError, RuntimeError) as e:
12676
raise handle_exception(e)

0 commit comments

Comments
 (0)