Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions be/app/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi.responses import RedirectResponse
from authlib.integrations.starlette_client import OAuth
from starlette.config import Config as AuthlibConfig
from fastapi.responses import JSONResponse

from app import config
from app.models import TokenRequest
Expand Down Expand Up @@ -33,6 +34,10 @@ async def login(request: Request):

@router.post("/api/auth/token")
def get_token(request: Request, token_req: TokenRequest):
"""
Exchanges the SSO code for an Access Token, sets the token in a secure
HTTP-only cookie, and returns minimal user details.
"""
data = {
"grant_type": "authorization_code",
"code": token_req.code,
Expand All @@ -41,17 +46,70 @@ def get_token(request: Request, token_req: TokenRequest):
"redirect_uri": config.REDIRECT_URI,
}
response = requests.post(config.TOKEN_URL, data=data)

if response.status_code != 200:
raise HTTPException(status_code=400, detail="Failed to exchange code for token")

token_data = response.json()
user_info_resp = requests.get(f"{config.OPENID_PROVIDER_URL}/userinfo",
headers={'Authorization': f'Bearer {token_data["access_token"]}'})
user_email = user_info_resp.json()['email']
request.session['user'] = user_email
return user_email
access_token = token_data["access_token"]

user_info_resp = requests.get(
f"{config.OPENID_PROVIDER_URL}/userinfo",
headers={'Authorization': f'Bearer {access_token}'}
)
user_data = user_info_resp.json()

user_return_data = {
"id": user_data.get('sub', user_data.get('email', 'unknown')),
"email": user_data.get('email', 'unknown'),
}

response_to_client = JSONResponse(content=user_return_data)

@router.get("/api/logout")
request.session['user'] = user_data.get('email')
request.session['access_token'] = access_token

return response_to_client

@router.get("/api/auth/session")
def check_session(request: Request):
# 1. Get the token from the cookie sent by the browser
access_token = request.session.get("access_token")

if not access_token:
# If no cookie exists, return 401 Unauthorized
raise HTTPException(status_code=401, detail="No session token found")

# 2. Use the token to get the user info/validate it against the SSO provider
# NOTE: Your BE may need to check the token's expiration itself before calling the SSO provider
user_info_resp = requests.get(
f"{config.OPENID_PROVIDER_URL}/userinfo",
headers={'Authorization': f'Bearer {access_token}'}
)

if user_info_resp.status_code != 200:
# If the SSO provider says the token is invalid/expired
raise HTTPException(status_code=401, detail="Token validation failed or expired")

# 3. Token is valid. Return the user data to the frontend.
user_data = user_info_resp.json()
return {
"id": user_data.get('sub', user_data.get('email', 'unknown')),
"email": user_data.get('email', 'unknown'),
}

@router.post("/api/logout")
def logout(request: Request):
"""
Clears the server-side session.
"""
request.session.pop("user", None)
request.session.pop("access_token", None)
return JSONResponse(content={"message": "Logged out successfully"})

@router.get("/api/logout")
def logout_get(request: Request):
"""Fallback GET logout endpoint (original behavior)"""
request.session.pop("user", None)
request.session.pop("access_token", None)
return RedirectResponse(config.REDIRECT_URI)
6 changes: 3 additions & 3 deletions be/tests/routes/test_api_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_get_token_success(mock_requests_get, mock_requests_post, client: TestCl
response = client.post("/api/auth/token", json={"code": "test-code"})

assert response.status_code == 200
assert response.json() == "user@example.com"
assert response.json() == {'email': 'user@example.com', 'id': 'user@example.com'}
# Check that the session was set
assert client.cookies.get("session") is not None

Expand All @@ -67,7 +67,7 @@ def test_get_token_failure(mock_requests_post, client: TestClient):
def test_logout(client: TestClient):
"""Test that the logout endpoint returns a redirect."""
# We test the direct outcome (a redirect response) rather than inspect session state
response = client.get("/api/logout", follow_redirects=False)
response = client.post("/api/logout", follow_redirects=False)
# Check for a redirect status code (307 is used by FastAPI for temporary redirects)
assert response.status_code == 307
assert client.cookies.get("session") is None

2 changes: 2 additions & 0 deletions fe/src/lib/stores.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@ export const selectedNamespce = writable(null);
export const selectedTable = writable(null);
export const sample_limit = writable(100);

export const user = writable(null);

export const healthEnabled = writable(false);
export const HEALTH_DISABLED_MESSAGE = 'Feature is disabled. Please contact your app administrator or <a href="https://github.com/lakevision-project/lakevision?tab=readme-ov-file#configuration" target="_blank">read the documentation</a> to enable DB connection.';
Loading