Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backend/app/application/errors/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ def __init__(self, msg: str = "Internal server error"):


class UnauthorizedError(AppException):
def __init__(self, msg: str = "Unauthorized"):
def __init__(self, msg: str = "Authentication required"):
super().__init__(code=401, msg=msg, status_code=401)
59 changes: 55 additions & 4 deletions backend/app/application/services/agent_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,13 @@ async def chat(
yield event
logger.info(f"Chat with session {session_id} completed")

async def get_session(self, session_id: str, user_id: str) -> Optional[Session]:
async def get_session(self, session_id: str, user_id: Optional[str] = None) -> Optional[Session]:
"""Get a session by ID, ensuring it belongs to the user"""
logger.info(f"Getting session {session_id} for user {user_id}")
session = await self._session_repository.find_by_id_and_user_id(session_id, user_id)
if not user_id:
session = await self._session_repository.find_by_id(session_id)
else:
session = await self._session_repository.find_by_id_and_user_id(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for user {user_id}")
return session
Expand Down Expand Up @@ -209,12 +212,60 @@ async def file_view(self, session_id: str, file_path: str, user_id: str) -> File
return FileViewResponse(**result.data)
else:
raise RuntimeError(f"Failed to read file: {result.message}")

async def is_session_shared(self, session_id: str) -> bool:
"""Check if a session is shared"""
logger.info(f"Checking if session {session_id} is shared")
session = await self._session_repository.find_by_id(session_id)
if not session:
logger.error(f"Session {session_id} not found")
raise RuntimeError("Session not found")
return session.is_shared

async def get_session_files(self, session_id: str, user_id: str) -> List[FileInfo]:
async def get_session_files(self, session_id: str, user_id: Optional[str] = None) -> List[FileInfo]:
"""Get files for a session, ensuring it belongs to the user"""
logger.info(f"Getting files for session {session_id} for user {user_id}")
session = await self.get_session(session_id, user_id)
return session.files

async def get_shared_session_files(self, session_id: str) -> List[FileInfo]:
"""Get files for a shared session"""
logger.info(f"Getting files for shared session {session_id}")
session = await self._session_repository.find_by_id(session_id)
if not session or not session.is_shared:
logger.error(f"Shared session {session_id} not found or not shared")
raise RuntimeError("Session not found")
return session.files

async def share_session(self, session_id: str, user_id: str) -> None:
"""Share a session, ensuring it belongs to the user"""
logger.info(f"Sharing session {session_id} for user {user_id}")
# First verify the session belongs to the user
session = await self._session_repository.find_by_id_and_user_id(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for user {user_id}")
raise RuntimeError("Session not found")
return session.files

await self._session_repository.update_shared_status(session_id, True)
logger.info(f"Session {session_id} shared successfully")

async def unshare_session(self, session_id: str, user_id: str) -> None:
"""Unshare a session, ensuring it belongs to the user"""
logger.info(f"Unsharing session {session_id} for user {user_id}")
# First verify the session belongs to the user
session = await self._session_repository.find_by_id_and_user_id(session_id, user_id)
if not session:
logger.error(f"Session {session_id} not found for user {user_id}")
raise RuntimeError("Session not found")

await self._session_repository.update_shared_status(session_id, False)
logger.info(f"Session {session_id} unshared successfully")

async def get_shared_session(self, session_id: str) -> Optional[Session]:
"""Get a shared session by ID (no user authentication required)"""
logger.info(f"Getting shared session {session_id}")
session = await self._session_repository.find_by_id(session_id)
if not session or not session.is_shared:
logger.error(f"Shared session {session_id} not found or not shared")
return None
return session
47 changes: 45 additions & 2 deletions backend/app/application/services/file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import logging
from app.domain.external.file import FileStorage
from app.domain.models.file import FileInfo
from app.application.services.token_service import TokenService

# Set up logger
logger = logging.getLogger(__name__)

class FileService:
def __init__(self, file_storage: Optional[FileStorage] = None):
def __init__(self, file_storage: Optional[FileStorage] = None, token_service: Optional[TokenService] = None):
self._file_storage = file_storage
self._token_service = token_service

async def upload_file(self, file_data: BinaryIO, filename: str, user_id: str, content_type: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None) -> FileInfo:
"""Upload file"""
Expand Down Expand Up @@ -58,7 +60,7 @@ async def delete_file(self, file_id: str, user_id: str) -> bool:
logger.error(f"Failed to delete file {file_id} for user {user_id}: {str(e)}")
raise

async def get_file_info(self, file_id: str, user_id: str) -> Optional[FileInfo]:
async def get_file_info(self, file_id: str, user_id: Optional[str] = None) -> Optional[FileInfo]:
"""Get file information"""
logger.info(f"Get file info request: file_id={file_id}, user_id={user_id}")
if not self._file_storage:
Expand All @@ -75,3 +77,44 @@ async def get_file_info(self, file_id: str, user_id: str) -> Optional[FileInfo]:
except Exception as e:
logger.error(f"Failed to get file info {file_id} for user {user_id}: {str(e)}")
raise

async def enrich_with_file_url(self, file_info: FileInfo) -> FileInfo:
"""Enrich file information with file URL"""
logger.info(f"Enrich file info request: file_info={file_info}")

try:
signed_url = await self.create_signed_url(file_info.file_id, file_info.user_id)
file_info.file_url = signed_url
return file_info
except Exception as e:
logger.error(f"Failed to enrich file info {file_info.file_id} with file URL: {str(e)}")
raise

async def create_signed_url(self, file_id: str, user_id: Optional[str] = None, expire_minutes: int = 30) -> str:
"""Create signed URL for file download"""
logger.info(f"Create signed URL request: file_id={file_id}, user_id={user_id}, expire_minutes={expire_minutes}")

if not self._token_service:
logger.error("Token service not available")
raise RuntimeError("Token service not available")

# Validate expiration time (max 15 minutes)
if expire_minutes > 30:
expire_minutes = 30

# Check if file exists and user has access
file_info = await self.get_file_info(file_id, user_id)
if not file_info:
logger.warning(f"File not found or access denied for signed URL: file_id={file_id}, user_id={user_id}")
raise FileNotFoundError("File not found")

# Create signed URL for file download
base_url = f"/api/v1/files/{file_id}"
signed_url = self._token_service.create_signed_url(
base_url=base_url,
expire_minutes=expire_minutes
)

logger.info(f"Created signed URL for file download for user {user_id}, file {file_id}")

return signed_url
2 changes: 1 addition & 1 deletion backend/app/domain/external/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async def delete_file(
async def get_file_info(
self,
file_id: str,
user_id: str
user_id: Optional[str] = None
) -> Optional[FileInfo]:
"""Get file metadata from storage

Expand Down
1 change: 1 addition & 0 deletions backend/app/domain/models/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ class FileInfo(BaseModel):
upload_date: Optional[datetime] = None
metadata: Optional[Dict[str, Any]] = None
user_id: Optional[str] = None
file_url: Optional[str] = None
1 change: 1 addition & 0 deletions backend/app/domain/models/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class Session(BaseModel):
events: List[AgentEvent] = []
files: List[FileInfo] = []
status: SessionStatus = SessionStatus.PENDING
is_shared: bool = False # Whether this session is shared publicly

def get_last_plan(self) -> Optional[Plan]:
"""Get the last plan from the events"""
Expand Down
4 changes: 4 additions & 0 deletions backend/app/domain/repositories/session_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ async def decrement_unread_message_count(self, session_id: str) -> None:
"""Decrement the unread message count of a session"""
...

async def update_shared_status(self, session_id: str, is_shared: bool) -> None:
"""Update the shared status of a session"""
...

async def delete(self, session_id: str) -> None:
"""Delete a session"""
...
Expand Down
4 changes: 2 additions & 2 deletions backend/app/infrastructure/external/file/gridfsfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ async def delete_file(self, file_id: str, user_id: str) -> bool:
logger.error(f"Failed to delete file {file_id} for user {user_id}: {str(e)}")
return False

async def get_file_info(self, file_id: str, user_id: str) -> Optional[FileInfo]:
async def get_file_info(self, file_id: str, user_id: Optional[str] = None) -> Optional[FileInfo]:
"""Get file information"""
try:
files_collection = self._get_files_collection()
Expand All @@ -195,7 +195,7 @@ async def get_file_info(self, file_id: str, user_id: str) -> Optional[FileInfo]:

# Check if file belongs to the user
file_user_id = file_info.get('metadata', {}).get('user_id')
if file_user_id != user_id:
if user_id is not None and file_user_id != user_id:
logger.warning(f"Access denied: file {file_id} does not belong to user {user_id}")
return None

Expand Down
1 change: 1 addition & 0 deletions backend/app/infrastructure/models/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class SessionDocument(BaseDocument[Session], id_field="session_id", domain_model
events: List[AgentEvent]
status: SessionStatus
files: List[FileInfo] = []
is_shared: Optional[bool] = False
class Settings:
name = "sessions"
indexes = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,13 @@ async def decrement_unread_message_count(self, session_id: str) -> None:
if not result:
raise ValueError(f"Session {session_id} not found")

async def update_shared_status(self, session_id: str, is_shared: bool) -> None:
"""Update the shared status of a session"""
result = await SessionDocument.find_one(
SessionDocument.session_id == session_id
).update(
{"$set": {"is_shared": is_shared, "updated_at": datetime.now(UTC)}}
)
if not result:
raise ValueError(f"Session {session_id} not found")

14 changes: 5 additions & 9 deletions backend/app/interfaces/api/auth_routes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import logging

from app.application.services.auth_service import AuthService
Expand Down Expand Up @@ -186,21 +187,16 @@ async def refresh_token(

@router.post("/logout", response_model=APIResponse[dict])
async def logout(
request: Request,
current_user: User = Depends(get_current_user),
bearer_credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer()),
auth_service: AuthService = Depends(get_auth_service)
) -> APIResponse[dict]:
"""User logout endpoint"""
if get_settings().auth_provider == "none":
raise BadRequestError("Logout is not allowed")
# Extract token from Authorization header
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
raise UnauthorizedError("Authentication required")

token = auth_header.split(" ")[1]

# Revoke token
await auth_service.logout(token)
await auth_service.logout(bearer_credentials.credentials)

return APIResponse.success({})

Expand Down
Loading