diff --git a/src/auth/Makefile b/src/auth/Makefile index 060b51cb..8d4db0ca 100644 --- a/src/auth/Makefile +++ b/src/auth/Makefile @@ -1,32 +1,57 @@ -tests: pytest +help:: + @echo "Available commands" + @echo " help -- (default) print this message" + +tests: mypy pytest +help:: + @echo " tests -- run all tests for supabase_auth" pytest: start-infra uv run --package supabase_auth pytest --cov=./ --cov-report=xml --cov-report=html -vv +mypy: + uv run --package supabase_auth mypy src/supabase_auth tests +help:: + @echo " mypy -- run mypy on supabase_auth" + start-infra: cd infra &&\ docker compose down &&\ docker compose up -d sleep 2 +help:: + @echo " start-infra -- start containers for tests" clean-infra: cd infra &&\ docker compose down --remove-orphans &&\ docker system prune -a --volumes -f +help:: + @echo " clean-infra -- delete all stored information about the containers" stop-infra: cd infra &&\ docker compose down --remove-orphans +help:: + @echo " stop-infra -- stop containers for tests" sync-infra: uv run --package supabase_auth scripts/gh-download.py --repo=supabase/gotrue-js --branch=master --folder=infra +help:: + @echo " sync-infra -- update locked versions for test containers" build-sync: uv run --package supabase_auth scripts/run-unasync.py +help:: + @echo " build-sync -- generate _sync from _async code" clean: rm -rf htmlcov .pytest_cache .mypy_cache .ruff_cache rm -f .coverage coverage.xml +help:: + @echo " clean -- clean intermediary files" build: uv build --package supabase_auth +help:: + @echo " build -- invoke uv build on supabase_auth package" diff --git a/src/auth/infra/docker-compose.yml b/src/auth/infra/docker-compose.yml index 6506cea4..7ad97278 100644 --- a/src/auth/infra/docker-compose.yml +++ b/src/auth/infra/docker-compose.yml @@ -21,6 +21,7 @@ services: GOTRUE_LOG_LEVEL: DEBUG GOTRUE_OPERATOR_TOKEN: super-secret-operator-token DATABASE_URL: 'postgres://postgres:postgres@db:5432/postgres?sslmode=disable' + GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED: 'true' GOTRUE_EXTERNAL_GOOGLE_ENABLED: 'true' GOTRUE_EXTERNAL_GOOGLE_CLIENT_ID: 53566906701-bmhc1ndue7hild39575gkpimhs06b7ds.apps.googleusercontent.com GOTRUE_EXTERNAL_GOOGLE_SECRET: Sm3s8RE85rDcS36iMy8YjrpC @@ -61,6 +62,7 @@ services: GOTRUE_LOG_LEVEL: DEBUG GOTRUE_OPERATOR_TOKEN: super-secret-operator-token DATABASE_URL: 'postgres://postgres:postgres@db:5432/postgres?sslmode=disable' + GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED: 'true' GOTRUE_EXTERNAL_PHONE_ENABLED: 'true' GOTRUE_SMTP_HOST: mail GOTRUE_SMTP_PORT: 2500 @@ -90,6 +92,7 @@ services: GOTRUE_SMS_AUTOCONFIRM: 'true' GOTRUE_LOG_LEVEL: DEBUG GOTRUE_OPERATOR_TOKEN: super-secret-operator-token + GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED: 'true' DATABASE_URL: 'postgres://postgres:postgres@db:5432/postgres?sslmode=disable' GOTRUE_EXTERNAL_PHONE_ENABLED: 'true' GOTRUE_SMTP_HOST: mail @@ -119,6 +122,7 @@ services: GOTRUE_LOG_LEVEL: DEBUG GOTRUE_OPERATOR_TOKEN: super-secret-operator-token DATABASE_URL: 'postgres://postgres:postgres@db:5432/postgres?sslmode=disable' + GOTRUE_EXTERNAL_ANONYMOUS_USERS_ENABLED: 'true' GOTRUE_EXTERNAL_PHONE_ENABLED: 'false' GOTRUE_EXTERNAL_EMAIL_ENABLED: 'false' GOTRUE_SMTP_HOST: mail diff --git a/src/auth/pyproject.toml b/src/auth/pyproject.toml index 04af9cb0..34599e0f 100644 --- a/src/auth/pyproject.toml +++ b/src/auth/pyproject.toml @@ -44,6 +44,9 @@ tests = [ lints = [ "ruff >=0.12.1", "unasync >= 0.6.0", + "python-lsp-server (>=1.12.2,<2.0.0)", + "pylsp-mypy (>=0.7.0,<0.8.0)", + "python-lsp-ruff (>=2.2.2,<3.0.0)", ] dev = [{ include-group = "lints" }, {include-group = "tests" }] @@ -76,3 +79,15 @@ asyncio_mode = "auto" [build-system] requires = ["uv_build>=0.8.3,<0.9.0"] build-backend = "uv_build" + +[tool.mypy] +python_version = "3.9" +check_untyped_defs = true +allow_redefinition = true +follow_untyped_imports = true # for deprecation module that does not have stubs + +no_warn_no_return = true +warn_return_any = true +warn_unused_configs = true +warn_redundant_casts = true +warn_unused_ignores = true diff --git a/src/auth/scripts/run-unasync.py b/src/auth/scripts/run-unasync.py index aaf3052d..0765c518 100644 --- a/src/auth/scripts/run-unasync.py +++ b/src/auth/scripts/run-unasync.py @@ -2,7 +2,7 @@ import unasync -paths = Path("../src/supabase").glob("**/*.py") +paths = Path("src/supabase_auth").glob("**/*.py") tests = Path("tests").glob("**/*.py") rules = (unasync._DEFAULT_RULE,) diff --git a/src/auth/src/supabase_auth/__init__.py b/src/auth/src/supabase_auth/__init__.py index 159a72cf..eb058ac2 100644 --- a/src/auth/src/supabase_auth/__init__.py +++ b/src/auth/src/supabase_auth/__init__.py @@ -1,16 +1,16 @@ from __future__ import annotations -from ._async.gotrue_admin_api import AsyncGoTrueAdminAPI # type: ignore # noqa: F401 -from ._async.gotrue_client import AsyncGoTrueClient # type: ignore # noqa: F401 +from ._async.gotrue_admin_api import AsyncGoTrueAdminAPI +from ._async.gotrue_client import AsyncGoTrueClient from ._async.storage import ( - AsyncMemoryStorage, # type: ignore # noqa: F401 - AsyncSupportedStorage, # type: ignore # noqa: F401 + AsyncMemoryStorage, + AsyncSupportedStorage, ) -from ._sync.gotrue_admin_api import SyncGoTrueAdminAPI # type: ignore # noqa: F401 -from ._sync.gotrue_client import SyncGoTrueClient # type: ignore # noqa: F401 +from ._sync.gotrue_admin_api import SyncGoTrueAdminAPI +from ._sync.gotrue_client import SyncGoTrueClient from ._sync.storage import ( - SyncMemoryStorage, # type: ignore # noqa: F401 - SyncSupportedStorage, # type: ignore # noqa: F401 + SyncMemoryStorage, + SyncSupportedStorage, ) -from .types import * # type: ignore # noqa: F401, F403 +from .types import * from .version import __version__ diff --git a/src/auth/src/supabase_auth/_async/gotrue_admin_api.py b/src/auth/src/supabase_auth/_async/gotrue_admin_api.py index 408f3da0..8a5dcf04 100644 --- a/src/auth/src/supabase_auth/_async/gotrue_admin_api.py +++ b/src/auth/src/supabase_auth/_async/gotrue_admin_api.py @@ -1,7 +1,9 @@ from __future__ import annotations -from functools import partial -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional + +from httpx import QueryParams, Response +from pydantic import TypeAdapter from ..helpers import ( is_valid_uuid, @@ -21,6 +23,7 @@ InviteUserByEmailOptions, SignOutScope, User, + UserList, UserResponse, ) from .gotrue_admin_mfa_api import AsyncGoTrueAdminMFAAPI @@ -45,18 +48,19 @@ def __init__( verify=verify, proxy=proxy, ) + # TODO(@o-santi): why is is this done this way? self.mfa = AsyncGoTrueAdminMFAAPI() - self.mfa.list_factors = self._list_factors - self.mfa.delete_factor = self._delete_factor + self.mfa.list_factors = self._list_factors # type: ignore + self.mfa.delete_factor = self._delete_factor # type: ignore async def sign_out(self, jwt: str, scope: SignOutScope = "global") -> None: """ Removes a logged-in session. """ - return await self._request( + await self._request( "POST", "logout", - query={"scope": scope}, + query=QueryParams(scope=scope), jwt=jwt, no_resolve_json=True, ) @@ -69,19 +73,19 @@ async def invite_user_by_email( """ Sends an invite link to an email address. """ - return await self._request( + response = await self._request( "POST", "invite", body={"email": email, "data": options.get("data")}, redirect_to=options.get("redirect_to"), - xform=parse_user_response, ) + return parse_user_response(response) async def generate_link(self, params: GenerateLinkParams) -> GenerateLinkResponse: """ Generates email links and OTPs to be sent via a custom email provider. """ - return await self._request( + response = await self._request( "POST", "admin/generate_link", body={ @@ -92,9 +96,10 @@ async def generate_link(self, params: GenerateLinkParams) -> GenerateLinkRespons "data": params.get("options", {}).get("data"), }, redirect_to=params.get("options", {}).get("redirect_to"), - xform=parse_link_response, ) + return parse_link_response(response) + # User Admin API async def create_user(self, attributes: AdminUserAttributes) -> UserResponse: @@ -104,30 +109,28 @@ async def create_user(self, attributes: AdminUserAttributes) -> UserResponse: This function should only be called on a server. Never expose your `service_role` key in the browser. """ - return await self._request( + response = await self._request( "POST", "admin/users", body=attributes, - xform=parse_user_response, ) + return parse_user_response(response) - async def list_users(self, page: int = None, per_page: int = None) -> List[User]: + async def list_users( + self, page: Optional[int] = None, per_page: Optional[int] = None + ) -> List[User]: """ Get a list of users. This function should only be called on a server. Never expose your `service_role` key in the browser. """ - return await self._request( + response = await self._request( "GET", "admin/users", - query={"page": page, "per_page": per_page}, - xform=lambda data: ( - [model_validate(User, user) for user in data["users"]] - if "users" in data - else [] - ), + query=QueryParams(page=page, per_page=per_page), ) + return model_validate(UserList, response.content).users async def get_user_by_id(self, uid: str) -> UserResponse: """ @@ -138,11 +141,11 @@ async def get_user_by_id(self, uid: str) -> UserResponse: """ self._validate_uuid(uid) - return await self._request( + response = await self._request( "GET", f"admin/users/{uid}", - xform=parse_user_response, ) + return parse_user_response(response) async def update_user_by_id( self, @@ -156,12 +159,12 @@ async def update_user_by_id( Never expose your `service_role` key in the browser. """ self._validate_uuid(uid) - return await self._request( + response = await self._request( "PUT", f"admin/users/{uid}", body=attributes, - xform=parse_user_response, ) + return parse_user_response(response) async def delete_user(self, id: str, should_soft_delete: bool = False) -> None: """ @@ -172,18 +175,18 @@ async def delete_user(self, id: str, should_soft_delete: bool = False) -> None: """ self._validate_uuid(id) body = {"should_soft_delete": should_soft_delete} - return await self._request("DELETE", f"admin/users/{id}", body=body) + await self._request("DELETE", f"admin/users/{id}", body=body) async def _list_factors( self, params: AuthMFAAdminListFactorsParams, ) -> AuthMFAAdminListFactorsResponse: self._validate_uuid(params.get("user_id")) - return await self._request( + response = await self._request( "GET", f"admin/users/{params.get('user_id')}/factors", - xform=partial(model_validate, AuthMFAAdminListFactorsResponse), ) + return model_validate(AuthMFAAdminListFactorsResponse, response.content) async def _delete_factor( self, @@ -191,12 +194,14 @@ async def _delete_factor( ) -> AuthMFAAdminDeleteFactorResponse: self._validate_uuid(params.get("user_id")) self._validate_uuid(params.get("id")) - return await self._request( + response = await self._request( "DELETE", f"admin/users/{params.get('user_id')}/factors/{params.get('id')}", - xform=partial(model_validate, AuthMFAAdminDeleteFactorResponse), ) + return model_validate(AuthMFAAdminDeleteFactorResponse, response.content) - def _validate_uuid(self, id: str) -> None: + def _validate_uuid(self, id: str | None) -> None: + if id is None: + raise ValueError("Invalid id, id cannot be none") if not is_valid_uuid(id): raise ValueError(f"Invalid id, '{id}' is not a valid uuid") diff --git a/src/auth/src/supabase_auth/_async/gotrue_base_api.py b/src/auth/src/supabase_auth/_async/gotrue_base_api.py index 21d0b444..84faffbc 100644 --- a/src/auth/src/supabase_auth/_async/gotrue_base_api.py +++ b/src/auth/src/supabase_auth/_async/gotrue_base_api.py @@ -2,16 +2,14 @@ from typing import Any, Callable, Dict, Optional, TypeVar, overload -from httpx import Response +from httpx import HTTPStatusError, QueryParams, Response from pydantic import BaseModel from typing_extensions import Literal, Self -from ..constants import API_VERSION_HEADER_NAME, API_VERSIONS +from ..constants import API_VERSION_HEADER_NAME, API_VERSIONS_2024_01_01_NAME from ..helpers import handle_exception, model_dump from ..http_clients import AsyncClient -T = TypeVar("T") - class AsyncGoTrueBaseAPI: def __init__( @@ -41,7 +39,6 @@ async def __aexit__(self, exc_t, exc_v, exc_tb) -> None: async def close(self) -> None: await self._http_client.aclose() - @overload async def _request( self, method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], @@ -50,65 +47,21 @@ async def _request( jwt: Optional[str] = None, redirect_to: Optional[str] = None, headers: Optional[Dict[str, str]] = None, - query: Optional[Dict[str, str]] = None, - body: Optional[Any] = None, - no_resolve_json: Literal[False] = False, - xform: Callable[[Any], T], - ) -> T: ... # pragma: no cover - - @overload - async def _request( - self, - method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], - path: str, - *, - jwt: Optional[str] = None, - redirect_to: Optional[str] = None, - headers: Optional[Dict[str, str]] = None, - query: Optional[Dict[str, str]] = None, - body: Optional[Any] = None, - no_resolve_json: Literal[True], - xform: Callable[[Response], T], - ) -> T: ... # pragma: no cover - - @overload - async def _request( - self, - method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], - path: str, - *, - jwt: Optional[str] = None, - redirect_to: Optional[str] = None, - headers: Optional[Dict[str, str]] = None, - query: Optional[Dict[str, str]] = None, - body: Optional[Any] = None, - no_resolve_json: bool = False, - ) -> None: ... # pragma: no cover - - async def _request( - self, - method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], - path: str, - *, - jwt: Optional[str] = None, - redirect_to: Optional[str] = None, - headers: Optional[Dict[str, str]] = None, - query: Optional[Dict[str, str]] = None, + query: Optional[QueryParams] = None, body: Optional[Any] = None, no_resolve_json: bool = False, - xform: Optional[Callable[[Any], T]] = None, - ) -> Optional[T]: + ) -> Response: url = f"{self._url}/{path}" headers = {**self._headers, **(headers or {})} if API_VERSION_HEADER_NAME not in headers: - headers[API_VERSION_HEADER_NAME] = API_VERSIONS["2024-01-01"].get("name") + headers[API_VERSION_HEADER_NAME] = API_VERSIONS_2024_01_01_NAME if "Content-Type" not in headers: headers["Content-Type"] = "application/json;charset=UTF-8" if jwt: headers["Authorization"] = f"Bearer {jwt}" - query = query or {} + query = query or QueryParams() if redirect_to: - query["redirect_to"] = redirect_to + query = query.set("redirect_to", redirect_to) try: response = await self._http_client.request( method, @@ -117,9 +70,8 @@ async def _request( params=query, json=model_dump(body) if isinstance(body, BaseModel) else body, ) + response.raise_for_status() - result = response if no_resolve_json else response.json() - if xform: - return xform(result) - except Exception as e: + return response + except (HTTPStatusError, RuntimeError) as e: raise handle_exception(e) diff --git a/src/auth/src/supabase_auth/_async/gotrue_client.py b/src/auth/src/supabase_auth/_async/gotrue_client.py index 25facea0..9d1a5855 100644 --- a/src/auth/src/supabase_auth/_async/gotrue_client.py +++ b/src/auth/src/supabase_auth/_async/gotrue_client.py @@ -4,11 +4,13 @@ from contextlib import suppress from functools import partial from json import loads -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Mapping, Optional, Tuple, Union from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 +from httpx import QueryParams from jwt import get_algorithm_by_name +from typing_extensions import cast from ..constants import ( DEFAULT_HEADERS, @@ -25,6 +27,7 @@ AuthInvalidJwtError, AuthRetryableError, AuthSessionMissingError, + UserDoesntExist, ) from ..helpers import ( decode_jwt, @@ -45,6 +48,7 @@ from ..timer import Timer from ..types import ( JWK, + AMREntry, AuthChangeEvent, AuthenticatorAssuranceLevels, AuthFlowType, @@ -71,13 +75,17 @@ ResendCredentials, Session, SignInAnonymouslyCredentials, + SignInWithEmailAndPasswordlessCredentialsOptions, SignInWithIdTokenCredentials, SignInWithOAuthCredentials, SignInWithPasswordCredentials, SignInWithPasswordlessCredentials, + SignInWithPhoneAndPasswordlessCredentialsOptions, SignInWithSSOCredentials, SignOutOptions, + SignUpWithEmailAndPasswordCredentialsOptions, SignUpWithPasswordCredentials, + SignUpWithPhoneAndPasswordCredentialsOptions, Subscription, UpdateUserOptions, UserAttributes, @@ -134,16 +142,17 @@ def __init__( headers=self._headers, http_client=self._http_client, ) + # TODO(@o-santi): why is it like this? self.mfa = AsyncGoTrueMFAAPI() - self.mfa.challenge = self._challenge - self.mfa.challenge_and_verify = self._challenge_and_verify - self.mfa.enroll = self._enroll - self.mfa.get_authenticator_assurance_level = ( + self.mfa.challenge = self._challenge # type: ignore + self.mfa.challenge_and_verify = self._challenge_and_verify # type: ignore + self.mfa.enroll = self._enroll # type: ignore + self.mfa.get_authenticator_assurance_level = ( # type: ignore self._get_authenticator_assurance_level ) - self.mfa.list_factors = self._list_factors - self.mfa.unenroll = self._unenroll - self.mfa.verify = self._verify + self.mfa.list_factors = self._list_factors # type: ignore + self.mfa.unenroll = self._unenroll # type: ignore + self.mfa.verify = self._verify # type: ignore # Initializations @@ -191,12 +200,12 @@ async def sign_in_anonymously( "captcha_token": captcha_token, }, }, - xform=parse_auth_response, ) - if response.session: - await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response + auth_response = parse_auth_response(response) + if auth_response.session: + await self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) + return auth_response async def sign_up( self, @@ -209,12 +218,17 @@ async def sign_up( email = credentials.get("email") phone = credentials.get("phone") password = credentials.get("password") - options = credentials.get("options", {}) - redirect_to = options.get("redirect_to") or options.get("email_redirect_to") - data = options.get("data") or {} - channel = options.get("channel", "sms") - captcha_token = options.get("captcha_token") - if email: + # TODO(@o-santi): this is horrible, but it is the easiest way to satisfy mypy + # it should have been a builder pattern instead, and with proper classes + if email and password: + email_options = cast( + SignUpWithEmailAndPasswordCredentialsOptions, + credentials.get("options", {}), + ) + data = email_options.get("data") or {} + channel = email_options.get("channel", "sms") + captcha_token = email_options.get("captcha_token") + redirect_to = email_options.get("email_redirect_to") response = await self._request( "POST", "signup", @@ -227,9 +241,15 @@ async def sign_up( }, }, redirect_to=redirect_to, - xform=parse_auth_response, ) - elif phone: + elif phone and password: + phone_options = cast( + SignUpWithPhoneAndPasswordCredentialsOptions, + credentials.get("options", {}), + ) + data = phone_options.get("data") or {} + channel = phone_options.get("channel", "sms") + captcha_token = phone_options.get("captcha_token") response = await self._request( "POST", "signup", @@ -242,16 +262,17 @@ async def sign_up( "captcha_token": captcha_token, }, }, - xform=parse_auth_response, ) else: raise AuthInvalidCredentialsError( "You must provide either an email or phone number and a password" ) - if response.session: - await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response + + auth_response = parse_auth_response(response) + if auth_response.session: + await self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) + return auth_response async def sign_in_with_password( self, @@ -267,7 +288,7 @@ async def sign_in_with_password( options = credentials.get("options", {}) data = options.get("data") or {} captcha_token = options.get("captcha_token") - if email: + if email and password: response = await self._request( "POST", "token", @@ -279,12 +300,9 @@ async def sign_in_with_password( "captcha_token": captcha_token, }, }, - query={ - "grant_type": "password", - }, - xform=parse_auth_response, + query=QueryParams(grant_type="password"), ) - elif phone: + elif phone and password: response = await self._request( "POST", "token", @@ -296,19 +314,17 @@ async def sign_in_with_password( "captcha_token": captcha_token, }, }, - query={ - "grant_type": "password", - }, - xform=parse_auth_response, + query=QueryParams(grant_type="password"), ) else: raise AuthInvalidCredentialsError( "You must provide either an email or phone number and a password" ) - if response.session: - await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response + auth_response = parse_auth_response(response) + if auth_response.session: + await self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) + return auth_response async def sign_in_with_id_token( self, @@ -318,8 +334,8 @@ async def sign_in_with_id_token( Allows signing in with an OIDC ID token. The authentication provider used should be enabled and configured. """ await self._remove_session() - provider = credentials.get("provider") - token = credentials.get("token") + provider = credentials["provider"] + token = credentials["token"] access_token = credentials.get("access_token") nonce = credentials.get("nonce") options = credentials.get("options", {}) @@ -337,16 +353,13 @@ async def sign_in_with_id_token( "captcha_token": captcha_token, }, }, - query={ - "grant_type": "id_token", - }, - xform=parse_auth_response, + query=QueryParams(grant_type="id_token"), ) - - if response.session: - await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response + auth_response = parse_auth_response(response) + if auth_response.session: + await self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) + return auth_response async def sign_in_with_sso(self, credentials: SignInWithSSOCredentials): """ @@ -370,11 +383,11 @@ async def sign_in_with_sso(self, credentials: SignInWithSSOCredentials): captcha_token = options.get("captcha_token") # HTTPX currently does not follow redirects: https://www.python-httpx.org/compatibility/ # Additionally, unlike the JS client, Python is a server side language and it's not possible - # to automatically redirect in browser for hte user + # to automatically redirect in browser for the user skip_http_redirect = options.get("skip_http_redirect", True) if domain: - return await self._request( + response = await self._request( "POST", "sso", body={ @@ -385,10 +398,10 @@ async def sign_in_with_sso(self, credentials: SignInWithSSOCredentials): }, "redirect_to": redirect_to, }, - xform=parse_sso_response, ) + return parse_sso_response(response) if provider_id: - return await self._request( + response = await self._request( "POST", "sso", body={ @@ -399,8 +412,8 @@ async def sign_in_with_sso(self, credentials: SignInWithSSOCredentials): }, "redirect_to": redirect_to, }, - xform=parse_sso_response, ) + return parse_sso_response(response) raise AuthInvalidCredentialsError( "You must provide either a domain or provider_id" ) @@ -414,7 +427,7 @@ async def sign_in_with_oauth( """ await self._remove_session() - provider = credentials.get("provider") + provider = credentials["provider"] options = credentials.get("options", {}) redirect_to = options.get("redirect_to") scopes = options.get("scopes") @@ -431,7 +444,7 @@ async def sign_in_with_oauth( async def link_identity( self, credentials: SignInWithOAuthCredentials ) -> OAuthResponse: - provider = credentials.get("provider") + provider = credentials["provider"] options = credentials.get("options", {}) redirect_to = options.get("redirect_to") scopes = options.get("scopes") @@ -453,17 +466,15 @@ async def link_identity( path=url, query=query, jwt=session.access_token, - xform=parse_link_identity_response, ) - return OAuthResponse(provider=provider, url=response.url) + link_identity = parse_link_identity_response(response) + return OAuthResponse(provider=provider, url=link_identity.url) - async def get_user_identities(self): + async def get_user_identities(self) -> IdentitiesResponse: response = await self.get_user() - return ( - IdentitiesResponse(identities=response.user.identities) - if response.user - else AuthSessionMissingError() - ) + if response: + return IdentitiesResponse(identities=response.user.identities or []) + raise AuthSessionMissingError() async def unlink_identity(self, identity: UserIdentity): session = await self.get_session() @@ -495,14 +506,19 @@ async def sign_in_with_otp( await self._remove_session() email = credentials.get("email") phone = credentials.get("phone") - options = credentials.get("options", {}) - email_redirect_to = options.get("email_redirect_to") - should_create_user = options.get("should_create_user", True) - data = options.get("data") - channel = options.get("channel", "sms") - captcha_token = options.get("captcha_token") + # TODO(@o-santi): this is horrible, but it is the easiest way to satisfy mypy + # it should have been a builder pattern instead, and with proper classes if email: - return await self._request( + email_options = cast( + SignInWithEmailAndPasswordlessCredentialsOptions, + credentials.get("options", {}), + ) + email_redirect_to = email_options.get("email_redirect_to") + should_create_user = email_options.get("should_create_user", True) + data = email_options.get("data") + channel = email_options.get("channel", "sms") + captcha_token = email_options.get("captcha_token") + response = await self._request( "POST", "otp", body={ @@ -514,10 +530,18 @@ async def sign_in_with_otp( }, }, redirect_to=email_redirect_to, - xform=parse_auth_otp_response, ) + return parse_auth_otp_response(response) if phone: - return await self._request( + phone_options = cast( + SignInWithPhoneAndPasswordlessCredentialsOptions, + credentials.get("options", {}), + ) + should_create_user = phone_options.get("should_create_user", True) + data = phone_options.get("data") + channel = phone_options.get("channel", "sms") + captcha_token = phone_options.get("captcha_token") + response = await self._request( "POST", "otp", body={ @@ -529,8 +553,8 @@ async def sign_in_with_otp( "captcha_token": captcha_token, }, }, - xform=parse_auth_otp_response, ) + return parse_auth_otp_response(response) raise AuthInvalidCredentialsError( "You must provide either an email or phone number" ) @@ -546,9 +570,9 @@ async def resend( phone = credentials.get("phone") type = credentials.get("type") options = credentials.get("options", {}) - email_redirect_to = options.get("email_redirect_to") + email_redirect_to: Optional[str] = options.get("email_redirect_to") # type: ignore captcha_token = options.get("captcha_token") - body = { + body: Dict[str, object] = { # improve later "type": type, "gotrue_meta_security": { "captcha_token": captcha_token, @@ -562,13 +586,13 @@ async def resend( body.update({"email": email} if email else {"phone": phone}) - return await self._request( + response = await self._request( "POST", "resend", body=body, redirect_to=email_redirect_to if email else None, - xform=parse_auth_otp_response, ) + return parse_auth_otp_response(response) async def verify_otp(self, params: VerifyOtpParams) -> AuthResponse: """ @@ -585,24 +609,24 @@ async def verify_otp(self, params: VerifyOtpParams) -> AuthResponse: **params, }, redirect_to=params.get("options", {}).get("redirect_to"), - xform=parse_auth_response, ) - if response.session: - await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response + auth_response = parse_auth_response(response) + if auth_response.session: + await self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) + return auth_response async def reauthenticate(self) -> AuthResponse: session = await self.get_session() if not session: raise AuthSessionMissingError() - return await self._request( + response = await self._request( "GET", "reauthenticate", jwt=session.access_token, - xform=parse_auth_response, ) + return AuthResponse(user=None, session=None) async def get_session(self) -> Optional[Session]: """ @@ -619,6 +643,7 @@ async def get_session(self) -> Optional[Session]: await self._remove_session() else: current_session = self._in_memory_session + if not current_session: return None time_now = round(time.time()) @@ -646,7 +671,7 @@ async def get_user(self, jwt: Optional[str] = None) -> Optional[UserResponse]: jwt = session.access_token else: return None - return await self._request("GET", "user", jwt=jwt, xform=parse_user_response) + return parse_user_response(await self._request("GET", "user", jwt=jwt)) async def update_user( self, attributes: UserAttributes, options: UpdateUserOptions = {} @@ -663,12 +688,12 @@ async def update_user( body=attributes, redirect_to=options.get("email_redirect_to"), jwt=session.access_token, - xform=parse_user_response, ) - session.user = response.user + user_response = parse_user_response(response) + session.user = user_response.user await self._save_session(session) self._notify_all_subscribers("USER_UPDATED", session) - return response + return user_response async def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: """ @@ -703,18 +728,20 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon return AuthResponse() session = response.session else: - response = await self.get_user(access_token) + user_response = await self.get_user(access_token) + if user_response is None: + raise UserDoesntExist(access_token) session = Session( access_token=access_token, refresh_token=refresh_token, - user=response.user, + user=user_response.user, token_type="bearer", expires_in=expires_at - time_now, expires_at=expires_at, ) await self._save_session(session) self._notify_all_subscribers("TOKEN_REFRESHED", session) - return AuthResponse(session=session, user=response.user) + return AuthResponse(session=session, user=session.user) async def refresh_session( self, refresh_token: Optional[str] = None @@ -823,23 +850,25 @@ async def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: "factors", body=body, jwt=session.access_token, - xform=partial(model_validate, AuthMFAEnrollResponse), ) - if params["factor_type"] == "totp" and response.totp.qr_code: - response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}" - return response + auth_response = model_validate(AuthMFAEnrollResponse, response.content) + if params["factor_type"] == "totp" and auth_response.totp: + auth_response.totp.qr_code = ( + f"data:image/svg+xml;utf-8,{auth_response.totp.qr_code}" + ) + return auth_response async def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: session = await self.get_session() if not session: raise AuthSessionMissingError() - return await self._request( + response = await self._request( "POST", f"factors/{params.get('factor_id')}/challenge", body={"channel": params.get("channel")}, jwt=session.access_token, - xform=partial(model_validate, AuthMFAChallengeResponse), ) + return model_validate(AuthMFAChallengeResponse, response.content) async def _challenge_and_verify( self, @@ -847,14 +876,14 @@ async def _challenge_and_verify( ) -> AuthMFAVerifyResponse: response = await self._challenge( { - "factor_id": params.get("factor_id"), + "factor_id": params["factor_id"], } ) return await self._verify( { - "factor_id": params.get("factor_id"), + "factor_id": params["factor_id"], "challenge_id": response.id, - "code": params.get("code"), + "code": params["code"], } ) @@ -867,30 +896,34 @@ async def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: f"factors/{params.get('factor_id')}/verify", body=params, jwt=session.access_token, - xform=partial(model_validate, AuthMFAVerifyResponse), ) - session = model_validate(Session, model_dump(response)) + auth_response = model_validate(AuthMFAVerifyResponse, response.content) + session = model_validate(Session, response.content) await self._save_session(session) self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session) - return response + return auth_response async def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: session = await self.get_session() if not session: raise AuthSessionMissingError() - return await self._request( + response = await self._request( "DELETE", f"factors/{params.get('factor_id')}", jwt=session.access_token, - xform=partial(model_validate, AuthMFAUnenrollResponse), ) + return model_validate(AuthMFAUnenrollResponse, response.content) async def _list_factors(self) -> AuthMFAListFactorsResponse: response = await self.get_user() - all = response.user.factors or [] - totp = [f for f in all if f.factor_type == "totp" and f.status == "verified"] - phone = [f for f in all if f.factor_type == "phone" and f.status == "verified"] - return AuthMFAListFactorsResponse(all=all, totp=totp, phone=phone) + factors = response.user.factors or [] if response else [] + totp = [ + f for f in factors if f.factor_type == "totp" and f.status == "verified" + ] + phone = [ + f for f in factors if f.factor_type == "phone" and f.status == "verified" + ] + return AuthMFAListFactorsResponse(all=factors, totp=totp, phone=phone) async def _get_authenticator_assurance_level( self, @@ -903,14 +936,15 @@ async def _get_authenticator_assurance_level( current_authentication_methods=[], ) payload = decode_jwt(session.access_token)["payload"] - current_level: Optional[AuthenticatorAssuranceLevels] = None - if payload.get("aal"): - current_level = payload.get("aal") + current_level = payload.get("aal") verified_factors = [ f for f in session.user.factors or [] if f.status == "verified" ] next_level = "aal2" if verified_factors else current_level - current_authentication_methods = payload.get("amr") or [] + amr_dict_list = payload.get("amr") or [] + current_authentication_methods = [ + AMREntry.model_validate(amr) for amr in amr_dict_list + ] return AuthMFAGetAuthenticatorAssuranceLevelResponse( current_level=current_level, next_level=next_level, @@ -965,6 +999,8 @@ async def _get_session_from_url( time_now = round(time.time()) expires_at = time_now + int(expires_in) user = await self.get_user(access_token) + if user is None: + raise UserDoesntExist(access_token) session = Session( provider_token=provider_token, provider_refresh_token=provider_refresh_token, @@ -1024,13 +1060,13 @@ async def _call_refresh_token(self, refresh_token: str) -> Session: return response.session async def _refresh_access_token(self, refresh_token: str) -> AuthResponse: - return await self._request( + response = await self._request( "POST", "token", - query={"grant_type": "refresh_token"}, + query=QueryParams(grant_type="refresh_token"), body={"refresh_token": refresh_token}, - xform=parse_auth_response, ) + return parse_auth_response(response) async def _save_session(self, session: Session) -> None: if not self._persist_session: @@ -1087,22 +1123,11 @@ def _get_valid_session( ) -> Optional[Session]: if not raw_session: return None - data = loads(raw_session) - if not data: - return None - if not data.get("access_token"): - return None - if not data.get("refresh_token"): - return None - if not data.get("expires_at"): - return None - try: - expires_at = int(data["expires_at"]) - data["expires_at"] = expires_at - except ValueError: - return None try: - return model_validate(Session, data) + session = model_validate(Session, raw_session) + if session.expires_at is None: + return None + return session except Exception: return None @@ -1123,7 +1148,8 @@ async def _get_url_for_provider( url: str, provider: Provider, params: Dict[str, str], - ) -> Tuple[str, Dict[str, str]]: + ) -> Tuple[str, QueryParams]: + query = QueryParams(params) if self._flow_type == "pkce": code_verifier = generate_pkce_verifier() code_challenge = generate_pkce_challenge(code_verifier) @@ -1133,12 +1159,11 @@ async def _get_url_for_provider( code_challenge_method = ( "plain" if code_verifier == code_challenge else "s256" ) - params["code_challenge"] = code_challenge - params["code_challenge_method"] = code_challenge_method - - params["provider"] = provider - query = urlencode(params) - return f"{url}?{query}", params + query = query.set("code_challenge", code_challenge).set( + "code_challenge_method", code_challenge_method + ) + query = query.set("provider", provider) + return f"{url}?{query}", query async def exchange_code_for_session(self, params: CodeExchangeParams): code_verifier = params.get("code_verifier") or await self._storage.get_item( @@ -1147,18 +1172,18 @@ async def exchange_code_for_session(self, params: CodeExchangeParams): response = await self._request( "POST", "token", - query={"grant_type": "pkce"}, + query=QueryParams(grant_type="pkce"), body={ "auth_code": params.get("auth_code"), "code_verifier": code_verifier, }, redirect_to=params.get("redirect_to"), - xform=parse_auth_response, ) + auth_response = parse_auth_response(response) await self._storage.remove_item(f"{self._storage_key}-code-verifier") - if response.session: - await self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) + if auth_response.session: + await self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) return response async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: @@ -1182,13 +1207,14 @@ async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: return jwk # jwk isn't cached in memory so we need to fetch it from the well-known endpoint - response = await self._request("GET", ".well-known/jwks.json", xform=parse_jwks) + response = await self._request("GET", ".well-known/jwks.json") + jwks = parse_jwks(response) if response: - self._jwks = response + self._jwks = jwks self._jwks_cached_at = time.time() # find the signing key - jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None) + jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None) if not jwk: raise AuthInvalidJwtError("No matching signing key found in JWKS") @@ -1227,9 +1253,8 @@ async def get_claims( return ClaimsResponse(claims=payload, headers=header, signature=signature) algorithm = get_algorithm_by_name(header["alg"]) - signing_key = algorithm.from_jwk( - await self._fetch_jwks(header["kid"], jwks or {"keys": []}) - ) + jwk_set = await self._fetch_jwks(header["kid"], jwks or {"keys": []}) + signing_key = algorithm.from_jwk(cast(Dict[str, str], jwk_set)) # verify the signature is_valid = algorithm.verify( diff --git a/src/auth/src/supabase_auth/_async/storage.py b/src/auth/src/supabase_auth/_async/storage.py index 5239dd9d..db520b11 100644 --- a/src/auth/src/supabase_auth/_async/storage.py +++ b/src/auth/src/supabase_auth/_async/storage.py @@ -16,12 +16,13 @@ async def remove_item(self, key: str) -> None: ... # pragma: no cover class AsyncMemoryStorage(AsyncSupportedStorage): - def __init__(self): + def __init__(self) -> None: self.storage: Dict[str, str] = {} async def get_item(self, key: str) -> Optional[str]: if key in self.storage: return self.storage[key] + return None async def set_item(self, key: str, value: str) -> None: self.storage[key] = value diff --git a/src/auth/src/supabase_auth/_sync/gotrue_admin_api.py b/src/auth/src/supabase_auth/_sync/gotrue_admin_api.py index afbb75e0..ea6e9f4a 100644 --- a/src/auth/src/supabase_auth/_sync/gotrue_admin_api.py +++ b/src/auth/src/supabase_auth/_sync/gotrue_admin_api.py @@ -1,7 +1,9 @@ from __future__ import annotations -from functools import partial -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional + +from httpx import QueryParams, Response +from pydantic import TypeAdapter from ..helpers import ( is_valid_uuid, @@ -21,6 +23,7 @@ InviteUserByEmailOptions, SignOutScope, User, + UserList, UserResponse, ) from .gotrue_admin_mfa_api import SyncGoTrueAdminMFAAPI @@ -45,18 +48,19 @@ def __init__( verify=verify, proxy=proxy, ) + # TODO(@o-santi): why is is this done this way? self.mfa = SyncGoTrueAdminMFAAPI() - self.mfa.list_factors = self._list_factors - self.mfa.delete_factor = self._delete_factor + self.mfa.list_factors = self._list_factors # type: ignore + self.mfa.delete_factor = self._delete_factor # type: ignore def sign_out(self, jwt: str, scope: SignOutScope = "global") -> None: """ Removes a logged-in session. """ - return self._request( + self._request( "POST", "logout", - query={"scope": scope}, + query=QueryParams(scope=scope), jwt=jwt, no_resolve_json=True, ) @@ -69,19 +73,19 @@ def invite_user_by_email( """ Sends an invite link to an email address. """ - return self._request( + response = self._request( "POST", "invite", body={"email": email, "data": options.get("data")}, redirect_to=options.get("redirect_to"), - xform=parse_user_response, ) + return parse_user_response(response) def generate_link(self, params: GenerateLinkParams) -> GenerateLinkResponse: """ Generates email links and OTPs to be sent via a custom email provider. """ - return self._request( + response = self._request( "POST", "admin/generate_link", body={ @@ -92,9 +96,10 @@ def generate_link(self, params: GenerateLinkParams) -> GenerateLinkResponse: "data": params.get("options", {}).get("data"), }, redirect_to=params.get("options", {}).get("redirect_to"), - xform=parse_link_response, ) + return parse_link_response(response) + # User Admin API def create_user(self, attributes: AdminUserAttributes) -> UserResponse: @@ -104,30 +109,28 @@ def create_user(self, attributes: AdminUserAttributes) -> UserResponse: This function should only be called on a server. Never expose your `service_role` key in the browser. """ - return self._request( + response = self._request( "POST", "admin/users", body=attributes, - xform=parse_user_response, ) + return parse_user_response(response) - def list_users(self, page: int = None, per_page: int = None) -> List[User]: + def list_users( + self, page: Optional[int] = None, per_page: Optional[int] = None + ) -> List[User]: """ Get a list of users. This function should only be called on a server. Never expose your `service_role` key in the browser. """ - return self._request( + response = self._request( "GET", "admin/users", - query={"page": page, "per_page": per_page}, - xform=lambda data: ( - [model_validate(User, user) for user in data["users"]] - if "users" in data - else [] - ), + query=QueryParams(page=page, per_page=per_page), ) + return model_validate(UserList, response.content).users def get_user_by_id(self, uid: str) -> UserResponse: """ @@ -138,11 +141,11 @@ def get_user_by_id(self, uid: str) -> UserResponse: """ self._validate_uuid(uid) - return self._request( + response = self._request( "GET", f"admin/users/{uid}", - xform=parse_user_response, ) + return parse_user_response(response) def update_user_by_id( self, @@ -156,12 +159,12 @@ def update_user_by_id( Never expose your `service_role` key in the browser. """ self._validate_uuid(uid) - return self._request( + response = self._request( "PUT", f"admin/users/{uid}", body=attributes, - xform=parse_user_response, ) + return parse_user_response(response) def delete_user(self, id: str, should_soft_delete: bool = False) -> None: """ @@ -172,18 +175,18 @@ def delete_user(self, id: str, should_soft_delete: bool = False) -> None: """ self._validate_uuid(id) body = {"should_soft_delete": should_soft_delete} - return self._request("DELETE", f"admin/users/{id}", body=body) + self._request("DELETE", f"admin/users/{id}", body=body) def _list_factors( self, params: AuthMFAAdminListFactorsParams, ) -> AuthMFAAdminListFactorsResponse: self._validate_uuid(params.get("user_id")) - return self._request( + response = self._request( "GET", f"admin/users/{params.get('user_id')}/factors", - xform=partial(model_validate, AuthMFAAdminListFactorsResponse), ) + return model_validate(AuthMFAAdminListFactorsResponse, response.content) def _delete_factor( self, @@ -191,12 +194,14 @@ def _delete_factor( ) -> AuthMFAAdminDeleteFactorResponse: self._validate_uuid(params.get("user_id")) self._validate_uuid(params.get("id")) - return self._request( + response = self._request( "DELETE", f"admin/users/{params.get('user_id')}/factors/{params.get('id')}", - xform=partial(model_validate, AuthMFAAdminDeleteFactorResponse), ) + return model_validate(AuthMFAAdminDeleteFactorResponse, response.content) - def _validate_uuid(self, id: str) -> None: + def _validate_uuid(self, id: str | None) -> None: + if id is None: + raise ValueError("Invalid id, id cannot be none") if not is_valid_uuid(id): raise ValueError(f"Invalid id, '{id}' is not a valid uuid") diff --git a/src/auth/src/supabase_auth/_sync/gotrue_base_api.py b/src/auth/src/supabase_auth/_sync/gotrue_base_api.py index c6c2b7b0..dbb8b171 100644 --- a/src/auth/src/supabase_auth/_sync/gotrue_base_api.py +++ b/src/auth/src/supabase_auth/_sync/gotrue_base_api.py @@ -2,16 +2,14 @@ from typing import Any, Callable, Dict, Optional, TypeVar, overload -from httpx import Response +from httpx import HTTPStatusError, QueryParams, Response from pydantic import BaseModel from typing_extensions import Literal, Self -from ..constants import API_VERSION_HEADER_NAME, API_VERSIONS +from ..constants import API_VERSION_HEADER_NAME, API_VERSIONS_2024_01_01_NAME from ..helpers import handle_exception, model_dump from ..http_clients import SyncClient -T = TypeVar("T") - class SyncGoTrueBaseAPI: def __init__( @@ -41,7 +39,6 @@ def __exit__(self, exc_t, exc_v, exc_tb) -> None: def close(self) -> None: self._http_client.aclose() - @overload def _request( self, method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], @@ -50,65 +47,21 @@ def _request( jwt: Optional[str] = None, redirect_to: Optional[str] = None, headers: Optional[Dict[str, str]] = None, - query: Optional[Dict[str, str]] = None, - body: Optional[Any] = None, - no_resolve_json: Literal[False] = False, - xform: Callable[[Any], T], - ) -> T: ... # pragma: no cover - - @overload - def _request( - self, - method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], - path: str, - *, - jwt: Optional[str] = None, - redirect_to: Optional[str] = None, - headers: Optional[Dict[str, str]] = None, - query: Optional[Dict[str, str]] = None, - body: Optional[Any] = None, - no_resolve_json: Literal[True], - xform: Callable[[Response], T], - ) -> T: ... # pragma: no cover - - @overload - def _request( - self, - method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], - path: str, - *, - jwt: Optional[str] = None, - redirect_to: Optional[str] = None, - headers: Optional[Dict[str, str]] = None, - query: Optional[Dict[str, str]] = None, - body: Optional[Any] = None, - no_resolve_json: bool = False, - ) -> None: ... # pragma: no cover - - def _request( - self, - method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"], - path: str, - *, - jwt: Optional[str] = None, - redirect_to: Optional[str] = None, - headers: Optional[Dict[str, str]] = None, - query: Optional[Dict[str, str]] = None, + query: Optional[QueryParams] = None, body: Optional[Any] = None, no_resolve_json: bool = False, - xform: Optional[Callable[[Any], T]] = None, - ) -> Optional[T]: + ) -> Response: url = f"{self._url}/{path}" headers = {**self._headers, **(headers or {})} if API_VERSION_HEADER_NAME not in headers: - headers[API_VERSION_HEADER_NAME] = API_VERSIONS["2024-01-01"].get("name") + headers[API_VERSION_HEADER_NAME] = API_VERSIONS_2024_01_01_NAME if "Content-Type" not in headers: headers["Content-Type"] = "application/json;charset=UTF-8" if jwt: headers["Authorization"] = f"Bearer {jwt}" - query = query or {} + query = query or QueryParams() if redirect_to: - query["redirect_to"] = redirect_to + query = query.set("redirect_to", redirect_to) try: response = self._http_client.request( method, @@ -117,9 +70,8 @@ def _request( params=query, json=model_dump(body) if isinstance(body, BaseModel) else body, ) + response.raise_for_status() - result = response if no_resolve_json else response.json() - if xform: - return xform(result) - except Exception as e: + return response + except (HTTPStatusError, RuntimeError) as e: raise handle_exception(e) diff --git a/src/auth/src/supabase_auth/_sync/gotrue_client.py b/src/auth/src/supabase_auth/_sync/gotrue_client.py index a575f57b..0716260f 100644 --- a/src/auth/src/supabase_auth/_sync/gotrue_client.py +++ b/src/auth/src/supabase_auth/_sync/gotrue_client.py @@ -4,11 +4,13 @@ from contextlib import suppress from functools import partial from json import loads -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Mapping, Optional, Tuple, Union from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 +from httpx import QueryParams from jwt import get_algorithm_by_name +from typing_extensions import cast from ..constants import ( DEFAULT_HEADERS, @@ -25,6 +27,7 @@ AuthInvalidJwtError, AuthRetryableError, AuthSessionMissingError, + UserDoesntExist, ) from ..helpers import ( decode_jwt, @@ -45,6 +48,7 @@ from ..timer import Timer from ..types import ( JWK, + AMREntry, AuthChangeEvent, AuthenticatorAssuranceLevels, AuthFlowType, @@ -71,13 +75,17 @@ ResendCredentials, Session, SignInAnonymouslyCredentials, + SignInWithEmailAndPasswordlessCredentialsOptions, SignInWithIdTokenCredentials, SignInWithOAuthCredentials, SignInWithPasswordCredentials, SignInWithPasswordlessCredentials, + SignInWithPhoneAndPasswordlessCredentialsOptions, SignInWithSSOCredentials, SignOutOptions, + SignUpWithEmailAndPasswordCredentialsOptions, SignUpWithPasswordCredentials, + SignUpWithPhoneAndPasswordCredentialsOptions, Subscription, UpdateUserOptions, UserAttributes, @@ -134,16 +142,17 @@ def __init__( headers=self._headers, http_client=self._http_client, ) + # TODO(@o-santi): why is it like this? self.mfa = SyncGoTrueMFAAPI() - self.mfa.challenge = self._challenge - self.mfa.challenge_and_verify = self._challenge_and_verify - self.mfa.enroll = self._enroll - self.mfa.get_authenticator_assurance_level = ( + self.mfa.challenge = self._challenge # type: ignore + self.mfa.challenge_and_verify = self._challenge_and_verify # type: ignore + self.mfa.enroll = self._enroll # type: ignore + self.mfa.get_authenticator_assurance_level = ( # type: ignore self._get_authenticator_assurance_level ) - self.mfa.list_factors = self._list_factors - self.mfa.unenroll = self._unenroll - self.mfa.verify = self._verify + self.mfa.list_factors = self._list_factors # type: ignore + self.mfa.unenroll = self._unenroll # type: ignore + self.mfa.verify = self._verify # type: ignore # Initializations @@ -191,12 +200,12 @@ def sign_in_anonymously( "captcha_token": captcha_token, }, }, - xform=parse_auth_response, ) - if response.session: - self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response + auth_response = parse_auth_response(response) + if auth_response.session: + self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) + return auth_response def sign_up( self, @@ -209,12 +218,17 @@ def sign_up( email = credentials.get("email") phone = credentials.get("phone") password = credentials.get("password") - options = credentials.get("options", {}) - redirect_to = options.get("redirect_to") or options.get("email_redirect_to") - data = options.get("data") or {} - channel = options.get("channel", "sms") - captcha_token = options.get("captcha_token") - if email: + # TODO(@o-santi): this is horrible, but it is the easiest way to satisfy mypy + # it should have been a builder pattern instead, and with proper classes + if email and password: + email_options = cast( + SignUpWithEmailAndPasswordCredentialsOptions, + credentials.get("options", {}), + ) + data = email_options.get("data") or {} + channel = email_options.get("channel", "sms") + captcha_token = email_options.get("captcha_token") + redirect_to = email_options.get("email_redirect_to") response = self._request( "POST", "signup", @@ -227,9 +241,15 @@ def sign_up( }, }, redirect_to=redirect_to, - xform=parse_auth_response, ) - elif phone: + elif phone and password: + phone_options = cast( + SignUpWithPhoneAndPasswordCredentialsOptions, + credentials.get("options", {}), + ) + data = phone_options.get("data") or {} + channel = phone_options.get("channel", "sms") + captcha_token = phone_options.get("captcha_token") response = self._request( "POST", "signup", @@ -242,16 +262,17 @@ def sign_up( "captcha_token": captcha_token, }, }, - xform=parse_auth_response, ) else: raise AuthInvalidCredentialsError( "You must provide either an email or phone number and a password" ) - if response.session: - self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response + + auth_response = parse_auth_response(response) + if auth_response.session: + self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) + return auth_response def sign_in_with_password( self, @@ -267,7 +288,7 @@ def sign_in_with_password( options = credentials.get("options", {}) data = options.get("data") or {} captcha_token = options.get("captcha_token") - if email: + if email and password: response = self._request( "POST", "token", @@ -279,12 +300,9 @@ def sign_in_with_password( "captcha_token": captcha_token, }, }, - query={ - "grant_type": "password", - }, - xform=parse_auth_response, + query=QueryParams(grant_type="password"), ) - elif phone: + elif phone and password: response = self._request( "POST", "token", @@ -296,19 +314,17 @@ def sign_in_with_password( "captcha_token": captcha_token, }, }, - query={ - "grant_type": "password", - }, - xform=parse_auth_response, + query=QueryParams(grant_type="password"), ) else: raise AuthInvalidCredentialsError( "You must provide either an email or phone number and a password" ) - if response.session: - self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response + auth_response = parse_auth_response(response) + if auth_response.session: + self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) + return auth_response def sign_in_with_id_token( self, @@ -318,8 +334,8 @@ def sign_in_with_id_token( Allows signing in with an OIDC ID token. The authentication provider used should be enabled and configured. """ self._remove_session() - provider = credentials.get("provider") - token = credentials.get("token") + provider = credentials["provider"] + token = credentials["token"] access_token = credentials.get("access_token") nonce = credentials.get("nonce") options = credentials.get("options", {}) @@ -337,16 +353,13 @@ def sign_in_with_id_token( "captcha_token": captcha_token, }, }, - query={ - "grant_type": "id_token", - }, - xform=parse_auth_response, + query=QueryParams(grant_type="id_token"), ) - - if response.session: - self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response + auth_response = parse_auth_response(response) + if auth_response.session: + self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) + return auth_response def sign_in_with_sso(self, credentials: SignInWithSSOCredentials): """ @@ -370,11 +383,11 @@ def sign_in_with_sso(self, credentials: SignInWithSSOCredentials): captcha_token = options.get("captcha_token") # HTTPX currently does not follow redirects: https://www.python-httpx.org/compatibility/ # Additionally, unlike the JS client, Python is a server side language and it's not possible - # to automatically redirect in browser for hte user + # to automatically redirect in browser for the user skip_http_redirect = options.get("skip_http_redirect", True) if domain: - return self._request( + response = self._request( "POST", "sso", body={ @@ -385,10 +398,10 @@ def sign_in_with_sso(self, credentials: SignInWithSSOCredentials): }, "redirect_to": redirect_to, }, - xform=parse_sso_response, ) + return parse_sso_response(response) if provider_id: - return self._request( + response = self._request( "POST", "sso", body={ @@ -399,8 +412,8 @@ def sign_in_with_sso(self, credentials: SignInWithSSOCredentials): }, "redirect_to": redirect_to, }, - xform=parse_sso_response, ) + return parse_sso_response(response) raise AuthInvalidCredentialsError( "You must provide either a domain or provider_id" ) @@ -414,7 +427,7 @@ def sign_in_with_oauth( """ self._remove_session() - provider = credentials.get("provider") + provider = credentials["provider"] options = credentials.get("options", {}) redirect_to = options.get("redirect_to") scopes = options.get("scopes") @@ -429,7 +442,7 @@ def sign_in_with_oauth( return OAuthResponse(provider=provider, url=url_with_qs) def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse: - provider = credentials.get("provider") + provider = credentials["provider"] options = credentials.get("options", {}) redirect_to = options.get("redirect_to") scopes = options.get("scopes") @@ -451,17 +464,15 @@ def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthRespons path=url, query=query, jwt=session.access_token, - xform=parse_link_identity_response, ) - return OAuthResponse(provider=provider, url=response.url) + link_identity = parse_link_identity_response(response) + return OAuthResponse(provider=provider, url=link_identity.url) - def get_user_identities(self): + def get_user_identities(self) -> IdentitiesResponse: response = self.get_user() - return ( - IdentitiesResponse(identities=response.user.identities) - if response.user - else AuthSessionMissingError() - ) + if response: + return IdentitiesResponse(identities=response.user.identities or []) + raise AuthSessionMissingError() def unlink_identity(self, identity: UserIdentity): session = self.get_session() @@ -493,14 +504,19 @@ def sign_in_with_otp( self._remove_session() email = credentials.get("email") phone = credentials.get("phone") - options = credentials.get("options", {}) - email_redirect_to = options.get("email_redirect_to") - should_create_user = options.get("should_create_user", True) - data = options.get("data") - channel = options.get("channel", "sms") - captcha_token = options.get("captcha_token") + # TODO(@o-santi): this is horrible, but it is the easiest way to satisfy mypy + # it should have been a builder pattern instead, and with proper classes if email: - return self._request( + email_options = cast( + SignInWithEmailAndPasswordlessCredentialsOptions, + credentials.get("options", {}), + ) + email_redirect_to = email_options.get("email_redirect_to") + should_create_user = email_options.get("should_create_user", True) + data = email_options.get("data") + channel = email_options.get("channel", "sms") + captcha_token = email_options.get("captcha_token") + response = self._request( "POST", "otp", body={ @@ -512,10 +528,18 @@ def sign_in_with_otp( }, }, redirect_to=email_redirect_to, - xform=parse_auth_otp_response, ) + return parse_auth_otp_response(response) if phone: - return self._request( + phone_options = cast( + SignInWithPhoneAndPasswordlessCredentialsOptions, + credentials.get("options", {}), + ) + should_create_user = phone_options.get("should_create_user", True) + data = phone_options.get("data") + channel = phone_options.get("channel", "sms") + captcha_token = phone_options.get("captcha_token") + response = self._request( "POST", "otp", body={ @@ -527,8 +551,8 @@ def sign_in_with_otp( "captcha_token": captcha_token, }, }, - xform=parse_auth_otp_response, ) + return parse_auth_otp_response(response) raise AuthInvalidCredentialsError( "You must provide either an email or phone number" ) @@ -544,9 +568,9 @@ def resend( phone = credentials.get("phone") type = credentials.get("type") options = credentials.get("options", {}) - email_redirect_to = options.get("email_redirect_to") + email_redirect_to: Optional[str] = options.get("email_redirect_to") # type: ignore captcha_token = options.get("captcha_token") - body = { + body: Dict[str, object] = { # improve later "type": type, "gotrue_meta_security": { "captcha_token": captcha_token, @@ -560,13 +584,13 @@ def resend( body.update({"email": email} if email else {"phone": phone}) - return self._request( + response = self._request( "POST", "resend", body=body, redirect_to=email_redirect_to if email else None, - xform=parse_auth_otp_response, ) + return parse_auth_otp_response(response) def verify_otp(self, params: VerifyOtpParams) -> AuthResponse: """ @@ -583,24 +607,24 @@ def verify_otp(self, params: VerifyOtpParams) -> AuthResponse: **params, }, redirect_to=params.get("options", {}).get("redirect_to"), - xform=parse_auth_response, ) - if response.session: - self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) - return response + auth_response = parse_auth_response(response) + if auth_response.session: + self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) + return auth_response def reauthenticate(self) -> AuthResponse: session = self.get_session() if not session: raise AuthSessionMissingError() - return self._request( + response = self._request( "GET", "reauthenticate", jwt=session.access_token, - xform=parse_auth_response, ) + return AuthResponse(user=None, session=None) def get_session(self) -> Optional[Session]: """ @@ -617,6 +641,7 @@ def get_session(self) -> Optional[Session]: self._remove_session() else: current_session = self._in_memory_session + if not current_session: return None time_now = round(time.time()) @@ -644,7 +669,7 @@ def get_user(self, jwt: Optional[str] = None) -> Optional[UserResponse]: jwt = session.access_token else: return None - return self._request("GET", "user", jwt=jwt, xform=parse_user_response) + return parse_user_response(self._request("GET", "user", jwt=jwt)) def update_user( self, attributes: UserAttributes, options: UpdateUserOptions = {} @@ -661,12 +686,12 @@ def update_user( body=attributes, redirect_to=options.get("email_redirect_to"), jwt=session.access_token, - xform=parse_user_response, ) - session.user = response.user + user_response = parse_user_response(response) + session.user = user_response.user self._save_session(session) self._notify_all_subscribers("USER_UPDATED", session) - return response + return user_response def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: """ @@ -701,18 +726,20 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse: return AuthResponse() session = response.session else: - response = self.get_user(access_token) + user_response = self.get_user(access_token) + if user_response is None: + raise UserDoesntExist(access_token) session = Session( access_token=access_token, refresh_token=refresh_token, - user=response.user, + user=user_response.user, token_type="bearer", expires_in=expires_at - time_now, expires_at=expires_at, ) self._save_session(session) self._notify_all_subscribers("TOKEN_REFRESHED", session) - return AuthResponse(session=session, user=response.user) + return AuthResponse(session=session, user=session.user) def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse: """ @@ -819,23 +846,25 @@ def _enroll(self, params: MFAEnrollParams) -> AuthMFAEnrollResponse: "factors", body=body, jwt=session.access_token, - xform=partial(model_validate, AuthMFAEnrollResponse), ) - if params["factor_type"] == "totp" and response.totp.qr_code: - response.totp.qr_code = f"data:image/svg+xml;utf-8,{response.totp.qr_code}" - return response + auth_response = model_validate(AuthMFAEnrollResponse, response.content) + if params["factor_type"] == "totp" and auth_response.totp: + auth_response.totp.qr_code = ( + f"data:image/svg+xml;utf-8,{auth_response.totp.qr_code}" + ) + return auth_response def _challenge(self, params: MFAChallengeParams) -> AuthMFAChallengeResponse: session = self.get_session() if not session: raise AuthSessionMissingError() - return self._request( + response = self._request( "POST", f"factors/{params.get('factor_id')}/challenge", body={"channel": params.get("channel")}, jwt=session.access_token, - xform=partial(model_validate, AuthMFAChallengeResponse), ) + return model_validate(AuthMFAChallengeResponse, response.content) def _challenge_and_verify( self, @@ -843,14 +872,14 @@ def _challenge_and_verify( ) -> AuthMFAVerifyResponse: response = self._challenge( { - "factor_id": params.get("factor_id"), + "factor_id": params["factor_id"], } ) return self._verify( { - "factor_id": params.get("factor_id"), + "factor_id": params["factor_id"], "challenge_id": response.id, - "code": params.get("code"), + "code": params["code"], } ) @@ -863,30 +892,34 @@ def _verify(self, params: MFAVerifyParams) -> AuthMFAVerifyResponse: f"factors/{params.get('factor_id')}/verify", body=params, jwt=session.access_token, - xform=partial(model_validate, AuthMFAVerifyResponse), ) - session = model_validate(Session, model_dump(response)) + auth_response = model_validate(AuthMFAVerifyResponse, response.content) + session = model_validate(Session, response.content) self._save_session(session) self._notify_all_subscribers("MFA_CHALLENGE_VERIFIED", session) - return response + return auth_response def _unenroll(self, params: MFAUnenrollParams) -> AuthMFAUnenrollResponse: session = self.get_session() if not session: raise AuthSessionMissingError() - return self._request( + response = self._request( "DELETE", f"factors/{params.get('factor_id')}", jwt=session.access_token, - xform=partial(model_validate, AuthMFAUnenrollResponse), ) + return model_validate(AuthMFAUnenrollResponse, response.content) def _list_factors(self) -> AuthMFAListFactorsResponse: response = self.get_user() - all = response.user.factors or [] - totp = [f for f in all if f.factor_type == "totp" and f.status == "verified"] - phone = [f for f in all if f.factor_type == "phone" and f.status == "verified"] - return AuthMFAListFactorsResponse(all=all, totp=totp, phone=phone) + factors = response.user.factors or [] if response else [] + totp = [ + f for f in factors if f.factor_type == "totp" and f.status == "verified" + ] + phone = [ + f for f in factors if f.factor_type == "phone" and f.status == "verified" + ] + return AuthMFAListFactorsResponse(all=factors, totp=totp, phone=phone) def _get_authenticator_assurance_level( self, @@ -899,14 +932,15 @@ def _get_authenticator_assurance_level( current_authentication_methods=[], ) payload = decode_jwt(session.access_token)["payload"] - current_level: Optional[AuthenticatorAssuranceLevels] = None - if payload.get("aal"): - current_level = payload.get("aal") + current_level = payload.get("aal") verified_factors = [ f for f in session.user.factors or [] if f.status == "verified" ] next_level = "aal2" if verified_factors else current_level - current_authentication_methods = payload.get("amr") or [] + amr_dict_list = payload.get("amr") or [] + current_authentication_methods = [ + AMREntry.model_validate(amr) for amr in amr_dict_list + ] return AuthMFAGetAuthenticatorAssuranceLevelResponse( current_level=current_level, next_level=next_level, @@ -961,6 +995,8 @@ def _get_session_from_url( time_now = round(time.time()) expires_at = time_now + int(expires_in) user = self.get_user(access_token) + if user is None: + raise UserDoesntExist(access_token) session = Session( provider_token=provider_token, provider_refresh_token=provider_refresh_token, @@ -1020,13 +1056,13 @@ def _call_refresh_token(self, refresh_token: str) -> Session: return response.session def _refresh_access_token(self, refresh_token: str) -> AuthResponse: - return self._request( + response = self._request( "POST", "token", - query={"grant_type": "refresh_token"}, + query=QueryParams(grant_type="refresh_token"), body={"refresh_token": refresh_token}, - xform=parse_auth_response, ) + return parse_auth_response(response) def _save_session(self, session: Session) -> None: if not self._persist_session: @@ -1083,22 +1119,11 @@ def _get_valid_session( ) -> Optional[Session]: if not raw_session: return None - data = loads(raw_session) - if not data: - return None - if not data.get("access_token"): - return None - if not data.get("refresh_token"): - return None - if not data.get("expires_at"): - return None - try: - expires_at = int(data["expires_at"]) - data["expires_at"] = expires_at - except ValueError: - return None try: - return model_validate(Session, data) + session = model_validate(Session, raw_session) + if session.expires_at is None: + return None + return session except Exception: return None @@ -1119,7 +1144,8 @@ def _get_url_for_provider( url: str, provider: Provider, params: Dict[str, str], - ) -> Tuple[str, Dict[str, str]]: + ) -> Tuple[str, QueryParams]: + query = QueryParams(params) if self._flow_type == "pkce": code_verifier = generate_pkce_verifier() code_challenge = generate_pkce_challenge(code_verifier) @@ -1127,12 +1153,11 @@ def _get_url_for_provider( code_challenge_method = ( "plain" if code_verifier == code_challenge else "s256" ) - params["code_challenge"] = code_challenge - params["code_challenge_method"] = code_challenge_method - - params["provider"] = provider - query = urlencode(params) - return f"{url}?{query}", params + query = query.set("code_challenge", code_challenge).set( + "code_challenge_method", code_challenge_method + ) + query = query.set("provider", provider) + return f"{url}?{query}", query def exchange_code_for_session(self, params: CodeExchangeParams): code_verifier = params.get("code_verifier") or self._storage.get_item( @@ -1141,18 +1166,18 @@ def exchange_code_for_session(self, params: CodeExchangeParams): response = self._request( "POST", "token", - query={"grant_type": "pkce"}, + query=QueryParams(grant_type="pkce"), body={ "auth_code": params.get("auth_code"), "code_verifier": code_verifier, }, redirect_to=params.get("redirect_to"), - xform=parse_auth_response, ) + auth_response = parse_auth_response(response) self._storage.remove_item(f"{self._storage_key}-code-verifier") - if response.session: - self._save_session(response.session) - self._notify_all_subscribers("SIGNED_IN", response.session) + if auth_response.session: + self._save_session(auth_response.session) + self._notify_all_subscribers("SIGNED_IN", auth_response.session) return response def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: @@ -1176,13 +1201,14 @@ def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK: return jwk # jwk isn't cached in memory so we need to fetch it from the well-known endpoint - response = self._request("GET", ".well-known/jwks.json", xform=parse_jwks) + response = self._request("GET", ".well-known/jwks.json") + jwks = parse_jwks(response) if response: - self._jwks = response + self._jwks = jwks self._jwks_cached_at = time.time() # find the signing key - jwk = next((jwk for jwk in response["keys"] if jwk["kid"] == kid), None) + jwk = next((jwk for jwk in jwks["keys"] if jwk["kid"] == kid), None) if not jwk: raise AuthInvalidJwtError("No matching signing key found in JWKS") @@ -1221,9 +1247,8 @@ def get_claims( return ClaimsResponse(claims=payload, headers=header, signature=signature) algorithm = get_algorithm_by_name(header["alg"]) - signing_key = algorithm.from_jwk( - self._fetch_jwks(header["kid"], jwks or {"keys": []}) - ) + jwk_set = self._fetch_jwks(header["kid"], jwks or {"keys": []}) + signing_key = algorithm.from_jwk(cast(Dict[str, str], jwk_set)) # verify the signature is_valid = algorithm.verify( diff --git a/src/auth/src/supabase_auth/_sync/storage.py b/src/auth/src/supabase_auth/_sync/storage.py index 03ede0c1..2557d5db 100644 --- a/src/auth/src/supabase_auth/_sync/storage.py +++ b/src/auth/src/supabase_auth/_sync/storage.py @@ -16,12 +16,13 @@ def remove_item(self, key: str) -> None: ... # pragma: no cover class SyncMemoryStorage(SyncSupportedStorage): - def __init__(self): + def __init__(self) -> None: self.storage: Dict[str, str] = {} def get_item(self, key: str) -> Optional[str]: if key in self.storage: return self.storage[key] + return None def set_item(self, key: str, value: str) -> None: self.storage[key] = value diff --git a/src/auth/src/supabase_auth/constants.py b/src/auth/src/supabase_auth/constants.py index 671510e5..a3e4afec 100644 --- a/src/auth/src/supabase_auth/constants.py +++ b/src/auth/src/supabase_auth/constants.py @@ -15,10 +15,8 @@ STORAGE_KEY = "supabase.auth.token" API_VERSION_HEADER_NAME = "X-Supabase-Api-Version" -API_VERSIONS = { - "2024-01-01": { - "timestamp": datetime.timestamp(datetime.strptime("2024-01-01", "%Y-%m-%d")), - "name": "2024-01-01", - }, -} +API_VERSIONS_2024_01_01_TIMESTAMP = datetime.timestamp( + datetime.strptime("2024-01-01", "%Y-%m-%d") +) +API_VERSIONS_2024_01_01_NAME = "2024-01-01" BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$" diff --git a/src/auth/src/supabase_auth/errors.py b/src/auth/src/supabase_auth/errors.py index cc85b87e..67eca908 100644 --- a/src/auth/src/supabase_auth/errors.py +++ b/src/auth/src/supabase_auth/errors.py @@ -86,11 +86,17 @@ "invalid_credentials", "email_address_not_authorized", "email_address_invalid", + "invalid_jwt", ] +class UserDoesntExist(Exception): + def __init__(self, access_token: str): + self.access_token = access_token + + class AuthError(Exception): - def __init__(self, message: str, code: ErrorCode) -> None: + def __init__(self, message: str, code: ErrorCode | None) -> None: Exception.__init__(self, message) self.message = message self.name = "AuthError" @@ -101,11 +107,11 @@ class AuthApiErrorDict(TypedDict): name: str message: str status: int - code: ErrorCode + code: ErrorCode | None class AuthApiError(AuthError): - def __init__(self, message: str, status: int, code: ErrorCode) -> None: + def __init__(self, message: str, status: int, code: Optional[ErrorCode]) -> None: AuthError.__init__(self, message, code) self.name = "AuthApiError" self.status = status @@ -128,7 +134,9 @@ def __init__(self, message: str, original_error: Exception) -> None: class CustomAuthError(AuthError): - def __init__(self, message: str, name: str, status: int, code: ErrorCode) -> None: + def __init__( + self, message: str, name: str, status: int, code: Optional[ErrorCode] + ) -> None: AuthError.__init__(self, message, code) self.name = name self.status = status @@ -138,6 +146,7 @@ def to_dict(self) -> AuthApiErrorDict: "name": self.name, "message": self.message, "status": self.status, + "code": self.code, } @@ -193,6 +202,7 @@ def to_dict(self) -> AuthImplicitGrantRedirectErrorDict: "message": self.message, "status": self.status, "details": self.details, + "code": self.code, } @@ -207,6 +217,10 @@ def __init__(self, message: str, status: int) -> None: ) +class AuthApiErrorWithReasonsDict(AuthApiErrorDict): + reasons: List[str] + + class AuthWeakPasswordError(CustomAuthError): def __init__(self, message: str, status: int, reasons: List[str]) -> None: CustomAuthError.__init__( @@ -218,12 +232,13 @@ def __init__(self, message: str, status: int, reasons: List[str]) -> None: ) self.reasons = reasons - def to_dict(self) -> AuthApiErrorDict: + def to_dict(self) -> AuthApiErrorWithReasonsDict: return { "name": self.name, "message": self.message, "status": self.status, "reasons": self.reasons, + "code": self.code, } diff --git a/src/auth/src/supabase_auth/helpers.py b/src/auth/src/supabase_auth/helpers.py index 7f9df7c8..a0ee444e 100644 --- a/src/auth/src/supabase_auth/helpers.py +++ b/src/auth/src/supabase_auth/helpers.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import binascii import hashlib import re import secrets @@ -8,14 +9,17 @@ import uuid from base64 import urlsafe_b64decode from datetime import datetime -from json import loads -from typing import Any, Dict, Optional, Type, TypedDict, TypeVar, cast +from typing import Any, Dict, Optional, Type, TypedDict, TypeVar, Union, cast from urllib.parse import urlparse from httpx import HTTPStatusError, Response -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter -from .constants import API_VERSION_HEADER_NAME, API_VERSIONS, BASE64URL_REGEX +from .constants import ( + API_VERSION_HEADER_NAME, + API_VERSIONS_2024_01_01_TIMESTAMP, + BASE64URL_REGEX, +) from .errors import ( AuthApiError, AuthError, @@ -42,15 +46,15 @@ TBaseModel = TypeVar("TBaseModel", bound=BaseModel) -def model_validate(model: Type[TBaseModel], contents) -> TBaseModel: +def model_validate(model: Type[TBaseModel], contents: Union[str, bytes]) -> TBaseModel: """Compatibility layer between pydantic 1 and 2 for parsing an instance of a BaseModel from varied""" try: # pydantic > 2 - return model.model_validate(contents) + return model.model_validate_json(contents) except AttributeError: # pydantic < 2 - return model.parse_obj(contents) + return model.parse_raw(contents) def model_dump(model: BaseModel) -> Dict[str, Any]: @@ -73,59 +77,51 @@ def model_dump_json(model: BaseModel) -> str: return model.json() -def parse_auth_response(data: Any) -> AuthResponse: - session: Optional[Session] = None - if ( - "access_token" in data - and "refresh_token" in data - and "expires_in" in data - and data["access_token"] - and data["refresh_token"] - and data["expires_in"] - ): - session = model_validate(Session, data) - user_data = data.get("user", data) - user = model_validate(User, user_data) if user_data else None - return AuthResponse(session=session, user=user) +def parse_auth_response(response: Response) -> AuthResponse: + try: + session = model_validate(Session, response.content) + user = session.user + except: + session = None + user = model_validate(User, response.content) + return AuthResponse(user=user, session=session) -def parse_auth_otp_response(data: Any) -> AuthOtpResponse: - return model_validate(AuthOtpResponse, data) +def parse_auth_otp_response(response: Response) -> AuthOtpResponse: + return model_validate(AuthOtpResponse, response.content) -def parse_link_identity_response(data: Any) -> LinkIdentityResponse: - return model_validate(LinkIdentityResponse, data) +def parse_link_identity_response(response: Response) -> LinkIdentityResponse: + return model_validate(LinkIdentityResponse, response.content) -def parse_link_response(data: Any) -> GenerateLinkResponse: - properties = GenerateLinkProperties( - action_link=data.get("action_link"), - email_otp=data.get("email_otp"), - hashed_token=data.get("hashed_token"), - redirect_to=data.get("redirect_to"), - verification_type=data.get("verification_type"), - ) - user = model_validate( - User, {k: v for k, v in data.items() if k not in model_dump(properties)} - ) +def parse_link_response(response: Response) -> GenerateLinkResponse: + properties = model_validate(GenerateLinkProperties, response.content) + user = model_validate(User, response.content) return GenerateLinkResponse(properties=properties, user=user) -def parse_user_response(data: Any) -> UserResponse: - if "user" not in data: - data = {"user": data} - return model_validate(UserResponse, data) +UserParser: TypeAdapter = TypeAdapter(Union[UserResponse, User]) + +def parse_user_response(response: Response) -> UserResponse: + parsed = UserParser.validate_json(response.content) + return UserResponse(user=parsed) if isinstance(parsed, User) else parsed -def parse_sso_response(data: Any) -> SSOResponse: - return model_validate(SSOResponse, data) +def parse_sso_response(response: Response) -> SSOResponse: + return model_validate(SSOResponse, response.content) -def parse_jwks(response: Any) -> JWKSet: - if "keys" not in response or len(response["keys"]) == 0: + +JWKSetParser = TypeAdapter(JWKSet) + + +def parse_jwks(response: Response) -> JWKSet: + jwk = JWKSetParser.validate_json(response.content) + if len(jwk["keys"]) == 0: raise AuthInvalidJwtError("JWKS is empty") - return {"keys": response["keys"]} + return jwk def get_error_message(error: Any) -> str: @@ -136,18 +132,9 @@ def get_error_message(error: Any) -> str: return next((error[prop] for prop in props if filter(prop)), str(error)) -def get_error_code(error: Any) -> str: - return error.get("error_code", None) if isinstance(error, dict) else None - - -def looks_like_http_status_error(exception: Exception) -> bool: - return isinstance(exception, HTTPStatusError) - - -def handle_exception(exception: Exception) -> AuthError: - if not looks_like_http_status_error(exception): - return AuthRetryableError(get_error_message(exception), 0) - error = cast(HTTPStatusError, exception) +def handle_exception(error: HTTPStatusError | RuntimeError) -> AuthError: + if not isinstance(error, HTTPStatusError): + return AuthRetryableError(get_error_message(error), 0) try: network_error_codes = [502, 503, 504] if error.response.status_code in network_error_codes: @@ -161,8 +148,10 @@ def handle_exception(exception: Exception) -> AuthError: if ( response_api_version - and datetime.timestamp(response_api_version) - >= API_VERSIONS.get("2024-01-01").get("timestamp") + and ( + datetime.timestamp(response_api_version) + >= API_VERSIONS_2024_01_01_TIMESTAMP + ) and isinstance(data, dict) and data and isinstance(data.get("code"), str) @@ -180,18 +169,18 @@ def handle_exception(exception: Exception) -> AuthError: and isinstance(data.get("weak_password"), dict) and data.get("weak_password") and isinstance(data.get("weak_password"), list) - and len(data.get("weak_password")) + and len(data["weak_password"]) ): return AuthWeakPasswordError( get_error_message(data), error.response.status_code, - data.get("weak_password").get("reasons"), + data["weak_password"].get("reasons"), ) elif error_code == "weak_password": return AuthWeakPasswordError( get_error_message(data), error.response.status_code, - data.get("weak_password", {}).get("reasons", {}), + data["weak_password"].get("reasons", {}), ) return AuthApiError( @@ -224,20 +213,26 @@ class DecodedJWT(TypedDict): raw: Dict[str, str] +JWTHeaderParser = TypeAdapter(JWTHeader) +JWTPayloadParser = TypeAdapter(JWTPayload) + + def decode_jwt(token: str) -> DecodedJWT: parts = token.split(".") if len(parts) != 3: raise AuthInvalidJwtError("Invalid JWT structure") - # regex check for base64url - for part in parts: - if not re.match(BASE64URL_REGEX, part, re.IGNORECASE): - raise AuthInvalidJwtError("JWT not in base64url format") + try: + header = base64url_to_bytes(parts[0]) + payload = base64url_to_bytes(parts[1]) + signature = base64url_to_bytes(parts[2]) + except binascii.Error: + raise AuthInvalidJwtError("Invalid JWT structure") return DecodedJWT( - header=JWTHeader(**loads(str_from_base64url(parts[0]))), - payload=JWTPayload(**loads(str_from_base64url(parts[1]))), - signature=base64url_to_bytes(parts[2]), + header=JWTHeaderParser.validate_json(header), + payload=JWTPayloadParser.validate_json(payload), + signature=signature, raw={ "header": parts[0], "payload": parts[1], diff --git a/src/auth/src/supabase_auth/types.py b/src/auth/src/supabase_auth/types.py index c41319c7..3ec9d812 100644 --- a/src/auth/src/supabase_auth/types.py +++ b/src/auth/src/supabase_auth/types.py @@ -4,7 +4,7 @@ from time import time from typing import Any, Callable, Dict, List, Optional, Union -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, with_config try: # > 2 @@ -15,7 +15,7 @@ # < 2 from pydantic import root_validator - model_validator_v1_v2_compat = root_validator + model_validator_v1_v2_compat = root_validator # type: ignore from typing_extensions import Literal, NotRequired, TypedDict @@ -82,6 +82,11 @@ class AMREntry(BaseModel): """ +class AMREntryDict(TypedDict): + timestamp: int + method: Union[Literal["password", "otp", "oauth", "mfa/totp"], str] + + class Options(TypedDict): redirect_to: NotRequired[str] captcha_token: NotRequired[str] @@ -124,6 +129,10 @@ class IdentitiesResponse(BaseModel): identities: List[UserIdentity] +class UserList(BaseModel): + users: List[User] + + class UserResponse(BaseModel): user: User @@ -810,12 +819,14 @@ class SignOutOptions(TypedDict): scope: NotRequired[SignOutScope] +@with_config(extra="allow") class JWTHeader(TypedDict): alg: Literal["RS256", "ES256", "HS256"] typ: str - kid: str + kid: NotRequired[str] +# TODO: useless, only kept for backwards compatibility class RequiredClaims(TypedDict): iss: str sub: str @@ -827,8 +838,17 @@ class RequiredClaims(TypedDict): session_id: str -class JWTPayload(RequiredClaims, total=False): - pass +@with_config(extra="allow") +class JWTPayload(TypedDict, total=False): + iss: str + sub: str + auth: Union[str, List[str]] + exp: int + iat: int + role: str + aal: AuthenticatorAssuranceLevels + session_id: str + amr: NotRequired[List[AMREntryDict]] class ClaimsResponse(TypedDict): @@ -837,6 +857,7 @@ class ClaimsResponse(TypedDict): signature: bytes +@with_config(extra="allow") class JWK(TypedDict, total=False): kty: Literal["RSA", "EC", "oct"] key_ops: List[str] @@ -871,7 +892,7 @@ class JWKSet(TypedDict): ]: try: # pydantic > 2 - model.model_rebuild() + model.model_rebuild() # type: ignore except AttributeError: # pydantic < 2 - model.update_forward_refs() + model.update_forward_refs() # type: ignore diff --git a/src/auth/tests/_async/clients.py b/src/auth/tests/_async/clients.py index 03356714..655873ab 100644 --- a/src/auth/tests/_async/clients.py +++ b/src/auth/tests/_async/clients.py @@ -1,6 +1,87 @@ +from dataclasses import dataclass +from random import random +from time import time +from typing import Optional + +from faker import Faker from jwt import encode +from typing_extensions import NotRequired, TypedDict from supabase_auth import AsyncGoTrueAdminAPI, AsyncGoTrueClient +from supabase_auth.types import User + + +def mock_access_token() -> str: + return encode( + { + "sub": "1234567890", + "role": "anon_key", + }, + GOTRUE_JWT_SECRET, + ) + + +class OptionalCredentials(TypedDict): + email: NotRequired[Optional[str]] + phone: NotRequired[Optional[str]] + password: NotRequired[Optional[str]] + + +@dataclass +class Credentials: + email: str + phone: str + password: str + + +def mock_user_credentials( + options: OptionalCredentials = {}, +) -> Credentials: + fake = Faker() + rand_numbers = str(int(time())) + return Credentials( + email=options.get("email") or fake.email(), + phone=options.get("phone") or f"1{rand_numbers[-11:]}", + password=options.get("password") or fake.password(), + ) + + +def mock_verification_otp() -> str: + return str(int(100000 + random() * 900000)) + + +def mock_user_metadata(): + fake = Faker() + return { + "profile_image": fake.url(), + } + + +def mock_app_metadata(): + return { + "roles": ["editor", "publisher"], + } + + +async def create_new_user_with_email( + *, + email: Optional[str] = None, + password: Optional[str] = None, +) -> User: + credentials = mock_user_credentials( + { + "email": email, + "password": password, + } + ) + response = await service_role_api_client().create_user( + { + "email": credentials.email, + "password": credentials.password, + } + ) + return response.user + SIGNUP_ENABLED_AUTO_CONFIRM_OFF_PORT = 9999 SIGNUP_ENABLED_AUTO_CONFIRM_ON_PORT = 9998 @@ -31,7 +112,7 @@ ) -def auth_client(): +def auth_client() -> AsyncGoTrueClient: return AsyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, auto_refresh_token=False, @@ -39,12 +120,15 @@ def auth_client(): ) -def auth_client_with_session(): - return AsyncGoTrueClient( +async def auth_client_with_session() -> AsyncGoTrueClient: + client = AsyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, auto_refresh_token=False, persist_session=False, ) + credentials = mock_user_credentials() + await client.sign_up({"email": credentials.email, "password": credentials.password}) + return client def auth_client_with_asymmetric_session() -> AsyncGoTrueClient: @@ -55,7 +139,7 @@ def auth_client_with_asymmetric_session() -> AsyncGoTrueClient: ) -def auth_subscription_client(): +def auth_subscription_client() -> AsyncGoTrueClient: return AsyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, auto_refresh_token=False, @@ -63,7 +147,7 @@ def auth_subscription_client(): ) -def client_api_auto_confirm_enabled_client(): +def client_api_auto_confirm_enabled_client() -> AsyncGoTrueClient: return AsyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, auto_refresh_token=False, @@ -71,7 +155,7 @@ def client_api_auto_confirm_enabled_client(): ) -def client_api_auto_confirm_off_signups_enabled_client(): +def client_api_auto_confirm_off_signups_enabled_client() -> AsyncGoTrueClient: return AsyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_OFF, auto_refresh_token=False, @@ -79,7 +163,7 @@ def client_api_auto_confirm_off_signups_enabled_client(): ) -def client_api_auto_confirm_disabled_client(): +def client_api_auto_confirm_disabled_client() -> AsyncGoTrueClient: return AsyncGoTrueClient( url=GOTRUE_URL_SIGNUP_DISABLED_AUTO_CONFIRM_OFF, auto_refresh_token=False, @@ -87,7 +171,7 @@ def client_api_auto_confirm_disabled_client(): ) -def auth_admin_api_auto_confirm_enabled_client(): +def auth_admin_api_auto_confirm_enabled_client() -> AsyncGoTrueAdminAPI: return AsyncGoTrueAdminAPI( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, headers={ @@ -96,7 +180,7 @@ def auth_admin_api_auto_confirm_enabled_client(): ) -def auth_admin_api_auto_confirm_disabled_client(): +def auth_admin_api_auto_confirm_disabled_client() -> AsyncGoTrueAdminAPI: return AsyncGoTrueAdminAPI( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_OFF, headers={ @@ -113,7 +197,7 @@ def auth_admin_api_auto_confirm_disabled_client(): ) -def service_role_api_client(): +def service_role_api_client() -> AsyncGoTrueAdminAPI: return AsyncGoTrueAdminAPI( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, headers={ @@ -122,7 +206,7 @@ def service_role_api_client(): ) -def service_role_api_client_with_sms(): +def service_role_api_client_with_sms() -> AsyncGoTrueAdminAPI: return AsyncGoTrueAdminAPI( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_OFF, headers={ @@ -131,7 +215,7 @@ def service_role_api_client_with_sms(): ) -def service_role_api_client_no_sms(): +def service_role_api_client_no_sms() -> AsyncGoTrueAdminAPI: return AsyncGoTrueAdminAPI( url=GOTRUE_URL_SIGNUP_DISABLED_AUTO_CONFIRM_OFF, headers={ diff --git a/src/auth/tests/_async/test_gotrue.py b/src/auth/tests/_async/test_gotrue.py index 0a4b512b..d7f3789e 100644 --- a/src/auth/tests/_async/test_gotrue.py +++ b/src/auth/tests/_async/test_gotrue.py @@ -11,14 +11,15 @@ AuthSessionMissingError, ) from supabase_auth.helpers import decode_jwt +from supabase_auth.types import SignUpWithEmailAndPasswordCredentials from .clients import ( GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session, auth_client_with_session, + mock_user_credentials, ) -from .utils import mock_user_credentials async def test_get_claims_returns_none_when_session_is_none(): @@ -29,28 +30,42 @@ async def test_get_claims_returns_none_when_session_is_none(): async def test_get_claims_calls_get_user_if_symmetric_jwt(mocker): client = auth_client() spy = mocker.spy(client, "get_user") + credentials = mock_user_credentials() + options: SignUpWithEmailAndPasswordCredentials = { + "email": credentials.email, + "password": credentials.password, + } + user = (await client.sign_up(options)).user - user = (await client.sign_up(mock_user_credentials())).user assert user is not None - claims = (await client.get_claims())["claims"] - assert claims["email"] == user.email + response = await client.get_claims() + assert response + claims = response["claims"] + + assert claims.get("email") == user.email spy.assert_called_once() async def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker): client = auth_client_with_asymmetric_session() - - user = (await client.sign_up(mock_user_credentials())).user + credentials = mock_user_credentials() + options: SignUpWithEmailAndPasswordCredentials = { + "email": credentials.email, + "password": credentials.password, + } + user = (await client.sign_up(options)).user assert user is not None spy = mocker.spy(client, "_request") - claims = (await client.get_claims())["claims"] - assert claims["email"] == user.email + response = await client.get_claims() + assert response + claims = response["claims"] + assert claims.get("email") == user.email spy.assert_called_once() - spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) + spy.assert_called_with("GET", ".well-known/jwks.json") expected_keyid = "638c54b8-28c2-4b12-9598-ba12ef610a29" @@ -64,11 +79,16 @@ async def test_jwks_ttl_cache_behavior(mocker): spy = mocker.spy(client, "_request") # First call should fetch JWKS from endpoint - user = (await client.sign_up(mock_user_credentials())).user + credentials = mock_user_credentials() + options: SignUpWithEmailAndPasswordCredentials = { + "email": credentials.email, + "password": credentials.password, + } + user = (await client.sign_up(options)).user assert user is not None await client.get_claims() - spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) + spy.assert_called_with("GET", ".well-known/jwks.json") first_call_count = spy.call_count # Second call within TTL should use cache @@ -96,8 +116,8 @@ async def test_set_session_with_valid_tokens(): # First sign up to get valid tokens signup_response = await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -117,7 +137,7 @@ async def test_set_session_with_valid_tokens(): assert response.session.access_token == access_token assert response.session.refresh_token == refresh_token assert response.user is not None - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email async def test_set_session_with_expired_token(): @@ -127,8 +147,8 @@ async def test_set_session_with_expired_token(): # First sign up to get valid tokens signup_response = await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -144,9 +164,9 @@ async def test_set_session_with_expired_token(): expired_token = access_token.split(".") payload = decode_jwt(access_token)["payload"] payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago - expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[ - 1 - ] + expired_token[1] = encode( + dict(payload), GOTRUE_JWT_SECRET, algorithm="HS256" + ).split(".")[1] expired_access_token = ".".join(expired_token) # Set the session with the expired token @@ -157,7 +177,7 @@ async def test_set_session_with_expired_token(): assert response.session.access_token != expired_access_token assert response.session.refresh_token != refresh_token assert response.user is not None - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email async def test_set_session_without_refresh_token(): @@ -167,8 +187,8 @@ async def test_set_session_without_refresh_token(): # First sign up to get valid tokens signup_response = await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -183,9 +203,9 @@ async def test_set_session_without_refresh_token(): expired_token = access_token.split(".") payload = decode_jwt(access_token)["payload"] payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago - expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[ - 1 - ] + expired_token[1] = encode( + dict(payload), GOTRUE_JWT_SECRET, algorithm="HS256" + ).split(".")[1] expired_access_token = ".".join(expired_token) # Try to set the session with an expired token but no refresh token @@ -202,15 +222,15 @@ async def test_set_session_with_invalid_token(): async def test_mfa_enroll(): - client = auth_client_with_session() + client = await auth_client_with_session() credentials = mock_user_credentials() # First sign up to get a valid session await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) @@ -222,6 +242,7 @@ async def test_mfa_enroll(): assert enroll_response.id is not None assert enroll_response.type == "totp" assert enroll_response.friendly_name == "test-factor" + assert enroll_response.totp assert enroll_response.totp.qr_code is not None @@ -232,8 +253,8 @@ async def test_mfa_challenge(): # First sign up to get a valid session signup_response = await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -256,8 +277,8 @@ async def test_mfa_unenroll(): # First sign up to get a valid session signup_response = await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -279,8 +300,8 @@ async def test_mfa_list_factors(): # First sign up to get a valid session signup_response = await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -295,101 +316,6 @@ async def test_mfa_list_factors(): assert len(list_response.all) == 1 -async def test_initialize_from_url(): - # This test verifies the URL format detection and initialization from URL - client = auth_client() - - # First we'll test the _is_implicit_grant_flow method - # The method checks for access_token or error_description in the query string, not the fragment - url_with_token = "http://example.com/?access_token=test_token&other=value" - assert client._is_implicit_grant_flow(url_with_token) == True - - url_with_error = "http://example.com/?error_description=test_error&other=value" - assert client._is_implicit_grant_flow(url_with_error) == True - - url_without_token = "http://example.com/?other=value" - assert client._is_implicit_grant_flow(url_without_token) == False - - # Now test actual URL initialization with a valid URL containing auth tokens - from unittest.mock import patch - - from supabase_auth.types import Session, User, UserResponse - - # Create a mock user and session to avoid actual API calls - mock_user = User( - id="user123", - email="test@example.com", - app_metadata={}, - user_metadata={}, - aud="authenticated", - created_at="2023-01-01T00:00:00Z", - confirmed_at="2023-01-01T00:00:00Z", - last_sign_in_at="2023-01-01T00:00:00Z", - role="authenticated", - updated_at="2023-01-01T00:00:00Z", - ) - - # Wrap the user in a UserResponse as that's what get_user returns - mock_user_response = UserResponse(user=mock_user) - - # Test successful initialization with tokens in URL - good_url = "http://example.com/?access_token=mock_access_token&refresh_token=mock_refresh_token&expires_in=3600&token_type=bearer" - - # We need to mock: - # 1. get_user which is called by _get_session_from_url to validate the token - # 2. _save_session which is called to store the session data - # 3. _notify_all_subscribers which is called to notify about sign-in - with patch.object(client, "get_user") as mock_get_user: - mock_get_user.return_value = mock_user_response - - with patch.object(client, "_save_session") as mock_save_session: - with patch.object(client, "_notify_all_subscribers") as mock_notify: - # Call initialize_from_url with the good URL - result = await client.initialize_from_url(good_url) - - # Verify get_user was called with the access token - mock_get_user.assert_called_once_with("mock_access_token") - - # Verify _save_session was called with a Session object - mock_save_session.assert_called_once() - session_arg = mock_save_session.call_args[0][0] - assert isinstance(session_arg, Session) - assert session_arg.access_token == "mock_access_token" - assert session_arg.refresh_token == "mock_refresh_token" - assert session_arg.expires_in == 3600 - - # Verify _notify_all_subscribers was called - mock_notify.assert_called_with("SIGNED_IN", session_arg) - - assert result is None # initialize_from_url doesn't have a return value - - # Test URL with error - need to include error_code for the test to work correctly - error_url = "http://example.com/?error=invalid_request&error_description=Invalid+request&error_code=400" - - # Should throw an error when URL contains error parameters - from supabase_auth.errors import AuthImplicitGrantRedirectError - - try: - await client.initialize_from_url(error_url) - assert False, "Expected AuthImplicitGrantRedirectError" - except AuthImplicitGrantRedirectError as e: - # The error message includes the error_description value - assert "Invalid request" in str(e) - - # Test URL with code for PKCE flow - code_url = "http://example.com/?code=authorization_code" - - # For the code URL path, we're not testing it here since it requires more mocking - # and is indirectly tested via other tests like exchange_code_for_session - - # Test URL with neither tokens nor code - should not throw but also not call anything - invalid_url = "http://example.com/?foo=bar" - with patch.object(client, "_get_session_from_url") as mock_get_session: - result = await client.initialize_from_url(invalid_url) - mock_get_session.assert_not_called() - assert result is None - - async def test_exchange_code_for_session(): client = auth_client() @@ -405,9 +331,8 @@ async def test_exchange_code_for_session(): client._flow_type = "pkce" # Test the PKCE URL generation which is needed for exchange_code_for_session - provider = "github" url, params = await client._get_url_for_provider( - f"{client._url}/authorize", provider, {} + f"{client._url}/authorize", "github", {} ) # Verify PKCE parameters were added @@ -432,8 +357,8 @@ async def test_get_authenticator_assurance_level(): # Sign up to get a valid session signup_response = await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -451,14 +376,16 @@ async def test_link_identity(): # Sign up to get a valid session signup_response = await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None from unittest.mock import patch + from httpx import Response + from supabase_auth.types import OAuthResponse # Since the test server has manual linking disabled, we'll mock the URL generation @@ -469,7 +396,9 @@ async def test_link_identity(): # Also mock the _request method since the server would reject it with patch.object(client, "_request") as mock_request: - mock_request.return_value = OAuthResponse(provider="github", url=mock_url) + mock_request.return_value = Response( + content=f'{{"url":"{mock_url}"}}', status_code=200 + ) # Call the method response = await client.link_identity({"provider": "github"}) @@ -486,8 +415,8 @@ async def test_get_user_identities(): # Sign up to get a valid session signup_response = await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -499,129 +428,6 @@ async def test_get_user_identities(): assert hasattr(identities_response, "identities") -async def test_unlink_identity(): - client = auth_client() - credentials = mock_user_credentials() - - # Sign up to get a valid session - signup_response = await client.sign_up( - { - "email": credentials.get("email"), - "password": credentials.get("password"), - } - ) - assert signup_response.session is not None - - # Mock a UserIdentity to test unlink_identity - from unittest.mock import patch - - from supabase_auth.types import UserIdentity - - # Create a mock identity - mock_identity = UserIdentity( - id="user-id", - identity_id="identity-id-1", - user_id="user-id", - identity_data={"email": "user@example.com"}, - provider="github", - created_at="2023-01-01T00:00:00Z", - last_sign_in_at="2023-01-01T00:00:00Z", - updated_at="2023-01-01T00:00:00Z", - ) - - # Mock the _request method since we can't actually unlink an identity that doesn't exist - with patch.object(client, "_request") as mock_request: - mock_request.return_value = None - - # Call the method - await client.unlink_identity(mock_identity) - - # Verify the request was made properly - mock_request.assert_called_once_with( - "DELETE", - "user/identities/identity-id-1", - jwt=signup_response.session.access_token, - ) - - # Test error case: no session - with patch.object(client, "get_session") as mock_get_session: - from supabase_auth.errors import AuthSessionMissingError - - mock_get_session.return_value = None - - try: - await client.unlink_identity(mock_identity) - assert False, "Expected AuthSessionMissingError" - except AuthSessionMissingError: - pass - - -async def test_verify_otp(): - client = auth_client() - - # Mock the _request method since we can't actually verify an OTP in the test - import time - from unittest.mock import patch - - from supabase_auth.types import AuthResponse, Session, User - - mock_user = User( - id="test-user-id", - app_metadata={}, - user_metadata={}, - aud="test-aud", - email="test@example.com", - phone="", - created_at="2023-01-01T00:00:00Z", - confirmed_at="2023-01-01T00:00:00Z", - last_sign_in_at="2023-01-01T00:00:00Z", - role="", - updated_at="2023-01-01T00:00:00Z", - ) - - mock_session = Session( - access_token="mock-access-token", - refresh_token="mock-refresh-token", - expires_in=3600, - expires_at=round(time.time()) + 3600, - token_type="bearer", - user=mock_user, - ) - - mock_response = AuthResponse(session=mock_session, user=mock_user) - - with patch.object(client, "_request") as mock_request: - # Configure the mock to return a predefined response - mock_request.return_value = mock_response - - # Also patch _save_session to avoid actual storage interactions - with patch.object(client, "_save_session") as mock_save: - # Call verify_otp with test parameters - params = { - "type": "sms", - "phone": "+11234567890", - "token": "123456", - "options": {"redirect_to": "https://example.com/callback"}, - } - - response = await client.verify_otp(params) - - # Verify the request was made with correct parameters - mock_request.assert_called_once() - args, kwargs = mock_request.call_args - assert args[0] == "POST" # method - assert args[1] == "verify" # path - assert kwargs["body"]["phone"] == "+11234567890" - assert kwargs["body"]["token"] == "123456" - assert kwargs["redirect_to"] == "https://example.com/callback" - - # Verify the session was saved - mock_save.assert_called_once_with(mock_session) - - # Verify the response - assert response == mock_response - - async def test_sign_in_with_password(): client = auth_client() credentials = mock_user_credentials() @@ -630,8 +436,8 @@ async def test_sign_in_with_password(): # First create a user we can sign in with signup_response = await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -639,15 +445,15 @@ async def test_sign_in_with_password(): # Test signing in with the same credentials (email) signin_response = await client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) # Verify the response has a valid session and user assert signin_response.session is not None assert signin_response.user is not None - assert signin_response.user.email == credentials.get("email") + assert signin_response.user.email == credentials.email # Test error case: wrong password @@ -657,7 +463,7 @@ async def test_sign_in_with_password(): try: await test_client.sign_in_with_password( { - "email": credentials.get("email"), + "email": credentials.email, "password": "wrong_password", } ) @@ -667,7 +473,7 @@ async def test_sign_in_with_password(): # Test error case: missing credentials try: - await test_client.sign_in_with_password({}) + await test_client.sign_in_with_password({}) # type: ignore assert False, "Expected AuthInvalidCredentialsError for missing credentials" except AuthInvalidCredentialsError: pass @@ -683,13 +489,16 @@ async def test_sign_in_with_otp(): # We can't fully test the actual OTP flow since that requires email verification from unittest.mock import patch + from httpx import Response + from supabase_auth.types import AuthOtpResponse # First test for email OTP + auth_otp = AuthOtpResponse( + message_id="mock-message-id", + ) with patch.object(client, "_request") as mock_request: - mock_response = AuthOtpResponse( - message_id="mock-message-id", email=email, phone=None, hash=None - ) + mock_response = Response(content=auth_otp.model_dump_json(), status_code=200) mock_request.return_value = mock_response response = await client.sign_in_with_otp( @@ -719,15 +528,13 @@ async def test_sign_in_with_otp(): assert kwargs["redirect_to"] == "https://example.com/callback" # Verify response - assert response == mock_response + assert response == auth_otp # Test with phone OTP phone = "+11234567890" - + auth_otp = AuthOtpResponse(message_id="mock-message-id") with patch.object(client, "_request") as mock_request: - mock_response = AuthOtpResponse( - message_id="mock-message-id", email=None, phone=phone, hash=None - ) + mock_response = Response(content=auth_otp.model_dump_json(), status_code=200) mock_request.return_value = mock_response response = await client.sign_in_with_otp( @@ -758,19 +565,20 @@ async def test_sign_in_with_otp(): assert kwargs.get("redirect_to") is None # No redirect for phone # Verify response - assert response == mock_response + assert response == auth_otp # Test with invalid parameters (missing both email and phone) from supabase_auth.errors import AuthInvalidCredentialsError try: - await client.sign_in_with_otp({}) + await client.sign_in_with_otp({}) # type: ignore assert False, "Expected AuthInvalidCredentialsError" except AuthInvalidCredentialsError: pass async def test_sign_out(): + from datetime import datetime from unittest.mock import patch from supabase_auth.types import Session, User @@ -778,17 +586,18 @@ async def test_sign_out(): client = auth_client() # Create a mock user and session + date = datetime(year=2023, month=1, day=1, hour=0, minute=0, second=0) mock_user = User( id="user123", email="test@example.com", app_metadata={}, user_metadata={}, aud="authenticated", - created_at="2023-01-01T00:00:00Z", - confirmed_at="2023-01-01T00:00:00Z", - last_sign_in_at="2023-01-01T00:00:00Z", + created_at=date, + confirmed_at=date, + last_sign_in_at=date, role="authenticated", - updated_at="2023-01-01T00:00:00Z", + updated_at=date, ) mock_session = Session( @@ -892,7 +701,7 @@ async def test_sign_out(): with patch.object(client.admin, "sign_out") as mock_admin_sign_out: mock_admin_sign_out.side_effect = AuthApiError( - "Test error", 401, "auth_error" + "Test error", 401, "validation_failed" ) with patch.object(client, "_remove_session") as mock_remove_session: diff --git a/src/auth/tests/_async/test_gotrue_admin_api.py b/src/auth/tests/_async/test_gotrue_admin_api.py index 69253c26..94c1c573 100644 --- a/src/auth/tests/_async/test_gotrue_admin_api.py +++ b/src/auth/tests/_async/test_gotrue_admin_api.py @@ -15,21 +15,19 @@ auth_client_with_session, client_api_auto_confirm_disabled_client, client_api_auto_confirm_off_signups_enabled_client, - service_role_api_client, -) -from .utils import ( create_new_user_with_email, mock_app_metadata, mock_user_credentials, mock_user_metadata, mock_verification_otp, + service_role_api_client, ) async def test_create_user_should_create_a_new_user(): credentials = mock_user_credentials() - response = await create_new_user_with_email(email=credentials.get("email")) - assert response.email == credentials.get("email") + response = await create_new_user_with_email(email=credentials.email) + assert response.email == credentials.email async def test_create_user_with_user_metadata(): @@ -37,12 +35,12 @@ async def test_create_user_with_user_metadata(): credentials = mock_user_credentials() response = await service_role_api_client().create_user( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, "user_metadata": user_metadata, } ) - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email assert response.user.user_metadata == user_metadata assert "profile_image" in response.user.user_metadata @@ -53,13 +51,13 @@ async def test_create_user_with_user_and_app_metadata(): credentials = mock_user_credentials() response = await service_role_api_client().create_user( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, "user_metadata": user_metadata, "app_metadata": app_metadata, } ) - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email assert "profile_image" in response.user.user_metadata assert "provider" in response.user.app_metadata assert "providers" in response.user.app_metadata @@ -67,39 +65,25 @@ async def test_create_user_with_user_and_app_metadata(): async def test_list_users_should_return_registered_users(): credentials = mock_user_credentials() - await create_new_user_with_email(email=credentials.get("email")) + await create_new_user_with_email(email=credentials.email) users = await service_role_api_client().list_users() assert users emails = [user.email for user in users] assert emails - assert credentials.get("email") in emails - - -async def test_get_user_fetches_a_user_by_their_access_token(): - credentials = mock_user_credentials() - auth_client_with_session_current_user = auth_client_with_session() - response = await auth_client_with_session_current_user.sign_up( - { - "email": credentials.get("email"), - "password": credentials.get("password"), - } - ) - assert response.session - response = await auth_client_with_session_current_user.get_user() - assert response.user.email == credentials.get("email") + assert credentials.email in emails async def test_get_user_by_id_should_a_registered_user_given_its_user_identifier(): credentials = mock_user_credentials() - user = await create_new_user_with_email(email=credentials.get("email")) + user = await create_new_user_with_email(email=credentials.email) assert user.id response = await service_role_api_client().get_user_by_id(user.id) - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email async def test_modify_email_using_update_user_by_id(): credentials = mock_user_credentials() - user = await create_new_user_with_email(email=credentials.get("email")) + user = await create_new_user_with_email(email=credentials.email) response = await service_role_api_client().update_user_by_id( user.id, { @@ -111,7 +95,7 @@ async def test_modify_email_using_update_user_by_id(): async def test_modify_user_metadata_using_update_user_by_id(): credentials = mock_user_credentials() - user = await create_new_user_with_email(email=credentials.get("email")) + user = await create_new_user_with_email(email=credentials.email) user_metadata = {"favorite_color": "yellow"} response = await service_role_api_client().update_user_by_id( user.id, @@ -125,7 +109,7 @@ async def test_modify_user_metadata_using_update_user_by_id(): async def test_modify_app_metadata_using_update_user_by_id(): credentials = mock_user_credentials() - user = await create_new_user_with_email(email=credentials.get("email")) + user = await create_new_user_with_email(email=credentials.email) app_metadata = {"roles": ["admin", "publisher"]} response = await service_role_api_client().update_user_by_id( user.id, @@ -141,8 +125,8 @@ async def test_modify_confirm_email_using_update_user_by_id(): credentials = mock_user_credentials() response = await client_api_auto_confirm_off_signups_enabled_client().sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert response.user @@ -203,62 +187,46 @@ async def test_sign_in_with_otp_phone(): async def test_resend(): - try: - await client_api_auto_confirm_off_signups_enabled_client().resend( - {"phone": "+112345678", "type": "sms"} - ) - except AuthApiError as e: - assert e.to_dict() + await client_api_auto_confirm_off_signups_enabled_client().resend( + {"phone": "+112345678", "type": "sms"} + ) async def test_reauthenticate(): - try: - response = await auth_client_with_session().reauthenticate() - except AuthSessionMissingError: - pass + client = await auth_client_with_session() + await client.reauthenticate() async def test_refresh_session(): - try: - response = await auth_client_with_session().refresh_session() - except AuthSessionMissingError: - pass + client = await auth_client_with_session() + await client.refresh_session() async def test_reset_password_for_email(): credentials = mock_user_credentials() - try: - response = await auth_client_with_session().reset_password_email( - email=credentials.get("email") - ) - except AuthSessionMissingError: - pass + client = await auth_client_with_session() + await client.reset_password_email(email=credentials.email) async def test_resend_missing_credentials(): - try: - await client_api_auto_confirm_off_signups_enabled_client().resend( - {"type": "email_change"} - ) - except AuthInvalidCredentialsError as e: - assert e.to_dict() + credentials = mock_user_credentials() + await client_api_auto_confirm_off_signups_enabled_client().resend( + {"type": "email_change", "email": credentials.email} + ) async def test_sign_in_anonymously(): - try: - response = await auth_client_with_session().sign_in_anonymously() - assert response - except AuthApiError: - pass + client = await auth_client_with_session() + await client.sign_in_anonymously() async def test_delete_user_should_be_able_delete_an_existing_user(): credentials = mock_user_credentials() - user = await create_new_user_with_email(email=credentials.get("email")) + user = await create_new_user_with_email(email=credentials.email) await service_role_api_client().delete_user(user.id) users = await service_role_api_client().list_users() emails = [user.email for user in users] - assert credentials.get("email") not in emails + assert credentials.email not in emails async def test_generate_link_supports_sign_up_with_generate_confirmation_signup_link(): @@ -268,8 +236,8 @@ async def test_generate_link_supports_sign_up_with_generate_confirmation_signup_ response = await service_role_api_client().generate_link( { "type": "signup", - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, "options": { "data": user_metadata, "redirect_to": redirect_to, @@ -281,22 +249,22 @@ async def test_generate_link_supports_sign_up_with_generate_confirmation_signup_ async def test_generate_link_supports_updating_emails_with_generate_email_change_links(): # noqa: E501 credentials = mock_user_credentials() - user = await create_new_user_with_email(email=credentials.get("email")) + user = await create_new_user_with_email(email=credentials.email) assert user.email - assert user.email == credentials.get("email") + assert user.email == credentials.email credentials = mock_user_credentials() redirect_to = "http://localhost:9999/welcome" response = await service_role_api_client().generate_link( { "type": "email_change_current", "email": user.email, - "new_email": credentials.get("email"), + "new_email": credentials.email, "options": { "redirect_to": redirect_to, }, }, ) - assert response.user.new_email == credentials.get("email") + assert response.user.new_email == credentials.email async def test_invite_user_by_email_creates_a_new_user_with_an_invited_at_timestamp(): @@ -304,7 +272,7 @@ async def test_invite_user_by_email_creates_a_new_user_with_an_invited_at_timest redirect_to = "http://localhost:9999/welcome" user_metadata = {"status": "alpha"} response = await service_role_api_client().invite_user_by_email( - credentials.get("email"), + credentials.email, { "data": user_metadata, "redirect_to": redirect_to, @@ -315,14 +283,15 @@ async def test_invite_user_by_email_creates_a_new_user_with_an_invited_at_timest async def test_sign_out_with_an_valid_access_token(): credentials = mock_user_credentials() - response = await auth_client_with_session().sign_up( + client = await auth_client_with_session() + response = await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, }, ) assert response.session - response = await service_role_api_client().sign_out(response.session.access_token) + await service_role_api_client().sign_out(response.session.access_token) async def test_sign_out_with_an_invalid_access_token(): @@ -339,7 +308,7 @@ async def test_verify_otp_with_non_existent_phone_number(): try: await client_api_auto_confirm_disabled_client().verify_otp( { - "phone": credentials.get("phone"), + "phone": credentials.phone, "token": otp, "type": "sms", }, @@ -355,7 +324,7 @@ async def test_verify_otp_with_invalid_phone_number(): try: await client_api_auto_confirm_disabled_client().verify_otp( { - "phone": f"{credentials.get('phone')}-invalid", + "phone": f"{credentials.phone}-invalid", "token": otp, "type": "sms", }, @@ -414,15 +383,15 @@ async def test_get_item_from_memory_storage(): client = auth_client() await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) await client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert await client._storage.get_item(client._storage_key) is not None @@ -433,19 +402,18 @@ async def test_remove_item_from_memory_storage(): client = auth_client() await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) await client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) await client._storage.remove_item(client._storage_key) - assert client._storage_key not in client._storage.storage async def test_list_factors(): @@ -453,15 +421,15 @@ async def test_list_factors(): client = auth_client() await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) await client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) factors = await client._list_factors() @@ -475,20 +443,18 @@ async def test_start_auto_refresh_token(): client._auto_refresh_token = True await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) await client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) - assert await client._start_auto_refresh_token(2.0) is None - async def test_recover_and_refresh(): credentials = mock_user_credentials() @@ -496,19 +462,18 @@ async def test_recover_and_refresh(): client._auto_refresh_token = True await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) await client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) await client._recover_and_refresh() - assert client._storage_key in client._storage.storage async def test_get_user_identities(): @@ -517,20 +482,20 @@ async def test_get_user_identities(): client._auto_refresh_token = True await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) await client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert (await client.get_user_identities()).identities[0].identity_data[ "email" - ] == credentials.get("email") + ] == credentials.email async def test_update_user(): @@ -539,14 +504,14 @@ async def test_update_user(): client._auto_refresh_token = True await client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) await client.update_user({"password": "123e5a"}) await client.sign_in_with_password( { - "email": credentials.get("email"), + "email": credentials.email, "password": "123e5a", } ) @@ -557,12 +522,12 @@ async def test_create_user_with_app_metadata(): credentials = mock_user_credentials() response = await service_role_api_client().create_user( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, "app_metadata": app_metadata, } ) - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email assert "provider" in response.user.app_metadata assert "providers" in response.user.app_metadata @@ -572,7 +537,7 @@ async def test_weak_email_password_error(): try: await client_api_auto_confirm_off_signups_enabled_client().sign_up( { - "email": credentials.get("email"), + "email": credentials.email, "password": "123", } ) @@ -585,7 +550,7 @@ async def test_weak_phone_password_error(): try: await client_api_auto_confirm_off_signups_enabled_client().sign_up( { - "phone": credentials.get("phone"), + "phone": credentials.phone, "password": "123", } ) diff --git a/src/auth/tests/_async/test_utils.py b/src/auth/tests/_async/test_utils.py index f5a144c4..865f1e95 100644 --- a/src/auth/tests/_async/test_utils.py +++ b/src/auth/tests/_async/test_utils.py @@ -1,6 +1,6 @@ from time import time -from .utils import ( +from .clients import ( create_new_user_with_email, mock_app_metadata, mock_user_credentials, @@ -8,18 +8,6 @@ ) -def test_mock_user_credentials_has_email(): - credentials = mock_user_credentials() - assert credentials.get("email") - assert credentials.get("password") - - -def test_mock_user_credentials_has_phone(): - credentials = mock_user_credentials() - assert credentials.get("phone") - assert credentials.get("password") - - async def test_create_new_user_with_email(): email = f"user+{int(time())}@example.com" user = await create_new_user_with_email(email=email) diff --git a/src/auth/tests/_async/utils.py b/src/auth/tests/_async/utils.py index 1f648197..e69de29b 100644 --- a/src/auth/tests/_async/utils.py +++ b/src/auth/tests/_async/utils.py @@ -1,82 +0,0 @@ -from random import random -from time import time -from typing import Optional - -from faker import Faker -from jwt import encode -from typing_extensions import NotRequired, TypedDict - -from supabase_auth.types import User - -from .clients import GOTRUE_JWT_SECRET, service_role_api_client - - -def mock_access_token() -> str: - return encode( - { - "sub": "1234567890", - "role": "anon_key", - }, - GOTRUE_JWT_SECRET, - ) - - -class OptionalCredentials(TypedDict): - email: NotRequired[Optional[str]] - phone: NotRequired[Optional[str]] - password: NotRequired[Optional[str]] - - -class Credentials(TypedDict): - email: str - phone: str - password: str - - -def mock_user_credentials( - options: OptionalCredentials = {}, -) -> Credentials: - fake = Faker() - rand_numbers = str(int(time())) - return { - "email": options.get("email") or fake.email(), - "phone": options.get("phone") or f"1{rand_numbers[-11:]}", - "password": options.get("password") or fake.password(), - } - - -def mock_verification_otp() -> str: - return str(int(100000 + random() * 900000)) - - -def mock_user_metadata(): - fake = Faker() - return { - "profile_image": fake.url(), - } - - -def mock_app_metadata(): - return { - "roles": ["editor", "publisher"], - } - - -async def create_new_user_with_email( - *, - email: Optional[str] = None, - password: Optional[str] = None, -) -> User: - credentials = mock_user_credentials( - { - "email": email, - "password": password, - } - ) - response = await service_role_api_client().create_user( - { - "email": credentials["email"], - "password": credentials["password"], - } - ) - return response.user diff --git a/src/auth/tests/_sync/clients.py b/src/auth/tests/_sync/clients.py index 3fee59d1..38a0938d 100644 --- a/src/auth/tests/_sync/clients.py +++ b/src/auth/tests/_sync/clients.py @@ -1,6 +1,87 @@ +from dataclasses import dataclass +from random import random +from time import time +from typing import Optional + +from faker import Faker from jwt import encode +from typing_extensions import NotRequired, TypedDict from supabase_auth import SyncGoTrueAdminAPI, SyncGoTrueClient +from supabase_auth.types import User + + +def mock_access_token() -> str: + return encode( + { + "sub": "1234567890", + "role": "anon_key", + }, + GOTRUE_JWT_SECRET, + ) + + +class OptionalCredentials(TypedDict): + email: NotRequired[Optional[str]] + phone: NotRequired[Optional[str]] + password: NotRequired[Optional[str]] + + +@dataclass +class Credentials: + email: str + phone: str + password: str + + +def mock_user_credentials( + options: OptionalCredentials = {}, +) -> Credentials: + fake = Faker() + rand_numbers = str(int(time())) + return Credentials( + email=options.get("email") or fake.email(), + phone=options.get("phone") or f"1{rand_numbers[-11:]}", + password=options.get("password") or fake.password(), + ) + + +def mock_verification_otp() -> str: + return str(int(100000 + random() * 900000)) + + +def mock_user_metadata(): + fake = Faker() + return { + "profile_image": fake.url(), + } + + +def mock_app_metadata(): + return { + "roles": ["editor", "publisher"], + } + + +def create_new_user_with_email( + *, + email: Optional[str] = None, + password: Optional[str] = None, +) -> User: + credentials = mock_user_credentials( + { + "email": email, + "password": password, + } + ) + response = service_role_api_client().create_user( + { + "email": credentials.email, + "password": credentials.password, + } + ) + return response.user + SIGNUP_ENABLED_AUTO_CONFIRM_OFF_PORT = 9999 SIGNUP_ENABLED_AUTO_CONFIRM_ON_PORT = 9998 @@ -31,7 +112,7 @@ ) -def auth_client(): +def auth_client() -> SyncGoTrueClient: return SyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, auto_refresh_token=False, @@ -39,12 +120,15 @@ def auth_client(): ) -def auth_client_with_session(): - return SyncGoTrueClient( +def auth_client_with_session() -> SyncGoTrueClient: + client = SyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, auto_refresh_token=False, persist_session=False, ) + credentials = mock_user_credentials() + client.sign_up({"email": credentials.email, "password": credentials.password}) + return client def auth_client_with_asymmetric_session() -> SyncGoTrueClient: @@ -55,7 +139,7 @@ def auth_client_with_asymmetric_session() -> SyncGoTrueClient: ) -def auth_subscription_client(): +def auth_subscription_client() -> SyncGoTrueClient: return SyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, auto_refresh_token=False, @@ -63,7 +147,7 @@ def auth_subscription_client(): ) -def client_api_auto_confirm_enabled_client(): +def client_api_auto_confirm_enabled_client() -> SyncGoTrueClient: return SyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, auto_refresh_token=False, @@ -71,7 +155,7 @@ def client_api_auto_confirm_enabled_client(): ) -def client_api_auto_confirm_off_signups_enabled_client(): +def client_api_auto_confirm_off_signups_enabled_client() -> SyncGoTrueClient: return SyncGoTrueClient( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_OFF, auto_refresh_token=False, @@ -79,7 +163,7 @@ def client_api_auto_confirm_off_signups_enabled_client(): ) -def client_api_auto_confirm_disabled_client(): +def client_api_auto_confirm_disabled_client() -> SyncGoTrueClient: return SyncGoTrueClient( url=GOTRUE_URL_SIGNUP_DISABLED_AUTO_CONFIRM_OFF, auto_refresh_token=False, @@ -87,7 +171,7 @@ def client_api_auto_confirm_disabled_client(): ) -def auth_admin_api_auto_confirm_enabled_client(): +def auth_admin_api_auto_confirm_enabled_client() -> SyncGoTrueAdminAPI: return SyncGoTrueAdminAPI( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, headers={ @@ -96,7 +180,7 @@ def auth_admin_api_auto_confirm_enabled_client(): ) -def auth_admin_api_auto_confirm_disabled_client(): +def auth_admin_api_auto_confirm_disabled_client() -> SyncGoTrueAdminAPI: return SyncGoTrueAdminAPI( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_OFF, headers={ @@ -113,7 +197,7 @@ def auth_admin_api_auto_confirm_disabled_client(): ) -def service_role_api_client(): +def service_role_api_client() -> SyncGoTrueAdminAPI: return SyncGoTrueAdminAPI( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_ON, headers={ @@ -122,7 +206,7 @@ def service_role_api_client(): ) -def service_role_api_client_with_sms(): +def service_role_api_client_with_sms() -> SyncGoTrueAdminAPI: return SyncGoTrueAdminAPI( url=GOTRUE_URL_SIGNUP_ENABLED_AUTO_CONFIRM_OFF, headers={ @@ -131,7 +215,7 @@ def service_role_api_client_with_sms(): ) -def service_role_api_client_no_sms(): +def service_role_api_client_no_sms() -> SyncGoTrueAdminAPI: return SyncGoTrueAdminAPI( url=GOTRUE_URL_SIGNUP_DISABLED_AUTO_CONFIRM_OFF, headers={ diff --git a/src/auth/tests/_sync/test_gotrue.py b/src/auth/tests/_sync/test_gotrue.py index 1e76c61b..f15dcbd8 100644 --- a/src/auth/tests/_sync/test_gotrue.py +++ b/src/auth/tests/_sync/test_gotrue.py @@ -11,14 +11,15 @@ AuthSessionMissingError, ) from supabase_auth.helpers import decode_jwt +from supabase_auth.types import SignUpWithEmailAndPasswordCredentials from .clients import ( GOTRUE_JWT_SECRET, auth_client, auth_client_with_asymmetric_session, auth_client_with_session, + mock_user_credentials, ) -from .utils import mock_user_credentials def test_get_claims_returns_none_when_session_is_none(): @@ -29,28 +30,42 @@ def test_get_claims_returns_none_when_session_is_none(): def test_get_claims_calls_get_user_if_symmetric_jwt(mocker): client = auth_client() spy = mocker.spy(client, "get_user") + credentials = mock_user_credentials() + options: SignUpWithEmailAndPasswordCredentials = { + "email": credentials.email, + "password": credentials.password, + } + user = (client.sign_up(options)).user - user = (client.sign_up(mock_user_credentials())).user assert user is not None - claims = (client.get_claims())["claims"] - assert claims["email"] == user.email + response = client.get_claims() + assert response + claims = response["claims"] + + assert claims.get("email") == user.email spy.assert_called_once() def test_get_claims_fetches_jwks_to_verify_asymmetric_jwt(mocker): client = auth_client_with_asymmetric_session() - - user = (client.sign_up(mock_user_credentials())).user + credentials = mock_user_credentials() + options: SignUpWithEmailAndPasswordCredentials = { + "email": credentials.email, + "password": credentials.password, + } + user = (client.sign_up(options)).user assert user is not None spy = mocker.spy(client, "_request") - claims = (client.get_claims())["claims"] - assert claims["email"] == user.email + response = client.get_claims() + assert response + claims = response["claims"] + assert claims.get("email") == user.email spy.assert_called_once() - spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) + spy.assert_called_with("GET", ".well-known/jwks.json") expected_keyid = "638c54b8-28c2-4b12-9598-ba12ef610a29" @@ -64,11 +79,16 @@ def test_jwks_ttl_cache_behavior(mocker): spy = mocker.spy(client, "_request") # First call should fetch JWKS from endpoint - user = (client.sign_up(mock_user_credentials())).user + credentials = mock_user_credentials() + options: SignUpWithEmailAndPasswordCredentials = { + "email": credentials.email, + "password": credentials.password, + } + user = (client.sign_up(options)).user assert user is not None client.get_claims() - spy.assert_called_with("GET", ".well-known/jwks.json", xform=unittest.mock.ANY) + spy.assert_called_with("GET", ".well-known/jwks.json") first_call_count = spy.call_count # Second call within TTL should use cache @@ -96,8 +116,8 @@ def test_set_session_with_valid_tokens(): # First sign up to get valid tokens signup_response = client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -117,7 +137,7 @@ def test_set_session_with_valid_tokens(): assert response.session.access_token == access_token assert response.session.refresh_token == refresh_token assert response.user is not None - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email def test_set_session_with_expired_token(): @@ -127,8 +147,8 @@ def test_set_session_with_expired_token(): # First sign up to get valid tokens signup_response = client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -144,9 +164,9 @@ def test_set_session_with_expired_token(): expired_token = access_token.split(".") payload = decode_jwt(access_token)["payload"] payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago - expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[ - 1 - ] + expired_token[1] = encode( + dict(payload), GOTRUE_JWT_SECRET, algorithm="HS256" + ).split(".")[1] expired_access_token = ".".join(expired_token) # Set the session with the expired token @@ -157,7 +177,7 @@ def test_set_session_with_expired_token(): assert response.session.access_token != expired_access_token assert response.session.refresh_token != refresh_token assert response.user is not None - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email def test_set_session_without_refresh_token(): @@ -167,8 +187,8 @@ def test_set_session_without_refresh_token(): # First sign up to get valid tokens signup_response = client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -183,9 +203,9 @@ def test_set_session_without_refresh_token(): expired_token = access_token.split(".") payload = decode_jwt(access_token)["payload"] payload["exp"] = int(time.time()) - 3600 # Set expiry to 1 hour ago - expired_token[1] = encode(payload, GOTRUE_JWT_SECRET, algorithm="HS256").split(".")[ - 1 - ] + expired_token[1] = encode( + dict(payload), GOTRUE_JWT_SECRET, algorithm="HS256" + ).split(".")[1] expired_access_token = ".".join(expired_token) # Try to set the session with an expired token but no refresh token @@ -209,8 +229,8 @@ def test_mfa_enroll(): # First sign up to get a valid session client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) @@ -222,6 +242,7 @@ def test_mfa_enroll(): assert enroll_response.id is not None assert enroll_response.type == "totp" assert enroll_response.friendly_name == "test-factor" + assert enroll_response.totp assert enroll_response.totp.qr_code is not None @@ -232,8 +253,8 @@ def test_mfa_challenge(): # First sign up to get a valid session signup_response = client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -256,8 +277,8 @@ def test_mfa_unenroll(): # First sign up to get a valid session signup_response = client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -279,8 +300,8 @@ def test_mfa_list_factors(): # First sign up to get a valid session signup_response = client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -295,101 +316,6 @@ def test_mfa_list_factors(): assert len(list_response.all) == 1 -def test_initialize_from_url(): - # This test verifies the URL format detection and initialization from URL - client = auth_client() - - # First we'll test the _is_implicit_grant_flow method - # The method checks for access_token or error_description in the query string, not the fragment - url_with_token = "http://example.com/?access_token=test_token&other=value" - assert client._is_implicit_grant_flow(url_with_token) == True - - url_with_error = "http://example.com/?error_description=test_error&other=value" - assert client._is_implicit_grant_flow(url_with_error) == True - - url_without_token = "http://example.com/?other=value" - assert client._is_implicit_grant_flow(url_without_token) == False - - # Now test actual URL initialization with a valid URL containing auth tokens - from unittest.mock import patch - - from supabase_auth.types import Session, User, UserResponse - - # Create a mock user and session to avoid actual API calls - mock_user = User( - id="user123", - email="test@example.com", - app_metadata={}, - user_metadata={}, - aud="authenticated", - created_at="2023-01-01T00:00:00Z", - confirmed_at="2023-01-01T00:00:00Z", - last_sign_in_at="2023-01-01T00:00:00Z", - role="authenticated", - updated_at="2023-01-01T00:00:00Z", - ) - - # Wrap the user in a UserResponse as that's what get_user returns - mock_user_response = UserResponse(user=mock_user) - - # Test successful initialization with tokens in URL - good_url = "http://example.com/?access_token=mock_access_token&refresh_token=mock_refresh_token&expires_in=3600&token_type=bearer" - - # We need to mock: - # 1. get_user which is called by _get_session_from_url to validate the token - # 2. _save_session which is called to store the session data - # 3. _notify_all_subscribers which is called to notify about sign-in - with patch.object(client, "get_user") as mock_get_user: - mock_get_user.return_value = mock_user_response - - with patch.object(client, "_save_session") as mock_save_session: - with patch.object(client, "_notify_all_subscribers") as mock_notify: - # Call initialize_from_url with the good URL - result = client.initialize_from_url(good_url) - - # Verify get_user was called with the access token - mock_get_user.assert_called_once_with("mock_access_token") - - # Verify _save_session was called with a Session object - mock_save_session.assert_called_once() - session_arg = mock_save_session.call_args[0][0] - assert isinstance(session_arg, Session) - assert session_arg.access_token == "mock_access_token" - assert session_arg.refresh_token == "mock_refresh_token" - assert session_arg.expires_in == 3600 - - # Verify _notify_all_subscribers was called - mock_notify.assert_called_with("SIGNED_IN", session_arg) - - assert result is None # initialize_from_url doesn't have a return value - - # Test URL with error - need to include error_code for the test to work correctly - error_url = "http://example.com/?error=invalid_request&error_description=Invalid+request&error_code=400" - - # Should throw an error when URL contains error parameters - from supabase_auth.errors import AuthImplicitGrantRedirectError - - try: - client.initialize_from_url(error_url) - assert False, "Expected AuthImplicitGrantRedirectError" - except AuthImplicitGrantRedirectError as e: - # The error message includes the error_description value - assert "Invalid request" in str(e) - - # Test URL with code for PKCE flow - code_url = "http://example.com/?code=authorization_code" - - # For the code URL path, we're not testing it here since it requires more mocking - # and is indirectly tested via other tests like exchange_code_for_session - - # Test URL with neither tokens nor code - should not throw but also not call anything - invalid_url = "http://example.com/?foo=bar" - with patch.object(client, "_get_session_from_url") as mock_get_session: - result = client.initialize_from_url(invalid_url) - mock_get_session.assert_not_called() - assert result is None - - def test_exchange_code_for_session(): client = auth_client() @@ -405,8 +331,7 @@ def test_exchange_code_for_session(): client._flow_type = "pkce" # Test the PKCE URL generation which is needed for exchange_code_for_session - provider = "github" - url, params = client._get_url_for_provider(f"{client._url}/authorize", provider, {}) + url, params = client._get_url_for_provider(f"{client._url}/authorize", "github", {}) # Verify PKCE parameters were added assert "code_challenge" in params @@ -430,8 +355,8 @@ def test_get_authenticator_assurance_level(): # Sign up to get a valid session signup_response = client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -449,14 +374,16 @@ def test_link_identity(): # Sign up to get a valid session signup_response = client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None from unittest.mock import patch + from httpx import Response + from supabase_auth.types import OAuthResponse # Since the test server has manual linking disabled, we'll mock the URL generation @@ -467,7 +394,9 @@ def test_link_identity(): # Also mock the _request method since the server would reject it with patch.object(client, "_request") as mock_request: - mock_request.return_value = OAuthResponse(provider="github", url=mock_url) + mock_request.return_value = Response( + content=f'{{"url":"{mock_url}"}}', status_code=200 + ) # Call the method response = client.link_identity({"provider": "github"}) @@ -484,8 +413,8 @@ def test_get_user_identities(): # Sign up to get a valid session signup_response = client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -497,129 +426,6 @@ def test_get_user_identities(): assert hasattr(identities_response, "identities") -def test_unlink_identity(): - client = auth_client() - credentials = mock_user_credentials() - - # Sign up to get a valid session - signup_response = client.sign_up( - { - "email": credentials.get("email"), - "password": credentials.get("password"), - } - ) - assert signup_response.session is not None - - # Mock a UserIdentity to test unlink_identity - from unittest.mock import patch - - from supabase_auth.types import UserIdentity - - # Create a mock identity - mock_identity = UserIdentity( - id="user-id", - identity_id="identity-id-1", - user_id="user-id", - identity_data={"email": "user@example.com"}, - provider="github", - created_at="2023-01-01T00:00:00Z", - last_sign_in_at="2023-01-01T00:00:00Z", - updated_at="2023-01-01T00:00:00Z", - ) - - # Mock the _request method since we can't actually unlink an identity that doesn't exist - with patch.object(client, "_request") as mock_request: - mock_request.return_value = None - - # Call the method - client.unlink_identity(mock_identity) - - # Verify the request was made properly - mock_request.assert_called_once_with( - "DELETE", - "user/identities/identity-id-1", - jwt=signup_response.session.access_token, - ) - - # Test error case: no session - with patch.object(client, "get_session") as mock_get_session: - from supabase_auth.errors import AuthSessionMissingError - - mock_get_session.return_value = None - - try: - client.unlink_identity(mock_identity) - assert False, "Expected AuthSessionMissingError" - except AuthSessionMissingError: - pass - - -def test_verify_otp(): - client = auth_client() - - # Mock the _request method since we can't actually verify an OTP in the test - import time - from unittest.mock import patch - - from supabase_auth.types import AuthResponse, Session, User - - mock_user = User( - id="test-user-id", - app_metadata={}, - user_metadata={}, - aud="test-aud", - email="test@example.com", - phone="", - created_at="2023-01-01T00:00:00Z", - confirmed_at="2023-01-01T00:00:00Z", - last_sign_in_at="2023-01-01T00:00:00Z", - role="", - updated_at="2023-01-01T00:00:00Z", - ) - - mock_session = Session( - access_token="mock-access-token", - refresh_token="mock-refresh-token", - expires_in=3600, - expires_at=round(time.time()) + 3600, - token_type="bearer", - user=mock_user, - ) - - mock_response = AuthResponse(session=mock_session, user=mock_user) - - with patch.object(client, "_request") as mock_request: - # Configure the mock to return a predefined response - mock_request.return_value = mock_response - - # Also patch _save_session to avoid actual storage interactions - with patch.object(client, "_save_session") as mock_save: - # Call verify_otp with test parameters - params = { - "type": "sms", - "phone": "+11234567890", - "token": "123456", - "options": {"redirect_to": "https://example.com/callback"}, - } - - response = client.verify_otp(params) - - # Verify the request was made with correct parameters - mock_request.assert_called_once() - args, kwargs = mock_request.call_args - assert args[0] == "POST" # method - assert args[1] == "verify" # path - assert kwargs["body"]["phone"] == "+11234567890" - assert kwargs["body"]["token"] == "123456" - assert kwargs["redirect_to"] == "https://example.com/callback" - - # Verify the session was saved - mock_save.assert_called_once_with(mock_session) - - # Verify the response - assert response == mock_response - - def test_sign_in_with_password(): client = auth_client() credentials = mock_user_credentials() @@ -628,8 +434,8 @@ def test_sign_in_with_password(): # First create a user we can sign in with signup_response = client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert signup_response.session is not None @@ -637,15 +443,15 @@ def test_sign_in_with_password(): # Test signing in with the same credentials (email) signin_response = client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) # Verify the response has a valid session and user assert signin_response.session is not None assert signin_response.user is not None - assert signin_response.user.email == credentials.get("email") + assert signin_response.user.email == credentials.email # Test error case: wrong password @@ -655,7 +461,7 @@ def test_sign_in_with_password(): try: test_client.sign_in_with_password( { - "email": credentials.get("email"), + "email": credentials.email, "password": "wrong_password", } ) @@ -665,7 +471,7 @@ def test_sign_in_with_password(): # Test error case: missing credentials try: - test_client.sign_in_with_password({}) + test_client.sign_in_with_password({}) # type: ignore assert False, "Expected AuthInvalidCredentialsError for missing credentials" except AuthInvalidCredentialsError: pass @@ -681,13 +487,16 @@ def test_sign_in_with_otp(): # We can't fully test the actual OTP flow since that requires email verification from unittest.mock import patch + from httpx import Response + from supabase_auth.types import AuthOtpResponse # First test for email OTP + auth_otp = AuthOtpResponse( + message_id="mock-message-id", + ) with patch.object(client, "_request") as mock_request: - mock_response = AuthOtpResponse( - message_id="mock-message-id", email=email, phone=None, hash=None - ) + mock_response = Response(content=auth_otp.model_dump_json(), status_code=200) mock_request.return_value = mock_response response = client.sign_in_with_otp( @@ -717,15 +526,13 @@ def test_sign_in_with_otp(): assert kwargs["redirect_to"] == "https://example.com/callback" # Verify response - assert response == mock_response + assert response == auth_otp # Test with phone OTP phone = "+11234567890" - + auth_otp = AuthOtpResponse(message_id="mock-message-id") with patch.object(client, "_request") as mock_request: - mock_response = AuthOtpResponse( - message_id="mock-message-id", email=None, phone=phone, hash=None - ) + mock_response = Response(content=auth_otp.model_dump_json(), status_code=200) mock_request.return_value = mock_response response = client.sign_in_with_otp( @@ -756,19 +563,20 @@ def test_sign_in_with_otp(): assert kwargs.get("redirect_to") is None # No redirect for phone # Verify response - assert response == mock_response + assert response == auth_otp # Test with invalid parameters (missing both email and phone) from supabase_auth.errors import AuthInvalidCredentialsError try: - client.sign_in_with_otp({}) + client.sign_in_with_otp({}) # type: ignore assert False, "Expected AuthInvalidCredentialsError" except AuthInvalidCredentialsError: pass def test_sign_out(): + from datetime import datetime from unittest.mock import patch from supabase_auth.types import Session, User @@ -776,17 +584,18 @@ def test_sign_out(): client = auth_client() # Create a mock user and session + date = datetime(year=2023, month=1, day=1, hour=0, minute=0, second=0) mock_user = User( id="user123", email="test@example.com", app_metadata={}, user_metadata={}, aud="authenticated", - created_at="2023-01-01T00:00:00Z", - confirmed_at="2023-01-01T00:00:00Z", - last_sign_in_at="2023-01-01T00:00:00Z", + created_at=date, + confirmed_at=date, + last_sign_in_at=date, role="authenticated", - updated_at="2023-01-01T00:00:00Z", + updated_at=date, ) mock_session = Session( @@ -890,7 +699,7 @@ def test_sign_out(): with patch.object(client.admin, "sign_out") as mock_admin_sign_out: mock_admin_sign_out.side_effect = AuthApiError( - "Test error", 401, "auth_error" + "Test error", 401, "validation_failed" ) with patch.object(client, "_remove_session") as mock_remove_session: diff --git a/src/auth/tests/_sync/test_gotrue_admin_api.py b/src/auth/tests/_sync/test_gotrue_admin_api.py index e179fe2b..75be02f7 100644 --- a/src/auth/tests/_sync/test_gotrue_admin_api.py +++ b/src/auth/tests/_sync/test_gotrue_admin_api.py @@ -15,21 +15,19 @@ auth_client_with_session, client_api_auto_confirm_disabled_client, client_api_auto_confirm_off_signups_enabled_client, - service_role_api_client, -) -from .utils import ( create_new_user_with_email, mock_app_metadata, mock_user_credentials, mock_user_metadata, mock_verification_otp, + service_role_api_client, ) def test_create_user_should_create_a_new_user(): credentials = mock_user_credentials() - response = create_new_user_with_email(email=credentials.get("email")) - assert response.email == credentials.get("email") + response = create_new_user_with_email(email=credentials.email) + assert response.email == credentials.email def test_create_user_with_user_metadata(): @@ -37,12 +35,12 @@ def test_create_user_with_user_metadata(): credentials = mock_user_credentials() response = service_role_api_client().create_user( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, "user_metadata": user_metadata, } ) - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email assert response.user.user_metadata == user_metadata assert "profile_image" in response.user.user_metadata @@ -53,13 +51,13 @@ def test_create_user_with_user_and_app_metadata(): credentials = mock_user_credentials() response = service_role_api_client().create_user( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, "user_metadata": user_metadata, "app_metadata": app_metadata, } ) - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email assert "profile_image" in response.user.user_metadata assert "provider" in response.user.app_metadata assert "providers" in response.user.app_metadata @@ -67,39 +65,25 @@ def test_create_user_with_user_and_app_metadata(): def test_list_users_should_return_registered_users(): credentials = mock_user_credentials() - create_new_user_with_email(email=credentials.get("email")) + create_new_user_with_email(email=credentials.email) users = service_role_api_client().list_users() assert users emails = [user.email for user in users] assert emails - assert credentials.get("email") in emails - - -def test_get_user_fetches_a_user_by_their_access_token(): - credentials = mock_user_credentials() - auth_client_with_session_current_user = auth_client_with_session() - response = auth_client_with_session_current_user.sign_up( - { - "email": credentials.get("email"), - "password": credentials.get("password"), - } - ) - assert response.session - response = auth_client_with_session_current_user.get_user() - assert response.user.email == credentials.get("email") + assert credentials.email in emails def test_get_user_by_id_should_a_registered_user_given_its_user_identifier(): credentials = mock_user_credentials() - user = create_new_user_with_email(email=credentials.get("email")) + user = create_new_user_with_email(email=credentials.email) assert user.id response = service_role_api_client().get_user_by_id(user.id) - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email def test_modify_email_using_update_user_by_id(): credentials = mock_user_credentials() - user = create_new_user_with_email(email=credentials.get("email")) + user = create_new_user_with_email(email=credentials.email) response = service_role_api_client().update_user_by_id( user.id, { @@ -111,7 +95,7 @@ def test_modify_email_using_update_user_by_id(): def test_modify_user_metadata_using_update_user_by_id(): credentials = mock_user_credentials() - user = create_new_user_with_email(email=credentials.get("email")) + user = create_new_user_with_email(email=credentials.email) user_metadata = {"favorite_color": "yellow"} response = service_role_api_client().update_user_by_id( user.id, @@ -125,7 +109,7 @@ def test_modify_user_metadata_using_update_user_by_id(): def test_modify_app_metadata_using_update_user_by_id(): credentials = mock_user_credentials() - user = create_new_user_with_email(email=credentials.get("email")) + user = create_new_user_with_email(email=credentials.email) app_metadata = {"roles": ["admin", "publisher"]} response = service_role_api_client().update_user_by_id( user.id, @@ -141,8 +125,8 @@ def test_modify_confirm_email_using_update_user_by_id(): credentials = mock_user_credentials() response = client_api_auto_confirm_off_signups_enabled_client().sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert response.user @@ -207,62 +191,46 @@ def test_sign_in_with_otp_phone(): def test_resend(): - try: - client_api_auto_confirm_off_signups_enabled_client().resend( - {"phone": "+112345678", "type": "sms"} - ) - except AuthApiError as e: - assert e.to_dict() + client_api_auto_confirm_off_signups_enabled_client().resend( + {"phone": "+112345678", "type": "sms"} + ) def test_reauthenticate(): - try: - response = auth_client_with_session().reauthenticate() - except AuthSessionMissingError: - pass + client = auth_client_with_session() + client.reauthenticate() def test_refresh_session(): - try: - response = auth_client_with_session().refresh_session() - except AuthSessionMissingError: - pass + client = auth_client_with_session() + client.refresh_session() def test_reset_password_for_email(): credentials = mock_user_credentials() - try: - response = auth_client_with_session().reset_password_email( - email=credentials.get("email") - ) - except AuthSessionMissingError: - pass + client = auth_client_with_session() + client.reset_password_email(email=credentials.email) def test_resend_missing_credentials(): - try: - client_api_auto_confirm_off_signups_enabled_client().resend( - {"type": "email_change"} - ) - except AuthInvalidCredentialsError as e: - assert e.to_dict() + credentials = mock_user_credentials() + client_api_auto_confirm_off_signups_enabled_client().resend( + {"type": "email_change", "email": credentials.email} + ) def test_sign_in_anonymously(): - try: - response = auth_client_with_session().sign_in_anonymously() - assert response - except AuthApiError: - pass + client = auth_client_with_session() + client.sign_in_anonymously() def test_delete_user_should_be_able_delete_an_existing_user(): credentials = mock_user_credentials() - user = create_new_user_with_email(email=credentials.get("email")) + user = create_new_user_with_email(email=credentials.email) service_role_api_client().delete_user(user.id) users = service_role_api_client().list_users() emails = [user.email for user in users] - assert credentials.get("email") not in emails + assert credentials.email not in emails def test_generate_link_supports_sign_up_with_generate_confirmation_signup_link(): @@ -272,8 +240,8 @@ def test_generate_link_supports_sign_up_with_generate_confirmation_signup_link() response = service_role_api_client().generate_link( { "type": "signup", - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, "options": { "data": user_metadata, "redirect_to": redirect_to, @@ -285,22 +253,22 @@ def test_generate_link_supports_sign_up_with_generate_confirmation_signup_link() def test_generate_link_supports_updating_emails_with_generate_email_change_links(): # noqa: E501 credentials = mock_user_credentials() - user = create_new_user_with_email(email=credentials.get("email")) + user = create_new_user_with_email(email=credentials.email) assert user.email - assert user.email == credentials.get("email") + assert user.email == credentials.email credentials = mock_user_credentials() redirect_to = "http://localhost:9999/welcome" response = service_role_api_client().generate_link( { "type": "email_change_current", "email": user.email, - "new_email": credentials.get("email"), + "new_email": credentials.email, "options": { "redirect_to": redirect_to, }, }, ) - assert response.user.new_email == credentials.get("email") + assert response.user.new_email == credentials.email def test_invite_user_by_email_creates_a_new_user_with_an_invited_at_timestamp(): @@ -308,7 +276,7 @@ def test_invite_user_by_email_creates_a_new_user_with_an_invited_at_timestamp(): redirect_to = "http://localhost:9999/welcome" user_metadata = {"status": "alpha"} response = service_role_api_client().invite_user_by_email( - credentials.get("email"), + credentials.email, { "data": user_metadata, "redirect_to": redirect_to, @@ -319,14 +287,15 @@ def test_invite_user_by_email_creates_a_new_user_with_an_invited_at_timestamp(): def test_sign_out_with_an_valid_access_token(): credentials = mock_user_credentials() - response = auth_client_with_session().sign_up( + client = auth_client_with_session() + response = client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, }, ) assert response.session - response = service_role_api_client().sign_out(response.session.access_token) + service_role_api_client().sign_out(response.session.access_token) def test_sign_out_with_an_invalid_access_token(): @@ -343,7 +312,7 @@ def test_verify_otp_with_non_existent_phone_number(): try: client_api_auto_confirm_disabled_client().verify_otp( { - "phone": credentials.get("phone"), + "phone": credentials.phone, "token": otp, "type": "sms", }, @@ -359,7 +328,7 @@ def test_verify_otp_with_invalid_phone_number(): try: client_api_auto_confirm_disabled_client().verify_otp( { - "phone": f"{credentials.get('phone')}-invalid", + "phone": f"{credentials.phone}-invalid", "token": otp, "type": "sms", }, @@ -371,11 +340,13 @@ def test_verify_otp_with_invalid_phone_number(): def test_sign_in_with_id_token(): try: - client_api_auto_confirm_off_signups_enabled_client().sign_in_with_id_token( - { - "provider": "google", - "token": "123456", - } + ( + client_api_auto_confirm_off_signups_enabled_client().sign_in_with_id_token( + { + "provider": "google", + "token": "123456", + } + ) ) except AuthApiError as e: assert e.to_dict() @@ -414,15 +385,15 @@ def test_get_item_from_memory_storage(): client = auth_client() client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert client._storage.get_item(client._storage_key) is not None @@ -433,19 +404,18 @@ def test_remove_item_from_memory_storage(): client = auth_client() client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) client._storage.remove_item(client._storage_key) - assert client._storage_key not in client._storage.storage def test_list_factors(): @@ -453,15 +423,15 @@ def test_list_factors(): client = auth_client() client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) factors = client._list_factors() @@ -475,20 +445,18 @@ def test_start_auto_refresh_token(): client._auto_refresh_token = True client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) - assert client._start_auto_refresh_token(2.0) is None - def test_recover_and_refresh(): credentials = mock_user_credentials() @@ -496,19 +464,18 @@ def test_recover_and_refresh(): client._auto_refresh_token = True client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) client._recover_and_refresh() - assert client._storage_key in client._storage.storage def test_get_user_identities(): @@ -517,20 +484,20 @@ def test_get_user_identities(): client._auto_refresh_token = True client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) client.sign_in_with_password( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) assert (client.get_user_identities()).identities[0].identity_data[ "email" - ] == credentials.get("email") + ] == credentials.email def test_update_user(): @@ -539,14 +506,14 @@ def test_update_user(): client._auto_refresh_token = True client.sign_up( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, } ) client.update_user({"password": "123e5a"}) client.sign_in_with_password( { - "email": credentials.get("email"), + "email": credentials.email, "password": "123e5a", } ) @@ -557,12 +524,12 @@ def test_create_user_with_app_metadata(): credentials = mock_user_credentials() response = service_role_api_client().create_user( { - "email": credentials.get("email"), - "password": credentials.get("password"), + "email": credentials.email, + "password": credentials.password, "app_metadata": app_metadata, } ) - assert response.user.email == credentials.get("email") + assert response.user.email == credentials.email assert "provider" in response.user.app_metadata assert "providers" in response.user.app_metadata @@ -572,7 +539,7 @@ def test_weak_email_password_error(): try: client_api_auto_confirm_off_signups_enabled_client().sign_up( { - "email": credentials.get("email"), + "email": credentials.email, "password": "123", } ) @@ -585,7 +552,7 @@ def test_weak_phone_password_error(): try: client_api_auto_confirm_off_signups_enabled_client().sign_up( { - "phone": credentials.get("phone"), + "phone": credentials.phone, "password": "123", } ) diff --git a/src/auth/tests/_sync/test_utils.py b/src/auth/tests/_sync/test_utils.py index 23b4ac9c..87ca2db4 100644 --- a/src/auth/tests/_sync/test_utils.py +++ b/src/auth/tests/_sync/test_utils.py @@ -1,6 +1,6 @@ from time import time -from .utils import ( +from .clients import ( create_new_user_with_email, mock_app_metadata, mock_user_credentials, @@ -8,18 +8,6 @@ ) -def test_mock_user_credentials_has_email(): - credentials = mock_user_credentials() - assert credentials.get("email") - assert credentials.get("password") - - -def test_mock_user_credentials_has_phone(): - credentials = mock_user_credentials() - assert credentials.get("phone") - assert credentials.get("password") - - def test_create_new_user_with_email(): email = f"user+{int(time())}@example.com" user = create_new_user_with_email(email=email) diff --git a/src/auth/tests/_sync/utils.py b/src/auth/tests/_sync/utils.py index 2cc31de2..e69de29b 100644 --- a/src/auth/tests/_sync/utils.py +++ b/src/auth/tests/_sync/utils.py @@ -1,82 +0,0 @@ -from random import random -from time import time -from typing import Optional - -from faker import Faker -from jwt import encode -from typing_extensions import NotRequired, TypedDict - -from supabase_auth.types import User - -from .clients import GOTRUE_JWT_SECRET, service_role_api_client - - -def mock_access_token() -> str: - return encode( - { - "sub": "1234567890", - "role": "anon_key", - }, - GOTRUE_JWT_SECRET, - ) - - -class OptionalCredentials(TypedDict): - email: NotRequired[Optional[str]] - phone: NotRequired[Optional[str]] - password: NotRequired[Optional[str]] - - -class Credentials(TypedDict): - email: str - phone: str - password: str - - -def mock_user_credentials( - options: OptionalCredentials = {}, -) -> Credentials: - fake = Faker() - rand_numbers = str(int(time())) - return { - "email": options.get("email") or fake.email(), - "phone": options.get("phone") or f"1{rand_numbers[-11:]}", - "password": options.get("password") or fake.password(), - } - - -def mock_verification_otp() -> str: - return str(int(100000 + random() * 900000)) - - -def mock_user_metadata(): - fake = Faker() - return { - "profile_image": fake.url(), - } - - -def mock_app_metadata(): - return { - "roles": ["editor", "publisher"], - } - - -def create_new_user_with_email( - *, - email: Optional[str] = None, - password: Optional[str] = None, -) -> User: - credentials = mock_user_credentials( - { - "email": email, - "password": password, - } - ) - response = service_role_api_client().create_user( - { - "email": credentials["email"], - "password": credentials["password"], - } - ) - return response.user diff --git a/src/auth/tests/test_helpers.py b/src/auth/tests/test_helpers.py index de85baf7..68abffdf 100644 --- a/src/auth/tests/test_helpers.py +++ b/src/auth/tests/test_helpers.py @@ -21,7 +21,6 @@ decode_jwt, generate_pkce_challenge, generate_pkce_verifier, - get_error_code, handle_exception, model_dump, model_dump_json, @@ -41,7 +40,7 @@ User, ) -from ._sync.utils import mock_access_token +from ._sync.clients import mock_access_token TEST_URL = "http://localhost" @@ -49,14 +48,14 @@ def test_handle_exception_with_api_version_and_error_code(): err = { "name": "without API version and error code", - "code": "error_code", + "code": "unexpected_failure", "ename": "AuthApiError", } with respx.mock: respx.get(f"{TEST_URL}/hello-world").mock( return_value=Response(status_code=200), - side_effect=AuthApiError("Error code message", 400, "error_code"), + side_effect=AuthApiError("Error code message", 400, "unexpected_failure"), ) with pytest.raises(AuthApiError, match=r"Error code message") as exc: httpx.get(f"{TEST_URL}/hello-world") @@ -91,14 +90,14 @@ def test_handle_exception_without_api_version_and_weak_password_error_code(): def test_handle_exception_with_api_version_2024_01_01_and_error_code(): err = { "name": "with API version 2024-01-01 and error code", - "code": "error_code", + "code": "unexpected_failure", "ename": "AuthApiError", } with respx.mock: respx.get(f"{TEST_URL}/hello-world").mock( return_value=Response(status_code=200), - side_effect=AuthApiError("Error code message", 400, "error_code"), + side_effect=AuthApiError("Error code message", 400, "unexpected_failure"), ) with pytest.raises(AuthApiError, match=r"Error code message") as exc: httpx.get(f"{TEST_URL}/hello-world") @@ -127,12 +126,8 @@ def test_parse_response_api_version_with_invalid_dates(): def test_parse_link_identity_response(): - assert parse_link_identity_response({"url": f"{TEST_URL}/hello-world"}) - - -def test_get_error_code(): - assert get_error_code({}) is None - assert get_error_code({"error_code": "500"}) == "500" + resp = Response(content=f'{{"url": "{TEST_URL}/hello-world"}}', status_code=200) + assert parse_link_identity_response(resp) def test_decode_jwt(): @@ -171,14 +166,14 @@ def test_model_validate_pydantic_v1(): with patch("supabase_auth.helpers.TBaseModel") as MockType: # Mock the behavior of the try block to raise AttributeError mock_model = MagicMock() - mock_model.model_validate.side_effect = AttributeError - mock_model.parse_obj.return_value = "parsed_obj_result" + mock_model.model_validate_json.side_effect = AttributeError + mock_model.parse_raw.return_value = "parsed_obj_result" # Use the patched model in the actual function - result = model_validate(mock_model, {"test": "data"}) + result = model_validate(mock_model, {"test": "data"}) # type: ignore # Check that parse_obj was called - mock_model.parse_obj.assert_called_once_with({"test": "data"}) + mock_model.parse_raw.assert_called_once_with({"test": "data"}) assert result == "parsed_obj_result" @@ -212,194 +207,6 @@ def test_model_dump_json_pydantic_v1(): mock_model.json.assert_called_once() -# Test for parse_auth_response with a session -def test_parse_auth_response_with_session(): - # Create our own AuthResponse object to avoid pydantic validation issues - mock_session = MagicMock(spec=Session) - mock_user = MagicMock(spec=User) - - # Test data with access_token, refresh_token, and expires_in - data = { - "access_token": "test_access_token", - "refresh_token": "test_refresh_token", - "expires_in": 3600, - "user": { - "id": "user-123", - "email": "test@example.com", - }, - } - - with patch("supabase_auth.helpers.model_validate") as mock_validate: - # First call for Session, second for User - mock_validate.side_effect = [mock_session, mock_user] - - with patch("supabase_auth.helpers.AuthResponse") as mock_auth_response: - mock_auth_response.return_value = "auth_response_result" - - result = parse_auth_response(data) - - # Verify model_validate was called for Session and User - assert mock_validate.call_count == 2 - mock_validate.assert_any_call(Session, data) - mock_validate.assert_any_call(User, data["user"]) - - # Verify AuthResponse was created with correct params - mock_auth_response.assert_called_once_with( - session=mock_session, user=mock_user - ) - assert result == "auth_response_result" - - -# Test for parse_auth_response without a session -def test_parse_auth_response_without_session(): - # Create our own User object to avoid pydantic validation issues - mock_user = MagicMock(spec=User) - - # Test data without session info - data = { - "user": { - "id": "user-123", - "email": "test@example.com", - } - } - - with patch("supabase_auth.helpers.model_validate") as mock_validate: - mock_validate.return_value = mock_user - - with patch("supabase_auth.helpers.AuthResponse") as mock_auth_response: - mock_auth_response.return_value = "auth_response_result" - - result = parse_auth_response(data) - - # Verify model_validate was called only for User - mock_validate.assert_called_once_with(User, data["user"]) - - # Verify AuthResponse was created with correct params - mock_auth_response.assert_called_once_with(session=None, user=mock_user) - assert result == "auth_response_result" - - -# Test for parse_link_response -def test_parse_link_response(): - # Create mocks to avoid pydantic validation issues - mock_user = MagicMock(spec=User) - mock_gen_link_response = MagicMock(spec=GenerateLinkResponse) - - # Test data for link response - data = { - "action_link": "https://example.com/verify", - "email_otp": "123456", - "hashed_token": "abc123", - "redirect_to": "https://example.com/app", - "verification_type": "signup", - "id": "user-123", - "email": "test@example.com", - } - - # We need to patch the GenerateLinkProperties constructor - with patch("supabase_auth.helpers.GenerateLinkProperties") as mock_gen_props: - mock_gen_props.return_value = "mock_properties" - - with patch("supabase_auth.helpers.model_dump") as mock_dump: - mock_dump.return_value = { - "action_link": "https://example.com/verify", - "email_otp": "123456", - "hashed_token": "abc123", - "redirect_to": "https://example.com/app", - "verification_type": "signup", - } - - with patch("supabase_auth.helpers.model_validate") as mock_validate: - mock_validate.return_value = mock_user - - with patch( - "supabase_auth.helpers.GenerateLinkResponse" - ) as mock_gen_link: - mock_gen_link.return_value = mock_gen_link_response - - result = parse_link_response(data) - - # Verify that props were created correctly - mock_gen_props.assert_called_once_with( - action_link=data.get("action_link"), - email_otp=data.get("email_otp"), - hashed_token=data.get("hashed_token"), - redirect_to=data.get("redirect_to"), - verification_type=data.get("verification_type"), - ) - - # Verify model_validate was called for User with filtered data - mock_validate.assert_called_once() - - # Verify GenerateLinkResponse was created - mock_gen_link.assert_called_once_with( - properties="mock_properties", user=mock_user - ) - assert result == mock_gen_link_response - - -# Test for parse_user_response -def test_parse_user_response_with_user_object(): - # Test data with 'user' key - data = {"user": {"id": "user-123", "email": "test@example.com"}} - - with patch("supabase_auth.helpers.model_validate") as mock_validate: - mock_validate.return_value = "mock_user_response" - - result = parse_user_response(data) - - assert result == "mock_user_response" - mock_validate.assert_called_once() - - -# Test for parse_user_response without user object -def test_parse_user_response_without_user_object(): - # Test data without 'user' key - data = {"id": "user-123", "email": "test@example.com"} - - with patch("supabase_auth.helpers.model_validate") as mock_validate: - mock_validate.return_value = "mock_user_response" - - result = parse_user_response(data) - - assert result == "mock_user_response" - mock_validate.assert_called_once() - # Verify that it wrapped the data in a user object - expected_wrapped_data = {"user": data} - assert mock_validate.call_args[0][1] == expected_wrapped_data - - -# Test for parse_sso_response -def test_parse_sso_response(): - with patch("supabase_auth.helpers.model_validate") as mock_validate: - mock_validate.return_value = "sso_response" - - result = parse_sso_response({"provider": "google"}) - assert result == "sso_response" - - # Verify model_validate was called with correct params - from supabase_auth.types import SSOResponse - - mock_validate.assert_called_once_with(SSOResponse, {"provider": "google"}) - - -# Test for parse_jwks with empty keys -def test_parse_jwks_empty_keys(): - with pytest.raises(AuthInvalidJwtError, match="JWKS is empty"): - parse_jwks({"keys": []}) - - -# Tests for handle_exception -def test_handle_exception_non_http_error(): - # Test case for non-HTTPStatusError - exception = ValueError("Test error") - result = handle_exception(exception) - - assert isinstance(result, AuthRetryableError) - assert result.message == "Test error" - assert result.status == 0 - - def test_handle_exception_network_error(): # Test case for network errors (502, 503, 504) mock_response = MagicMock(spec=Response) @@ -504,12 +311,6 @@ def test_handle_exception_unknown_error(): assert "Server error" in result.message -# Tests for validate_exp -def test_validate_exp_with_no_exp(): - with pytest.raises(AuthInvalidJwtError, match="JWT has no expiration time"): - validate_exp(None) - - def test_validate_exp_with_expired_exp(): # Set expiry to 1 hour ago exp = int(datetime.now().timestamp()) - 3600 @@ -595,25 +396,3 @@ def patched_isinstance(obj, cls): assert isinstance(result, AuthWeakPasswordError) assert result.message == "Password too weak" assert result.status == 400 - - -def test_parse_auth_otp_response(): - """Test for the parse_auth_otp_response function.""" - from supabase_auth.helpers import parse_auth_otp_response - from supabase_auth.types import AuthOtpResponse - - # Test with message_id field - data = {"message_id": "12345"} - result = parse_auth_otp_response(data) - assert isinstance(result, AuthOtpResponse) - assert result.message_id == "12345" - assert result.user is None - assert result.session is None - - # Test with no message_id field - data = {} - result = parse_auth_otp_response(data) - assert isinstance(result, AuthOtpResponse) - assert result.message_id is None - assert result.user is None - assert result.session is None