From 007ef80f1dd42ad7de47abea280e6592a097d768 Mon Sep 17 00:00:00 2001 From: patrykkotlowski-dsstream Date: Fri, 30 Aug 2024 14:51:50 +0200 Subject: [PATCH] Add option for custom auth --- backend/chainlit/auth.py | 8 +++++--- backend/chainlit/config.py | 4 ++++ backend/chainlit/oauth_providers.py | 29 ++++++++++++++++++----------- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/backend/chainlit/auth.py b/backend/chainlit/auth.py index 981d33abe4..9bd2073b55 100644 --- a/backend/chainlit/auth.py +++ b/backend/chainlit/auth.py @@ -42,9 +42,9 @@ def get_configuration(): "requireLogin": require_login(), "passwordAuth": config.code.password_auth_callback is not None, "headerAuth": config.code.header_auth_callback is not None, - "oauthProviders": get_configured_oauth_providers() - if is_oauth_enabled() - else [], + "oauthProviders": ( + get_configured_oauth_providers() if is_oauth_enabled() else [] + ), } @@ -88,4 +88,6 @@ async def get_current_user(token: str = Depends(reuseable_oauth)): if not require_login(): return None + if config.code.custom_authenticate_user: + return await config.code.custom_authenticate_user(token) return await authenticate_user(token) diff --git a/backend/chainlit/config.py b/backend/chainlit/config.py index e2523a9781..dd6b7c6fdf 100644 --- a/backend/chainlit/config.py +++ b/backend/chainlit/config.py @@ -8,6 +8,7 @@ import tomli from chainlit.logger import logger +from chainlit.oauth_providers import OAuthProvider from chainlit.translations import lint_translation_json from chainlit.version import __version__ from dataclasses_json import DataClassJsonMixin @@ -275,6 +276,9 @@ class CodeSettings: oauth_callback: Optional[ Callable[[str, str, Dict[str, str], "User"], Optional["User"]] ] = None + # Callbacks for authenticate mechanism + custom_authenticate_user: Optional[Callable[[str], "User"]] + custom_oauth_provider: Optional[OAuthProvider] on_logout: Optional[Callable[["Request", "Response"], Any]] = None on_stop: Optional[Callable[[], Any]] = None on_chat_start: Optional[Callable[[], Any]] = None diff --git a/backend/chainlit/oauth_providers.py b/backend/chainlit/oauth_providers.py index fe019859b1..89839e20b0 100644 --- a/backend/chainlit/oauth_providers.py +++ b/backend/chainlit/oauth_providers.py @@ -8,6 +8,8 @@ from chainlit.user import User from fastapi import HTTPException +from chainlit import config + class OAuthProvider: id: str @@ -621,17 +623,22 @@ async def get_user_info(self, token: str): return (gitlab_user, user) -providers = [ - GithubOAuthProvider(), - GoogleOAuthProvider(), - AzureADOAuthProvider(), - AzureADHybridOAuthProvider(), - OktaOAuthProvider(), - Auth0OAuthProvider(), - DescopeOAuthProvider(), - AWSCognitoOAuthProvider(), - GitlabOAuthProvider(), -] +providers = ( + [ + GithubOAuthProvider(), + GoogleOAuthProvider(), + AzureADOAuthProvider(), + AzureADHybridOAuthProvider(), + OktaOAuthProvider(), + Auth0OAuthProvider(), + DescopeOAuthProvider(), + AWSCognitoOAuthProvider(), + GitlabOAuthProvider(), + ] + + [config.code.custom_oauth_provider()] + if config.code.custom_oauth_provider + else [] +) def get_oauth_provider(provider: str) -> Optional[OAuthProvider]: