Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from typing import Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from loguru import logger
from datetime import datetime
Expand All @@ -24,6 +24,7 @@
from ....utils.error_handlers import ContentPlanningErrorHandler
from ....utils.response_builders import ResponseBuilder
from ....utils.constants import ERROR_MESSAGES, SUCCESS_MESSAGES
from middleware.auth_middleware import get_current_user

router = APIRouter(tags=["AI Strategy Generation"])

Expand All @@ -38,15 +39,24 @@ def get_db():
# Global storage for latest strategies (more persistent than task status)
_latest_strategies = {}


def _get_authenticated_user_id(current_user: Dict[str, Any]) -> str:
"""Extract authenticated user id from token claims."""
user_id = str((current_user or {}).get("id") or (current_user or {}).get("clerk_user_id") or "").strip()
if not user_id:
raise HTTPException(status_code=401, detail="Authenticated user ID is required")
return user_id

@router.post("/generate-comprehensive-strategy")
async def generate_comprehensive_strategy(
user_id: int,
strategy_name: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Dict[str, Any]:
"""Generate a comprehensive AI-powered content strategy."""
try:
user_id = _get_authenticated_user_id(current_user)
logger.info(f"🚀 Generating comprehensive AI strategy for user: {user_id}")

# Get user context and onboarding data
Expand Down Expand Up @@ -103,14 +113,15 @@ async def generate_comprehensive_strategy(

@router.post("/generate-strategy-component")
async def generate_strategy_component(
user_id: int,
component_type: str,
base_strategy: Optional[Dict[str, Any]] = None,
context: Optional[Dict[str, Any]] = None,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Dict[str, Any]:
"""Generate a specific strategy component using AI."""
try:
user_id = _get_authenticated_user_id(current_user)
logger.info(f"🚀 Generating strategy component '{component_type}' for user: {user_id}")

# Validate component type
Expand Down Expand Up @@ -187,11 +198,12 @@ async def generate_strategy_component(

@router.get("/strategy-generation-status")
async def get_strategy_generation_status(
user_id: int,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Dict[str, Any]:
"""Get the status of strategy generation for a user."""
try:
user_id = _get_authenticated_user_id(current_user)
logger.info(f"Getting strategy generation status for user: {user_id}")

# Get user's strategies
Expand Down Expand Up @@ -247,10 +259,12 @@ async def get_strategy_generation_status(
async def optimize_existing_strategy(
strategy_id: int,
optimization_type: str = "comprehensive",
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Dict[str, Any]:
"""Optimize an existing strategy using AI."""
try:
user_id = _get_authenticated_user_id(current_user)
logger.info(f"🚀 Optimizing existing strategy {strategy_id} with type: {optimization_type}")

# Get existing strategy
Expand All @@ -266,7 +280,13 @@ async def optimize_existing_strategy(
)

existing_strategy = strategies_data["strategies"][0]
user_id = existing_strategy.get("user_id")
strategy_owner_id = str(existing_strategy.get("user_id", ""))

if strategy_owner_id != user_id:
raise HTTPException(
status_code=403,
detail="Not authorized to optimize this strategy"
)

# Get user context
onboarding_data = await enhanced_service._get_onboarding_data(user_id)
Expand Down Expand Up @@ -309,12 +329,13 @@ async def optimize_existing_strategy(
@router.post("/generate-comprehensive-strategy-polling")
async def generate_comprehensive_strategy_polling(
request: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Dict[str, Any]:
"""Generate a comprehensive AI-powered content strategy using polling approach."""
try:
# Extract parameters from request body
user_id = request.get("user_id", 1)
user_id = _get_authenticated_user_id(current_user)
strategy_name = request.get("strategy_name")
config = request.get("config", {})

Expand Down Expand Up @@ -611,10 +632,12 @@ async def generate_strategy_background():
@router.get("/strategy-generation-status/{task_id}")
async def get_strategy_generation_status_by_task(
task_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Dict[str, Any]:
"""Get the status of strategy generation for a specific task."""
try:
user_id = _get_authenticated_user_id(current_user)
logger.info(f"Getting strategy generation status for task: {task_id}")

# Check if task status exists
Expand All @@ -631,6 +654,12 @@ async def get_strategy_generation_status_by_task(
status_code=404,
detail=f"Task {task_id} not found. It may have expired or never existed."
)

if str(task_status.get("user_id")) != user_id:
raise HTTPException(
status_code=403,
detail="Not authorized to access this task status"
)

logger.info(f"✅ Strategy generation status retrieved for task: {task_id}")

Expand All @@ -647,11 +676,12 @@ async def get_strategy_generation_status_by_task(

@router.get("/latest-strategy")
async def get_latest_generated_strategy(
user_id: int = Query(1, description="User ID"),
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db)
) -> Dict[str, Any]:
"""Get the latest generated strategy from the polling system or database."""
try:
user_id = _get_authenticated_user_id(current_user)
logger.info(f"🔍 Getting latest generated strategy for user: {user_id}")

# First, try to get from database (most reliable)
Expand Down
57 changes: 57 additions & 0 deletions backend/security_test_ai_generation_endpoints_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from pathlib import Path
import ast


SOURCE_PATH = Path("backend/api/content_planning/api/content_strategy/endpoints/ai_generation_endpoints.py")
SOURCE = SOURCE_PATH.read_text()
MODULE = ast.parse(SOURCE)


def _get_async_function(name: str) -> ast.AsyncFunctionDef:
for node in MODULE.body:
if isinstance(node, ast.AsyncFunctionDef) and node.name == name:
return node
raise AssertionError(f"Function {name} not found")


def _arg_names(node: ast.AsyncFunctionDef) -> list[str]:
return [arg.arg for arg in node.args.args]


def test_public_routes_use_current_user_not_client_user_id_param():
route_names = [
"generate_comprehensive_strategy",
"generate_strategy_component",
"get_strategy_generation_status",
"generate_comprehensive_strategy_polling",
"get_strategy_generation_status_by_task",
"get_latest_generated_strategy",
]

for name in route_names:
fn = _get_async_function(name)
arg_names = _arg_names(fn)
assert "current_user" in arg_names
assert "user_id" not in arg_names


def test_polling_route_derives_user_id_from_authenticated_claims_only():
fn = _get_async_function("generate_comprehensive_strategy_polling")
fn_source = ast.get_source_segment(SOURCE, fn)

assert 'user_id = _get_authenticated_user_id(current_user)' in fn_source
assert 'request.get("user_id"' not in fn_source
assert 'request.get("user_id", 1)' not in fn_source


def test_task_status_route_enforces_task_owner_authorization_check():
fn = _get_async_function("get_strategy_generation_status_by_task")
fn_source = ast.get_source_segment(SOURCE, fn)

assert 'if str(task_status.get("user_id")) != user_id:' in fn_source
assert 'status_code=403' in fn_source
assert 'Not authorized to access this task status' in fn_source


def test_missing_authenticated_user_id_returns_401():
assert 'raise HTTPException(status_code=401, detail="Authenticated user ID is required")' in SOURCE