From 3bffa8c918b2f8ddcaf39e58f0762c8d21912bf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D9=8A?= Date: Thu, 12 Mar 2026 15:15:08 +0530 Subject: [PATCH] Harden AI generation endpoints to use authenticated user claims --- .../endpoints/ai_generation_endpoints.py | 44 +++++++++++--- ...urity_test_ai_generation_endpoints_auth.py | 57 +++++++++++++++++++ 2 files changed, 94 insertions(+), 7 deletions(-) create mode 100644 backend/security_test_ai_generation_endpoints_auth.py diff --git a/backend/api/content_planning/api/content_strategy/endpoints/ai_generation_endpoints.py b/backend/api/content_planning/api/content_strategy/endpoints/ai_generation_endpoints.py index 20e5d1ec..65fa2830 100644 --- a/backend/api/content_planning/api/content_strategy/endpoints/ai_generation_endpoints.py +++ b/backend/api/content_planning/api/content_strategy/endpoints/ai_generation_endpoints.py @@ -4,7 +4,7 @@ """ from typing import Dict, Any, Optional -from fastapi import APIRouter, Depends, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session from loguru import logger from datetime import datetime @@ -24,6 +24,7 @@ from ....utils.error_handlers import ContentPlanningErrorHandler from ....utils.response_builders import ResponseBuilder from ....utils.constants import ERROR_MESSAGES, SUCCESS_MESSAGES +from middleware.auth_middleware import get_current_user router = APIRouter(tags=["AI Strategy Generation"]) @@ -38,15 +39,24 @@ def get_db(): # Global storage for latest strategies (more persistent than task status) _latest_strategies = {} + +def _get_authenticated_user_id(current_user: Dict[str, Any]) -> str: + """Extract authenticated user id from token claims.""" + user_id = str((current_user or {}).get("id") or (current_user or {}).get("clerk_user_id") or "").strip() + if not user_id: + raise HTTPException(status_code=401, detail="Authenticated user ID is required") + return user_id + @router.post("/generate-comprehensive-strategy") async def generate_comprehensive_strategy( - user_id: int, strategy_name: Optional[str] = None, config: Optional[Dict[str, Any]] = None, + current_user: Dict[str, Any] = Depends(get_current_user), db: Session = Depends(get_db) ) -> Dict[str, Any]: """Generate a comprehensive AI-powered content strategy.""" try: + user_id = _get_authenticated_user_id(current_user) logger.info(f"🚀 Generating comprehensive AI strategy for user: {user_id}") # Get user context and onboarding data @@ -103,14 +113,15 @@ async def generate_comprehensive_strategy( @router.post("/generate-strategy-component") async def generate_strategy_component( - user_id: int, component_type: str, base_strategy: Optional[Dict[str, Any]] = None, context: Optional[Dict[str, Any]] = None, + current_user: Dict[str, Any] = Depends(get_current_user), db: Session = Depends(get_db) ) -> Dict[str, Any]: """Generate a specific strategy component using AI.""" try: + user_id = _get_authenticated_user_id(current_user) logger.info(f"🚀 Generating strategy component '{component_type}' for user: {user_id}") # Validate component type @@ -187,11 +198,12 @@ async def generate_strategy_component( @router.get("/strategy-generation-status") async def get_strategy_generation_status( - user_id: int, + current_user: Dict[str, Any] = Depends(get_current_user), db: Session = Depends(get_db) ) -> Dict[str, Any]: """Get the status of strategy generation for a user.""" try: + user_id = _get_authenticated_user_id(current_user) logger.info(f"Getting strategy generation status for user: {user_id}") # Get user's strategies @@ -247,10 +259,12 @@ async def get_strategy_generation_status( async def optimize_existing_strategy( strategy_id: int, optimization_type: str = "comprehensive", + current_user: Dict[str, Any] = Depends(get_current_user), db: Session = Depends(get_db) ) -> Dict[str, Any]: """Optimize an existing strategy using AI.""" try: + user_id = _get_authenticated_user_id(current_user) logger.info(f"🚀 Optimizing existing strategy {strategy_id} with type: {optimization_type}") # Get existing strategy @@ -266,7 +280,13 @@ async def optimize_existing_strategy( ) existing_strategy = strategies_data["strategies"][0] - user_id = existing_strategy.get("user_id") + strategy_owner_id = str(existing_strategy.get("user_id", "")) + + if strategy_owner_id != user_id: + raise HTTPException( + status_code=403, + detail="Not authorized to optimize this strategy" + ) # Get user context onboarding_data = await enhanced_service._get_onboarding_data(user_id) @@ -309,12 +329,13 @@ async def optimize_existing_strategy( @router.post("/generate-comprehensive-strategy-polling") async def generate_comprehensive_strategy_polling( request: Dict[str, Any], + current_user: Dict[str, Any] = Depends(get_current_user), db: Session = Depends(get_db) ) -> Dict[str, Any]: """Generate a comprehensive AI-powered content strategy using polling approach.""" try: # Extract parameters from request body - user_id = request.get("user_id", 1) + user_id = _get_authenticated_user_id(current_user) strategy_name = request.get("strategy_name") config = request.get("config", {}) @@ -611,10 +632,12 @@ async def generate_strategy_background(): @router.get("/strategy-generation-status/{task_id}") async def get_strategy_generation_status_by_task( task_id: str, + current_user: Dict[str, Any] = Depends(get_current_user), db: Session = Depends(get_db) ) -> Dict[str, Any]: """Get the status of strategy generation for a specific task.""" try: + user_id = _get_authenticated_user_id(current_user) logger.info(f"Getting strategy generation status for task: {task_id}") # Check if task status exists @@ -631,6 +654,12 @@ async def get_strategy_generation_status_by_task( status_code=404, detail=f"Task {task_id} not found. It may have expired or never existed." ) + + if str(task_status.get("user_id")) != user_id: + raise HTTPException( + status_code=403, + detail="Not authorized to access this task status" + ) logger.info(f"✅ Strategy generation status retrieved for task: {task_id}") @@ -647,11 +676,12 @@ async def get_strategy_generation_status_by_task( @router.get("/latest-strategy") async def get_latest_generated_strategy( - user_id: int = Query(1, description="User ID"), + current_user: Dict[str, Any] = Depends(get_current_user), db: Session = Depends(get_db) ) -> Dict[str, Any]: """Get the latest generated strategy from the polling system or database.""" try: + user_id = _get_authenticated_user_id(current_user) logger.info(f"🔍 Getting latest generated strategy for user: {user_id}") # First, try to get from database (most reliable) diff --git a/backend/security_test_ai_generation_endpoints_auth.py b/backend/security_test_ai_generation_endpoints_auth.py new file mode 100644 index 00000000..00ff5c74 --- /dev/null +++ b/backend/security_test_ai_generation_endpoints_auth.py @@ -0,0 +1,57 @@ +from pathlib import Path +import ast + + +SOURCE_PATH = Path("backend/api/content_planning/api/content_strategy/endpoints/ai_generation_endpoints.py") +SOURCE = SOURCE_PATH.read_text() +MODULE = ast.parse(SOURCE) + + +def _get_async_function(name: str) -> ast.AsyncFunctionDef: + for node in MODULE.body: + if isinstance(node, ast.AsyncFunctionDef) and node.name == name: + return node + raise AssertionError(f"Function {name} not found") + + +def _arg_names(node: ast.AsyncFunctionDef) -> list[str]: + return [arg.arg for arg in node.args.args] + + +def test_public_routes_use_current_user_not_client_user_id_param(): + route_names = [ + "generate_comprehensive_strategy", + "generate_strategy_component", + "get_strategy_generation_status", + "generate_comprehensive_strategy_polling", + "get_strategy_generation_status_by_task", + "get_latest_generated_strategy", + ] + + for name in route_names: + fn = _get_async_function(name) + arg_names = _arg_names(fn) + assert "current_user" in arg_names + assert "user_id" not in arg_names + + +def test_polling_route_derives_user_id_from_authenticated_claims_only(): + fn = _get_async_function("generate_comprehensive_strategy_polling") + fn_source = ast.get_source_segment(SOURCE, fn) + + assert 'user_id = _get_authenticated_user_id(current_user)' in fn_source + assert 'request.get("user_id"' not in fn_source + assert 'request.get("user_id", 1)' not in fn_source + + +def test_task_status_route_enforces_task_owner_authorization_check(): + fn = _get_async_function("get_strategy_generation_status_by_task") + fn_source = ast.get_source_segment(SOURCE, fn) + + assert 'if str(task_status.get("user_id")) != user_id:' in fn_source + assert 'status_code=403' in fn_source + assert 'Not authorized to access this task status' in fn_source + + +def test_missing_authenticated_user_id_returns_401(): + assert 'raise HTTPException(status_code=401, detail="Authenticated user ID is required")' in SOURCE