diff --git a/.github/README.md b/.github/README.md
index 35cca027..8f3727f7 100644
--- a/.github/README.md
+++ b/.github/README.md
@@ -3,8 +3,7 @@
 > **You are reading the Github README!**
 >
 > - πŸ“š **Documentation**: See our [technical documentation](https://deepcritical.github.io/GradioDemo/) for detailed information
-> - πŸ“– **Demo README**: Check out the [Demo README](..README.md) for for more information about our MCP Hackathon submission
-> - πŸ† **Hackathon Submission**: Keep reading below for more information about our MCP Hackathon submission
+> - πŸ“– **Demo README**: Check out the [Demo README](..README.md) for more information
 > - πŸ† **Demo**: Kindly consider using our [Free Demo](https://hf.co/DataQuests/GradioDemo)
diff --git a/.github/scripts/deploy_to_hf_space.py b/.github/scripts/deploy_to_hf_space.py index 839a38dd..ba4f81ac 100644 --- a/.github/scripts/deploy_to_hf_space.py +++ b/.github/scripts/deploy_to_hf_space.py @@ -3,13 +3,22 @@ import os import shutil import subprocess +<<<<<<< HEAD +import tempfile +from pathlib import Path +======= from pathlib import Path from typing import Set +>>>>>>> origin/dev from huggingface_hub import HfApi +<<<<<<< HEAD +def get_excluded_dirs() -> set[str]: +======= def get_excluded_dirs() -> Set[str]: +>>>>>>> origin/dev """Get set of directory names to exclude from deployment.""" return { "docs", @@ -38,17 +47,28 @@ def get_excluded_dirs() -> Set[str]: "dist", ".eggs", "htmlcov", +<<<<<<< HEAD + "hf_space", # Exclude the cloned HF Space directory itself + } + + +def get_excluded_files() -> set[str]: +======= } def get_excluded_files() -> Set[str]: +>>>>>>> origin/dev """Get set of file names to exclude from deployment.""" return { ".pre-commit-config.yaml", "mkdocs.yml", "uv.lock", "AGENTS.txt", +<<<<<<< HEAD +======= "CONTRIBUTING.md", +>>>>>>> origin/dev ".env", ".env.local", "*.local", @@ -60,17 +80,29 @@ def get_excluded_files() -> Set[str]: } +<<<<<<< HEAD +def should_exclude(path: Path, excluded_dirs: set[str], excluded_files: set[str]) -> bool: +======= def should_exclude(path: Path, excluded_dirs: Set[str], excluded_files: Set[str]) -> bool: +>>>>>>> origin/dev """Check if a path should be excluded from deployment.""" # Check if any parent directory is excluded for parent in path.parents: if parent.name in excluded_dirs: return True +<<<<<<< HEAD + + # Check if the path itself is a directory that should be excluded + if path.is_dir() and path.name in excluded_dirs: + return True + +======= # Check if the path itself is a directory that should be excluded if path.is_dir() and path.name in excluded_dirs: return True +>>>>>>> origin/dev # Check if the file name matches excluded patterns if path.is_file(): # Check exact match @@ -83,39 +115,82 @@ def should_exclude(path: Path, excluded_dirs: Set[str], excluded_files: Set[str] suffix = pattern.replace("*", "") if path.name.endswith(suffix): return True +<<<<<<< HEAD + +======= +>>>>>>> origin/dev return False def deploy_to_hf_space() -> None: """Deploy repository to Hugging Face Space. +<<<<<<< HEAD + + Supports both user and organization Spaces: + - User Space: username/space-name + - Organization Space: organization-name/space-name + +======= Supports both user and organization Spaces: - User Space: username/space-name - Organization Space: organization-name/space-name +>>>>>>> origin/dev Works with both classic tokens and fine-grained tokens. 
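Aside: the exclusion helpers above decide which repository paths ever reach the Space. A minimal, self-contained sketch of that filtering logic (the sets below are a small illustrative subset of the real ones, and `is_deployable` stands in for `should_exclude`):

```python
from pathlib import Path

# Illustrative subset of the exclusion sets defined in the deploy script.
excluded_dirs = {"docs", "tests", ".github", "hf_space"}
excluded_files = {"mkdocs.yml", "uv.lock", "*.local"}


def is_deployable(path: Path) -> bool:
    """Return True if the path would survive the deploy filter."""
    # Skip anything under an excluded directory.
    if any(parent.name in excluded_dirs for parent in path.parents):
        return False
    # Skip excluded directories and exact file-name matches.
    if path.name in excluded_dirs or path.name in excluded_files:
        return False
    # Skip glob-style suffix patterns such as "*.local".
    return not any(
        p.startswith("*") and path.name.endswith(p[1:]) for p in excluded_files
    )


print(is_deployable(Path("src/app.py")))      # True  -> copied to the Space
print(is_deployable(Path("docs/index.md")))   # False -> parent "docs" excluded
print(is_deployable(Path("settings.local")))  # False -> matches "*.local"
```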
""" # Get configuration from environment variables hf_token = os.getenv("HF_TOKEN") hf_username = os.getenv("HF_USERNAME") # Can be username or organization name space_name = os.getenv("HF_SPACE_NAME") +<<<<<<< HEAD + + # Check which variables are missing and provide helpful error message + missing = [] + if not hf_token: + missing.append("HF_TOKEN (should be in repository secrets)") + if not hf_username: + missing.append("HF_USERNAME (should be in repository variables)") + if not space_name: + missing.append("HF_SPACE_NAME (should be in repository variables)") + + if missing: + raise ValueError( + f"Missing required environment variables: {', '.join(missing)}\n" + f"Please configure:\n" + f" - HF_TOKEN in Settings > Secrets and variables > Actions > Secrets\n" + f" - HF_USERNAME in Settings > Secrets and variables > Actions > Variables\n" + f" - HF_SPACE_NAME in Settings > Secrets and variables > Actions > Variables" + ) + +======= if not all([hf_token, hf_username, space_name]): raise ValueError( "Missing required environment variables: HF_TOKEN, HF_USERNAME, HF_SPACE_NAME" ) +>>>>>>> origin/dev # HF_USERNAME can be either a username or organization name # Format: {username|organization}/{space_name} repo_id = f"{hf_username}/{space_name}" local_dir = "hf_space" +<<<<<<< HEAD + + print(f"πŸš€ Deploying to Hugging Face Space: {repo_id}") + + # Initialize HF API + api = HfApi(token=hf_token) + +======= print(f"πŸš€ Deploying to Hugging Face Space: {repo_id}") # Initialize HF API api = HfApi(token=hf_token) +>>>>>>> origin/dev # Create Space if it doesn't exist try: api.repo_info(repo_id=repo_id, repo_type="space", token=hf_token) @@ -133,6 +208,47 @@ def deploy_to_hf_space() -> None: exist_ok=True, ) print(f"βœ… Created new Space: {repo_id}") +<<<<<<< HEAD + + # Configure Git credential helper for authentication + # This is needed for Git LFS to work properly with fine-grained tokens + print("πŸ” Configuring Git credentials...") + + # Use Git credential store to store the token + # This allows Git LFS to authenticate properly + temp_dir = Path(tempfile.gettempdir()) + credential_store = temp_dir / ".git-credentials-hf" + + # Write credentials in the format: https://username:token@huggingface.co + credential_store.write_text( + f"https://{hf_username}:{hf_token}@huggingface.co\n", encoding="utf-8" + ) + try: + credential_store.chmod(0o600) # Secure permissions (Unix only) + except OSError: + # Windows doesn't support chmod, skip + pass + + # Configure Git to use the credential store + subprocess.run( + ["git", "config", "--global", "credential.helper", f"store --file={credential_store}"], + check=True, + capture_output=True, + ) + + # Also set environment variable for Git LFS + os.environ["GIT_CREDENTIAL_HELPER"] = f"store --file={credential_store}" + + # Clone repository using git + # Use the token in the URL for initial clone, but LFS will use credential store + space_url = f"https://{hf_username}:{hf_token}@huggingface.co/spaces/{repo_id}" + + if Path(local_dir).exists(): + print(f"🧹 Removing existing {local_dir} directory...") + shutil.rmtree(local_dir) + + print("πŸ“₯ Cloning Space repository...") +======= # Clone repository using git space_url = f"https://{hf_token}@huggingface.co/spaces/{repo_id}" @@ -142,6 +258,7 @@ def deploy_to_hf_space() -> None: shutil.rmtree(local_dir) print(f"πŸ“₯ Cloning Space repository...") +>>>>>>> origin/dev try: result = subprocess.run( ["git", "clone", space_url, local_dir], @@ -149,6 +266,66 @@ def deploy_to_hf_space() -> None: capture_output=True, 
text=True, ) +<<<<<<< HEAD + print("βœ… Cloned Space repository") + + # After clone, configure the remote to use credential helper + # This ensures future operations (like push) use the credential store + os.chdir(local_dir) + subprocess.run( + ["git", "remote", "set-url", "origin", f"https://huggingface.co/spaces/{repo_id}"], + check=True, + capture_output=True, + ) + os.chdir("..") + + except subprocess.CalledProcessError as e: + error_msg = e.stderr if e.stderr else e.stdout if e.stdout else "Unknown error" + print(f"❌ Failed to clone Space repository: {error_msg}") + + # Try alternative: clone with LFS skip, then fetch LFS files separately + print("πŸ”„ Trying alternative clone method (skip LFS during clone)...") + try: + env = os.environ.copy() + env["GIT_LFS_SKIP_SMUDGE"] = "1" # Skip LFS during clone + + subprocess.run( + ["git", "clone", space_url, local_dir], + check=True, + capture_output=True, + text=True, + env=env, + ) + print("βœ… Cloned Space repository (LFS skipped)") + + # Configure remote + os.chdir(local_dir) + subprocess.run( + ["git", "remote", "set-url", "origin", f"https://huggingface.co/spaces/{repo_id}"], + check=True, + capture_output=True, + ) + + # Try to fetch LFS files with proper authentication + print("πŸ“₯ Fetching LFS files...") + subprocess.run( + ["git", "lfs", "pull"], + check=False, # Don't fail if LFS pull fails - we'll continue without LFS files + capture_output=True, + text=True, + ) + os.chdir("..") + print("βœ… Repository cloned (LFS files may be incomplete, but deployment can continue)") + except subprocess.CalledProcessError as e2: + error_msg2 = e2.stderr if e2.stderr else e2.stdout if e2.stdout else "Unknown error" + print(f"❌ Alternative clone method also failed: {error_msg2}") + raise RuntimeError(f"Git clone failed: {error_msg}") from e + + # Get exclusion sets + excluded_dirs = get_excluded_dirs() + excluded_files = get_excluded_files() + +======= print(f"βœ… Cloned Space repository") except subprocess.CalledProcessError as e: error_msg = e.stderr if e.stderr else e.stdout if e.stdout else "Unknown error" @@ -159,6 +336,7 @@ def deploy_to_hf_space() -> None: excluded_dirs = get_excluded_dirs() excluded_files = get_excluded_files() +>>>>>>> origin/dev # Remove all existing files in HF Space (except .git) print("🧹 Cleaning existing files...") for item in Path(local_dir).iterdir(): @@ -168,28 +346,61 @@ def deploy_to_hf_space() -> None: shutil.rmtree(item) else: item.unlink() +<<<<<<< HEAD + +======= +>>>>>>> origin/dev # Copy files from repository root print("πŸ“¦ Copying files...") repo_root = Path(".") files_copied = 0 dirs_copied = 0 +<<<<<<< HEAD + +======= +>>>>>>> origin/dev for item in repo_root.rglob("*"): # Skip if in .git directory if ".git" in item.parts: continue +<<<<<<< HEAD + + # Skip if in hf_space directory (the cloned Space directory) + if "hf_space" in item.parts: + continue + + # Skip if should be excluded + if should_exclude(item, excluded_dirs, excluded_files): + continue + +======= # Skip if should be excluded if should_exclude(item, excluded_dirs, excluded_files): continue +>>>>>>> origin/dev # Calculate relative path try: rel_path = item.relative_to(repo_root) except ValueError: # Item is outside repo root, skip continue +<<<<<<< HEAD + + # Skip if in excluded directory + if any(part in excluded_dirs for part in rel_path.parts): + continue + + # Destination path + dest_path = Path(local_dir) / rel_path + + # Create parent directories + dest_path.parent.mkdir(parents=True, exist_ok=True) + +======= # Skip if in 
excluded directory if any(part in excluded_dirs for part in rel_path.parts): @@ -201,6 +412,7 @@ def deploy_to_hf_space() -> None: # Create parent directories dest_path.parent.mkdir(parents=True, exist_ok=True) +>>>>>>> origin/dev # Copy file or directory if item.is_file(): shutil.copy2(item, dest_path) @@ -208,6 +420,18 @@ def deploy_to_hf_space() -> None: elif item.is_dir(): # Directory will be created by parent mkdir, but we track it dirs_copied += 1 +<<<<<<< HEAD + + print(f"βœ… Copied {files_copied} files and {dirs_copied} directories") + + # Commit and push changes using git + print("πŸ’Ύ Committing changes...") + + # Change to the Space directory + original_cwd = os.getcwd() + os.chdir(local_dir) + +======= print(f"βœ… Copied {files_copied} files and {dirs_copied} directories") @@ -218,6 +442,7 @@ def deploy_to_hf_space() -> None: original_cwd = os.getcwd() os.chdir(local_dir) +>>>>>>> origin/dev try: # Configure git user (required for commit) subprocess.run( @@ -230,13 +455,28 @@ def deploy_to_hf_space() -> None: check=True, capture_output=True, ) +<<<<<<< HEAD + +======= +>>>>>>> origin/dev # Add all files subprocess.run( ["git", "add", "."], check=True, capture_output=True, ) +<<<<<<< HEAD + + # Check if there are changes to commit + result = subprocess.run( + ["git", "status", "--porcelain"], + check=False, + capture_output=True, + text=True, + ) + +======= # Check if there are changes to commit result = subprocess.run( @@ -245,6 +485,7 @@ def deploy_to_hf_space() -> None: text=True, ) +>>>>>>> origin/dev if result.stdout.strip(): # There are changes, commit and push subprocess.run( @@ -253,6 +494,15 @@ def deploy_to_hf_space() -> None: capture_output=True, ) print("πŸ“€ Pushing to Hugging Face Space...") +<<<<<<< HEAD + # Ensure remote URL uses credential helper (not token in URL) + subprocess.run( + ["git", "remote", "set-url", "origin", f"https://huggingface.co/spaces/{repo_id}"], + check=True, + capture_output=True, + ) +======= +>>>>>>> origin/dev subprocess.run( ["git", "push"], check=True, @@ -273,10 +523,25 @@ def deploy_to_hf_space() -> None: finally: # Return to original directory os.chdir(original_cwd) +<<<<<<< HEAD + + # Clean up credential store for security + try: + if credential_store.exists(): + credential_store.unlink() + except Exception: + # Ignore cleanup errors + pass + +======= +>>>>>>> origin/dev print(f"πŸŽ‰ Successfully deployed to: https://huggingface.co/spaces/{repo_id}") if __name__ == "__main__": deploy_to_hf_space() +<<<<<<< HEAD +======= +>>>>>>> origin/dev diff --git a/.github/workflows/deploy-hf-space.yml b/.github/workflows/deploy-hf-space.yml index 5e788686..2561b279 100644 --- a/.github/workflows/deploy-hf-space.yml +++ b/.github/workflows/deploy-hf-space.yml @@ -30,9 +30,18 @@ jobs: - name: Deploy to Hugging Face Space env: +<<<<<<< HEAD + # Token from secrets (sensitive data) + HF_TOKEN: ${{ secrets.HF_TOKEN }} + # Username/Organization from repository variables (non-sensitive) + HF_USERNAME: ${{ vars.HF_USERNAME }} + # Space name from repository variables (non-sensitive) + HF_SPACE_NAME: ${{ vars.HF_SPACE_NAME }} +======= HF_TOKEN: ${{ secrets.HF_TOKEN }} HF_USERNAME: ${{ secrets.HF_USERNAME }} HF_SPACE_NAME: ${{ secrets.HF_SPACE_NAME }} +>>>>>>> origin/dev run: | python .github/scripts/deploy_to_hf_space.py @@ -40,5 +49,9 @@ jobs: if: success() run: | echo "βœ… Deployment completed successfully!" 
+<<<<<<< HEAD + echo "Space URL: https://huggingface.co/spaces/${{ vars.HF_USERNAME }}/${{ vars.HF_SPACE_NAME }}" +======= echo "Space URL: https://huggingface.co/spaces/${{ secrets.HF_USERNAME }}/${{ secrets.HF_SPACE_NAME }}" +>>>>>>> origin/dev diff --git a/.gitignore b/.gitignore index 84c2e770..9a982ecb 100644 --- a/.gitignore +++ b/.gitignore @@ -57,6 +57,9 @@ reference_repos/DeepCritical/ # Keep the README in reference_repos !reference_repos/README.md +# Development directory +dev/ + # OS .DS_Store Thumbs.db @@ -70,6 +73,7 @@ logs/ .mypy_cache/ .coverage htmlcov/ +test_output*.txt # Database files chroma_db/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 12a77673..8b1184f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: hooks: - id: mypy files: ^src/ - exclude: ^folder + exclude: ^folder|^src/app.py additional_dependencies: - pydantic>=2.7 - pydantic-settings>=2.2 diff --git a/README.md b/README.md index aedb9d90..a4c0e90e 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,14 @@ app_file: src/app.py hf_oauth: true hf_oauth_expiration_minutes: 480 hf_oauth_scopes: - - inference-api + # Required for HuggingFace Inference API (includes all third-party providers) + # This scope grants access to: + # - HuggingFace's own Inference API + # - Third-party inference providers (nebius, together, scaleway, hyperbolic, novita, nscale, sambanova, ovh, fireworks, etc.) + # - All models available through the Inference Providers API + - inference-api + # Optional: Uncomment if you need to access user's billing information + # - read-billing pinned: true license: mit tags: diff --git a/deployments/README.md b/deployments/README.md new file mode 100644 index 00000000..3a4f4a4d --- /dev/null +++ b/deployments/README.md @@ -0,0 +1,46 @@ +# Deployments + +This directory contains infrastructure deployment scripts for DeepCritical services. + +## Modal Deployments + +### TTS Service (`modal_tts.py`) + +Deploys the Kokoro TTS (Text-to-Speech) function to Modal's GPU infrastructure. + +**Deploy:** +```bash +modal deploy deployments/modal_tts.py +``` + +**Features:** +- Kokoro 82M TTS model +- GPU-accelerated (T4) +- Voice options: af_heart, af_bella, am_michael, etc. +- Configurable speech speed + +**Requirements:** +- Modal account and credentials (`MODAL_TOKEN_ID`, `MODAL_TOKEN_SECRET` in `.env`) +- GPU quota on Modal + +**After Deployment:** +The function will be available at: +- App: `deepcritical-tts` +- Function: `kokoro_tts_function` + +The main application (`src/services/tts_modal.py`) will call this deployed function. + +--- + +## Adding New Deployments + +When adding new deployment scripts: + +1. Create a new file: `deployments/.py` +2. Use Modal's app pattern: + ```python + import modal + app = modal.App("deepcritical-") + ``` +3. Document in this README +4. Test deployment: `modal deploy deployments/.py` diff --git a/deployments/modal_tts.py b/deployments/modal_tts.py new file mode 100644 index 00000000..9987a339 --- /dev/null +++ b/deployments/modal_tts.py @@ -0,0 +1,97 @@ +"""Deploy Kokoro TTS function to Modal. + +This script deploys the TTS function to Modal so it can be called +from the main DeepCritical application. 
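The diff does not show how the main application reaches the deployed function; a hedged sketch of one way to call it, assuming Modal's `Function.lookup` client API and the `deepcritical-tts` / `kokoro_tts_function` names used here (the real call site lives in `src/services/tts_modal.py` and may differ):

```python
# Sketch only: assumes Modal's Function.lookup client API and the app/function
# names used by deployments/modal_tts.py.
import modal


def synthesize(text: str, voice: str = "af_heart", speed: float = 1.0):
    """Call the deployed Kokoro TTS function and return (sample_rate, audio)."""
    tts_fn = modal.Function.lookup("deepcritical-tts", "kokoro_tts_function")
    # .remote() runs the call on Modal's GPU workers, not locally.
    return tts_fn.remote(text, voice, speed)


if __name__ == "__main__":
    sample_rate, audio = synthesize("Hello from the deployed TTS function.")
    print(f"{sample_rate} Hz, {audio.shape[0]} samples")
```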
+ +Usage: + modal deploy deploy_modal_tts.py + +After deployment, the function will be available at: + App: deepcritical-tts + Function: kokoro_tts_function +""" + +import modal +import numpy as np + +# Create Modal app +app = modal.App("deepcritical-tts") + +# Define Kokoro TTS dependencies +KOKORO_DEPENDENCIES = [ + "torch>=2.0.0", + "transformers>=4.30.0", + "numpy<2.0", +] + +# Create Modal image with Kokoro +tts_image = ( + modal.Image.debian_slim(python_version="3.11") + .apt_install("git") # Install git first for pip install from github + .pip_install(*KOKORO_DEPENDENCIES) + .pip_install("git+https://github.com/hexgrad/kokoro.git") +) + + +@app.function( + image=tts_image, + gpu="T4", + timeout=60, +) +def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, np.ndarray]: + """Modal GPU function for Kokoro TTS. + + This function runs on Modal's GPU infrastructure. + Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS + + Args: + text: Text to synthesize + voice: Voice ID (e.g., af_heart, af_bella, am_michael) + speed: Speech speed multiplier (0.5-2.0) + + Returns: + Tuple of (sample_rate, audio_array) + """ + import numpy as np + + try: + import torch + from kokoro import KModel, KPipeline + + # Initialize model (cached on GPU) + model = KModel().to("cuda").eval() + pipeline = KPipeline(lang_code=voice[0]) + pack = pipeline.load_voice(voice) + + # Generate audio - accumulate all chunks + audio_chunks = [] + for _, ps, _ in pipeline(text, voice, speed): + ref_s = pack[len(ps) - 1] + audio = model(ps, ref_s, speed) + audio_chunks.append(audio.numpy()) + + # Concatenate all audio chunks + if audio_chunks: + full_audio = np.concatenate(audio_chunks) + return (24000, full_audio) + + # If no audio generated, return empty + return (24000, np.zeros(1, dtype=np.float32)) + + except ImportError as e: + raise RuntimeError( + f"Kokoro not installed: {e}. " + "Install with: pip install git+https://github.com/hexgrad/kokoro.git" + ) from e + except Exception as e: + raise RuntimeError(f"TTS synthesis failed: {e}") from e + + +# Optional: Add a test entrypoint +@app.local_entrypoint() +def test(): + """Test the TTS function.""" + print("Testing Modal TTS function...") + sample_rate, audio = kokoro_tts_function.remote("Hello, this is a test.", "af_heart", 1.0) + print(f"Generated audio: {sample_rate}Hz, shape={audio.shape}") + print("βœ“ TTS function works!") diff --git a/docs/LICENSE.md b/docs/LICENSE.md index 7a0c1fad..f44ce878 100644 --- a/docs/LICENSE.md +++ b/docs/LICENSE.md @@ -25,3 +25,14 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +<<<<<<< HEAD + + + + + + + + +======= +>>>>>>> origin/dev diff --git a/docs/api/orchestrators.md b/docs/api/orchestrators.md index 9e3d22df..1f157036 100644 --- a/docs/api/orchestrators.md +++ b/docs/api/orchestrators.md @@ -23,9 +23,18 @@ Runs iterative research flow. - `background_context`: Background context (default: "") - `output_length`: Optional description of desired output length (default: "") - `output_instructions`: Optional additional instructions for report generation (default: "") +<<<<<<< HEAD +- `message_history`: Optional user conversation history in Pydantic AI `ModelMessage` format (default: None) **Returns**: Final report string. +**Note**: The `message_history` parameter enables multi-turn conversations by providing context from previous interactions. + +======= + +**Returns**: Final report string. 
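A hedged usage sketch of the `run()` signatures documented here; only the class name and parameter names come from this page, while the import path and the assumption that `run()` is awaitable are illustrative:

```python
import asyncio

# Assumed import path; only DeepResearchFlow and its run() signature are documented above.
from src.orchestrator.deep_research import DeepResearchFlow  # hypothetical module path


async def main() -> None:
    flow = DeepResearchFlow()  # max_iterations_per_section etc. are constructor params
    report = await flow.run(
        "Summarize current first-line treatments for type 2 diabetes.",
        message_history=None,  # or a list[ModelMessage] carried over from earlier turns
    )
    print(report[:500])


asyncio.run(main())
```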
+ +>>>>>>> origin/dev **Note**: `max_iterations`, `max_time_minutes`, and `token_budget` are constructor parameters, not `run()` parameters. ## DeepResearchFlow @@ -46,9 +55,18 @@ Runs deep research flow. **Parameters**: - `query`: Research query string +<<<<<<< HEAD +- `message_history`: Optional user conversation history in Pydantic AI `ModelMessage` format (default: None) **Returns**: Final report string. +**Note**: The `message_history` parameter enables multi-turn conversations by providing context from previous interactions. + +======= + +**Returns**: Final report string. + +>>>>>>> origin/dev **Note**: `max_iterations_per_section`, `max_time_minutes`, and `token_budget` are constructor parameters, not `run()` parameters. ## GraphOrchestrator @@ -69,10 +87,20 @@ Runs graph-based research orchestration. **Parameters**: - `query`: Research query string +<<<<<<< HEAD +- `message_history`: Optional user conversation history in Pydantic AI `ModelMessage` format (default: None) + +**Yields**: `AgentEvent` objects during graph execution. + +**Note**: +- `research_mode` and `use_graph` are constructor parameters, not `run()` parameters. +- The `message_history` parameter enables multi-turn conversations by providing context from previous interactions. Message history is stored in `GraphExecutionContext` and passed to agents during execution. +======= **Yields**: `AgentEvent` objects during graph execution. **Note**: `research_mode` and `use_graph` are constructor parameters, not `run()` parameters. +>>>>>>> origin/dev ## Orchestrator Factory diff --git a/docs/architecture/graph_orchestration.md b/docs/architecture/graph_orchestration.md index cf8c4dbd..5cb48dd9 100644 --- a/docs/architecture/graph_orchestration.md +++ b/docs/architecture/graph_orchestration.md @@ -4,6 +4,47 @@ DeepCritical implements a graph-based orchestration system for research workflows using Pydantic AI agents as nodes. This enables better parallel execution, conditional routing, and state management compared to simple agent chains. +<<<<<<< HEAD +## Conversation History + +DeepCritical supports multi-turn conversations through Pydantic AI's native message history format. The system maintains two types of history: + +1. **User Conversation History**: Multi-turn user interactions (from Gradio chat interface) stored as `list[ModelMessage]` +2. 
**Research Iteration History**: Internal research process state (existing `Conversation` model) + +### Message History Flow + +``` +Gradio Chat History β†’ convert_gradio_to_message_history() β†’ GraphOrchestrator.run(message_history) + ↓ +GraphExecutionContext (stores message_history) + ↓ +Agent Nodes (receive message_history via agent.run()) + ↓ +WorkflowState (persists user_message_history) +``` + +### Usage + +Message history is automatically converted from Gradio format and passed through the orchestrator: + +```python +# In app.py - automatic conversion +message_history = convert_gradio_to_message_history(history) if history else None +async for event in orchestrator.run(query, message_history=message_history): + yield event +``` + +Agents receive message history through their `run()` methods: + +```python +# In agent execution +if message_history: + result = await agent.run(input_data, message_history=message_history) +``` + +======= +>>>>>>> origin/dev ## Graph Patterns ### Iterative Research Graph diff --git a/pyproject.toml b/pyproject.toml index 6bef4f10..6f59cac5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,6 +127,7 @@ ignore = [ "PLR0913", # Too many arguments (agents need many params) "PLR0912", # Too many branches (complex orchestrator logic) "PLR0911", # Too many return statements (complex agent logic) + "PLR0915", # Too many statements (Gradio UI setup functions) "PLR2004", # Magic values (statistical constants like p-values) "PLW0603", # Global statement (singleton pattern for Modal) "PLC0415", # Lazy imports for optional dependencies @@ -152,6 +153,7 @@ exclude = [ "^reference_repos/", "^examples/", "^folder/", + "^src/app.py", ] # ============== PYTEST CONFIG ============== diff --git a/src/agent_factory/agents.py b/src/agent_factory/agents.py index 69d01380..c676140e 100644 --- a/src/agent_factory/agents.py +++ b/src/agent_factory/agents.py @@ -27,7 +27,9 @@ logger = structlog.get_logger() -def create_input_parser_agent(model: Any | None = None, oauth_token: str | None = None) -> "InputParserAgent": +def create_input_parser_agent( + model: Any | None = None, oauth_token: str | None = None +) -> "InputParserAgent": """ Create input parser agent for query analysis and research mode detection. @@ -51,7 +53,9 @@ def create_input_parser_agent(model: Any | None = None, oauth_token: str | None raise ConfigurationError(f"Failed to create input parser agent: {e}") from e -def create_planner_agent(model: Any | None = None, oauth_token: str | None = None) -> "PlannerAgent": +def create_planner_agent( + model: Any | None = None, oauth_token: str | None = None +) -> "PlannerAgent": """ Create planner agent with web search and crawl tools. @@ -76,7 +80,9 @@ def create_planner_agent(model: Any | None = None, oauth_token: str | None = Non raise ConfigurationError(f"Failed to create planner agent: {e}") from e -def create_knowledge_gap_agent(model: Any | None = None, oauth_token: str | None = None) -> "KnowledgeGapAgent": +def create_knowledge_gap_agent( + model: Any | None = None, oauth_token: str | None = None +) -> "KnowledgeGapAgent": """ Create knowledge gap agent for evaluating research completeness. 
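The `convert_gradio_to_message_history` helper referenced in the flow above is not part of this diff; a sketch of what such a conversion could look like, assuming Gradio's `type="messages"` history format and pydantic-ai's message classes:

```python
# Sketch of a possible conversion; the real helper may differ. Assumes Gradio
# chat history in "messages" format: [{"role": "user"|"assistant", "content": str}, ...]
from pydantic_ai.messages import (
    ModelMessage,
    ModelRequest,
    ModelResponse,
    TextPart,
    UserPromptPart,
)


def convert_gradio_to_message_history(history: list[dict[str, str]]) -> list[ModelMessage]:
    messages: list[ModelMessage] = []
    for turn in history:
        content = turn.get("content", "")
        if not isinstance(content, str) or not content:
            continue  # skip empty or multimodal payloads in this sketch
        if turn.get("role") == "user":
            messages.append(ModelRequest(parts=[UserPromptPart(content=content)]))
        else:
            messages.append(ModelResponse(parts=[TextPart(content=content)]))
    return messages
```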
@@ -100,7 +106,9 @@ def create_knowledge_gap_agent(model: Any | None = None, oauth_token: str | None raise ConfigurationError(f"Failed to create knowledge gap agent: {e}") from e -def create_tool_selector_agent(model: Any | None = None, oauth_token: str | None = None) -> "ToolSelectorAgent": +def create_tool_selector_agent( + model: Any | None = None, oauth_token: str | None = None +) -> "ToolSelectorAgent": """ Create tool selector agent for choosing tools to address gaps. @@ -124,7 +132,9 @@ def create_tool_selector_agent(model: Any | None = None, oauth_token: str | None raise ConfigurationError(f"Failed to create tool selector agent: {e}") from e -def create_thinking_agent(model: Any | None = None, oauth_token: str | None = None) -> "ThinkingAgent": +def create_thinking_agent( + model: Any | None = None, oauth_token: str | None = None +) -> "ThinkingAgent": """ Create thinking agent for generating observations. @@ -172,7 +182,9 @@ def create_writer_agent(model: Any | None = None, oauth_token: str | None = None raise ConfigurationError(f"Failed to create writer agent: {e}") from e -def create_long_writer_agent(model: Any | None = None, oauth_token: str | None = None) -> "LongWriterAgent": +def create_long_writer_agent( + model: Any | None = None, oauth_token: str | None = None +) -> "LongWriterAgent": """ Create long writer agent for iteratively writing report sections. @@ -196,7 +208,9 @@ def create_long_writer_agent(model: Any | None = None, oauth_token: str | None = raise ConfigurationError(f"Failed to create long writer agent: {e}") from e -def create_proofreader_agent(model: Any | None = None, oauth_token: str | None = None) -> "ProofreaderAgent": +def create_proofreader_agent( + model: Any | None = None, oauth_token: str | None = None +) -> "ProofreaderAgent": """ Create proofreader agent for finalizing report drafts. diff --git a/src/agent_factory/graph_builder.py b/src/agent_factory/graph_builder.py index 0c03b89c..c06c1b0a 100644 --- a/src/agent_factory/graph_builder.py +++ b/src/agent_factory/graph_builder.py @@ -487,12 +487,13 @@ def create_iterative_graph( # Add nodes builder.add_agent_node("thinking", thinking_agent, "Generate observations") builder.add_agent_node("knowledge_gap", knowledge_gap_agent, "Evaluate knowledge gaps") + def _decision_function(result: Any) -> str: """Decision function for continue_decision node. - + Args: result: Result from knowledge_gap node (KnowledgeGapOutput or tuple) - + Returns: Next node ID: "writer" if research complete, "tool_selector" otherwise """ @@ -510,11 +511,11 @@ def _decision_function(result: Any) -> str: return "writer" if item["research_complete"] else "tool_selector" # Default to continuing research if we can't determine return "tool_selector" - + # Normal case: result is KnowledgeGapOutput object research_complete = getattr(result, "research_complete", False) return "writer" if research_complete else "tool_selector" - + builder.add_decision_node( "continue_decision", decision_function=_decision_function, diff --git a/src/agent_factory/judges.py b/src/agent_factory/judges.py index 92fe3b1e..59ccdc08 100644 --- a/src/agent_factory/judges.py +++ b/src/agent_factory/judges.py @@ -37,7 +37,7 @@ def get_model(oauth_token: str | None = None) -> Any: 1. HuggingFace (if OAuth token or API key available - preferred for free tier) 2. OpenAI (if API key available) 3. Anthropic (if API key available) - + If OAuth token is available, prefer HuggingFace (even if provider is set to OpenAI). 
This ensures users logged in via HuggingFace Spaces get the free tier. @@ -50,9 +50,23 @@ def get_model(oauth_token: str | None = None) -> Any: Raises: ConfigurationError: If no LLM provider is available """ + from src.utils.hf_error_handler import log_token_info, validate_hf_token + # Priority: oauth_token > settings.hf_token > settings.huggingface_api_key effective_hf_token = oauth_token or settings.hf_token or settings.huggingface_api_key + # Validate and log token information + if effective_hf_token: + log_token_info(effective_hf_token, context="get_model") + is_valid, error_msg = validate_hf_token(effective_hf_token) + if not is_valid: + logger.warning( + "Token validation failed", + error=error_msg, + has_oauth=bool(oauth_token), + ) + # Continue anyway - let the API call fail with a clear error + # Try HuggingFace first (preferred for free tier) if effective_hf_token: model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct" @@ -157,7 +171,27 @@ async def assess( return assessment except Exception as e: - logger.error("Assessment failed", error=str(e)) + # Extract error details for better logging and handling + from src.utils.hf_error_handler import ( + extract_error_details, + get_user_friendly_error_message, + ) + + error_details = extract_error_details(e) + logger.error( + "Assessment failed", + error=str(e), + status_code=error_details.get("status_code"), + model_name=error_details.get("model_name"), + is_auth_error=error_details.get("is_auth_error"), + is_model_error=error_details.get("is_model_error"), + ) + + # Log user-friendly message for debugging + if error_details.get("is_auth_error") or error_details.get("is_model_error"): + user_msg = get_user_friendly_error_message(e, error_details.get("model_name")) + logger.warning("API error details", user_message=user_msg[:200]) + # Return a safe default assessment on failure return self._create_fallback_assessment(question, str(e)) @@ -209,9 +243,7 @@ class HFInferenceJudgeHandler: "HuggingFaceH4/zephyr-7b-beta", # Fallback (Ungated) ] - def __init__( - self, model_id: str | None = None, api_key: str | None = None - ) -> None: + def __init__(self, model_id: str | None = None, api_key: str | None = None) -> None: """ Initialize with HF Inference client. diff --git a/src/agents/audio_refiner.py b/src/agents/audio_refiner.py new file mode 100644 index 00000000..257c6f5e --- /dev/null +++ b/src/agents/audio_refiner.py @@ -0,0 +1,402 @@ +"""Audio Refiner Agent - Cleans markdown reports for TTS audio clarity. + +This agent transforms markdown-formatted research reports into clean, +audio-friendly plain text suitable for text-to-speech synthesis. +""" + +import re + +import structlog +from pydantic_ai import Agent + +from src.utils.llm_factory import get_pydantic_ai_model + +logger = structlog.get_logger(__name__) + + +class AudioRefiner: + """Refines markdown reports for optimal TTS audio output. 
+ + Handles common formatting issues that make text difficult to listen to: + - Markdown syntax (headers, bold, italic, links) + - Citations and reference markers + - Roman numerals in medical contexts + - Multiple References sections + - Special characters and formatting artifacts + """ + + # Roman numeral to integer mapping + ROMAN_VALUES = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000} + + # Number to word mapping (1-20, common in medical literature) + NUMBER_TO_WORD = { + 1: "One", + 2: "Two", + 3: "Three", + 4: "Four", + 5: "Five", + 6: "Six", + 7: "Seven", + 8: "Eight", + 9: "Nine", + 10: "Ten", + 11: "Eleven", + 12: "Twelve", + 13: "Thirteen", + 14: "Fourteen", + 15: "Fifteen", + 16: "Sixteen", + 17: "Seventeen", + 18: "Eighteen", + 19: "Nineteen", + 20: "Twenty", + } + + async def refine_for_audio(self, markdown_text: str, use_llm_polish: bool = False) -> str: + """Transform markdown report into audio-friendly plain text. + + Args: + markdown_text: Markdown-formatted research report + use_llm_polish: If True, apply LLM-based final polish (optional) + + Returns: + Clean plain text optimized for TTS audio + """ + logger.info("Refining report for audio output", use_llm_polish=use_llm_polish) + + text = markdown_text + + # Step 1: Remove References sections first (before other processing) + text = self._remove_references_sections(text) + + # Step 2: Remove markdown formatting + text = self._remove_markdown_syntax(text) + + # Step 3: Convert roman numerals to words + text = self._convert_roman_numerals(text) + + # Step 4: Remove citations + text = self._remove_citations(text) + + # Step 5: Clean up special characters and artifacts + text = self._clean_special_characters(text) + + # Step 6: Normalize whitespace + text = self._normalize_whitespace(text) + + # Step 7 (Optional): LLM polish for edge cases + if use_llm_polish: + text = await self._llm_polish(text) + + logger.info( + "Audio refinement complete", + original_length=len(markdown_text), + refined_length=len(text), + llm_polish_applied=use_llm_polish, + ) + + return text.strip() + + def _remove_references_sections(self, text: str) -> str: + """Remove References sections while preserving other content. + + Removes the References section and its content until the next section + heading or end of document. Handles multiple References sections. 
+ + Matches various References heading formats: + - # References + - ## References + - **References:** + - **Additional References:** + - References: (plain text) + """ + # Pattern to match References section heading (case-insensitive) + # Matches: markdown headers (# References), bold (**References:**), or plain text (References:) + references_pattern = r"\n(?:#+\s*References?:?\s*\n|\*\*\s*(?:Additional\s+)?References?:?\s*\*\*\s*\n|References?:?\s*\n)" + + # Find all References sections + while True: + match = re.search(references_pattern, text, re.IGNORECASE) + if not match: + break + + # Find the start of the References section + section_start = match.start() + + # Find the next section (markdown header or bold heading) or end of document + # Match: "# Header", "## Header", or "**Header**" + next_section_patterns = [ + r"\n#+\s+\w+", # Markdown headers (# Section, ## Section) + r"\n\*\*[A-Z][^*]+\*\*", # Bold headings (**Section Name**) + ] + + remaining_text = text[match.end() :] + next_section_match = None + + # Try all patterns and find the earliest match + earliest_match = None + for pattern in next_section_patterns: + m = re.search(pattern, remaining_text) + if m and (earliest_match is None or m.start() < earliest_match.start()): + earliest_match = m + + next_section_match = earliest_match + + if next_section_match: + # Remove from References heading to next section + section_end = match.end() + next_section_match.start() + else: + # No next section - remove to end of document + section_end = len(text) + + # Remove the References section + text = text[:section_start] + text[section_end:] + logger.debug("Removed References section", removed_chars=section_end - section_start) + + return text + + def _remove_markdown_syntax(self, text: str) -> str: + """Remove markdown formatting syntax.""" + + # Headers (# ## ###) + text = re.sub(r"^\s*#+\s+", "", text, flags=re.MULTILINE) + + # Bold (**text** or __text__) + text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text) + text = re.sub(r"__([^_]+)__", r"\1", text) + + # Italic (*text* or _text_) + text = re.sub(r"\*([^*]+)\*", r"\1", text) + text = re.sub(r"_([^_]+)_", r"\1", text) + + # Links [text](url) β†’ text + text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text) + + # Inline code `code` β†’ code + text = re.sub(r"`([^`]+)`", r"\1", text) + + # Strikethrough ~~text~~ + text = re.sub(r"~~([^~]+)~~", r"\1", text) + + # Blockquotes (> text) + text = re.sub(r"^\s*>\s+", "", text, flags=re.MULTILINE) + + # Horizontal rules (---, ***, ___) + text = re.sub(r"^\s*[-*_]{3,}\s*$", "", text, flags=re.MULTILINE) + + # List markers (-, *, 1., 2.) + text = re.sub(r"^\s*[-*]\s+", "", text, flags=re.MULTILINE) + text = re.sub(r"^\s*\d+\.\s+", "", text, flags=re.MULTILINE) + + return text + + def _roman_to_int(self, roman: str) -> int | None: + """Convert roman numeral string to integer. + + Args: + roman: Roman numeral string (e.g., 'IV', 'XII') + + Returns: + Integer value, or None if invalid roman numeral + """ + roman = roman.upper() + result = 0 + prev_value = 0 + + for char in reversed(roman): + if char not in self.ROMAN_VALUES: + return None + + value = self.ROMAN_VALUES[char] + + # Subtractive notation (IV = 4, IX = 9) + if value < prev_value: + result -= value + else: + result += value + + prev_value = value + + return result + + def _int_to_word(self, num: int) -> str: + """Convert integer to word representation. 
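To make the subtractive-notation handling concrete, a small standalone rendition of the same reversed-scan idea with a worked trace (the value table mirrors `ROMAN_VALUES` above):

```python
ROMAN_VALUES = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}


def roman_to_int(roman: str) -> int:
    """Reversed scan: subtract a value when it is smaller than the one to its right."""
    result, prev = 0, 0
    for char in reversed(roman.upper()):
        value = ROMAN_VALUES[char]
        result = result - value if value < prev else result + value
        prev = value
    return result


# Worked trace for "XIV": reversed order is V, I, X
#   V -> +5  (prev 0)   => 5
#   I -> -1  (1 < 5)    => 4
#   X -> +10 (10 > 1)   => 14
print(roman_to_int("XIV"))  # 14
print(roman_to_int("III"))  # 3 -> "Phase III" becomes "Phase Three" downstream
```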
+ + Args: + num: Integer to convert (1-20 supported) + + Returns: + Word representation (e.g., 'One', 'Twelve') + """ + if num in self.NUMBER_TO_WORD: + return self.NUMBER_TO_WORD[num] + else: + # For numbers > 20, just return the digit + return str(num) + + def _convert_roman_numerals(self, text: str) -> str: + """Convert roman numerals to words for better TTS pronunciation. + + Handles patterns like: + - Phase I, Phase II, Phase III + - Trial I, Trial II + - Type I, Type II + - Stage I, Stage II + - Standalone I, II, III (with word boundaries) + """ + + def replace_roman(match: re.Match[str]) -> str: + """Callback to replace matched roman numeral.""" + prefix = match.group(1) # Word before roman numeral (if any) + roman = match.group(2) # The roman numeral + + # Convert to integer + num = self._roman_to_int(roman) + if num is None: + return match.group(0) # Return original if invalid + + # Convert to word + word = self._int_to_word(num) + + # Return with prefix if present + if prefix: + return f"{prefix} {word}" + else: + return word + + # Pattern: Optional word + space + roman numeral + # Matches: "Phase I", "Trial II", standalone "I", "II" + # Uses word boundaries to avoid matching "I" in "INVALID" + pattern = r"\b(Phase|Trial|Type|Stage|Class|Group|Arm|Cohort)?\s*([IVXLCDM]+)\b" + + text = re.sub(pattern, replace_roman, text) + + return text + + def _remove_citations(self, text: str) -> str: + """Remove citation markers and references.""" + + # Numbered citations [1], [2], [1,2], [1-3] + text = re.sub(r"\[\d+(?:[-,]\d+)*\]", "", text) + + # Author citations (Smith et al., 2023) or (Smith et al. 2023) + text = re.sub(r"\([A-Z][a-z]+\s+et\s+al\.?,?\s+\d{4}\)", "", text) + + # Simple year citations (2023) + text = re.sub(r"\(\d{4}\)", "", text) + + # Author-year (Smith, 2023) + text = re.sub(r"\([A-Z][a-z]+,?\s+\d{4}\)", "", text) + + # Footnote markers (ΒΉ, Β², Β³) + text = re.sub(r"[¹²³⁴⁡⁢⁷⁸⁹⁰]+", "", text) + + return text + + def _clean_special_characters(self, text: str) -> str: + """Clean up special characters and formatting artifacts.""" + + # Replace em dashes with regular dashes + text = text.replace("\u2014", "-") # em dash + text = text.replace("\u2013", "-") # en dash + + # Replace smart quotes with regular quotes + text = text.replace("\u201c", '"') # left double quote + text = text.replace("\u201d", '"') # right double quote + text = text.replace("\u2018", "'") # left single quote + text = text.replace("\u2019", "'") # right single quote + + # Remove excessive punctuation (!!!, ???) + text = re.sub(r"([!?]){2,}", r"\1", text) + + # Remove asterisks used for footnotes + text = re.sub(r"\*+", "", text) + + # Remove hash symbols (from headers) + text = text.replace("#", "") + + # Remove excessive dots (...) + text = re.sub(r"\.{4,}", "...", text) + + return text + + def _normalize_whitespace(self, text: str) -> str: + """Normalize whitespace for clean audio output.""" + + # Replace multiple spaces with single space + text = re.sub(r" {2,}", " ", text) + + # Replace multiple newlines with double newline (paragraph break) + text = re.sub(r"\n{3,}", "\n\n", text) + + # Remove trailing/leading whitespace from lines + text = "\n".join(line.strip() for line in text.split("\n")) + + # Remove empty lines at start/end + text = text.strip() + + return text + + async def _llm_polish(self, text: str) -> str: + """Apply LLM-based final polish to catch edge cases. + + This is a lightweight pass that removes any remaining formatting + artifacts the rule-based methods might have missed. 
+ + Args: + text: Pre-cleaned text from rule-based methods + + Returns: + Final polished text ready for TTS + """ + try: + # Create a simple agent for text cleanup + model = get_pydantic_ai_model() + polish_agent = Agent( + model=model, + system_prompt=( + "You are a text cleanup assistant. Your ONLY job is to remove " + "any remaining formatting artifacts (markdown, citations, special " + "characters) that make text unsuitable for text-to-speech audio. " + "DO NOT rewrite, improve, or change the content. " + "DO NOT add explanations. " + "ONLY output the cleaned text." + ), + ) + + # Run asynchronously + result = await polish_agent.run( + f"Clean this text for audio (remove any formatting artifacts):\n\n{text}" + ) + + polished_text = result.output.strip() + + logger.info( + "llm_polish_applied", original_length=len(text), polished_length=len(polished_text) + ) + + return polished_text + + except Exception as e: + logger.warning( + "llm_polish_failed", error=str(e), message="Falling back to rule-based output" + ) + # Graceful fallback: return original text if LLM fails + return text + + +# Singleton instance for easy import +audio_refiner = AudioRefiner() + + +async def refine_text_for_audio(markdown_text: str, use_llm_polish: bool = False) -> str: + """Convenience function to refine markdown text for audio. + + Args: + markdown_text: Markdown-formatted text + use_llm_polish: If True, apply LLM-based final polish (optional) + + Returns: + Audio-friendly plain text + """ + return await audio_refiner.refine_for_audio(markdown_text, use_llm_polish=use_llm_polish) diff --git a/src/agents/knowledge_gap.py b/src/agents/knowledge_gap.py index 92114eaf..4ceab6cc 100644 --- a/src/agents/knowledge_gap.py +++ b/src/agents/knowledge_gap.py @@ -9,6 +9,11 @@ import structlog from pydantic_ai import Agent +try: + from pydantic_ai import ModelMessage +except ImportError: + ModelMessage = Any # type: ignore[assignment, misc] + from src.agent_factory.judges import get_model from src.utils.exceptions import ConfigurationError from src.utils.models import KnowledgeGapOutput @@ -68,6 +73,7 @@ async def evaluate( query: str, background_context: str = "", conversation_history: str = "", + message_history: list[ModelMessage] | None = None, iteration: int = 0, time_elapsed_minutes: float = 0.0, max_time_minutes: int = 10, @@ -78,7 +84,8 @@ async def evaluate( Args: query: The original research query background_context: Optional background context - conversation_history: History of actions, findings, and thoughts + conversation_history: History of actions, findings, and thoughts (backward compat) + message_history: Optional user conversation history (Pydantic AI format) iteration: Current iteration number time_elapsed_minutes: Time elapsed so far max_time_minutes: Maximum time allowed @@ -111,8 +118,11 @@ async def evaluate( """ try: - # Run the agent - result = await self.agent.run(user_message) + # Run the agent with message_history if provided + if message_history: + result = await self.agent.run(user_message, message_history=message_history) + else: + result = await self.agent.run(user_message) evaluation = result.output self.logger.info( @@ -132,7 +142,9 @@ async def evaluate( ) -def create_knowledge_gap_agent(model: Any | None = None, oauth_token: str | None = None) -> KnowledgeGapAgent: +def create_knowledge_gap_agent( + model: Any | None = None, oauth_token: str | None = None +) -> KnowledgeGapAgent: """ Factory function to create a knowledge gap agent. 
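For reference, a minimal usage sketch of the `refine_text_for_audio` convenience function defined at the end of `src/agents/audio_refiner.py`; the sample markdown is made up and the comments describe the expected flavour of the output rather than an exact string:

```python
import asyncio

from src.agents.audio_refiner import refine_text_for_audio

SAMPLE = """# Phase II Results

**Key finding**: the intervention improved outcomes [1] (Smith et al., 2023).

## References
[1] Smith et al. Example trial. https://example.org
"""


async def main() -> None:
    # Rule-based cleanup only; pass use_llm_polish=True to add the optional LLM pass.
    clean = await refine_text_for_audio(SAMPLE, use_llm_polish=False)
    print(clean)
    # Headers, bold markers, citations and the References section should be gone,
    # and "Phase II" should read as "Phase Two".


asyncio.run(main())
```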
diff --git a/src/agents/long_writer.py b/src/agents/long_writer.py index 0e5f0804..5b0553aa 100644 --- a/src/agents/long_writer.py +++ b/src/agents/long_writer.py @@ -225,25 +225,27 @@ async def write_next_section( "Section writing failed after all attempts", error=str(last_exception) if last_exception else "Unknown error", ) - + # Try to enhance fallback with evidence if available try: from src.middleware.state_machine import get_workflow_state - + state = get_workflow_state() if state and state.evidence: # Include evidence citations in fallback evidence_refs: list[str] = [] for i, ev in enumerate(state.evidence[:10], 1): # Limit to 10 - authors = ", ".join(ev.citation.authors[:2]) if ev.citation.authors else "Unknown" + authors = ( + ", ".join(ev.citation.authors[:2]) if ev.citation.authors else "Unknown" + ) evidence_refs.append( f"[{i}] {authors}. *{ev.citation.title}*. {ev.citation.url}" ) - + enhanced_draft = f"## {next_section_title}\n\n{next_section_draft}" if evidence_refs: enhanced_draft += "\n\n### Sources\n\n" + "\n".join(evidence_refs) - + return LongWriterOutput( next_section_markdown=enhanced_draft, references=evidence_refs, @@ -253,7 +255,7 @@ async def write_next_section( "Failed to enhance fallback with evidence", error=str(e), ) - + # Basic fallback return LongWriterOutput( next_section_markdown=f"## {next_section_title}\n\n{next_section_draft}", @@ -437,7 +439,9 @@ def adjust_heading_level(match: re.Match[str]) -> str: return re.sub(r"^(#+)\s(.+)$", adjust_heading_level, section_markdown, flags=re.MULTILINE) -def create_long_writer_agent(model: Any | None = None, oauth_token: str | None = None) -> LongWriterAgent: +def create_long_writer_agent( + model: Any | None = None, oauth_token: str | None = None +) -> LongWriterAgent: """ Factory function to create a long writer agent. diff --git a/src/agents/proofreader.py b/src/agents/proofreader.py index 5fdf15e3..7a209c36 100644 --- a/src/agents/proofreader.py +++ b/src/agents/proofreader.py @@ -181,7 +181,9 @@ async def proofread( return f"# Research Report\n\n## Query\n{query}\n\n" + "\n\n".join(sections) -def create_proofreader_agent(model: Any | None = None, oauth_token: str | None = None) -> ProofreaderAgent: +def create_proofreader_agent( + model: Any | None = None, oauth_token: str | None = None +) -> ProofreaderAgent: """ Factory function to create a proofreader agent. 
diff --git a/src/agents/thinking.py b/src/agents/thinking.py index 225543b7..eff08e51 100644 --- a/src/agents/thinking.py +++ b/src/agents/thinking.py @@ -9,6 +9,11 @@ import structlog from pydantic_ai import Agent +try: + from pydantic_ai import ModelMessage +except ImportError: + ModelMessage = Any # type: ignore[assignment, misc] + from src.agent_factory.judges import get_model from src.utils.exceptions import ConfigurationError @@ -72,6 +77,7 @@ async def generate_observations( query: str, background_context: str = "", conversation_history: str = "", + message_history: list[ModelMessage] | None = None, iteration: int = 1, ) -> str: """ @@ -80,7 +86,8 @@ async def generate_observations( Args: query: The original research query background_context: Optional background context - conversation_history: History of actions, findings, and thoughts + conversation_history: History of actions, findings, and thoughts (backward compat) + message_history: Optional user conversation history (Pydantic AI format) iteration: Current iteration number Returns: @@ -110,8 +117,11 @@ async def generate_observations( """ try: - # Run the agent - result = await self.agent.run(user_message) + # Run the agent with message_history if provided + if message_history: + result = await self.agent.run(user_message, message_history=message_history) + else: + result = await self.agent.run(user_message) observations = result.output self.logger.info("Observations generated", length=len(observations)) @@ -124,7 +134,9 @@ async def generate_observations( return f"Starting iteration {iteration}. Need to gather information about: {query}" -def create_thinking_agent(model: Any | None = None, oauth_token: str | None = None) -> ThinkingAgent: +def create_thinking_agent( + model: Any | None = None, oauth_token: str | None = None +) -> ThinkingAgent: """ Factory function to create a thinking agent. diff --git a/src/agents/tool_selector.py b/src/agents/tool_selector.py index 3f06fe92..0da35f43 100644 --- a/src/agents/tool_selector.py +++ b/src/agents/tool_selector.py @@ -9,6 +9,11 @@ import structlog from pydantic_ai import Agent +try: + from pydantic_ai import ModelMessage +except ImportError: + ModelMessage = Any # type: ignore[assignment, misc] + from src.agent_factory.judges import get_model from src.utils.exceptions import ConfigurationError from src.utils.models import AgentSelectionPlan @@ -81,6 +86,7 @@ async def select_tools( query: str, background_context: str = "", conversation_history: str = "", + message_history: list[ModelMessage] | None = None, ) -> AgentSelectionPlan: """ Select tools to address a knowledge gap. 
@@ -89,7 +95,8 @@ async def select_tools( gap: The knowledge gap to address query: The original research query background_context: Optional background context - conversation_history: History of actions, findings, and thoughts + conversation_history: History of actions, findings, and thoughts (backward compat) + message_history: Optional user conversation history (Pydantic AI format) Returns: AgentSelectionPlan with tasks for selected agents @@ -115,8 +122,11 @@ async def select_tools( """ try: - # Run the agent - result = await self.agent.run(user_message) + # Run the agent with message_history if provided + if message_history: + result = await self.agent.run(user_message, message_history=message_history) + else: + result = await self.agent.run(user_message) selection_plan = result.output self.logger.info( @@ -144,7 +154,9 @@ async def select_tools( ) -def create_tool_selector_agent(model: Any | None = None, oauth_token: str | None = None) -> ToolSelectorAgent: +def create_tool_selector_agent( + model: Any | None = None, oauth_token: str | None = None +) -> ToolSelectorAgent: """ Factory function to create a tool selector agent. diff --git a/src/agents/writer.py b/src/agents/writer.py index 3bfe7ee0..9fc2edff 100644 --- a/src/agents/writer.py +++ b/src/agents/writer.py @@ -175,12 +175,12 @@ async def write_report( "Report writing failed after all attempts", error=str(last_exception) if last_exception else "Unknown error", ) - + # Try to use evidence-based report generator for better fallback try: from src.middleware.state_machine import get_workflow_state from src.utils.report_generator import generate_report_from_evidence - + state = get_workflow_state() if state and state.evidence: self.logger.info( @@ -197,7 +197,7 @@ async def write_report( "Failed to use evidence-based report generator", error=str(e), ) - + # Fallback to simple report if evidence generator fails # Truncate findings in fallback if too long fallback_findings = findings[:500] + "..." if len(findings) > 500 else findings diff --git a/src/app.py b/src/app.py index ecb13412..e700eaec 100644 --- a/src/app.py +++ b/src/app.py @@ -1,4 +1,12 @@ -"""Gradio UI for The DETERMINATOR agent with MCP server support.""" +"""Main Gradio application for DeepCritical research agent. + +This module provides the Gradio interface with: +- OAuth authentication via HuggingFace +- Multimodal input support (text, images, audio) +- Research agent orchestration +- Real-time event streaming +- MCP server integration +""" import os from collections.abc import AsyncGenerator @@ -6,38 +14,43 @@ import gradio as gr import numpy as np -from gradio.components.multimodal_textbox import MultimodalPostprocess - -# Try to import HuggingFace support (may not be available in all pydantic-ai versions) -# According to https://ai.pydantic.dev/models/huggingface/, HuggingFace support requires -# pydantic-ai with huggingface extra or pydantic-ai-slim[huggingface] -# There are two ways to use HuggingFace: -# 1. Inference API: HuggingFaceModel with HuggingFaceProvider (uses AsyncInferenceClient internally) -# 2. 
Local models: Would use transformers directly (not via pydantic-ai) +import structlog + +from src.agent_factory.judges import HFInferenceJudgeHandler, JudgeHandler, MockJudgeHandler +from src.orchestrator_factory import create_orchestrator +from src.services.multimodal_processing import get_multimodal_service +from src.utils.config import settings +from src.utils.models import AgentEvent, OrchestratorConfig + +# Import ModelMessage from pydantic_ai with fallback +try: + from pydantic_ai import ModelMessage +except ImportError: + from typing import Any + + ModelMessage = Any # type: ignore[assignment, misc] + +# Type alias for Gradio multimodal input +MultimodalPostprocess = dict[str, Any] | str + +# Import HuggingFace components with graceful fallback try: - from huggingface_hub import AsyncInferenceClient from pydantic_ai.models.huggingface import HuggingFaceModel from pydantic_ai.providers.huggingface import HuggingFaceProvider _HUGGINGFACE_AVAILABLE = True except ImportError: + _HUGGINGFACE_AVAILABLE = False HuggingFaceModel = None # type: ignore[assignment, misc] HuggingFaceProvider = None # type: ignore[assignment, misc] - AsyncInferenceClient = None # type: ignore[assignment, misc] - _HUGGINGFACE_AVAILABLE = False -from src.agent_factory.judges import HFInferenceJudgeHandler, JudgeHandler, MockJudgeHandler -from src.orchestrator_factory import create_orchestrator -from src.services.audio_processing import get_audio_service -from src.services.multimodal_processing import get_multimodal_service -import structlog -from src.tools.clinicaltrials import ClinicalTrialsTool -from src.tools.europepmc import EuropePMCTool -from src.tools.pubmed import PubMedTool -from src.tools.search_handler import SearchHandler -from src.tools.neo4j_search import Neo4jSearchTool -from src.utils.config import settings -from src.utils.models import AgentEvent, OrchestratorConfig +try: + from huggingface_hub import AsyncInferenceClient + + _ASYNC_INFERENCE_AVAILABLE = True +except ImportError: + _ASYNC_INFERENCE_AVAILABLE = False + AsyncInferenceClient = None # type: ignore[assignment, misc] logger = structlog.get_logger() @@ -50,40 +63,39 @@ def configure_orchestrator( hf_provider: str | None = None, graph_mode: str | None = None, use_graph: bool = True, + web_search_provider: str | None = None, ) -> tuple[Any, str]: """ - Create an orchestrator instance. + Configure and create the research orchestrator. 
Args: - use_mock: If True, use MockJudgeHandler (no API key needed) - mode: Orchestrator mode ("simple", "advanced", "iterative", "deep", or "auto") - oauth_token: Optional OAuth token from HuggingFace login - hf_model: Selected HuggingFace model ID - hf_provider: Selected inference provider - graph_mode: Graph research mode ("iterative", "deep", or "auto") - used when mode is graph-based - use_graph: Whether to use graph execution (True) or agent chains (False) + use_mock: Force mock judge handler (for testing) + mode: Orchestrator mode ("simple", "iterative", "deep", "auto", "advanced") + oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars) + hf_model: Optional HuggingFace model ID (overrides settings) + hf_provider: Optional inference provider (currently not used by HuggingFaceProvider) + graph_mode: Optional graph execution mode + use_graph: Whether to use graph execution + web_search_provider: Optional web search provider ("auto", "serper", "duckduckgo") Returns: - Tuple of (Orchestrator instance, backend_name) + Tuple of (orchestrator, backend_info_string) """ - # Create orchestrator config - config = OrchestratorConfig( - max_iterations=10, - max_results_per_tool=10, - ) - - # Create search tools with RAG enabled - # Pass OAuth token to SearchHandler so it can be used by RAG service - tools = [Neo4jSearchTool(),PubMedTool(), ClinicalTrialsTool(), EuropePMCTool()] - - # Add web search tool if available + from src.tools.search_handler import SearchHandler from src.tools.web_search_factory import create_web_search_tool - web_search_tool = create_web_search_tool() - if web_search_tool is not None: + # Create search handler with tools + tools = [] + + # Add web search tool + web_search_tool = create_web_search_tool(provider=web_search_provider or "auto") + if web_search_tool: tools.append(web_search_tool) logger.info("Web search tool added to search handler", provider=web_search_tool.name) + # Create config if not provided + config = OrchestratorConfig() + search_handler = SearchHandler( tools=tools, timeout=config.search_timeout, @@ -104,11 +116,37 @@ def configure_orchestrator( # 2. API Key (OAuth or Env) - HuggingFace only (OAuth provides HF token) # Priority: oauth_token > env vars # On HuggingFace Spaces, OAuth token is available via request.oauth_token + # + # OAuth Scope Requirements: + # - 'inference-api': Required for HuggingFace Inference API access + # This scope grants access to: + # * HuggingFace's own Inference API + # * All third-party inference providers (nebius, together, scaleway, hyperbolic, novita, nscale, sambanova, ovh, fireworks, etc.) + # * All models available through the Inference Providers API + # See: https://huggingface.co/docs/hub/oauth#currently-supported-scopes + # + # Note: The hf_provider parameter is accepted but not used here because HuggingFaceProvider + # from pydantic-ai doesn't support provider selection. Provider selection happens at the + # InferenceClient level (used in HuggingFaceChatClient for advanced mode). 
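A hedged construction example for `configure_orchestrator` as documented above; the values are placeholders (in the Space they come from the Gradio UI and the OAuth login), and only the parameter names and the `(orchestrator, backend)` return shape are taken from the code:

```python
import os

from src.app import configure_orchestrator

# Placeholder configuration; real values come from the UI controls and OAuth token.
orchestrator, backend = configure_orchestrator(
    use_mock=not os.getenv("HF_TOKEN"),  # no token available -> fall back to the mock judge
    mode="auto",
    oauth_token=None,
    hf_model=None,                       # defer to settings/env for the model ID
    web_search_provider="auto",
)
print(f"Selected backend: {backend}")
```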
effective_api_key = oauth_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY") + # Log which authentication source is being used + if effective_api_key: + auth_source = ( + "OAuth token" + if oauth_token + else ("HF_TOKEN env var" if os.getenv("HF_TOKEN") else "HUGGINGFACE_API_KEY env var") + ) + logger.info( + "Using HuggingFace authentication", + source=auth_source, + has_token=bool(effective_api_key), + ) + if effective_api_key: # We have an API key (OAuth or env) - use pydantic-ai with JudgeHandler - # This uses HuggingFace's own inference API, not third-party providers + # This uses HuggingFace Inference API, which includes access to all third-party providers + # via the Inference Providers API (router.huggingface.co) model: Any | None = None # Use selected model or fall back to env var/settings model_name = ( @@ -126,6 +164,7 @@ def configure_orchestrator( # Per https://ai.pydantic.dev/models/huggingface/#configure-the-provider # HuggingFaceProvider accepts api_key parameter directly # This is consistent with usage in src/utils/llm_factory.py and src/agent_factory/judges.py + # The OAuth token with 'inference-api' scope provides access to all inference providers provider = HuggingFaceProvider(api_key=effective_api_key) # type: ignore[misc] model = HuggingFaceModel(model_name, provider=provider) # type: ignore[misc] backend_info = "API (HuggingFace OAuth)" if oauth_token else "API (Env Config)" @@ -167,203 +206,44 @@ def configure_orchestrator( def _is_file_path(text: str) -> bool: """Check if text appears to be a file path. - + Args: text: Text to check - + Returns: True if text looks like a file path """ - import os - # Check for common file extensions - file_extensions = ['.md', '.pdf', '.txt', '.json', '.csv', '.xlsx', '.docx', '.html'] - text_lower = text.lower().strip() - - # Check if it ends with a file extension - if any(text_lower.endswith(ext) for ext in file_extensions): - # Check if it's a valid path (absolute or relative) - if os.path.sep in text or '/' in text or '\\' in text: - return True - # Or if it's just a filename with extension - if '.' in text and len(text.split('.')) == 2: - return True - - # Check if it's an absolute path - if os.path.isabs(text): - return True - - return False - - -def _get_file_name(file_path: str) -> str: - """Extract filename from file path. - - Args: - file_path: Full file path - - Returns: - Filename with extension - """ - import os - return os.path.basename(file_path) + return ("/" in text or "\\" in text) and ( + "." in text.split("/")[-1] or "." in text.split("\\")[-1] + ) def event_to_chat_message(event: AgentEvent) -> dict[str, Any]: - """ - Convert AgentEvent to gr.ChatMessage with metadata for accordion display. + """Convert AgentEvent to Gradio chat message format. 
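+
+    Example (illustrative):
+        event_to_chat_message(AgentEvent(type="complete", message="Report ready"))
+        # -> {"role": "assistant", "content": "..."}
+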
Args: - event: The AgentEvent to convert + event: AgentEvent to convert Returns: - ChatMessage with metadata for collapsible accordion + Dictionary with 'role' and 'content' keys for Gradio Chatbot """ - # Map event types to accordion titles and determine if pending - event_configs: dict[str, dict[str, Any]] = { - "started": {"title": "πŸš€ Starting Research", "status": "done", "icon": "πŸš€"}, - "searching": {"title": "πŸ” Searching Literature", "status": "pending", "icon": "πŸ”"}, - "search_complete": {"title": "πŸ“š Search Results", "status": "done", "icon": "πŸ“š"}, - "judging": {"title": "🧠 Evaluating Evidence", "status": "pending", "icon": "🧠"}, - "judge_complete": {"title": "βœ… Evidence Assessment", "status": "done", "icon": "βœ…"}, - "looping": {"title": "πŸ”„ Research Iteration", "status": "pending", "icon": "πŸ”„"}, - "synthesizing": {"title": "πŸ“ Synthesizing Report", "status": "pending", "icon": "πŸ“"}, - "hypothesizing": {"title": "πŸ”¬ Generating Hypothesis", "status": "pending", "icon": "πŸ”¬"}, - "analyzing": {"title": "πŸ“Š Statistical Analysis", "status": "pending", "icon": "πŸ“Š"}, - "analysis_complete": {"title": "πŸ“ˆ Analysis Results", "status": "done", "icon": "πŸ“ˆ"}, - "streaming": {"title": "πŸ“‘ Processing", "status": "pending", "icon": "πŸ“‘"}, - "complete": {"title": None, "status": "done", "icon": "πŸŽ‰"}, # Main response, no accordion - "error": {"title": "❌ Error", "status": "done", "icon": "❌"}, - } - - config = event_configs.get( - event.type, {"title": f"β€’ {event.type}", "status": "done", "icon": "β€’"} - ) - - # For complete events, return main response without accordion - if event.type == "complete": - # Check if event contains file information - content = event.message - files: list[str] | None = None - - # Check event.data for file paths - if event.data and isinstance(event.data, dict): - # Support both "files" (list) and "file" (single path) keys - if "files" in event.data: - files = event.data["files"] - if isinstance(files, str): - files = [files] - elif not isinstance(files, list): - files = None - else: - # Filter to only valid file paths - files = [f for f in files if isinstance(f, str) and _is_file_path(f)] - elif "file" in event.data: - file_path = event.data["file"] - if isinstance(file_path, str) and _is_file_path(file_path): - files = [file_path] - - # Also check if message itself is a file path (less common, but possible) - if not files and isinstance(event.message, str) and _is_file_path(event.message): - files = [event.message] - # Keep message as text description - content = "Report generated. Download available below." 
- - # Return as dict format for Gradio Chatbot compatibility - result: dict[str, Any] = { - "role": "assistant", - "content": content, - } - - # Add files if present - # Gradio Chatbot supports file paths in content as markdown links - # The links will be clickable and downloadable - if files: - # Validate files exist before including them - import os - valid_files = [f for f in files if os.path.exists(f)] - - if valid_files: - # Format files for Gradio: include as markdown download links - # Gradio ChatInterface automatically renders file links as downloadable files - import os - file_links = [] - for f in valid_files: - file_name = _get_file_name(f) - try: - file_size = os.path.getsize(f) - # Format file size (bytes to KB/MB) - if file_size < 1024: - size_str = f"{file_size} B" - elif file_size < 1024 * 1024: - size_str = f"{file_size / 1024:.1f} KB" - else: - size_str = f"{file_size / (1024 * 1024):.1f} MB" - file_links.append(f"πŸ“Ž [Download: {file_name} ({size_str})]({f})") - except OSError: - # If we can't get file size, just show the name - file_links.append(f"πŸ“Ž [Download: {file_name}]({f})") - - result["content"] = f"{content}\n\n" + "\n\n".join(file_links) - - # Also store in metadata for potential future use - if "metadata" not in result: - result["metadata"] = {} - result["metadata"]["files"] = valid_files - - return result - - # Build metadata for accordion according to Gradio ChatMessage spec - # Metadata keys: title (str), status ("pending"|"done"), log (str), duration (float) - # See: https://www.gradio.app/guides/agents-and-tool-usage - metadata: dict[str, Any] = {} - - # Title is required for accordion display - must be string - if config["title"]: - metadata["title"] = str(config["title"]) - - # Set status (pending shows spinner, done is collapsed) - # Must be exactly "pending" or "done" per Gradio spec - if config["status"] == "pending": - metadata["status"] = "pending" - elif config["status"] == "done": - metadata["status"] = "done" - - # Add duration if available in data (must be float) - if event.data and isinstance(event.data, dict) and "duration" in event.data: - duration = event.data["duration"] - if isinstance(duration, int | float): - metadata["duration"] = float(duration) - - # Add log info (iteration number, etc.) 
- must be string - log_parts: list[str] = [] - if event.iteration > 0: - log_parts.append(f"Iteration {event.iteration}") - if event.data and isinstance(event.data, dict): - if "tool" in event.data: - log_parts.append(f"Tool: {event.data['tool']}") - if "results_count" in event.data: - log_parts.append(f"Results: {event.data['results_count']}") - if log_parts: - metadata["log"] = " | ".join(log_parts) - - # Return as dict format for Gradio Chatbot compatibility - # According to Gradio docs: https://www.gradio.app/guides/agents-and-tool-usage - # ChatMessage format: {"role": "assistant", "content": "...", "metadata": {...}} - # Metadata must have "title" key for accordion display - # Valid metadata keys: title (str), status ("pending"|"done"), log (str), duration (float) result: dict[str, Any] = { "role": "assistant", - "content": event.message, + "content": event.to_markdown(), } - # Only add metadata if it has a title (required for accordion display) - # Ensure metadata values match Gradio's expected types - if metadata and metadata.get("title"): - # Ensure status is valid if present - if "status" in metadata: - status = metadata["status"] - if status not in ("pending", "done"): - metadata["status"] = "done" # Default to "done" if invalid - result["metadata"] = metadata + + # Add metadata if available + if event.data: + metadata: dict[str, Any] = {} + + # Extract file path if present + if isinstance(event.data, dict): + file_path = event.data.get("file_path") + if file_path: + metadata["file_path"] = file_path + + if metadata: + result["metadata"] = metadata return result @@ -402,9 +282,9 @@ def extract_oauth_info(request: gr.Request | None) -> tuple[str | None, str | No oauth_username = request.username # Also try accessing via oauth_profile if available elif hasattr(request, "oauth_profile") and request.oauth_profile is not None: - if hasattr(request.oauth_profile, "username"): + if hasattr(request.oauth_profile, "username") and request.oauth_profile.username: oauth_username = request.oauth_profile.username - elif hasattr(request.oauth_profile, "name"): + elif hasattr(request.oauth_profile, "name") and request.oauth_profile.name: oauth_username = request.oauth_profile.name return oauth_token, oauth_username @@ -417,134 +297,141 @@ async def yield_auth_messages( mode: str, ) -> AsyncGenerator[dict[str, Any], None]: """ - Yield authentication and mode status messages. + Yield authentication status messages. Args: oauth_username: OAuth username if available oauth_token: OAuth token if available - has_huggingface: Whether HuggingFace credentials are available - mode: Orchestrator mode + has_huggingface: Whether HuggingFace authentication is available + mode: Research mode Yields: - ChatMessage objects with authentication status + Chat message dictionaries """ - # Show user greeting if logged in via OAuth if oauth_username: yield { "role": "assistant", - "content": f"πŸ‘‹ **Welcome, {oauth_username}!** Using your HuggingFace account.\n\n", + "content": f"πŸ‘‹ **Welcome, {oauth_username}!**\n\nAuthenticated via HuggingFace OAuth.", } - # Advanced mode is not currently supported with HuggingFace inference - # For now, we only support simple mode with HuggingFace - if mode == "advanced": + if oauth_token: yield { "role": "assistant", "content": ( - "⚠️ **Note**: Advanced mode is not available with HuggingFace inference providers. " - "Falling back to simple mode.\n\n" + "πŸ” **Authentication Status**: βœ… Authenticated\n\n" + "Your OAuth token has been validated. 
You can now use all AI models and research tools." ), } - - # Inform user about authentication status - if oauth_token: + elif has_huggingface: yield { "role": "assistant", "content": ( - "πŸ” **Using HuggingFace OAuth token** - " - "Authenticated via your HuggingFace account.\n\n" + "πŸ” **Authentication Status**: βœ… Using environment token\n\n" + "Using HF_TOKEN from environment variables." ), } - elif not has_huggingface: - # No keys at all - will use FREE HuggingFace Inference (public models) + else: yield { "role": "assistant", "content": ( - "πŸ€— **Free Tier**: Using HuggingFace Inference (Llama 3.1 / Mistral) for AI analysis.\n" - "For premium models or higher rate limits, sign in with HuggingFace above.\n\n" + "⚠️ **Authentication Status**: ❌ No authentication\n\n" + "Please sign in with HuggingFace or set HF_TOKEN environment variable." ), } + yield { + "role": "assistant", + "content": f"πŸš€ **Mode**: {mode.upper()}\n\nStarting research agent...", + } -async def handle_orchestrator_events( - orchestrator: Any, - message: str, -) -> AsyncGenerator[dict[str, Any], None]: - """ - Handle orchestrator events and yield ChatMessages. - Args: - orchestrator: The orchestrator instance - message: The research question +def _extract_oauth_token(oauth_token: gr.OAuthToken | None) -> str | None: + """Extract token value from OAuth token object.""" + if oauth_token is None: + return None - Yields: - ChatMessage objects from orchestrator events - """ - # Track pending accordions for real-time updates - pending_accordions: dict[str, str] = {} # title -> accumulated content - - async for event in orchestrator.run(message): - # Convert event to ChatMessage with metadata - chat_msg = event_to_chat_message(event) - - # Handle complete events (main response) - if event.type == "complete": - # Close any pending accordions first - if pending_accordions: - for title, content in pending_accordions.items(): - yield { - "role": "assistant", - "content": content.strip(), - "metadata": {"title": title, "status": "done"}, - } - pending_accordions.clear() - - # Yield final response (no accordion for main response) - # chat_msg is already a dict from event_to_chat_message - yield chat_msg - continue + if hasattr(oauth_token, "token"): + token_value: str | None = getattr(oauth_token, "token", None) # type: ignore[assignment] + if token_value is None: + return None + logger.debug("OAuth token extracted from oauth_token.token attribute") - # Handle events with metadata (accordions) - # chat_msg is always a dict from event_to_chat_message - metadata: dict[str, Any] = chat_msg.get("metadata", {}) - if metadata: - msg_title: str | None = metadata.get("title") - msg_status: str | None = metadata.get("status") - - if msg_title: - # For pending operations, accumulate content and show spinner - if msg_status == "pending": - if msg_title not in pending_accordions: - pending_accordions[msg_title] = "" - # chat_msg is always a dict, so access content via key - content = chat_msg.get("content", "") - pending_accordions[msg_title] += content + "\n" - # Yield updated accordion with accumulated content - yield { - "role": "assistant", - "content": pending_accordions[msg_title].strip(), - "metadata": chat_msg.get("metadata", {}), - } - elif msg_title in pending_accordions: - # Combine pending content with final content - # chat_msg is always a dict, so access content via key - content = chat_msg.get("content", "") - final_content = pending_accordions[msg_title] + content - del pending_accordions[msg_title] - yield { - 
"role": "assistant", - "content": final_content.strip(), - "metadata": {"title": msg_title, "status": "done"}, - } - else: - # New done accordion (no pending state) - yield chat_msg - else: - # No title, yield as-is - yield chat_msg - else: - # No metadata, yield as plain message - yield chat_msg + # Validate token format + from src.utils.hf_error_handler import log_token_info, validate_hf_token + + log_token_info(token_value, context="research_agent") + is_valid, error_msg = validate_hf_token(token_value) + if not is_valid: + logger.warning( + "OAuth token validation failed", + error=error_msg, + oauth_token_type=type(oauth_token).__name__, + ) + return token_value + + if isinstance(oauth_token, str): + logger.debug("OAuth token extracted as string") + + # Validate token format + from src.utils.hf_error_handler import log_token_info, validate_hf_token + + log_token_info(oauth_token, context="research_agent") + return oauth_token + + logger.warning( + "OAuth token object present but token extraction failed", + oauth_token_type=type(oauth_token).__name__, + ) + return None + + +def _extract_username(oauth_profile: gr.OAuthProfile | None) -> str | None: + """Extract username from OAuth profile.""" + if oauth_profile is None: + return None + + username: str | None = None + if hasattr(oauth_profile, "username") and oauth_profile.username: + username = str(oauth_profile.username) + elif hasattr(oauth_profile, "name") and oauth_profile.name: + username = str(oauth_profile.name) + + if username: + logger.info("OAuth user authenticated", username=username) + return username + + +async def _process_multimodal_input( + message: str | MultimodalPostprocess, + enable_image_input: bool, + enable_audio_input: bool, + token_value: str | None, +) -> tuple[str, tuple[int, np.ndarray[Any, Any]] | None]: # type: ignore[type-arg] + """Process multimodal input and return processed text and audio data.""" + processed_text = "" + audio_input_data: tuple[int, np.ndarray[Any, Any]] | None = None # type: ignore[type-arg] + + if isinstance(message, dict): + processed_text = message.get("text", "") or "" + files = message.get("files", []) or [] + audio_input_data = message.get("audio") or None + + if (files and enable_image_input) or (audio_input_data is not None and enable_audio_input): + try: + multimodal_service = get_multimodal_service() + processed_text = await multimodal_service.process_multimodal_input( + processed_text, + files=files if enable_image_input else [], + audio_input=audio_input_data if enable_audio_input else None, + hf_token=token_value, + prepend_multimodal=True, + ) + except Exception as e: + logger.warning("multimodal_processing_failed", error=str(e)) + else: + processed_text = str(message) if message else "" + + return processed_text, audio_input_data async def research_agent( @@ -557,57 +444,33 @@ async def research_agent( use_graph: bool = True, enable_image_input: bool = True, enable_audio_input: bool = True, - tts_voice: str = "af_heart", - tts_speed: float = 1.0, + web_search_provider: str = "auto", oauth_token: gr.OAuthToken | None = None, oauth_profile: gr.OAuthProfile | None = None, -) -> AsyncGenerator[dict[str, Any] | tuple[dict[str, Any], tuple[int, np.ndarray] | None], None]: +) -> AsyncGenerator[dict[str, Any], None]: """ - Gradio chat function that runs the research agent. + Main research agent function that processes queries and streams results. 
Args: - message: User's research question (str or MultimodalPostprocess with text/files) - history: Chat history (Gradio format) - mode: Orchestrator mode ("simple" or "advanced") - hf_model: Selected HuggingFace model ID (from dropdown) - hf_provider: Selected inference provider (from dropdown) + message: User message (text, image, or audio) + history: Conversation history + mode: Orchestrator mode + hf_model: Optional HuggingFace model ID + hf_provider: Optional inference provider + graph_mode: Graph execution mode + use_graph: Whether to use graph execution + enable_image_input: Whether to process image inputs + enable_audio_input: Whether to process audio inputs + web_search_provider: Web search provider selection oauth_token: Gradio OAuth token (None if user not logged in) oauth_profile: Gradio OAuth profile (None if user not logged in) Yields: - ChatMessage objects with metadata for accordion display, optionally with audio output + Chat message dictionaries """ - import structlog - - logger = structlog.get_logger() - - # REQUIRE LOGIN BEFORE USE - # Extract OAuth token and username using Gradio's OAuth types - # According to Gradio docs: OAuthToken and OAuthProfile are None if user not logged in - token_value: str | None = None - username: str | None = None - - if oauth_token is not None: - # OAuthToken has a .token attribute containing the access token - if hasattr(oauth_token, "token"): - token_value = oauth_token.token - elif isinstance(oauth_token, str): - # Handle case where oauth_token is already a string (shouldn't happen but defensive) - token_value = oauth_token - else: - token_value = None - - if oauth_profile is not None: - # OAuthProfile has .username, .name, .profile_image attributes - username = ( - oauth_profile.username - if hasattr(oauth_profile, "username") and oauth_profile.username - else ( - oauth_profile.name - if hasattr(oauth_profile, "name") and oauth_profile.name - else None - ) - ) + # Extract OAuth token and username + token_value = _extract_oauth_token(oauth_token) + username = _extract_username(oauth_profile) # Check if user is logged in (OAuth token or env var) # Fallback to env vars for local development or Spaces with HF_TOKEN secret @@ -624,47 +487,19 @@ async def research_agent( "before using this application.\n\n" "The login button is required to access the AI models and research tools." 
), - }, None + } return - # Process multimodal input (text + images + audio) - processed_text = "" - audio_input_data: tuple[int, np.ndarray] | None = None - - if isinstance(message, dict): - # MultimodalPostprocess format: {"text": str, "files": list[FileData], "audio": tuple | None} - processed_text = message.get("text", "") or "" - files = message.get("files", []) - # Check for audio input in message (Gradio may include it as a separate field) - audio_input_data = message.get("audio") or None - - # Process multimodal input (images, audio files, audio input) - # Process if we have files (and image input enabled) or audio input (and audio input enabled) - # Use UI settings from function parameters - if (files and enable_image_input) or (audio_input_data is not None and enable_audio_input): - try: - multimodal_service = get_multimodal_service() - # Prepend audio/image text to original text (prepend_multimodal=True) - # Filter files and audio based on UI settings - processed_text = await multimodal_service.process_multimodal_input( - processed_text, - files=files if enable_image_input else [], - audio_input=audio_input_data if enable_audio_input else None, - hf_token=token_value, - prepend_multimodal=True, # Prepend audio/image text to text input - ) - except Exception as e: - logger.warning("multimodal_processing_failed", error=str(e)) - # Continue with text-only input - else: - # Plain string message - processed_text = str(message) if message else "" + # Process multimodal input + processed_text, audio_input_data = await _process_multimodal_input( + message, enable_image_input, enable_audio_input, token_value + ) if not processed_text.strip(): yield { "role": "assistant", "content": "Please enter a research question or provide an image/audio input.", - }, None + } return # Check available keys (use token_value instead of oauth_token) @@ -687,56 +522,64 @@ async def research_agent( model_id = hf_model if hf_model and hf_model.strip() else None provider_name = hf_provider if hf_provider and hf_provider.strip() else None + # Log authentication source for debugging + auth_source = ( + "OAuth" + if token_value + else ( + "Env (HF_TOKEN)" + if os.getenv("HF_TOKEN") + else ("Env (HUGGINGFACE_API_KEY)" if os.getenv("HUGGINGFACE_API_KEY") else "None") + ) + ) + logger.info( + "Configuring orchestrator", + mode=effective_mode, + auth_source=auth_source, + has_oauth_token=bool(token_value), + model=model_id or "default", + provider=provider_name or "auto", + ) + + # Convert empty string to None for web_search_provider + web_search_provider_value = ( + web_search_provider if web_search_provider and web_search_provider.strip() else None + ) + orchestrator, backend_name = configure_orchestrator( use_mock=False, # Never use mock in production - HF Inference is the free fallback mode=effective_mode, - oauth_token=token_value, # Use extracted token value + oauth_token=token_value, # Use extracted token value - passed to all agents and services hf_model=model_id, # None will use defaults in configure_orchestrator hf_provider=provider_name, # None will use defaults in configure_orchestrator graph_mode=graph_mode if graph_mode else None, use_graph=use_graph, + web_search_provider=web_search_provider_value, # None will use settings default ) yield { "role": "assistant", - "content": f"🧠 **Backend**: {backend_name}\n\n", + "content": f"πŸ”§ **Backend**: {backend_name}\n\nProcessing your query...", } - # Handle orchestrator events and generate audio output - audio_output_data: tuple[int, np.ndarray] | None = 
None - final_message = "" - - async for msg in handle_orchestrator_events(orchestrator, processed_text): - # Track final message for TTS - if isinstance(msg, dict) and msg.get("role") == "assistant": + # Convert history to ModelMessage format if needed + message_history: list[ModelMessage] = [] + if history: + for msg in history: + role = msg.get("role", "user") content = msg.get("content", "") - metadata = msg.get("metadata", {}) - # This is the main response (not an accordion) if no title in metadata - if content and not metadata.get("title"): - final_message = content - - # Yield without audio for intermediate messages - yield msg, None - - # Generate audio output for final response - if final_message and settings.enable_audio_output: - try: - audio_service = get_audio_service() - # Use UI-configured voice and speed, fallback to settings defaults - audio_output_data = await audio_service.generate_audio_output( - final_message, - voice=tts_voice or settings.tts_voice, - speed=tts_speed if tts_speed else settings.tts_speed, - ) - except Exception as e: - logger.warning("audio_synthesis_failed", error=str(e)) - # Continue without audio output + if isinstance(content, str) and content.strip(): + message_history.append(ModelMessage(role=role, content=content)) # type: ignore[operator] + + # Run orchestrator and stream events + async for event in orchestrator.run( + processed_text, message_history=message_history if message_history else None + ): + chat_msg = event_to_chat_message(event) + yield chat_msg - # If we have audio output, we need to yield it with the final message - # Note: The final message was already yielded above, so we yield None, audio_output_data - # This will update the audio output component - if audio_output_data is not None: - yield None, audio_output_data + # Note: Audio output is now handled via on-demand TTS button + # Users click "Generate Audio" button to create TTS for the last response except Exception as e: # Return error message without metadata to avoid issues during example caching @@ -747,7 +590,110 @@ async def research_agent( yield { "role": "assistant", "content": f"Error: {error_msg}. Please check your configuration and try again.", - }, None + } + + +async def update_model_provider_dropdowns( + oauth_token: gr.OAuthToken | None = None, + oauth_profile: gr.OAuthProfile | None = None, +) -> tuple[dict[str, Any], dict[str, Any], str]: + """Update model and provider dropdowns based on OAuth token. + + This function is called when OAuth token/profile changes (user logs in/out). + It queries HuggingFace API to get available models and providers. 
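+
+    Example (illustrative):
+        model_upd, provider_upd, status = await update_model_provider_dropdowns(token, profile)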
+ + Args: + oauth_token: Gradio OAuth token + oauth_profile: Gradio OAuth profile + + Returns: + Tuple of (model_dropdown_update, provider_dropdown_update, status_message) + """ + from src.utils.hf_model_validator import ( + get_available_models, + get_available_providers, + validate_oauth_token, + ) + + # Extract token value + token_value: str | None = None + if oauth_token is not None: + if hasattr(oauth_token, "token"): + token_value = oauth_token.token + elif isinstance(oauth_token, str): + token_value = oauth_token + + # Default values (empty = use default) + default_models = [""] + default_providers = [""] + status_msg = "⚠️ Not authenticated - using default models" + + if not token_value: + # No token - return defaults + return ( + gr.update(choices=default_models, value=""), + gr.update(choices=default_providers, value=""), + status_msg, + ) + + try: + # Validate token and get available resources + validation_result = await validate_oauth_token(token_value) + + if not validation_result["is_valid"]: + status_msg = ( + f"❌ Token validation failed: {validation_result.get('error', 'Unknown error')}" + ) + return ( + gr.update(choices=default_models, value=""), + gr.update(choices=default_providers, value=""), + status_msg, + ) + + # Get available models and providers + models = await get_available_models(token=token_value, limit=50) + providers = await get_available_providers(token=token_value) + + # Combine with defaults + model_choices = ["", *models[:49]] # Keep first 49 + empty option + provider_choices = providers # Already includes "auto" + + username = validation_result.get("username", "User") + + # Build status message with warning if scope is missing + scope_warning = "" + if not validation_result["has_inference_api_scope"]: + scope_warning = ( + "⚠️ Token may not have 'inference-api' scope - some models may not work\n\n" + ) + + status_msg = ( + f"{scope_warning}βœ… Authenticated as {username}\n\n" + f"πŸ“Š Found {len(models)} available models\n" + f"πŸ”§ Found {len(providers)} available providers" + ) + + logger.info( + "Updated model/provider dropdowns", + model_count=len(model_choices), + provider_count=len(provider_choices), + username=username, + ) + + return ( + gr.update(choices=model_choices, value=""), + gr.update(choices=provider_choices, value=""), + status_msg, + ) + + except Exception as e: + logger.error("Failed to update dropdowns", error=str(e)) + status_msg = f"⚠️ Failed to load models: {e!s}" + return ( + gr.update(choices=default_models, value=""), + gr.update(choices=default_providers, value=""), + status_msg, + ) def create_demo() -> gr.Blocks: @@ -768,24 +714,31 @@ def create_demo() -> gr.Blocks: ) gr.LoginButton("Sign in with Hugging Face") gr.Markdown("---") - gr.Markdown("### ℹ️ About") # noqa: RUF001 - gr.Markdown( - "**The DETERMINATOR** - Generalist Deep Research Agent\n\n" - "A powerful research agent that stops at nothing until finding precise answers to complex questions.\n\n" - "**Available Sources**:\n" - "- Web Search (general knowledge)\n" - "- PubMed (biomedical literature)\n" - "- ClinicalTrials.gov (clinical trials)\n" - "- Europe PMC (preprints & papers)\n" - "- RAG (semantic search)\n\n" - "**Automatic Detection**: Automatically determines if medical knowledge sources are needed for your query.\n\n" - "⚠️ **Research tool only** - Synthesizes evidence but cannot provide medical advice." 
- ) + + # About Section - Collapsible with details + with gr.Accordion("ℹ️ About", open=False): + gr.Markdown( + "**The DETERMINATOR** - Generalist Deep Research Agent\n\n" + "Stops at nothing until finding precise answers to complex questions.\n\n" + "**How It Works**:\n" + "- πŸ” Multi-source search (Web, PubMed, ClinicalTrials.gov, Europe PMC, RAG)\n" + "- 🧠 Automatic medical knowledge detection\n" + "- πŸ”„ Iterative refinement with search-judge loops\n" + "- ⏹️ Continues until budget/time/iteration limits\n" + "- πŸ“Š Evidence synthesis with citations\n\n" + "**Multimodal Input**:\n" + "- πŸ“· **Images**: Click image icon in textbox (OCR)\n" + "- 🎀 **Audio**: Click microphone icon (speech-to-text)\n" + "- πŸ“„ **Files**: Drag & drop or click to upload\n\n" + "**MCP Server**: Connect Claude Desktop to `/gradio_api/mcp/`\n\n" + "⚠️ **Research tool only** - Synthesizes evidence but cannot provide medical advice." + ) + gr.Markdown("---") - + # Settings Section - Organized in Accordions gr.Markdown("## βš™οΈ Settings") - + # Research Configuration Accordion with gr.Accordion("πŸ”¬ Research Configuration", open=True): mode_radio = gr.Radio( @@ -800,24 +753,30 @@ def create_demo() -> gr.Blocks: "Auto: Smart routing" ), ) - + graph_mode_radio = gr.Radio( choices=["iterative", "deep", "auto"], value="auto", label="Graph Research Mode", info="Iterative: Single loop | Deep: Parallel sections | Auto: Detect from query", ) - + use_graph_checkbox = gr.Checkbox( value=True, label="Use Graph Execution", info="Enable graph-based workflow execution", ) - + # Model and Provider selection gr.Markdown("### πŸ€– Model & Provider") - - # Popular models list + + # Status message for model/provider loading + model_provider_status = gr.Markdown( + value="⚠️ Sign in to see available models and providers", + visible=True, + ) + + # Popular models list (will be updated by validator) popular_models = [ "", # Empty = use default "Qwen/Qwen3-Next-80B-A3B-Thinking", @@ -828,16 +787,16 @@ def create_demo() -> gr.Blocks: "mistralai/Mistral-7B-Instruct-v0.2", "google/gemma-2-9b-it", ] - + hf_model_dropdown = gr.Dropdown( choices=popular_models, value="", # Empty string - will be converted to None in research_agent label="Reasoning Model", - info="Select a HuggingFace model (leave empty for default)", + info="Select a HuggingFace model (leave empty for default). Sign in to see all available models.", allow_custom_value=True, # Allow users to type custom model IDs ) - # Provider list from README + # Provider list from README (will be updated by validator) providers = [ "", # Empty string = auto-select "nebius", @@ -850,131 +809,411 @@ def create_demo() -> gr.Blocks: "ovh", "fireworks", ] - + hf_provider_dropdown = gr.Dropdown( choices=providers, value="", # Empty string - will be converted to None in research_agent label="Inference Provider", - info="Select inference provider (leave empty for auto-select)", + info="Select inference provider (leave empty for auto-select). 
Sign in to see all available providers.", + ) + + # Refresh button for updating models/providers after login + def refresh_models_and_providers( + request: gr.Request, + ) -> tuple[dict[str, Any], dict[str, Any], str]: + """Handle refresh button click and update dropdowns.""" + import asyncio + + # Extract OAuth token and profile from request + oauth_token: gr.OAuthToken | None = None + oauth_profile: gr.OAuthProfile | None = None + + if request is not None: + # Try to get OAuth token from request + if hasattr(request, "oauth_token"): + oauth_token = request.oauth_token + if hasattr(request, "oauth_profile"): + oauth_profile = request.oauth_profile + + # Run async function in sync context + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result = loop.run_until_complete( + update_model_provider_dropdowns(oauth_token, oauth_profile) + ) + return result + finally: + loop.close() + + refresh_models_btn = gr.Button( + value="πŸ”„ Refresh Available Models", + visible=True, + size="sm", ) - - # Multimodal Input Configuration Accordion - with gr.Accordion("πŸ“· Multimodal Input", open=False): + + # Pass request to get OAuth token from Gradio context + refresh_models_btn.click( + fn=refresh_models_and_providers, + inputs=[], # Request is automatically available in Gradio context + outputs=[hf_model_dropdown, hf_provider_dropdown, model_provider_status], + ) + + # Web Search Provider selection + gr.Markdown("### πŸ” Web Search Provider") + + # Available providers with labels indicating availability + # Format: (display_label, value) - Gradio Dropdown supports tuples + web_search_provider_options = [ + ("Auto-detect (Recommended)", "auto"), + ("Serper (Google Search + Full Content)", "serper"), + ("DuckDuckGo (Free, Snippets Only)", "duckduckgo"), + ("SearchXNG (Self-hosted) - Coming Soon", "searchxng"), # Not fully implemented + ("Brave - Coming Soon", "brave"), # Not implemented + ("Tavily - Coming Soon", "tavily"), # Not implemented + ] + + # Create Dropdown with label-value pairs + # Gradio will display labels but return values + # Disabled options are marked with "Coming Soon" in the label + # The factory will handle "not implemented" cases gracefully + web_search_provider_dropdown = gr.Dropdown( + choices=web_search_provider_options, + value="auto", + label="Web Search Provider", + info="Select web search provider. 
'Auto' detects best available.", + ) + + # Multimodal Input Configuration + gr.Markdown("### πŸ“·πŸŽ€ Multimodal Input") + enable_image_input_checkbox = gr.Checkbox( value=settings.enable_image_input, label="Enable Image Input (OCR)", - info="Extract text from uploaded images using OCR", + info="Process uploaded images with OCR", ) - + enable_audio_input_checkbox = gr.Checkbox( value=settings.enable_audio_input, label="Enable Audio Input (STT)", - info="Transcribe audio recordings using speech-to-text", - ) - - # Audio/TTS Configuration Accordion - with gr.Accordion("πŸ”Š Audio Output", open=False): - enable_audio_output_checkbox = gr.Checkbox( - value=settings.enable_audio_output, - label="Enable Audio Output", - info="Generate audio responses using TTS", + info="Process uploaded/recorded audio with speech-to-text", ) - - tts_voice_dropdown = gr.Dropdown( - choices=[ - "af_heart", - "af_bella", - "af_nicole", - "af_aoede", - "af_kore", - "af_sarah", - "af_nova", - "af_sky", - "af_alloy", - "af_jessica", - "af_river", - "am_michael", - "am_fenrir", - "am_puck", - "am_echo", - "am_eric", - "am_liam", - "am_onyx", - "am_santa", - "am_adam", - ], - value=settings.tts_voice, - label="TTS Voice", - info="Select TTS voice (American English voices: af_*, am_*)", + + # Audio Output Configuration - Collapsible + with gr.Accordion("πŸ”Š Audio Output (TTS)", open=False): + gr.Markdown( + "**Generate audio for research responses on-demand.**\n\n" + "Enter Modal keys below or set `MODAL_TOKEN_ID`/`MODAL_TOKEN_SECRET` in `.env` for local development." ) - - tts_speed_slider = gr.Slider( - minimum=0.5, - maximum=2.0, - value=settings.tts_speed, - step=0.1, - label="TTS Speech Speed", - info="Adjust TTS speech speed (0.5x to 2.0x)", + + with gr.Accordion("πŸ”‘ Modal Credentials (Optional)", open=False): + modal_token_id_input = gr.Textbox( + label="Modal Token ID", + placeholder="ak-... (leave empty to use .env)", + type="password", + value="", + ) + + modal_token_secret_input = gr.Textbox( + label="Modal Token Secret", + placeholder="as-... 
(leave empty to use .env)", + type="password", + value="", + ) + + with gr.Accordion("🎚️ Voice & Quality Settings", open=False): + tts_voice_dropdown = gr.Dropdown( + choices=[ + "af_heart", + "af_bella", + "af_sarah", + "af_sky", + "af_nova", + "af_shimmer", + "af_echo", + "af_fable", + "af_onyx", + "af_angel", + "af_asteria", + "af_jessica", + "af_elli", + "af_domi", + "af_gigi", + "af_freya", + "af_glinda", + "af_cora", + "af_serena", + "af_liv", + "af_naomi", + "af_rachel", + "af_antoni", + "af_thomas", + "af_charlie", + "af_emily", + "af_george", + "af_arnold", + "af_adam", + "af_sam", + "af_paul", + "af_josh", + "af_daniel", + "af_liam", + "af_dave", + "af_fin", + "af_sarah", + "af_glinda", + "af_grace", + "af_dorothy", + "af_michael", + "af_james", + "af_joseph", + "af_jeremy", + "af_ryan", + "af_oliver", + "af_harry", + "af_kyle", + "af_leo", + "af_otto", + "af_owen", + "af_pepper", + "af_phil", + "af_raven", + "af_rocky", + "af_rusty", + "af_serena", + "af_sky", + "af_spark", + "af_stella", + "af_storm", + "af_taylor", + "af_vera", + "af_will", + "af_aria", + "af_ash", + "af_ballad", + "af_bella", + "af_breeze", + "af_cove", + "af_dusk", + "af_ember", + "af_flash", + "af_flow", + "af_glow", + "af_harmony", + "af_journey", + "af_lullaby", + "af_lyra", + "af_melody", + "af_midnight", + "af_moon", + "af_muse", + "af_music", + "af_narrator", + "af_nightingale", + "af_poet", + "af_rain", + "af_redwood", + "af_rewind", + "af_river", + "af_sage", + "af_seashore", + "af_shadow", + "af_silver", + "af_song", + "af_starshine", + "af_story", + "af_summer", + "af_sun", + "af_thunder", + "af_tide", + "af_time", + "af_valentino", + "af_verdant", + "af_verse", + "af_vibrant", + "af_vivid", + "af_warmth", + "af_whisper", + "af_wilderness", + "af_willow", + "af_winter", + "af_wit", + "af_witness", + "af_wren", + "af_writer", + "af_zara", + "af_zeus", + "af_ziggy", + "af_zoom", + "af_river", + "am_michael", + "am_fenrir", + "am_puck", + "am_echo", + "am_eric", + "am_liam", + "am_onyx", + "am_santa", + "am_adam", + ], + value=settings.tts_voice, + label="TTS Voice", + info="Select TTS voice (American English voices: af_*, am_*)", + ) + + tts_speed_slider = gr.Slider( + minimum=0.5, + maximum=2.0, + value=settings.tts_speed, + step=0.1, + label="TTS Speech Speed", + info="Adjust TTS speech speed (0.5x to 2.0x)", + ) + + gr.Dropdown( + choices=["T4", "A10", "A100", "L4", "L40S"], + value=settings.tts_gpu or "T4", + label="TTS GPU Type", + info="Modal GPU type for TTS (T4 is cheapest, A100 is fastest). Note: GPU changes require app restart.", + visible=settings.modal_available, + interactive=False, # GPU type set at function definition time, requires restart + ) + + tts_use_llm_polish_checkbox = gr.Checkbox( + value=settings.tts_use_llm_polish, + label="Use LLM Polish for Audio", + info="Apply LLM-based final polish to remove remaining formatting artifacts (costs API calls)", + ) + + tts_generate_button = gr.Button( + "🎡 Generate Audio for Last Response", + variant="primary", + size="lg", ) - - tts_gpu_dropdown = gr.Dropdown( - choices=["T4", "A10", "A100", "L4", "L40S"], - value=settings.tts_gpu or "T4", - label="TTS GPU Type", - info="Modal GPU type for TTS (T4 is cheapest, A100 is fastest). 
Note: GPU changes require app restart.", - visible=settings.modal_available, - interactive=False, # GPU type set at function definition time, requires restart + + tts_status_text = gr.Markdown( + "Click the button above to generate audio for the last research response.", + elem_classes="tts-status", ) - - # Audio output component (for TTS response) - moved to sidebar + + # Audio output component (for TTS response) audio_output = gr.Audio( - label="πŸ”Š Audio Response", - visible=settings.enable_audio_output, + label="πŸ”Š Audio Output", + visible=True, ) - - # Update TTS component visibility based on enable_audio_output_checkbox - # This must be after audio_output is defined - def update_tts_visibility(enabled: bool) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: - """Update visibility of TTS components based on enable checkbox.""" - return ( - gr.update(visible=enabled), - gr.update(visible=enabled), - gr.update(visible=enabled), + + # TTS on-demand generation handler + async def handle_tts_generation( + history: list[dict[str, Any]], + modal_token_id: str, + modal_token_secret: str, + voice: str, + speed: float, + use_llm_polish: bool, + ) -> tuple[Any | None, str]: + """Generate audio on-demand for the last response. + + Args: + history: Chat history + modal_token_id: Modal token ID from UI + modal_token_secret: Modal token secret from UI + voice: TTS voice selection + speed: TTS speed + use_llm_polish: Enable LLM polish + + Returns: + Tuple of (audio_output, status_message) + """ + from src.services.tts_modal import generate_audio_on_demand + + # Get last assistant message from history + # History is a list of tuples: [(user_msg, assistant_msg), ...] + if not history: + logger.warning("tts_no_history", history=history) + return None, "❌ No messages in history to generate audio for" + + # Debug: Log history format + logger.info( + "tts_history_debug", + history_type=type(history).__name__, + history_length=len(history) if isinstance(history, list) else 0, + first_entry_type=type(history[0]).__name__ + if isinstance(history, list) and len(history) > 0 + else None, + first_entry_sample=str(history[0])[:200] + if isinstance(history, list) and len(history) > 0 + else None, ) - - enable_audio_output_checkbox.change( - fn=update_tts_visibility, - inputs=[enable_audio_output_checkbox], - outputs=[tts_voice_dropdown, tts_speed_slider, audio_output], - ) + + # Get the last assistant message (second element of last tuple) + last_message = None + if isinstance(history, list) and len(history) > 0: + last_entry = history[-1] + # ChatInterface format: (user_message, assistant_message) + if isinstance(last_entry, (tuple, list)) and len(last_entry) >= 2: + last_message = last_entry[1] + logger.info( + "tts_extracted_from_tuple", message_type=type(last_message).__name__ + ) + # Dict format: {"role": "assistant", "content": "..."} + elif isinstance(last_entry, dict): + if last_entry.get("role") == "assistant": + content = last_entry.get("content", "") + # Content might be a list (multimodal) or string + if isinstance(content, list): + # Extract text from multimodal content list + last_message = " ".join(str(item) for item in content if item) + else: + last_message = content + logger.info( + "tts_extracted_from_dict", + message_type=type(content).__name__, + message_length=len(last_message) + if isinstance(last_message, str) + else 0, + ) + else: + logger.warning( + "tts_unknown_format", + entry_type=type(last_entry).__name__, + entry=str(last_entry)[:200], + ) + + # Also handle if last_message 
itself is a list + if isinstance(last_message, list): + last_message = " ".join(str(item) for item in last_message if item) + + if not last_message or not isinstance(last_message, str) or not last_message.strip(): + logger.error( + "tts_no_message_found", + last_message_type=type(last_message).__name__ if last_message else None, + last_message_value=str(last_message)[:100] if last_message else None, + ) + return None, "❌ No assistant response found in history" + + # Generate audio + audio_output, status_message = await generate_audio_on_demand( + text=last_message, + modal_token_id=modal_token_id, + modal_token_secret=modal_token_secret, + voice=voice, + speed=speed, + use_llm_polish=use_llm_polish, + ) + + return audio_output, status_message # Chat interface with multimodal support # Examples are provided but will NOT run at startup (cache_examples=False) # Users must log in first before using examples or submitting queries - gr.ChatInterface( + chat_interface = gr.ChatInterface( fn=research_agent, multimodal=True, # Enable multimodal input (text + images + audio) title="πŸ”¬ The DETERMINATOR", description=( - "*Generalist Deep Research Agent β€” stops at nothing until finding precise answers to complex questions*\n\n" - "---\n" - "**The DETERMINATOR** uses iterative search-and-judge loops to comprehensively investigate any research question. " - "It automatically determines if medical knowledge sources (PubMed, ClinicalTrials.gov) are needed and adapts its search strategy accordingly.\n\n" - "**Key Features**:\n" - "- πŸ” Multi-source search (Web, PubMed, ClinicalTrials.gov, Europe PMC, RAG)\n" - "- 🧠 Automatic medical knowledge detection\n" - "- πŸ”„ Iterative refinement until precise answers are found\n" - "- ⏹️ Stops only at configured limits (budget, time, iterations)\n" - "- πŸ“Š Evidence synthesis with citations\n\n" - "**MCP Server Active**: Connect Claude Desktop to `/gradio_api/mcp/`\n\n" - "**πŸ“·πŸŽ€ Multimodal Input Support**:\n" - "- **Images**: Click the πŸ“· image icon in the textbox to upload images (OCR)\n" - "- **Audio**: Click the 🎀 microphone icon in the textbox to record audio (STT)\n" - "- **Files**: Drag & drop or click to upload image/audio files\n" - "- **Text**: Type your research questions directly\n\n" - "πŸ’‘ **Tip**: Look for the πŸ“· and 🎀 icons in the text input box below!\n\n" - "Configure multimodal inputs in the sidebar settings.\n\n" - "**⚠️ Authentication Required**: Please **sign in with HuggingFace** above before using this application." + "*Generalist Deep Research Agent β€” stops at nothing until finding precise answers*\n\n" + "πŸ’‘ **Quick Start**: Type your research question below. Use πŸ“· for images, 🎀 for audio.\n\n" + "⚠️ **Sign in with HuggingFace** (sidebar) before starting." 
), examples=[ # When additional_inputs are provided, examples must be lists of lists @@ -997,24 +1236,38 @@ def update_tts_visibility(enabled: bool) -> tuple[dict[str, Any], dict[str, Any] "Analyze the current state of quantum computing architectures: compare different qubit technologies, error correction methods, and scalability challenges across major platforms including IBM, Google, and IonQ.", "deep", "Qwen/Qwen3-Next-80B-A3B-Thinking", - "", + "nebius", "deep", True, ], [ - # Business/Scientific example requiring iterative search - "Investigate the economic and environmental impact of renewable energy transition: analyze cost trends, grid integration challenges, policy frameworks, and market dynamics across solar, wind, and battery storage technologies, in china", + # Historical/Social Science example + "Research and synthesize information about the economic impact of the Industrial Revolution on European social structures, including changes in class dynamics, urbanization patterns, and labor movements from 1750-1900.", + "deep", + "meta-llama/Llama-3.1-70B-Instruct", + "together", + "deep", + True, + ], + [ + # Scientific/Physics example + "Investigate the latest developments in fusion energy research: compare ITER, SPARC, and other major projects, analyze recent breakthroughs in plasma confinement, and assess the timeline to commercial fusion power.", "deep", "Qwen/Qwen3-235B-A22B-Instruct-2507", - "", + "hyperbolic", + "deep", + True, + ], + [ + # Technology/Business example + "Research the competitive landscape of AI chip manufacturers: analyze NVIDIA, AMD, Intel, and emerging players, compare architectures (GPU vs. TPU vs. NPU), and assess market positioning and future trends.", + "deep", + "zai-org/GLM-4.5-Air", + "fireworks", "deep", True, ], ], - cache_examples=False, # CRITICAL: Disable example caching to prevent examples from running at startup - # Examples will only run when user explicitly clicks them (after login) - # Note: additional_inputs_accordion is not a valid parameter in Gradio 6.0 ChatInterface - # Components will be displayed in the order provided additional_inputs=[ mode_radio, hf_model_dropdown, @@ -1023,28 +1276,29 @@ def update_tts_visibility(enabled: bool) -> tuple[dict[str, Any], dict[str, Any] use_graph_checkbox, enable_image_input_checkbox, enable_audio_input_checkbox, + web_search_provider_dropdown, + # Note: gr.OAuthToken and gr.OAuthProfile are automatically passed as function parameters + ], + cache_examples=False, # Don't cache examples - requires authentication + ) + + # Wire up TTS generation button + tts_generate_button.click( + fn=handle_tts_generation, + inputs=[ + chat_interface.chatbot, # Get chat history from ChatInterface + modal_token_id_input, + modal_token_secret_input, tts_voice_dropdown, tts_speed_slider, - # Note: gr.OAuthToken and gr.OAuthProfile are automatically passed as function parameters - # when user is logged in - they should NOT be added to additional_inputs + tts_use_llm_polish_checkbox, ], - additional_outputs=[audio_output], # Add audio output for TTS + outputs=[audio_output, tts_status_text], ) return demo # type: ignore[no-any-return] -def main() -> None: - """Run the Gradio app with MCP server enabled.""" - demo = create_demo() - demo.launch( - # server_name="0.0.0.0", - # server_port=7860, - # share=False, - mcp_server=True, # Enable MCP server for Claude Desktop integration - ssr_mode=False, # Fix for intermittent loading/hydration issues in HF Spaces - ) - - if __name__ == "__main__": - main() + demo = create_demo() 
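+    # The removed main() enabled mcp_server=True and ssr_mode=False on launch(); if the
+    # MCP endpoint advertised in the UI is still wanted, a sketch (assuming Gradio still
+    # accepts these flags) would be:
+    #     demo.launch(server_name="0.0.0.0", server_port=7860, mcp_server=True, ssr_mode=False)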
+ demo.launch(server_name="0.0.0.0", server_port=7860) diff --git a/src/legacy_orchestrator.py b/src/legacy_orchestrator.py index ac1ee46c..b41ba8aa 100644 --- a/src/legacy_orchestrator.py +++ b/src/legacy_orchestrator.py @@ -6,6 +6,11 @@ import structlog +try: + from pydantic_ai import ModelMessage +except ImportError: + ModelMessage = Any # type: ignore[assignment, misc] + from src.utils.config import settings from src.utils.models import ( AgentEvent, @@ -153,7 +158,9 @@ async def _run_analysis_phase( iteration=iteration, ) - async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: # noqa: PLR0915 + async def run( + self, query: str, message_history: list[ModelMessage] | None = None + ) -> AsyncGenerator[AgentEvent, None]: # noqa: PLR0915 """ Run the agent loop for a query. @@ -161,11 +168,16 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: # noqa: PL Args: query: The user's research question + message_history: Optional user conversation history (for compatibility) Yields: AgentEvent objects for each step of the process """ - logger.info("Starting orchestrator", query=query) + logger.info( + "Starting orchestrator", + query=query, + has_history=bool(message_history), + ) yield AgentEvent( type="started", diff --git a/src/mcp_tools.py b/src/mcp_tools.py index fdd15ac0..7b59f9be 100644 --- a/src/mcp_tools.py +++ b/src/mcp_tools.py @@ -242,7 +242,6 @@ async def extract_text_from_image( Extracted text from the image """ from src.services.image_ocr import get_image_ocr_service - from src.utils.config import settings try: @@ -280,7 +279,6 @@ async def transcribe_audio_file( Transcribed text from the audio file """ from src.services.stt_gradio import get_stt_service - from src.utils.config import settings try: @@ -300,4 +298,4 @@ async def transcribe_audio_file( return f"## Audio Transcription\n\n{transcribed_text}" except Exception as e: - return f"Error transcribing audio: {e}" \ No newline at end of file + return f"Error transcribing audio: {e}" diff --git a/src/middleware/state_machine.py b/src/middleware/state_machine.py index d43e131e..8fbdf793 100644 --- a/src/middleware/state_machine.py +++ b/src/middleware/state_machine.py @@ -11,6 +11,11 @@ import structlog from pydantic import BaseModel, Field +try: + from pydantic_ai import ModelMessage +except ImportError: + ModelMessage = Any # type: ignore[assignment, misc] + from src.utils.models import Citation, Conversation, Evidence if TYPE_CHECKING: @@ -28,6 +33,10 @@ class WorkflowState(BaseModel): evidence: list[Evidence] = Field(default_factory=list) conversation: Conversation = Field(default_factory=Conversation) + user_message_history: list[ModelMessage] = Field( + default_factory=list, + description="User conversation history (multi-turn interactions)", + ) # Type as Any to avoid circular imports/runtime resolution issues # The actual object injected will be an EmbeddingService instance embedding_service: Any = Field(default=None) @@ -90,6 +99,31 @@ async def search_related(self, query: str, n_results: int = 5) -> list[Evidence] return evidence_list + def add_user_message(self, message: ModelMessage) -> None: + """Add a user message to conversation history. + + Args: + message: Message to add + """ + self.user_message_history.append(message) + + def get_user_history(self, max_messages: int | None = None) -> list[ModelMessage]: + """Get user conversation history. 
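+
+        Example (illustrative):
+            recent = get_workflow_state().get_user_history(max_messages=5)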
+ + Args: + max_messages: Maximum messages to return (None for all) + + Returns: + List of messages + """ + if max_messages is None: + return self.user_message_history.copy() + return ( + self.user_message_history[-max_messages:] + if len(self.user_message_history) > max_messages + else self.user_message_history.copy() + ) + # The ContextVar holds the WorkflowState for the current execution context _workflow_state_var: ContextVar[WorkflowState | None] = ContextVar("workflow_state", default=None) @@ -97,18 +131,26 @@ async def search_related(self, query: str, n_results: int = 5) -> list[Evidence] def init_workflow_state( embedding_service: "EmbeddingService | None" = None, + message_history: list[ModelMessage] | None = None, ) -> WorkflowState: """Initialize a new state for the current context. Args: embedding_service: Optional embedding service for semantic search. + message_history: Optional user conversation history. Returns: The initialized WorkflowState instance. """ state = WorkflowState(embedding_service=embedding_service) + if message_history: + state.user_message_history = message_history.copy() _workflow_state_var.set(state) - logger.debug("Workflow state initialized", has_embeddings=embedding_service is not None) + logger.debug( + "Workflow state initialized", + has_embeddings=embedding_service is not None, + has_history=bool(message_history), + ) return state @@ -126,14 +168,4 @@ def get_workflow_state() -> WorkflowState: # Auto-initialize if missing (e.g. during tests or simple scripts) logger.debug("Workflow state not found, auto-initializing") return init_workflow_state() - return state - - - - - - - - - - + return state \ No newline at end of file diff --git a/src/orchestrator/graph_orchestrator.py b/src/orchestrator/graph_orchestrator.py index 82650b9f..d71a337d 100644 --- a/src/orchestrator/graph_orchestrator.py +++ b/src/orchestrator/graph_orchestrator.py @@ -10,6 +10,11 @@ import structlog +try: + from pydantic_ai import ModelMessage +except ImportError: + ModelMessage = Any # type: ignore[assignment, misc] + from src.agent_factory.agents import ( create_input_parser_agent, create_knowledge_gap_agent, @@ -44,12 +49,18 @@ class GraphExecutionContext: """Context for managing graph execution state.""" - def __init__(self, state: WorkflowState, budget_tracker: BudgetTracker) -> None: + def __init__( + self, + state: WorkflowState, + budget_tracker: BudgetTracker, + message_history: list[ModelMessage] | None = None, + ) -> None: """Initialize execution context. Args: state: Current workflow state budget_tracker: Budget tracker instance + message_history: Optional user conversation history """ self.current_node: str = "" self.visited_nodes: set[str] = set() @@ -57,6 +68,7 @@ def __init__(self, state: WorkflowState, budget_tracker: BudgetTracker) -> None: self.state = state self.budget_tracker = budget_tracker self.iteration_count = 0 + self.message_history: list[ModelMessage] = message_history or [] def set_node_result(self, node_id: str, result: Any) -> None: """Store result from node execution. @@ -108,6 +120,31 @@ def update_state( """ self.state = updater(self.state, data) + def add_message(self, message: ModelMessage) -> None: + """Add a message to the history. + + Args: + message: Message to add + """ + self.message_history.append(message) + + def get_message_history(self, max_messages: int | None = None) -> list[ModelMessage]: + """Get message history, optionally truncated. 
+ + Args: + max_messages: Maximum messages to return (None for all) + + Returns: + List of messages + """ + if max_messages is None: + return self.message_history.copy() + return ( + self.message_history[-max_messages:] + if len(self.message_history) > max_messages + else self.message_history.copy() + ) + class GraphOrchestrator: """ @@ -174,12 +211,15 @@ def _get_file_service(self) -> ReportFileService | None: return None return self._file_service - async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: + async def run( + self, query: str, message_history: list[ModelMessage] | None = None + ) -> AsyncGenerator[AgentEvent, None]: """ Run the research workflow. Args: query: The user's research query + message_history: Optional user conversation history Yields: AgentEvent objects for real-time UI updates @@ -189,6 +229,7 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: query=query[:100], mode=self.mode, use_graph=self.use_graph, + has_history=bool(message_history), ) yield AgentEvent( @@ -205,10 +246,10 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: # Use graph execution if enabled, otherwise fall back to agent chains if self.use_graph: - async for event in self._run_with_graph(query, research_mode): + async for event in self._run_with_graph(query, research_mode, message_history): yield event else: - async for event in self._run_with_chains(query, research_mode): + async for event in self._run_with_chains(query, research_mode, message_history): yield event except Exception as e: @@ -220,13 +261,17 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: ) async def _run_with_graph( - self, query: str, research_mode: Literal["iterative", "deep"] + self, + query: str, + research_mode: Literal["iterative", "deep"], + message_history: list[ModelMessage] | None = None, ) -> AsyncGenerator[AgentEvent, None]: """Run workflow using graph execution. Args: query: The research query research_mode: The research mode + message_history: Optional user conversation history Yields: AgentEvent objects @@ -235,7 +280,10 @@ async def _run_with_graph( from src.services.embeddings import get_embedding_service embedding_service = get_embedding_service() - state = init_workflow_state(embedding_service=embedding_service) + state = init_workflow_state( + embedding_service=embedding_service, + message_history=message_history, + ) budget_tracker = BudgetTracker() budget_tracker.create_budget( loop_id="graph_execution", @@ -245,7 +293,11 @@ async def _run_with_graph( ) budget_tracker.start_timer("graph_execution") - context = GraphExecutionContext(state, budget_tracker) + context = GraphExecutionContext( + state, + budget_tracker, + message_history=message_history or [], + ) # Build graph self._graph = await self._build_graph(research_mode) @@ -255,13 +307,17 @@ async def _run_with_graph( yield event async def _run_with_chains( - self, query: str, research_mode: Literal["iterative", "deep"] + self, + query: str, + research_mode: Literal["iterative", "deep"], + message_history: list[ModelMessage] | None = None, ) -> AsyncGenerator[AgentEvent, None]: """Run workflow using agent chains (backward compatibility). 
Args: query: The research query research_mode: The research mode + message_history: Optional user conversation history Yields: AgentEvent objects @@ -282,7 +338,9 @@ async def _run_with_chains( ) try: - final_report = await self._iterative_flow.run(query) + final_report = await self._iterative_flow.run( + query, message_history=message_history + ) except Exception as e: self.logger.error("Iterative flow failed", error=str(e), exc_info=True) # Yield error event - outer handler will also catch and yield error event @@ -318,7 +376,7 @@ async def _run_with_chains( ) try: - final_report = await self._deep_flow.run(query) + final_report = await self._deep_flow.run(query, message_history=message_history) except Exception as e: self.logger.error("Deep flow failed", error=str(e), exc_info=True) # Yield error event before re-raising so test can capture it @@ -488,73 +546,17 @@ def _emit_completion_event( iteration=iteration, ) - async def _execute_graph( - self, query: str, context: GraphExecutionContext - ) -> AsyncGenerator[AgentEvent, None]: - """Execute the graph from entry node. - - Args: - query: The research query - context: Execution context - - Yields: - AgentEvent objects - """ + def _get_final_result_from_exit_nodes( + self, context: GraphExecutionContext, current_node_id: str | None + ) -> tuple[Any, str | None]: + """Get final result from exit nodes, prioritizing synthesizer/writer.""" if not self._graph: - raise ValueError("Graph not built") + return None, current_node_id - current_node_id = self._graph.entry_node - iteration = 0 - - # Execute nodes until we reach an exit node - while current_node_id: - # Check budget - if not context.budget_tracker.can_continue("graph_execution"): - self.logger.warning("Budget exceeded, exiting graph execution") - break - - # Execute current node - iteration += 1 - context.current_node = current_node_id - node = self._graph.get_node(current_node_id) - - # Emit start event - yield self._emit_start_event(node, current_node_id, iteration, context) - - try: - result = await self._execute_node(current_node_id, query, context) - context.set_node_result(current_node_id, result) - context.mark_visited(current_node_id) - - # Yield completion event - yield self._emit_completion_event(node, current_node_id, result, iteration) - - except Exception as e: - self.logger.error("Node execution failed", node_id=current_node_id, error=str(e)) - yield AgentEvent( - type="error", - message=f"Node {current_node_id} failed: {e!s}", - iteration=iteration, - ) - break - - # Check if current node is an exit node - if so, we're done - if current_node_id in self._graph.exit_nodes: - break - - # Get next node(s) - next_nodes = self._get_next_node(current_node_id, context) - - if not next_nodes: - # No more nodes, we've reached a dead end - self.logger.warning("Reached dead end in graph", node_id=current_node_id) - break - - current_node_id = next_nodes[0] # For now, take first next node (handle parallel later) + final_result = None + result_node_id = current_node_id - # Final event - get result from exit nodes (prioritize synthesizer/writer nodes) # First try to get result from current node (if it's an exit node) - final_result = None if current_node_id and current_node_id in self._graph.exit_nodes: final_result = context.get_node_result(current_node_id) self.logger.debug( @@ -563,7 +565,7 @@ async def _execute_graph( has_result=final_result is not None, result_type=type(final_result).__name__ if final_result else None, ) - + # If no result from current node, check all exit nodes 
for results # Prioritize synthesizer (deep research) or writer (iterative research) if not final_result: @@ -573,28 +575,28 @@ async def _execute_graph( result = context.get_node_result(exit_node_id) if result: final_result = result - current_node_id = exit_node_id + result_node_id = exit_node_id self.logger.debug( "Final result from priority exit node", node_id=exit_node_id, result_type=type(final_result).__name__, ) break - + # If still no result, check all exit nodes if not final_result: for exit_node_id in self._graph.exit_nodes: result = context.get_node_result(exit_node_id) if result: final_result = result - current_node_id = exit_node_id + result_node_id = exit_node_id self.logger.debug( "Final result from any exit node", node_id=exit_node_id, result_type=type(final_result).__name__, ) break - + # Log warning if no result found if not final_result: self.logger.warning( @@ -604,8 +606,11 @@ async def _execute_graph( all_node_results=list(context.node_results.keys()), ) - # Check if final result contains file information - event_data: dict[str, Any] = {"mode": self.mode, "iterations": iteration} + return final_result, result_node_id + + def _extract_final_message_and_files(self, final_result: Any) -> tuple[str, dict[str, Any]]: + """Extract message and file information from final result.""" + event_data: dict[str, Any] = {"mode": self.mode} message: str = "Research completed" if isinstance(final_result, str): @@ -619,7 +624,7 @@ async def _execute_graph( "Final message extracted from dict 'message' key", length=len(message) if isinstance(message, str) else 0, ) - + # Then check for file paths if "file" in final_result: file_path = final_result["file"] @@ -629,26 +634,89 @@ async def _execute_graph( if "message" not in final_result: message = "Report generated. Download available." self.logger.debug("File path added to event data", file_path=file_path) - elif "files" in final_result: + + # Check for multiple files + if "files" in final_result: files = final_result["files"] if isinstance(files, list): event_data["files"] = files - # Only override message if not already set from "message" key - if "message" not in final_result: - message = "Report generated. Downloads available." - elif isinstance(files, str): - event_data["files"] = [files] - # Only override message if not already set from "message" key - if "message" not in final_result: - message = "Report generated. Download available." - self.logger.debug("File paths added to event data", count=len(event_data.get("files", []))) - else: - # Log warning if result type is unexpected - self.logger.warning( - "Final result has unexpected type", - result_type=type(final_result).__name__ if final_result else None, - result_repr=str(final_result)[:200] if final_result else None, - ) + self.logger.debug("Multiple files added to event data", count=len(files)) + + return message, event_data + + async def _execute_graph( + self, query: str, context: GraphExecutionContext + ) -> AsyncGenerator[AgentEvent, None]: + """Execute the graph from entry node. 
+ + Args: + query: The research query + context: Execution context + + Yields: + AgentEvent objects + """ + if not self._graph: + raise ValueError("Graph not built") + + current_node_id = self._graph.entry_node + iteration = 0 + + # Execute nodes until we reach an exit node + while current_node_id: + # Check budget + if not context.budget_tracker.can_continue("graph_execution"): + self.logger.warning("Budget exceeded, exiting graph execution") + break + + # Execute current node + iteration += 1 + context.current_node = current_node_id + node = self._graph.get_node(current_node_id) + + # Emit start event + yield self._emit_start_event(node, current_node_id, iteration, context) + + try: + result = await self._execute_node(current_node_id, query, context) + context.set_node_result(current_node_id, result) + context.mark_visited(current_node_id) + + # Yield completion event + yield self._emit_completion_event(node, current_node_id, result, iteration) + + except Exception as e: + self.logger.error("Node execution failed", node_id=current_node_id, error=str(e)) + yield AgentEvent( + type="error", + message=f"Node {current_node_id} failed: {e!s}", + iteration=iteration, + ) + break + + # Check if current node is an exit node - if so, we're done + if current_node_id in self._graph.exit_nodes: + break + + # Get next node(s) + next_nodes = self._get_next_node(current_node_id, context) + + if not next_nodes: + # No more nodes, we've reached a dead end + self.logger.warning("Reached dead end in graph", node_id=current_node_id) + break + + current_node_id = next_nodes[0] # For now, take first next node (handle parallel later) + + # Final event - get result from exit nodes (prioritize synthesizer/writer nodes) + final_result, result_node_id = self._get_final_result_from_exit_nodes( + context, current_node_id + ) + + # Check if final result contains file information + event_data: dict[str, Any] = {"mode": self.mode, "iterations": iteration} + message, file_event_data = self._extract_final_message_and_files(final_result) + event_data.update(file_event_data) yield AgentEvent( type="complete", @@ -686,170 +754,121 @@ async def _execute_node(self, node_id: str, query: str, context: GraphExecutionC else: raise ValueError(f"Unknown node type: {type(node)}") - async def _execute_agent_node( - self, node: AgentNode, query: str, context: GraphExecutionContext - ) -> Any: - """Execute an agent node. 
- - Special handling for deep research nodes: - - "planner": Takes query string, returns ReportPlan - - "synthesizer": Takes query + ReportPlan + section drafts, returns final report - - Args: - node: The agent node - query: The research query - context: Execution context - - Returns: - Agent execution result - """ - # Special handling for synthesizer node (deep research) - if node.node_id == "synthesizer": - # Call LongWriterAgent.write_report() directly instead of using agent.run() - from src.agent_factory.agents import create_long_writer_agent - from src.utils.models import ReportDraft, ReportDraftSection, ReportPlan + async def _execute_synthesizer_node(self, query: str, context: GraphExecutionContext) -> Any: + """Execute synthesizer node for deep research.""" + from src.agent_factory.agents import create_long_writer_agent + from src.utils.models import ReportDraft, ReportDraftSection, ReportPlan - report_plan = context.get_node_result("planner") - section_drafts = context.get_node_result("parallel_loops") or [] + report_plan = context.get_node_result("planner") + section_drafts = context.get_node_result("parallel_loops") or [] - if not isinstance(report_plan, ReportPlan): - raise ValueError("ReportPlan not found for synthesizer") + if not isinstance(report_plan, ReportPlan): + raise ValueError("ReportPlan not found for synthesizer") - if not section_drafts: - raise ValueError("Section drafts not found for synthesizer") + if not section_drafts: + raise ValueError("Section drafts not found for synthesizer") - # Create ReportDraft from section drafts - report_draft = ReportDraft( - sections=[ - ReportDraftSection( - section_title=section.title, - section_content=draft, - ) - for section, draft in zip( - report_plan.report_outline, section_drafts, strict=False - ) - ] - ) - - # Get LongWriterAgent instance and call write_report directly - long_writer_agent = create_long_writer_agent(oauth_token=self.oauth_token) - final_report = await long_writer_agent.write_report( - original_query=query, - report_title=report_plan.report_title, - report_draft=report_draft, - ) + # Create ReportDraft from section drafts + report_draft = ReportDraft( + sections=[ + ReportDraftSection( + section_title=section.title, + section_content=draft, + ) + for section, draft in zip(report_plan.report_outline, section_drafts, strict=False) + ] + ) - # Estimate tokens (rough estimate) - estimated_tokens = len(final_report) // 4 # Rough token estimate - context.budget_tracker.add_tokens("graph_execution", estimated_tokens) + # Get LongWriterAgent instance and call write_report directly + long_writer_agent = create_long_writer_agent(oauth_token=self.oauth_token) + final_report = await long_writer_agent.write_report( + original_query=query, + report_title=report_plan.report_title, + report_draft=report_draft, + ) - # Save report to file if enabled (may generate multiple formats) - file_path: str | None = None - pdf_path: str | None = None - try: - file_service = self._get_file_service() - if file_service: - # Use save_report_multiple_formats to get both MD and PDF if enabled - saved_files = file_service.save_report_multiple_formats( - report_content=final_report, - query=query, - ) - file_path = saved_files.get("md") - pdf_path = saved_files.get("pdf") - self.logger.info( - "Report saved to file", - md_path=file_path, - pdf_path=pdf_path, - ) - except Exception as e: - # Don't fail the entire operation if file saving fails - self.logger.warning("Failed to save report to file", error=str(e)) - file_path = None - pdf_path 
= None - - # Return dict with file paths if available, otherwise return string (backward compatible) - if file_path: - result: dict[str, Any] = { - "message": final_report, - "file": file_path, - } - # Add PDF path if generated - if pdf_path: - result["files"] = [file_path, pdf_path] - return result - return final_report + # Estimate tokens (rough estimate) + estimated_tokens = len(final_report) // 4 # Rough token estimate + context.budget_tracker.add_tokens("graph_execution", estimated_tokens) - # Special handling for writer node (iterative research) - if node.node_id == "writer": - # Call WriterAgent.write_report() directly instead of using agent.run() - # Collect all findings from workflow state - from src.agent_factory.agents import create_writer_agent + # Save report to file if enabled (may generate multiple formats) + return self._save_report_and_return_result(final_report, query) - # Get all evidence from workflow state and convert to findings string - evidence = context.state.evidence - if evidence: - # Convert evidence to findings format (similar to conversation.get_all_findings()) - findings_parts: list[str] = [] - for ev in evidence: - finding = f"**{ev.title}**\n{ev.content}" - if ev.url: - finding += f"\nSource: {ev.url}" - findings_parts.append(finding) - all_findings = "\n\n".join(findings_parts) - else: - all_findings = "No findings available yet." + def _save_report_and_return_result(self, final_report: str, query: str) -> dict[str, Any] | str: + """Save report to file and return result with file paths if available.""" + file_path: str | None = None + pdf_path: str | None = None + try: + file_service = self._get_file_service() + if file_service: + # Use save_report_multiple_formats to get both MD and PDF if enabled + saved_files = file_service.save_report_multiple_formats( + report_content=final_report, + query=query, + ) + file_path = saved_files.get("md") + pdf_path = saved_files.get("pdf") + self.logger.info( + "Report saved to file", + md_path=file_path, + pdf_path=pdf_path, + ) + except Exception as e: + # Don't fail the entire operation if file saving fails + self.logger.warning("Failed to save report to file", error=str(e)) + file_path = None + pdf_path = None + + # Return dict with file paths if available, otherwise return string (backward compatible) + if file_path: + result: dict[str, Any] = { + "message": final_report, + "file": file_path, + } + # Add PDF path if generated + if pdf_path: + result["files"] = [file_path, pdf_path] + return result + return final_report + + async def _execute_writer_node(self, query: str, context: GraphExecutionContext) -> Any: + """Execute writer node for iterative research.""" + from src.agent_factory.agents import create_writer_agent + + # Get all evidence from workflow state and convert to findings string + evidence = context.state.evidence + if evidence: + # Convert evidence to findings format (similar to conversation.get_all_findings()) + findings_parts: list[str] = [] + for ev in evidence: + finding = f"**{ev.citation.title}**\n{ev.content}" + if ev.citation.url: + finding += f"\nSource: {ev.citation.url}" + findings_parts.append(finding) + all_findings = "\n\n".join(findings_parts) + else: + all_findings = "No findings available yet." 
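# --- Illustrative sketch, not part of the patch: how the message_history plumbing
# introduced in this diff is typically consumed by a caller. It assumes pydantic-ai's
# Agent.run(..., message_history=...) and result.new_messages(), which
# _execute_standard_agent below relies on; the model string and the ask() helper are
# placeholders, not code from this repository.
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage

agent = Agent("huggingface:meta-llama/Llama-3.1-8B-Instruct")  # placeholder model

async def ask(question: str, history: list[ModelMessage]) -> tuple[str, list[ModelMessage]]:
    # Keep only the most recent turns for token efficiency, mirroring
    # GraphExecutionContext.get_message_history(max_messages=10).
    recent = history[-10:] if len(history) > 10 else list(history)
    result = await agent.run(question, message_history=recent)
    # Recent pydantic-ai exposes .output (older releases use .data), which is why the
    # patch guards with hasattr(result, "output") when extracting agent output.
    answer = result.output
    # Fold the new turn back into the caller-side history so follow-ups keep context,
    # the same accumulation _execute_standard_agent performs via context.add_message().
    return answer, history + result.new_messages()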
+ + # Get WriterAgent instance and call write_report directly + writer_agent = create_writer_agent(oauth_token=self.oauth_token) + final_report = await writer_agent.write_report( + query=query, + findings=all_findings, + output_length="", + output_instructions="", + ) - # Get WriterAgent instance and call write_report directly - writer_agent = create_writer_agent(oauth_token=self.oauth_token) - final_report = await writer_agent.write_report( - query=query, - findings=all_findings, - output_length="", - output_instructions="", - ) + # Estimate tokens (rough estimate) + estimated_tokens = len(final_report) // 4 # Rough token estimate + context.budget_tracker.add_tokens("graph_execution", estimated_tokens) - # Estimate tokens (rough estimate) - estimated_tokens = len(final_report) // 4 # Rough token estimate - context.budget_tracker.add_tokens("graph_execution", estimated_tokens) + # Save report to file if enabled (may generate multiple formats) + return self._save_report_and_return_result(final_report, query) - # Save report to file if enabled (may generate multiple formats) - file_path: str | None = None - pdf_path: str | None = None - try: - file_service = self._get_file_service() - if file_service: - # Use save_report_multiple_formats to get both MD and PDF if enabled - saved_files = file_service.save_report_multiple_formats( - report_content=final_report, - query=query, - ) - file_path = saved_files.get("md") - pdf_path = saved_files.get("pdf") - self.logger.info( - "Report saved to file", - md_path=file_path, - pdf_path=pdf_path, - ) - except Exception as e: - # Don't fail the entire operation if file saving fails - self.logger.warning("Failed to save report to file", error=str(e)) - file_path = None - pdf_path = None - - # Return dict with file paths if available, otherwise return string (backward compatible) - if file_path: - result: dict[str, Any] = { - "message": final_report, - "file": file_path, - } - # Add PDF path if generated - if pdf_path: - result["files"] = [file_path, pdf_path] - return result - return final_report - - # Standard agent execution - # Prepare input based on node type + def _prepare_agent_input( + self, node: AgentNode, query: str, context: GraphExecutionContext + ) -> Any: + """Prepare input data for agent execution.""" if node.node_id == "planner": # Planner takes the original query input_data = query @@ -862,94 +881,348 @@ async def _execute_agent_node( if node.input_transformer: input_data = node.input_transformer(input_data) - # Execute agent with error handling + return input_data + + async def _execute_standard_agent( + self, node: AgentNode, input_data: Any, query: str, context: GraphExecutionContext + ) -> Any: + """Execute standard agent with error handling and fallback models.""" + # Get message history from context (limit to most recent 10 messages for token efficiency) + message_history = context.get_message_history(max_messages=10) + + # Try with the original agent first try: - result = await node.agent.run(input_data) + # Pass message_history if available (Pydantic AI agents support this) + if message_history: + result = await node.agent.run(input_data, message_history=message_history) + else: + result = await node.agent.run(input_data) + + # Accumulate new messages from agent result if available + if hasattr(result, "new_messages"): + try: + new_messages = result.new_messages() + for msg in new_messages: + context.add_message(msg) + except Exception as e: + # Don't fail if message accumulation fails + self.logger.debug( + "Failed to accumulate 
messages from agent result", error=str(e) + ) + return result except Exception as e: - # Handle validation errors and API errors for planner node + # Check if we should retry with fallback models + from src.utils.hf_error_handler import ( + extract_error_details, + should_retry_with_fallback, + ) + + error_details = extract_error_details(e) + should_retry = should_retry_with_fallback(e) + + # Handle validation errors and API errors for planner node (with fallback) if node.node_id == "planner": - self.logger.error( - "Planner agent execution failed, using fallback plan", - error=str(e), - error_type=type(e).__name__, + if should_retry: + self.logger.warning( + "Planner failed, trying fallback models", + original_error=str(e), + status_code=error_details.get("status_code"), + ) + # Try fallback models for planner + fallback_result = await self._try_fallback_models( + node, input_data, message_history, query, context, e + ) + if fallback_result is not None: + return fallback_result + # If fallback failed or not applicable, use fallback plan + return self._create_fallback_plan(query, input_data) + + # For other nodes, try fallback models if applicable + if should_retry: + self.logger.warning( + "Agent node failed, trying fallback models", + node_id=node.node_id, + original_error=str(e), + status_code=error_details.get("status_code"), ) - # Return a minimal fallback ReportPlan - from src.utils.models import ReportPlan, ReportPlanSection - - # Extract query from input_data if possible - fallback_query = query - if isinstance(input_data, str): - # Try to extract query from input string - if "QUERY:" in input_data: - fallback_query = input_data.split("QUERY:")[-1].strip() - - return ReportPlan( - background_context="", - report_outline=[ - ReportPlanSection( - title="Research Findings", - key_question=fallback_query, - ) - ], - report_title=f"Research Report: {fallback_query[:50]}", + fallback_result = await self._try_fallback_models( + node, input_data, message_history, query, context, e ) - # For other nodes, re-raise the exception + if fallback_result is not None: + return fallback_result + + # If fallback didn't work or wasn't applicable, re-raise the exception raise - # Transform output if needed + async def _try_fallback_models( + self, + node: AgentNode, + input_data: Any, + message_history: list[Any], + query: str, + context: GraphExecutionContext, + original_error: Exception, + ) -> Any | None: + """Try executing agent with fallback models. 
+ + Args: + node: The agent node that failed + input_data: Input data for the agent + message_history: Message history for the agent + query: The research query + context: Execution context + original_error: The original error that triggered fallback + + Returns: + Agent result if successful, None if all fallbacks failed + """ + from src.utils.hf_error_handler import extract_error_details, get_fallback_models + + error_details = extract_error_details(original_error) + original_model = error_details.get("model_name") + fallback_models = get_fallback_models(original_model) + + # Also try models from settings fallback list + from src.utils.config import settings + + settings_fallbacks = settings.get_hf_fallback_models_list() + for model in settings_fallbacks: + if model not in fallback_models: + fallback_models.append(model) + + self.logger.info( + "Trying fallback models", + node_id=node.node_id, + original_model=original_model, + fallback_count=len(fallback_models), + ) + + # Try each fallback model + for fallback_model in fallback_models: + try: + # Recreate agent with fallback model + fallback_agent = self._recreate_agent_with_model(node.node_id, fallback_model) + if fallback_agent is None: + continue + + # Try running with fallback agent + if message_history: + result = await fallback_agent.run(input_data, message_history=message_history) + else: + result = await fallback_agent.run(input_data) + + self.logger.info( + "Fallback model succeeded", + node_id=node.node_id, + fallback_model=fallback_model, + ) + + # Accumulate new messages from agent result if available + if hasattr(result, "new_messages"): + try: + new_messages = result.new_messages() + for msg in new_messages: + context.add_message(msg) + except Exception as e: + self.logger.debug( + "Failed to accumulate messages from fallback agent result", error=str(e) + ) + + return result + + except Exception as e: + self.logger.warning( + "Fallback model failed", + node_id=node.node_id, + fallback_model=fallback_model, + error=str(e), + ) + continue + + # All fallback models failed + self.logger.error( + "All fallback models failed", + node_id=node.node_id, + fallback_count=len(fallback_models), + ) + return None + + def _recreate_agent_with_model(self, node_id: str, model_name: str) -> Any | None: + """Recreate an agent with a specific model. 
+ + Args: + node_id: The node ID (e.g., "thinking", "knowledge_gap") + model_name: The model name to use + + Returns: + Agent instance or None if recreation failed + """ + try: + from pydantic_ai.models.huggingface import HuggingFaceModel + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + # Create model with fallback model name + hf_provider = HuggingFaceProvider(api_key=self.oauth_token) + model = HuggingFaceModel(model_name, provider=hf_provider) + + # Recreate agent based on node_id + if node_id == "thinking": + from src.agent_factory.agents import create_thinking_agent + + agent_wrapper = create_thinking_agent(model=model, oauth_token=self.oauth_token) + return agent_wrapper.agent + elif node_id == "knowledge_gap": + from src.agent_factory.agents import create_knowledge_gap_agent + + agent_wrapper = create_knowledge_gap_agent( # type: ignore[assignment] + model=model, oauth_token=self.oauth_token + ) + return agent_wrapper.agent + elif node_id == "tool_selector": + from src.agent_factory.agents import create_tool_selector_agent + + agent_wrapper = create_tool_selector_agent( # type: ignore[assignment] + model=model, oauth_token=self.oauth_token + ) + return agent_wrapper.agent + elif node_id == "planner": + from src.agent_factory.agents import create_planner_agent + + agent_wrapper = create_planner_agent(model=model, oauth_token=self.oauth_token) # type: ignore[assignment] + return agent_wrapper.agent + elif node_id == "writer": + from src.agent_factory.agents import create_writer_agent + + agent_wrapper = create_writer_agent(model=model, oauth_token=self.oauth_token) # type: ignore[assignment] + return agent_wrapper.agent + else: + self.logger.warning("Unknown node_id for agent recreation", node_id=node_id) + return None + + except Exception as e: + self.logger.error( + "Failed to recreate agent with fallback model", + node_id=node_id, + model_name=model_name, + error=str(e), + ) + return None + + def _create_fallback_plan(self, query: str, input_data: Any) -> Any: + """Create fallback ReportPlan when planner fails.""" + from src.utils.models import ReportPlan, ReportPlanSection + + self.logger.error( + "Planner agent execution failed, using fallback plan", + error_type=type(input_data).__name__, + ) + + # Extract query from input_data if possible + fallback_query = query + if isinstance(input_data, str): + # Try to extract query from input string + if "QUERY:" in input_data: + fallback_query = input_data.split("QUERY:")[-1].strip() + + return ReportPlan( + background_context="", + report_outline=[ + ReportPlanSection( + title="Research Findings", + key_question=fallback_query, + ) + ], + report_title=f"Research Report: {fallback_query[:50]}", + ) + + def _extract_agent_output(self, node: AgentNode, result: Any) -> Any: + """Extract and transform output from agent result.""" # Defensively extract output - handle various result formats output = result.output if hasattr(result, "output") else result # Handle case where output might be a tuple (from pydantic-ai validation errors) if isinstance(output, tuple): - # If tuple contains a dict-like structure, try to reconstruct the object - if len(output) == 2 and isinstance(output[0], str) and output[0] == "research_complete": - # This is likely a validation error format: ('research_complete', False) - # Try to get the actual output from result - self.logger.warning( - "Agent result output is a tuple, attempting to extract actual output", + output = self._handle_tuple_output(node, output, result) + return output + + def 
_handle_tuple_output(self, node: AgentNode, output: tuple[Any, ...], result: Any) -> Any: + """Handle tuple output from agent (validation errors).""" + # If tuple contains a dict-like structure, try to reconstruct the object + if len(output) == 2 and isinstance(output[0], str) and output[0] == "research_complete": + # This is likely a validation error format: ('research_complete', False) + # Try to get the actual output from result + self.logger.warning( + "Agent result output is a tuple, attempting to extract actual output", + node_id=node.node_id, + tuple_value=output, + ) + # Try to get output from result attributes + if hasattr(result, "data"): + return result.data + if hasattr(result, "response"): + return result.response + # Last resort: try to reconstruct from tuple + # This shouldn't happen, but handle gracefully + from src.utils.models import KnowledgeGapOutput + + if node.node_id == "knowledge_gap": + # Reconstruct KnowledgeGapOutput from validation error tuple + reconstructed = KnowledgeGapOutput( + research_complete=output[1] if len(output) > 1 else False, + outstanding_gaps=[], + ) + self.logger.info( + "Reconstructed KnowledgeGapOutput from validation error tuple", node_id=node.node_id, - tuple_value=output, + research_complete=reconstructed.research_complete, ) - # Try to get output from result attributes - if hasattr(result, "data"): - output = result.data - elif hasattr(result, "response"): - output = result.response - else: - # Last resort: try to reconstruct from tuple - # This shouldn't happen, but handle gracefully - from src.utils.models import KnowledgeGapOutput + return reconstructed - if node.node_id == "knowledge_gap": - # Reconstruct KnowledgeGapOutput from validation error tuple - output = KnowledgeGapOutput( - research_complete=output[1] if len(output) > 1 else False, - outstanding_gaps=[], - ) - self.logger.info( - "Reconstructed KnowledgeGapOutput from validation error tuple", - node_id=node.node_id, - research_complete=output.research_complete, - ) - else: - # For other nodes, try to extract meaningful output or use fallback - self.logger.warning( - "Agent node output is tuple format, attempting extraction", - node_id=node.node_id, - tuple_value=output, - ) - # Try to extract first meaningful element - if len(output) > 0: - # If first element is a string or dict, might be the actual output - if isinstance(output[0], (str, dict)): - output = output[0] - else: - # Last resort: use first element - output = output[0] - else: - # Empty tuple - use None and let downstream handle it - output = None + # For other nodes, try to extract meaningful output or use fallback + self.logger.warning( + "Agent node output is tuple format, attempting extraction", + node_id=node.node_id, + tuple_value=output, + ) + # Try to extract first meaningful element + if len(output) > 0: + # If first element is a string or dict, might be the actual output + if isinstance(output[0], str | dict): + return output[0] + # Last resort: use first element + return output[0] + # Empty tuple - use None and let downstream handle it + return None + + async def _execute_agent_node( + self, node: AgentNode, query: str, context: GraphExecutionContext + ) -> Any: + """Execute an agent node. 
+ + Special handling for deep research nodes: + - "planner": Takes query string, returns ReportPlan + - "synthesizer": Takes query + ReportPlan + section drafts, returns final report + + Args: + node: The agent node + query: The research query + context: Execution context + + Returns: + Agent execution result + """ + # Special handling for synthesizer node (deep research) + if node.node_id == "synthesizer": + return await self._execute_synthesizer_node(query, context) + + # Special handling for writer node (iterative research) + if node.node_id == "writer": + return await self._execute_writer_node(query, context) + + # Standard agent execution + input_data = self._prepare_agent_input(node, query, context) + result = await self._execute_standard_agent(node, input_data, query, context) + output = self._extract_agent_output(node, result) if node.output_transformer: output = node.output_transformer(output) @@ -1133,10 +1406,15 @@ async def _execute_decision_node( prev_result = prev_result[0] elif len(prev_result) > 1 and hasattr(prev_result[1], "research_complete"): prev_result = prev_result[1] - elif len(prev_result) == 2 and isinstance(prev_result[0], str) and prev_result[0] == "research_complete": + elif ( + len(prev_result) == 2 + and isinstance(prev_result[0], str) + and prev_result[0] == "research_complete" + ): # Handle validation error format: ('research_complete', False) # Reconstruct KnowledgeGapOutput from tuple from src.utils.models import KnowledgeGapOutput + self.logger.warning( "Decision node received validation error tuple, reconstructing KnowledgeGapOutput", node_id=node.node_id, @@ -1157,6 +1435,7 @@ async def _execute_decision_node( # Try to reconstruct KnowledgeGapOutput if this is from knowledge_gap node if prev_node_id == "knowledge_gap": from src.utils.models import KnowledgeGapOutput + # Try to extract research_complete from tuple research_complete = False for item in prev_result: diff --git a/src/orchestrator/research_flow.py b/src/orchestrator/research_flow.py index 52756654..3ce777bd 100644 --- a/src/orchestrator/research_flow.py +++ b/src/orchestrator/research_flow.py @@ -10,6 +10,11 @@ import structlog +try: + from pydantic_ai import ModelMessage +except ImportError: + ModelMessage = Any # type: ignore[assignment, misc] + from src.agent_factory.agents import ( create_graph_orchestrator, create_knowledge_gap_agent, @@ -137,6 +142,7 @@ async def run( background_context: str = "", output_length: str = "", output_instructions: str = "", + message_history: list[ModelMessage] | None = None, ) -> str: """ Run the iterative research flow. 
@@ -146,17 +152,18 @@ async def run( background_context: Optional background context output_length: Optional description of desired output length output_instructions: Optional additional instructions + message_history: Optional user conversation history Returns: Final report string """ if self.use_graph: return await self._run_with_graph( - query, background_context, output_length, output_instructions + query, background_context, output_length, output_instructions, message_history ) else: return await self._run_with_chains( - query, background_context, output_length, output_instructions + query, background_context, output_length, output_instructions, message_history ) async def _run_with_chains( @@ -165,6 +172,7 @@ async def _run_with_chains( background_context: str = "", output_length: str = "", output_instructions: str = "", + message_history: list[ModelMessage] | None = None, ) -> str: """ Run the iterative research flow using agent chains. @@ -174,6 +182,7 @@ async def _run_with_chains( background_context: Optional background context output_length: Optional description of desired output length output_instructions: Optional additional instructions + message_history: Optional user conversation history Returns: Final report string @@ -193,10 +202,10 @@ async def _run_with_chains( self.conversation.add_iteration() # 1. Generate observations - await self._generate_observations(query, background_context) + await self._generate_observations(query, background_context, message_history) # 2. Evaluate gaps - evaluation = await self._evaluate_gaps(query, background_context) + evaluation = await self._evaluate_gaps(query, background_context, message_history) # 3. Assess with judge (after tools execute, we'll assess again) # For now, check knowledge gap evaluation @@ -210,7 +219,9 @@ async def _run_with_chains( # 4. Select tools for next gap next_gap = evaluation.outstanding_gaps[0] if evaluation.outstanding_gaps else query - selection_plan = await self._select_agents(next_gap, query, background_context) + selection_plan = await self._select_agents( + next_gap, query, background_context, message_history + ) # 5. Execute tools await self._execute_tools(selection_plan.tasks) @@ -250,6 +261,7 @@ async def _run_with_graph( background_context: str = "", output_length: str = "", output_instructions: str = "", + message_history: list[ModelMessage] | None = None, ) -> str: """ Run the iterative research flow using graph execution. 
@@ -313,7 +325,12 @@ def _check_constraints(self) -> bool: return True - async def _generate_observations(self, query: str, background_context: str = "") -> str: + async def _generate_observations( + self, + query: str, + background_context: str = "", + message_history: list[ModelMessage] | None = None, + ) -> str: """Generate observations from current research state.""" # Build input prompt for token estimation conversation_history = self.conversation.compile_conversation_history() @@ -335,6 +352,7 @@ async def _generate_observations(self, query: str, background_context: str = "") query=query, background_context=background_context, conversation_history=conversation_history, + message_history=message_history, iteration=self.iteration, ) @@ -350,7 +368,12 @@ async def _generate_observations(self, query: str, background_context: str = "") self.conversation.set_latest_thought(observations) return observations - async def _evaluate_gaps(self, query: str, background_context: str = "") -> KnowledgeGapOutput: + async def _evaluate_gaps( + self, + query: str, + background_context: str = "", + message_history: list[ModelMessage] | None = None, + ) -> KnowledgeGapOutput: """Evaluate knowledge gaps in current research.""" if self.start_time: elapsed_minutes = (time.time() - self.start_time) / 60 @@ -377,6 +400,7 @@ async def _evaluate_gaps(self, query: str, background_context: str = "") -> Know query=query, background_context=background_context, conversation_history=conversation_history, + message_history=message_history, iteration=self.iteration, time_elapsed_minutes=elapsed_minutes, max_time_minutes=self.max_time_minutes, @@ -437,7 +461,11 @@ async def _assess_with_judge(self, query: str) -> JudgeAssessment: return assessment async def _select_agents( - self, gap: str, query: str, background_context: str = "" + self, + gap: str, + query: str, + background_context: str = "", + message_history: list[ModelMessage] | None = None, ) -> AgentSelectionPlan: """Select tools to address knowledge gap.""" # Build input prompt for token estimation @@ -461,6 +489,7 @@ async def _select_agents( query=query, background_context=background_context, conversation_history=conversation_history, + message_history=message_history, ) # Track tokens for this iteration @@ -775,27 +804,31 @@ def _get_file_service(self) -> ReportFileService | None: return None return self._file_service - async def run(self, query: str) -> str: + async def run(self, query: str, message_history: list[ModelMessage] | None = None) -> str: """ Run the deep research flow. Args: query: The research query + message_history: Optional user conversation history Returns: Final report string """ if self.use_graph: - return await self._run_with_graph(query) + return await self._run_with_graph(query, message_history) else: - return await self._run_with_chains(query) + return await self._run_with_chains(query, message_history) - async def _run_with_chains(self, query: str) -> str: + async def _run_with_chains( + self, query: str, message_history: list[ModelMessage] | None = None + ) -> str: """ Run the deep research flow using agent chains. 
Args: query: The research query + message_history: Optional user conversation history Returns: Final report string @@ -812,11 +845,11 @@ async def _run_with_chains(self, query: str) -> str: embedding_service = None self.logger.debug("Embedding service unavailable, initializing state without it") - init_workflow_state(embedding_service=embedding_service) + init_workflow_state(embedding_service=embedding_service, message_history=message_history) self.logger.debug("Workflow state initialized for deep research") # 1. Build report plan - report_plan = await self._build_report_plan(query) + report_plan = await self._build_report_plan(query, message_history) self.logger.info( "Report plan created", sections=len(report_plan.report_outline), @@ -824,7 +857,7 @@ async def _run_with_chains(self, query: str) -> str: ) # 2. Run parallel research loops with state synchronization - section_drafts = await self._run_research_loops(report_plan) + section_drafts = await self._run_research_loops(report_plan, message_history) # Verify state synchronization - log evidence count state = get_workflow_state() @@ -845,12 +878,15 @@ async def _run_with_chains(self, query: str) -> str: return final_report - async def _run_with_graph(self, query: str) -> str: + async def _run_with_graph( + self, query: str, message_history: list[ModelMessage] | None = None + ) -> str: """ Run the deep research flow using graph execution. Args: query: The research query + message_history: Optional user conversation history Returns: Final report string @@ -868,7 +904,7 @@ async def _run_with_graph(self, query: str) -> str: # Run orchestrator and collect events final_report = "" - async for event in self._graph_orchestrator.run(query): + async for event in self._graph_orchestrator.run(query, message_history=message_history): if event.type == "complete": final_report = event.message break @@ -884,13 +920,17 @@ async def _run_with_graph(self, query: str) -> str: return final_report - async def _build_report_plan(self, query: str) -> ReportPlan: + async def _build_report_plan( + self, query: str, message_history: list[ModelMessage] | None = None + ) -> ReportPlan: """Build the initial report plan.""" self.logger.info("Building report plan") # Build input prompt for token estimation input_prompt = f"QUERY: {query}" + # Planner agent may not support message_history yet, so we'll pass it if available + # For now, just use the standard run() call report_plan = await self.planner_agent.run(query) # Track tokens for planner agent @@ -913,7 +953,9 @@ async def _build_report_plan(self, query: str) -> ReportPlan: return report_plan - async def _run_research_loops(self, report_plan: ReportPlan) -> list[str]: + async def _run_research_loops( + self, report_plan: ReportPlan, message_history: list[ModelMessage] | None = None + ) -> list[str]: """Run parallel iterative research loops for each section.""" self.logger.info("Running research loops", sections=len(report_plan.report_outline)) @@ -950,10 +992,11 @@ async def run_research_for_section(config: dict[str, Any]) -> str: judge_handler=self.judge_handler if not self.use_graph else None, ) - # Run research + # Run research with message_history result = await flow.run( query=query, background_context=background_context, + message_history=message_history, ) # Sync evidence from flow to loop diff --git a/src/orchestrator_hierarchical.py b/src/orchestrator_hierarchical.py index bf3848ad..a7bfb85a 100644 --- a/src/orchestrator_hierarchical.py +++ b/src/orchestrator_hierarchical.py @@ -2,9 +2,15 @@ import 
asyncio from collections.abc import AsyncGenerator +from typing import Any import structlog +try: + from pydantic_ai import ModelMessage +except ImportError: + ModelMessage = Any # type: ignore[assignment, misc] + from src.agents.judge_agent_llm import LLMSubIterationJudge from src.agents.magentic_agents import create_search_agent from src.middleware.sub_iteration import SubIterationMiddleware, SubIterationTeam @@ -38,8 +44,14 @@ def __init__(self) -> None: self.judge = LLMSubIterationJudge() self.middleware = SubIterationMiddleware(self.team, self.judge, max_iterations=5) - async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: - logger.info("Starting hierarchical orchestrator", query=query) + async def run( + self, query: str, message_history: list[ModelMessage] | None = None + ) -> AsyncGenerator[AgentEvent, None]: + logger.info( + "Starting hierarchical orchestrator", + query=query, + has_history=bool(message_history), + ) try: service = get_embedding_service() @@ -58,6 +70,8 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: async def event_callback(event: AgentEvent) -> None: await queue.put(event) + # Note: middleware.run() may not support message_history yet + # Pass query for now, message_history can be added to middleware later if needed task_future = asyncio.create_task(self.middleware.run(query, event_callback)) while not task_future.done(): diff --git a/src/orchestrator_magentic.py b/src/orchestrator_magentic.py index fd9d4f72..416d8968 100644 --- a/src/orchestrator_magentic.py +++ b/src/orchestrator_magentic.py @@ -4,6 +4,11 @@ from typing import TYPE_CHECKING, Any import structlog + +try: + from pydantic_ai import ModelMessage +except ImportError: + ModelMessage = Any # type: ignore[assignment, misc] from agent_framework import ( MagenticAgentDeltaEvent, MagenticAgentMessageEvent, @@ -98,17 +103,24 @@ def _build_workflow(self) -> Any: .build() ) - async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: + async def run( + self, query: str, message_history: list[ModelMessage] | None = None + ) -> AsyncGenerator[AgentEvent, None]: """ Run the Magentic workflow. Args: query: User's research question + message_history: Optional user conversation history (for compatibility) Yields: AgentEvent objects for real-time UI updates """ - logger.info("Starting Magentic orchestrator", query=query) + logger.info( + "Starting Magentic orchestrator", + query=query, + has_history=bool(message_history), + ) yield AgentEvent( type="started", @@ -122,7 +134,17 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: workflow = self._build_workflow() - task = f"""Research query: {query} + # Include conversation history context if provided + history_context = "" + if message_history: + # Convert message history to string context for task + from src.utils.message_history import message_history_to_string + + history_str = message_history_to_string(message_history, max_messages=5) + if history_str: + history_context = f"\n\nPrevious conversation context:\n{history_str}" + + task = f"""Research query: {query}{history_context} Workflow: 1. 
SearchAgent: Find evidence from available sources (automatically selects: web search, PubMed, ClinicalTrials.gov, Europe PMC, or RAG based on query) diff --git a/src/services/audio_processing.py b/src/services/audio_processing.py index 076ba571..588497f9 100644 --- a/src/services/audio_processing.py +++ b/src/services/audio_processing.py @@ -6,9 +6,9 @@ import numpy as np import structlog +from src.agents.audio_refiner import audio_refiner from src.services.stt_gradio import STTService, get_stt_service from src.utils.config import settings -from src.utils.exceptions import ConfigurationError logger = structlog.get_logger(__name__) @@ -53,7 +53,7 @@ def __init__( async def process_audio_input( self, - audio_input: tuple[int, np.ndarray] | None, + audio_input: tuple[int, np.ndarray[Any, Any]] | None, # type: ignore[type-arg] hf_token: str | None = None, ) -> str | None: """Process audio input and return transcribed text. @@ -82,11 +82,11 @@ async def generate_audio_output( text: str, voice: str | None = None, speed: float | None = None, - ) -> tuple[int, np.ndarray] | None: + ) -> tuple[int, np.ndarray[Any, Any]] | None: # type: ignore[type-arg] """Generate audio output from text. Args: - text: Text to synthesize + text: Text to synthesize (markdown will be cleaned for audio) voice: Voice ID (default: settings.tts_voice) speed: Speech speed (default: settings.tts_speed) @@ -102,11 +102,23 @@ async def generate_audio_output( return None try: + # Refine text for audio (remove markdown, citations, etc.) + # Use LLM polish if enabled in settings + refined_text = await audio_refiner.refine_for_audio( + text, use_llm_polish=settings.tts_use_llm_polish + ) + logger.info( + "text_refined_for_audio", + original_length=len(text), + refined_length=len(refined_text), + llm_polish_enabled=settings.tts_use_llm_polish, + ) + # Use provided voice/speed or fallback to settings defaults voice = voice if voice else settings.tts_voice speed = speed if speed is not None else settings.tts_speed - audio_output = await self.tts.synthesize_async(text, voice, speed) # type: ignore[misc] + audio_output = await self.tts.synthesize_async(refined_text, voice, speed) # type: ignore[misc] if audio_output: logger.info( @@ -115,7 +127,7 @@ async def generate_audio_output( sample_rate=audio_output[0], ) - return audio_output + return audio_output # type: ignore[no-any-return] except Exception as e: logger.error("audio_output_generation_failed", error=str(e)) @@ -131,4 +143,3 @@ def get_audio_service() -> AudioService: AudioService instance """ return AudioService() - diff --git a/src/services/image_ocr.py b/src/services/image_ocr.py index b885c6d4..ed888193 100644 --- a/src/services/image_ocr.py +++ b/src/services/image_ocr.py @@ -31,7 +31,10 @@ def __init__(self, api_url: str | None = None, hf_token: str | None = None) -> N ConfigurationError: If API URL not configured """ # Defensively access ocr_api_url - may not exist in older config versions - default_url = getattr(settings, "ocr_api_url", None) or "https://prithivmlmods-multimodal-ocr3.hf.space" + default_url = ( + getattr(settings, "ocr_api_url", None) + or "https://prithivmlmods-multimodal-ocr3.hf.space" + ) self.api_url = api_url or default_url if not self.api_url: raise ConfigurationError("OCR API URL not configured") @@ -49,11 +52,11 @@ async def _get_client(self, hf_token: str | None = None) -> Client: """ # Use provided token or instance token token = hf_token or self.hf_token - + # If client exists but token changed, recreate it if self.client is not None and 
token != self.hf_token: self.client = None - + if self.client is None: loop = asyncio.get_running_loop() # Pass token to Client for authenticated Spaces @@ -129,7 +132,7 @@ async def extract_text( async def extract_text_from_image( self, - image_data: np.ndarray | Image.Image | str, + image_data: np.ndarray[Any, Any] | Image.Image | str, # type: ignore[type-arg] hf_token: str | None = None, ) -> str: """Extract text from image data (numpy array, PIL Image, or file path). @@ -240,10 +243,3 @@ def get_image_ocr_service() -> ImageOCRService: ImageOCRService instance """ return ImageOCRService() - - - - - - - diff --git a/src/services/llamaindex_rag.py b/src/services/llamaindex_rag.py index 9107c240..5de92f55 100644 --- a/src/services/llamaindex_rag.py +++ b/src/services/llamaindex_rag.py @@ -86,13 +86,15 @@ def __init__( self._initialize_chromadb() def _import_dependencies(self) -> dict[str, Any]: - """Import LlamaIndex dependencies and return as dict.""" + """Import LlamaIndex dependencies and return as dict. + + OpenAI dependencies are imported lazily (only when needed) to avoid + tiktoken circular import issues on Windows when using local embeddings. + """ try: import chromadb from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex from llama_index.core.retrievers import VectorIndexRetriever - from llama_index.embeddings.openai import OpenAIEmbedding - from llama_index.llms.openai import OpenAI from llama_index.vector_stores.chroma import ChromaVectorStore # Try to import Hugging Face embeddings (may not be available in all versions) @@ -120,10 +122,22 @@ def _import_dependencies(self) -> dict[str, Any]: HuggingFaceLLM as _HuggingFaceLLM, # type: ignore[import-untyped] ) - huggingface_llm = _HuggingFaceLLM + huggingface_llm = _HuggingFaceLLM # type: ignore[assignment] except ImportError: huggingface_llm = None # type: ignore[assignment] + # OpenAI imports are optional - only import when actually needed + # This avoids tiktoken circular import issues on Windows + try: + from llama_index.embeddings.openai import OpenAIEmbedding + except ImportError: + OpenAIEmbedding = None # type: ignore[assignment, misc] # noqa: N806 + + try: + from llama_index.llms.openai import OpenAI + except ImportError: + OpenAI = None # type: ignore[assignment, misc] # noqa: N806 + return { "chromadb": chromadb, "Document": Document, @@ -151,6 +165,10 @@ def _configure_embeddings( ) -> None: """Configure embedding model.""" if use_openai_embeddings: + if openai_embedding is None: + raise ConfigurationError( + "OpenAI embeddings not available. Install with: uv sync --extra modal" + ) if not settings.openai_api_key: raise ConfigurationError("OPENAI_API_KEY required for OpenAI embeddings") self.embedding_model = embedding_model or settings.openai_embedding_model @@ -167,8 +185,33 @@ def _configure_embeddings( self._Settings.embed_model = self._create_sentence_transformer_embedding(model_name) def _create_sentence_transformer_embedding(self, model_name: str) -> Any: - """Create sentence-transformer embedding wrapper.""" - from sentence_transformers import SentenceTransformer + """Create sentence-transformer embedding wrapper. + + Note: sentence-transformers is a required dependency (in pyproject.toml). + If this fails, it's likely a Windows-specific regex package issue. 
+ + Raises: + ConfigurationError: If sentence_transformers cannot be imported + (e.g., due to circular import issues on Windows with regex package) + """ + try: + from sentence_transformers import SentenceTransformer + except ImportError as e: + # Handle Windows-specific circular import issues with regex package + # This is a known bug: https://github.com/mrabarnett/mrab-regex/issues/417 + error_msg = str(e) + if "regex" in error_msg.lower() or "_regex" in error_msg: + raise ConfigurationError( + "sentence_transformers cannot be imported due to circular import issue " + "with regex package (Windows-specific bug). " + "sentence-transformers is installed but regex has a circular import. " + "Try: uv pip install --upgrade --force-reinstall regex " + "Or use HuggingFace embeddings via llama-index-embeddings-huggingface instead." + ) from e + raise ConfigurationError( + f"sentence_transformers not available: {e}. " + "This is a required dependency - check your uv sync installation." + ) from e try: from llama_index.embeddings.base import ( @@ -205,11 +248,7 @@ async def _aget_text_embedding(self, text: str) -> list[float]: def _configure_llm(self, huggingface_llm: Any, openai_llm: Any) -> None: """Configure LLM for query synthesis.""" # Priority: oauth_token > env vars - effective_token = ( - self.oauth_token - or settings.hf_token - or settings.huggingface_api_key - ) + effective_token = self.oauth_token or settings.hf_token or settings.huggingface_api_key if huggingface_llm is not None and effective_token: model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct" token = effective_token @@ -245,7 +284,7 @@ def _configure_llm(self, huggingface_llm: Any, openai_llm: Any) -> None: tokenizer_name=model_name, ) logger.info("Using HuggingFace LLM for query synthesis", model=model_name) - elif settings.openai_api_key: + elif settings.openai_api_key and openai_llm is not None: self._Settings.llm = openai_llm( model=settings.openai_model, api_key=settings.openai_api_key, @@ -461,6 +500,4 @@ def get_rag_service( # Default to local embeddings if not explicitly set if "use_openai_embeddings" not in kwargs: kwargs["use_openai_embeddings"] = False - return LlamaIndexRAGService( - collection_name=collection_name, oauth_token=oauth_token, **kwargs - ) + return LlamaIndexRAGService(collection_name=collection_name, oauth_token=oauth_token, **kwargs) diff --git a/src/services/multimodal_processing.py b/src/services/multimodal_processing.py index cefb9d7b..9199f194 100644 --- a/src/services/multimodal_processing.py +++ b/src/services/multimodal_processing.py @@ -83,7 +83,9 @@ async def process_multimodal_input( # For now, log a warning logger.warning("audio_file_upload_not_supported", file_path=file_path) except Exception as e: - logger.warning("audio_file_processing_failed", file_path=file_path, error=str(e)) + logger.warning( + "audio_file_processing_failed", file_path=file_path, error=str(e) + ) # Add original text if present if text and text.strip(): @@ -142,7 +144,3 @@ def get_multimodal_service() -> MultimodalService: MultimodalService instance """ return MultimodalService() - - - - diff --git a/src/services/neo4j_service.py b/src/services/neo4j_service.py index 26a2340b..6a9c6f7e 100644 --- a/src/services/neo4j_service.py +++ b/src/services/neo4j_service.py @@ -1,25 +1,28 @@ """Neo4j Knowledge Graph Service for Drug Repurposing""" -from neo4j import GraphDatabase -from typing import List, Dict, Optional, Any + +import logging import os +from typing import Any + from dotenv import 
load_dotenv -import logging +from neo4j import GraphDatabase load_dotenv() logger = logging.getLogger(__name__) + class Neo4jService: - def __init__(self): + def __init__(self) -> None: self.uri = os.getenv("NEO4J_URI", "bolt://localhost:7687") self.user = os.getenv("NEO4J_USER", "neo4j") self.password = os.getenv("NEO4J_PASSWORD") self.database = os.getenv("NEO4J_DATABASE", "neo4j") - + if not self.password: logger.warning("⚠️ NEO4J_PASSWORD not set") self.driver = None return - + try: self.driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password)) self.driver.verify_connectivity() @@ -27,80 +30,96 @@ def __init__(self): except Exception as e: logger.error(f"❌ Neo4j connection failed: {e}") self.driver = None - + def is_connected(self) -> bool: return self.driver is not None - - def close(self): + + def close(self) -> None: if self.driver: self.driver.close() - - def ingest_search_results(self, disease_name: str, papers: List[Dict[str, Any]], - drugs_mentioned: List[str] = None) -> Dict[str, int]: + + def ingest_search_results( + self, + disease_name: str, + papers: list[dict[str, Any]], + drugs_mentioned: list[str] | None = None, + ) -> dict[str, int]: if not self.driver: - return {"error": "Neo4j not connected"} - + return {"error": "Neo4j not connected"} # type: ignore[dict-item] + stats = {"papers": 0, "drugs": 0, "relationships": 0, "errors": 0} - + try: with self.driver.session(database=self.database) as session: session.run("MERGE (d:Disease {name: $name})", name=disease_name) - + for paper in papers: try: - paper_id = paper.get('id') or paper.get('url', '') + paper_id = paper.get("id") or paper.get("url", "") if not paper_id: continue - - session.run(""" + + session.run( + """ MERGE (p:Paper {paper_id: $id}) SET p.title = $title, p.abstract = $abstract, p.url = $url, p.source = $source, p.updated_at = datetime() - """, - id=paper_id, - title=str(paper.get('title', ''))[:500], - abstract=str(paper.get('abstract', ''))[:2000], - url=str(paper.get('url', ''))[:500], - source=str(paper.get('source', ''))[:100]) - - session.run(""" + """, + id=paper_id, + title=str(paper.get("title", ""))[:500], + abstract=str(paper.get("abstract", ""))[:2000], + url=str(paper.get("url", ""))[:500], + source=str(paper.get("source", ""))[:100], + ) + + session.run( + """ MATCH (p:Paper {paper_id: $id}) MATCH (d:Disease {name: $disease}) MERGE (p)-[r:ABOUT]->(d) - """, id=paper_id, disease=disease_name) - - stats['papers'] += 1 - stats['relationships'] += 1 - except Exception as e: - stats['errors'] += 1 - + """, + id=paper_id, + disease=disease_name, + ) + + stats["papers"] += 1 + stats["relationships"] += 1 + except Exception: + stats["errors"] += 1 + if drugs_mentioned: for drug in drugs_mentioned: try: session.run("MERGE (d:Drug {name: $name})", name=drug) - session.run(""" + session.run( + """ MATCH (drug:Drug {name: $drug}) MATCH (disease:Disease {name: $disease}) MERGE (drug)-[r:POTENTIAL_TREATMENT]->(disease) - """, drug=drug, disease=disease_name) - stats['drugs'] += 1 - stats['relationships'] += 1 - except Exception as e: - stats['errors'] += 1 - + """, + drug=drug, + disease=disease_name, + ) + stats["drugs"] += 1 + stats["relationships"] += 1 + except Exception: + stats["errors"] += 1 + logger.info(f"οΏ½οΏ½ Neo4j ingestion: {stats['papers']} papers, {stats['drugs']} drugs") except Exception as e: logger.error(f"Neo4j ingestion error: {e}") - stats['errors'] += 1 - + stats["errors"] += 1 + return stats + _neo4j_service = None -def get_neo4j_service() -> Optional[Neo4jService]: + 
+def get_neo4j_service() -> Neo4jService | None: global _neo4j_service if _neo4j_service is None: _neo4j_service = Neo4jService() diff --git a/src/services/report_file_service.py b/src/services/report_file_service.py index 9b9f82fe..14496640 100644 --- a/src/services/report_file_service.py +++ b/src/services/report_file_service.py @@ -329,4 +329,3 @@ def _get_service() -> ReportFileService: return ReportFileService() return _get_service() - diff --git a/src/services/stt_gradio.py b/src/services/stt_gradio.py index 28062205..44b993b4 100644 --- a/src/services/stt_gradio.py +++ b/src/services/stt_gradio.py @@ -46,11 +46,11 @@ async def _get_client(self, hf_token: str | None = None) -> Client: """ # Use provided token or instance token token = hf_token or self.hf_token - + # If client exists but token changed, recreate it if self.client is not None and token != self.hf_token: self.client = None - + if self.client is None: loop = asyncio.get_running_loop() # Pass token to Client for authenticated Spaces @@ -130,7 +130,7 @@ async def transcribe_file( async def transcribe_audio( self, - audio_data: tuple[int, np.ndarray], + audio_data: tuple[int, np.ndarray[Any, Any]], # type: ignore[type-arg] hf_token: str | None = None, ) -> str: """Transcribe audio numpy array to text. @@ -163,7 +163,7 @@ async def transcribe_audio( except Exception as e: logger.warning("failed_to_cleanup_temp_file", path=temp_path, error=str(e)) - def _extract_transcription(self, api_result: tuple) -> str: + def _extract_transcription(self, api_result: tuple[Any, ...]) -> str: """Extract transcription text from API result. Args: @@ -210,7 +210,7 @@ def _extract_transcription(self, api_result: tuple) -> str: def _save_audio_temp( self, - audio_data: tuple[int, np.ndarray], + audio_data: tuple[int, np.ndarray[Any, Any]], # type: ignore[type-arg] ) -> str: """Save audio numpy array to temporary WAV file. @@ -269,4 +269,3 @@ def get_stt_service() -> STTService: STTService instance """ return STTService() - diff --git a/src/services/tts_modal.py b/src/services/tts_modal.py index e0e30afb..ce55c49a 100644 --- a/src/services/tts_modal.py +++ b/src/services/tts_modal.py @@ -1,12 +1,22 @@ """Text-to-Speech service using Kokoro 82M via Modal GPU.""" import asyncio +import os +from collections.abc import Iterator +from contextlib import contextmanager from functools import lru_cache -from typing import Any +from typing import Any, cast import numpy as np +from numpy.typing import NDArray import structlog +# Load .env file BEFORE importing Modal SDK +# Modal SDK reads MODAL_TOKEN_ID and MODAL_TOKEN_SECRET from environment on import +from dotenv import load_dotenv + +load_dotenv() + from src.utils.config import settings from src.utils.exceptions import ConfigurationError @@ -24,39 +34,107 @@ # Modal app and function definitions (module-level for Modal) _modal_app: Any | None = None _tts_function: Any | None = None +_tts_image: Any | None = None + + +@contextmanager +def modal_credentials_override(token_id: str | None, token_secret: str | None) -> Iterator[None]: + """Context manager to temporarily override Modal credentials. + + Args: + token_id: Modal token ID (overrides env if provided) + token_secret: Modal token secret (overrides env if provided) + + Yields: + None + + Note: + Resets global Modal state to force re-initialization with new credentials. 
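+
+    Example:
+        A minimal usage sketch; the token values below are placeholders, not
+        real Modal credentials:
+
+            with modal_credentials_override("ak-XXXXXXXX", "as-XXXXXXXX"):
+                executor = ModalTTSExecutor()
+                sample_rate, audio = executor.synthesize("Hello from Modal")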
+ """ + global _modal_app, _tts_function + + # Save original credentials + original_token_id = os.environ.get("MODAL_TOKEN_ID") + original_token_secret = os.environ.get("MODAL_TOKEN_SECRET") + + # Save original Modal state + original_app = _modal_app + original_function = _tts_function + + try: + # Override environment variables if provided + if token_id: + os.environ["MODAL_TOKEN_ID"] = token_id + if token_secret: + os.environ["MODAL_TOKEN_SECRET"] = token_secret + + # Reset Modal state to force re-initialization + _modal_app = None + _tts_function = None + + yield + + finally: + # Restore original credentials + if original_token_id is not None: + os.environ["MODAL_TOKEN_ID"] = original_token_id + elif "MODAL_TOKEN_ID" in os.environ: + del os.environ["MODAL_TOKEN_ID"] + + if original_token_secret is not None: + os.environ["MODAL_TOKEN_SECRET"] = original_token_secret + elif "MODAL_TOKEN_SECRET" in os.environ: + del os.environ["MODAL_TOKEN_SECRET"] + + # Restore original Modal state + _modal_app = original_app + _tts_function = original_function def _get_modal_app() -> Any: - """Get or create Modal app instance.""" + """Get or create Modal app instance. + + Retrieves Modal credentials directly from environment variables (.env file) + instead of relying on settings configuration. + """ global _modal_app if _modal_app is None: try: import modal - # Validate Modal credentials before attempting lookup - if not settings.modal_available: + # Get credentials directly from environment variables + token_id = os.getenv("MODAL_TOKEN_ID") + token_secret = os.getenv("MODAL_TOKEN_SECRET") + + # Validate Modal credentials + if not token_id or not token_secret: raise ConfigurationError( - "Modal credentials not configured. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables." + "Modal credentials not found in environment. " + "Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file." ) # Validate token ID format (Modal token IDs are typically UUIDs or specific formats) - token_id = settings.modal_token_id - if token_id: - # Basic validation: token ID should not be empty and should be a reasonable length - if len(token_id.strip()) < 10: - raise ConfigurationError( - f"Modal token ID appears malformed (too short: {len(token_id)} chars). " - "Token ID should be a valid Modal token identifier." - ) + if len(token_id.strip()) < 10: + raise ConfigurationError( + f"Modal token ID appears malformed (too short: {len(token_id)} chars). " + "Token ID should be a valid Modal token identifier." + ) + + logger.info( + "modal_credentials_loaded", + token_id_prefix=token_id[:8] + "...", # Log prefix for debugging + has_secret=bool(token_secret), + ) try: + # Use lookup with create_if_missing for inline function fallback _modal_app = modal.App.lookup("deepcritical-tts", create_if_missing=True) except Exception as e: error_msg = str(e).lower() if "token" in error_msg or "malformed" in error_msg or "invalid" in error_msg: raise ConfigurationError( f"Modal token validation failed: {e}. " - "Please check that MODAL_TOKEN_ID and MODAL_TOKEN_SECRET are correctly set." + "Please check that MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env are correctly set." 
) from e raise except ImportError as e: @@ -69,103 +147,137 @@ def _get_modal_app() -> Any: # Define Modal image with Kokoro dependencies (module-level) def _get_tts_image() -> Any: """Get Modal image with Kokoro dependencies.""" + global _tts_image + if _tts_image is not None: + return _tts_image + try: import modal - return ( + _tts_image = ( modal.Image.debian_slim(python_version="3.11") .pip_install(*KOKORO_DEPENDENCIES) .pip_install("git+https://github.com/hexgrad/kokoro.git") ) + return _tts_image except ImportError: return None -def _setup_modal_function() -> None: - """Setup Modal GPU function for TTS (called once, lazy initialization). +# Modal TTS function - Using serialized=True to allow dynamic creation +# This will be initialized lazily when _setup_modal_function() is called +def _create_tts_function() -> Any: + """Create the Modal TTS function using serialized=True. - Note: GPU type is set at function definition time. Changes to settings.tts_gpu - require app restart to take effect. + The serialized=True parameter allows the function to be defined outside + of global scope, which is necessary for dynamic initialization. """ - global _tts_function, _modal_app + app = _get_modal_app() + tts_image = _get_tts_image() + + if tts_image is None: + raise ConfigurationError("Modal image setup failed") + + # Get GPU and timeout from settings (with defaults) + gpu_type = getattr(settings, "tts_gpu", None) or "T4" + timeout_seconds = getattr(settings, "tts_timeout", None) or 120 # 2 minutes for cold starts + + @app.function( + image=tts_image, + gpu=gpu_type, + timeout=timeout_seconds, + serialized=True, # Allow function to be defined outside global scope + ) + def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, NDArray[np.float32]]: + """Modal GPU function for Kokoro TTS. + + This function runs on Modal's GPU infrastructure. + Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS + Reference: https://huggingface.co/spaces/hexgrad/Kokoro-TTS/raw/main/app.py + """ + import numpy as np - if _tts_function is not None: - return # Already set up + # Import Kokoro inside function (lazy load) + try: + import torch + from kokoro import KModel, KPipeline - try: - app = _get_modal_app() - tts_image = _get_tts_image() - - if tts_image is None: - raise ConfigurationError("Modal image setup failed") - - # Get GPU and timeout from settings (with defaults) - # Note: These are evaluated at function definition time, not at call time - # Changes to settings require app restart - gpu_type = getattr(settings, "tts_gpu", None) or "T4" - timeout_seconds = getattr(settings, "tts_timeout", None) or 60 - - # Define GPU function at module level (required by Modal) - # Modal functions are immutable once defined, so GPU changes require restart - @app.function( - image=tts_image, - gpu=gpu_type, - timeout=timeout_seconds, - ) - def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, np.ndarray]: - """Modal GPU function for Kokoro TTS. + # Initialize model (cached on GPU) + model = KModel().to("cuda").eval() + pipeline = KPipeline(lang_code=voice[0]) + pack = pipeline.load_voice(voice) - This function runs on Modal's GPU infrastructure. 
- Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS - Reference: https://huggingface.co/spaces/hexgrad/Kokoro-TTS/raw/main/app.py - """ - import numpy as np + # Generate audio + for _, ps, _ in pipeline(text, voice, speed): + ref_s = pack[len(ps) - 1] + audio = model(ps, ref_s, speed) + return (24000, audio.numpy()) - # Import Kokoro inside function (lazy load) - try: - import torch - from kokoro import KModel, KPipeline + # If no audio generated, return empty + return (24000, np.zeros(1, dtype=np.float32)) - # Initialize model (cached on GPU) - model = KModel().to("cuda").eval() - pipeline = KPipeline(lang_code=voice[0]) - pack = pipeline.load_voice(voice) + except ImportError as e: + raise ConfigurationError( + "Kokoro not installed. Install with: pip install git+https://github.com/hexgrad/kokoro.git" + ) from e + except Exception as e: + raise ConfigurationError(f"TTS synthesis failed: {e}") from e - # Generate audio - for _, ps, _ in pipeline(text, voice, speed): - ref_s = pack[len(ps) - 1] - audio = model(ps, ref_s, speed) - return (24000, audio.numpy()) + return kokoro_tts_function - # If no audio generated, return empty - return (24000, np.zeros(1, dtype=np.float32)) - except ImportError as e: - raise ConfigurationError( - "Kokoro not installed. Install with: pip install git+https://github.com/hexgrad/kokoro.git" - ) from e - except Exception as e: - raise ConfigurationError(f"TTS synthesis failed: {e}") from e +def _setup_modal_function() -> None: + """Setup Modal GPU function for TTS (called once, lazy initialization). - # Store function reference for remote calls - _tts_function = kokoro_tts_function + Hybrid approach: + 1. Try to lookup pre-deployed function (fast path for advanced users) + 2. If lookup fails, create function inline (fallback for casual users) - # Verify function is properly attached to app - if not hasattr(app, kokoro_tts_function.__name__): - logger.warning( - "modal_function_not_attached", function_name=kokoro_tts_function.__name__ + This allows both workflows: + - Advanced: Deploy with `modal deploy deployments/modal_tts.py` for best performance + - Casual: Just add Modal keys and it auto-creates function on first use + """ + global _tts_function + + if _tts_function is not None: + return # Already set up + + try: + import modal + + # Try path 1: Lookup pre-deployed function (fast path) + try: + _tts_function = modal.Function.from_name("deepcritical-tts", "kokoro_tts_function") + logger.info( + "modal_tts_function_lookup_success", + app_name="deepcritical-tts", + function_name="kokoro_tts_function", + method="lookup", + ) + return + except Exception as lookup_error: + logger.info( + "modal_tts_function_lookup_failed", + error=str(lookup_error), + fallback="Creating function inline", ) + # Try path 2: Create function inline (fallback for casual users) + logger.info("modal_tts_creating_inline_function") + _tts_function = _create_tts_function() logger.info( "modal_tts_function_setup_complete", - gpu=gpu_type, - timeout=timeout_seconds, - function_name=kokoro_tts_function.__name__, + app_name="deepcritical-tts", + function_name="kokoro_tts_function", + method="inline", ) except Exception as e: logger.error("modal_tts_function_setup_failed", error=str(e)) - raise ConfigurationError(f"Failed to setup Modal TTS function: {e}") from e + raise ConfigurationError( + f"Failed to setup Modal TTS function: {e}. " + "Ensure Modal credentials (MODAL_TOKEN_ID, MODAL_TOKEN_SECRET) are valid." 
+ ) from e class ModalTTSExecutor: @@ -179,13 +291,17 @@ def __init__(self) -> None: """Initialize Modal TTS executor. Note: - Logs a warning if Modal credentials are not configured. - Execution will fail at runtime without valid credentials. + Logs a warning if Modal credentials are not configured in environment. + Execution will fail at runtime without valid credentials in .env file. """ - # Check for Modal credentials - if not settings.modal_available: + # Check for Modal credentials directly from environment + token_id = os.getenv("MODAL_TOKEN_ID") + token_secret = os.getenv("MODAL_TOKEN_SECRET") + + if not token_id or not token_secret: logger.warning( - "Modal credentials not found. TTS will not be available unless modal setup is run." + "Modal credentials not found in environment. " + "TTS will not be available. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file." ) def synthesize( @@ -193,8 +309,8 @@ def synthesize( text: str, voice: str = "af_heart", speed: float = 1.0, - timeout: int = 60, - ) -> tuple[int, np.ndarray]: + timeout: int = 120, + ) -> tuple[int, NDArray[np.float32]]: """Synthesize text to speech using Kokoro on Modal GPU. Args: @@ -219,7 +335,7 @@ def synthesize( try: # Call the GPU function remotely - result = _tts_function.remote(text, voice, speed) + result = cast(tuple[int, NDArray[np.float32]], _tts_function.remote(text, voice, speed)) logger.info( "tts_synthesis_complete", sample_rate=result[0], audio_shape=result[1].shape @@ -236,9 +352,19 @@ class TTSService: """TTS service wrapper for async usage.""" def __init__(self) -> None: - """Initialize TTS service.""" - if not settings.modal_available: - raise ConfigurationError("Modal credentials required for TTS") + """Initialize TTS service. + + Validates Modal credentials from environment variables (.env file). + """ + # Check credentials directly from environment + token_id = os.getenv("MODAL_TOKEN_ID") + token_secret = os.getenv("MODAL_TOKEN_SECRET") + + if not token_id or not token_secret: + raise ConfigurationError( + "Modal credentials required for TTS. " + "Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file." + ) self.executor = ModalTTSExecutor() async def synthesize_async( @@ -246,7 +372,7 @@ async def synthesize_async( text: str, voice: str = "af_heart", speed: float = 1.0, - ) -> tuple[int, np.ndarray] | None: + ) -> tuple[int, NDArray[np.float32]] | None: """Async wrapper for TTS synthesis. Args: @@ -284,3 +410,73 @@ def get_tts_service() -> TTSService: ConfigurationError: If Modal credentials not configured """ return TTSService() + + +async def generate_audio_on_demand( + text: str, + modal_token_id: str | None = None, + modal_token_secret: str | None = None, + voice: str = "af_heart", + speed: float = 1.0, + use_llm_polish: bool = False, +) -> tuple[tuple[int, NDArray[np.float32]] | None, str]: + """Generate audio on-demand with optional runtime credentials. 
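+
+    A minimal call sketch; the token values below are placeholders, not real
+    Modal credentials:
+
+        audio, status = await generate_audio_on_demand(
+            "Hello world",
+            modal_token_id="ak-XXXXXXXX",
+            modal_token_secret="as-XXXXXXXX",
+        )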
+ + Args: + text: Text to synthesize + modal_token_id: Modal token ID (UI input, overrides .env) + modal_token_secret: Modal token secret (UI input, overrides .env) + voice: Voice ID (default: af_heart) + speed: Speech speed (default: 1.0) + use_llm_polish: Apply LLM polish to text (default: False) + + Returns: + Tuple of (audio_output, status_message) + - audio_output: (sample_rate, audio_array) or None if failed + - status_message: Status/error message for user + + Priority: UI credentials > .env credentials + """ + # Priority: UI keys > .env keys + token_id = (modal_token_id or "").strip() or os.getenv("MODAL_TOKEN_ID") + token_secret = (modal_token_secret or "").strip() or os.getenv("MODAL_TOKEN_SECRET") + + if not token_id or not token_secret: + return ( + None, + "❌ Modal credentials required. Enter keys above or set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env", + ) + + try: + # Use credentials override context + with modal_credentials_override(token_id, token_secret): + # Import audio_processing here to avoid circular import + from src.services.audio_processing import AudioService + + # Temporarily override LLM polish setting + original_llm_polish = settings.tts_use_llm_polish + try: + settings.tts_use_llm_polish = use_llm_polish + + # Create fresh AudioService instance (bypass cache to pick up new credentials) + audio_service = AudioService() + audio_output = await audio_service.generate_audio_output( + text=text, + voice=voice, + speed=speed, + ) + + if audio_output: + return audio_output, "βœ… Audio generated successfully" + else: + return None, "⚠️ Audio generation returned no output" + + finally: + settings.tts_use_llm_polish = original_llm_polish + + except ConfigurationError as e: + logger.error("audio_generation_config_error", error=str(e)) + return None, f"❌ Configuration error: {e}" + except Exception as e: + logger.error("audio_generation_failed", error=str(e), exc_info=True) + return None, f"❌ Audio generation failed: {e}" diff --git a/src/tools/crawl_adapter.py b/src/tools/crawl_adapter.py index 53394234..332569c5 100644 --- a/src/tools/crawl_adapter.py +++ b/src/tools/crawl_adapter.py @@ -56,8 +56,3 @@ async def crawl_website(starting_url: str) -> str: except Exception as e: logger.error("Crawl failed", error=str(e), url=starting_url) return f"Error crawling website: {e!s}" - - - - - diff --git a/src/tools/neo4j_search.py b/src/tools/neo4j_search.py index 93b49ace..01983889 100644 --- a/src/tools/neo4j_search.py +++ b/src/tools/neo4j_search.py @@ -1,16 +1,19 @@ """Neo4j knowledge graph search tool.""" + import structlog -from src.utils.models import Citation, Evidence + from src.services.neo4j_service import get_neo4j_service +from src.utils.models import Citation, Evidence logger = structlog.get_logger() + class Neo4jSearchTool: """Search Neo4j knowledge graph for papers.""" - - def __init__(self): + + def __init__(self) -> None: self.name = "neo4j" # βœ… Definir explΓ­citamente - + async def search(self, query: str, max_results: int = 10) -> list[Evidence]: """Search Neo4j for papers about diseases in the query.""" try: @@ -18,25 +21,32 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]: if not service: logger.warning("Neo4j service not available") return [] - + # Extract disease name from query disease = query if "for" in query.lower(): disease = query.split("for")[-1].strip().rstrip("?") - + # Query Neo4j + if not service.driver: + logger.warning("Neo4j driver not available") + return [] with 
service.driver.session(database=service.database) as session: - result = session.run(""" + result = session.run( + """ MATCH (p:Paper)-[:ABOUT]->(d:Disease) WHERE d.name CONTAINS $disease RETURN p.title as title, p.abstract as abstract, p.url as url, p.source as source ORDER BY p.updated_at DESC LIMIT $max_results - """, disease=disease, max_results=max_results) - + """, + disease=disease, + max_results=max_results, + ) + records = list(result) - + results = [] for record in records: citation = Citation( @@ -44,17 +54,14 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]: title=record["title"] or "Untitled", url=record["url"] or "", date="", - authors=[] + authors=[], ) - + evidence = Evidence( content=record["abstract"] or record["title"] or "", citation=citation, relevance=1.0, - metadata={ - "from_kb": True, - "original_source": record["source"] - } + metadata={"from_kb": True, "original_source": record["source"]}, ) results.append(evidence) diff --git a/src/tools/search_handler.py b/src/tools/search_handler.py index 77edba21..8814a8b7 100644 --- a/src/tools/search_handler.py +++ b/src/tools/search_handler.py @@ -5,11 +5,11 @@ import structlog +from src.services.neo4j_service import get_neo4j_service from src.tools.base import SearchTool from src.tools.rag_tool import create_rag_tool from src.utils.exceptions import ConfigurationError, SearchError from src.utils.models import Evidence, SearchResult, SourceName -from src.services.neo4j_service import get_neo4j_service if TYPE_CHECKING: from src.services.llamaindex_rag import LlamaIndexRAGService @@ -113,6 +113,8 @@ async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchRes # Some tools have internal names that differ from SourceName literals tool_name_to_source: dict[str, SourceName] = { "duckduckgo": "web", + "serper": "web", # Serper uses Google search but maps to "web" source + "searchxng": "web", # SearchXNG also maps to "web" source "pubmed": "pubmed", "clinicaltrials": "clinicaltrials", "europepmc": "europepmc", @@ -131,7 +133,15 @@ async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchRes # Map tool.name to SourceName (handle tool names that don't match SourceName literals) tool_name = tool_name_to_source.get(tool.name, cast(SourceName, tool.name)) - if tool_name not in ["pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint", "rag", "web"]: + if tool_name not in [ + "pubmed", + "clinicaltrials", + "biorxiv", + "europepmc", + "preprint", + "rag", + "web", + ]: logger.warning( "Tool name not in SourceName literals, defaulting to 'web'", tool_name=tool.name, @@ -173,18 +183,20 @@ async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchRes disease = query if "for" in query.lower(): disease = query.split("for")[-1].strip().rstrip("?") - + # Convert Evidence objects to dicts for Neo4j papers = [] for ev in all_evidence: - papers.append({ - 'id': ev.citation.url or '', - 'title': ev.citation.title or '', - 'abstract': ev.content, - 'url': ev.citation.url or '', - 'source': ev.citation.source, - }) - + papers.append( + { + "id": ev.citation.url or "", + "title": ev.citation.title or "", + "abstract": ev.content, + "url": ev.citation.url or "", + "source": ev.citation.source, + } + ) + stats = neo4j_service.ingest_search_results(disease, papers) logger.info("πŸ’Ύ Saved to Neo4j", stats=stats) except Exception as e: diff --git a/src/tools/searchxng_web_search.py b/src/tools/searchxng_web_search.py index 80cf8be5..e120b61f 100644 --- 
a/src/tools/searchxng_web_search.py +++ b/src/tools/searchxng_web_search.py @@ -85,12 +85,17 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]: # Convert ScrapeResult to Evidence objects evidence = [] for result in scraped: + # Truncate title to max 500 characters to match Citation model validation + title = result.title + if len(title) > 500: + title = title[:497] + "..." + ev = Evidence( content=result.text, citation=Citation( - title=result.title, + title=title, url=result.url, - source="searchxng", + source="web", # Use "web" to match SourceName literal, not "searchxng" date="Unknown", authors=[], ), @@ -113,13 +118,3 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]: except Exception as e: logger.error("Unexpected error in SearchXNG search", error=str(e), query=final_query) raise SearchError(f"SearchXNG search failed: {e}") from e - - - - - - - - - - diff --git a/src/tools/serper_web_search.py b/src/tools/serper_web_search.py index 79e9449e..77fc6eb0 100644 --- a/src/tools/serper_web_search.py +++ b/src/tools/serper_web_search.py @@ -85,12 +85,17 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]: # Convert ScrapeResult to Evidence objects evidence = [] for result in scraped: + # Truncate title to max 500 characters to match Citation model validation + title = result.title + if len(title) > 500: + title = title[:497] + "..." + ev = Evidence( content=result.text, citation=Citation( - title=result.title, + title=title, url=result.url, - source="serper", + source="web", # Use "web" to match SourceName literal, not "serper" date="Unknown", authors=[], ), @@ -113,13 +118,3 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]: except Exception as e: logger.error("Unexpected error in Serper search", error=str(e), query=final_query) raise SearchError(f"Serper search failed: {e}") from e - - - - - - - - - - diff --git a/src/tools/tool_executor.py b/src/tools/tool_executor.py index 04099226..e4999607 100644 --- a/src/tools/tool_executor.py +++ b/src/tools/tool_executor.py @@ -182,9 +182,9 @@ async def execute_tool_tasks( results[f"{task.agent}_{i}"] = ToolAgentOutput(output=f"Error: {result!s}", sources=[]) else: # Type narrowing: result is ToolAgentOutput after Exception check - assert isinstance( - result, ToolAgentOutput - ), "Expected ToolAgentOutput after Exception check" + assert isinstance(result, ToolAgentOutput), ( + "Expected ToolAgentOutput after Exception check" + ) key = f"{task.agent}_{task.gap or i}" if task.gap else f"{task.agent}_{i}" results[key] = result diff --git a/src/tools/vendored/__init__.py b/src/tools/vendored/__init__.py index 6db9f15a..e96e8252 100644 --- a/src/tools/vendored/__init__.py +++ b/src/tools/vendored/__init__.py @@ -16,12 +16,12 @@ __all__ = [ "CONTENT_LENGTH_LIMIT", "ScrapeResult", - "WebpageSnippet", - "SerperClient", "SearchXNGClient", - "scrape_urls", + "SerperClient", + "WebpageSnippet", + "crawl_website", "fetch_and_process_url", "html_to_text", "is_valid_url", - "crawl_website", + "scrape_urls", ] diff --git a/src/tools/vendored/crawl_website.py b/src/tools/vendored/crawl_website.py index 32fc2d94..cd75bb50 100644 --- a/src/tools/vendored/crawl_website.py +++ b/src/tools/vendored/crawl_website.py @@ -20,6 +20,63 @@ logger = structlog.get_logger() +async def _extract_links( + html: str, current_url: str, base_domain: str +) -> tuple[list[str], list[str]]: + """Extract prioritized links from HTML content.""" + soup = BeautifulSoup(html, 
"html.parser") + nav_links = set() + body_links = set() + + # Find navigation/header links + for nav_element in soup.find_all(["nav", "header"]): + for a in nav_element.find_all("a", href=True): + href = str(a["href"]) + link = urljoin(current_url, href) + if urlparse(link).netloc == base_domain: + nav_links.add(link) + + # Find remaining body links + for a in soup.find_all("a", href=True): + href = str(a["href"]) + link = urljoin(current_url, href) + if urlparse(link).netloc == base_domain and link not in nav_links: + body_links.add(link) + + return list(nav_links), list(body_links) + + +async def _fetch_page(url: str) -> str: + """Fetch HTML content from a URL.""" + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: + try: + timeout = aiohttp.ClientTimeout(total=30) + async with session.get(url, timeout=timeout) as response: + if response.status == 200: + return await response.text() + return "" + except Exception as e: + logger.warning("Error fetching URL", url=url, error=str(e)) + return "" + + +def _add_links_to_queue( + links: list[str], + queue: list[str], + all_pages_to_scrape: set[str], + remaining_slots: int, +) -> int: + """Add normalized links to queue if not already visited.""" + for link in links: + normalized_link = link.rstrip("/") + if normalized_link not in all_pages_to_scrape and remaining_slots > 0: + queue.append(normalized_link) + all_pages_to_scrape.add(normalized_link) + remaining_slots -= 1 + return remaining_slots + + async def crawl_website(starting_url: str) -> list[ScrapeResult] | str: """Crawl the pages of a website starting with the starting_url and then descending into the pages linked from there. @@ -45,41 +102,6 @@ async def crawl_website(starting_url: str) -> list[ScrapeResult] | str: max_pages = 10 base_domain = urlparse(starting_url).netloc - async def extract_links(html: str, current_url: str) -> tuple[list[str], list[str]]: - """Extract prioritized links from HTML content""" - soup = BeautifulSoup(html, "html.parser") - nav_links = set() - body_links = set() - - # Find navigation/header links - for nav_element in soup.find_all(["nav", "header"]): - for a in nav_element.find_all("a", href=True): - link = urljoin(current_url, a["href"]) - if urlparse(link).netloc == base_domain: - nav_links.add(link) - - # Find remaining body links - for a in soup.find_all("a", href=True): - link = urljoin(current_url, a["href"]) - if urlparse(link).netloc == base_domain and link not in nav_links: - body_links.add(link) - - return list(nav_links), list(body_links) - - async def fetch_page(url: str) -> str: - """Fetch HTML content from a URL""" - connector = aiohttp.TCPConnector(ssl=ssl_context) - async with aiohttp.ClientSession(connector=connector) as session: - try: - timeout = aiohttp.ClientTimeout(total=30) - async with session.get(url, timeout=timeout) as response: - if response.status == 200: - return await response.text() - return "" - except Exception as e: - logger.warning("Error fetching URL", url=url, error=str(e)) - return "" - # Initialize with starting URL queue: list[str] = [starting_url] next_level_queue: list[str] = [] @@ -90,26 +112,20 @@ async def fetch_page(url: str) -> str: current_url = queue.pop(0) # Fetch and process the page - html_content = await fetch_page(current_url) + html_content = await _fetch_page(current_url) if html_content: - nav_links, body_links = await extract_links(html_content, current_url) + nav_links, body_links = await _extract_links(html_content, current_url, 
base_domain) # Add unvisited nav links to current queue (higher priority) remaining_slots = max_pages - len(all_pages_to_scrape) - for link in nav_links: - link = link.rstrip("/") - if link not in all_pages_to_scrape and remaining_slots > 0: - queue.append(link) - all_pages_to_scrape.add(link) - remaining_slots -= 1 + remaining_slots = _add_links_to_queue( + nav_links, queue, all_pages_to_scrape, remaining_slots + ) # Add unvisited body links to next level queue (lower priority) - for link in body_links: - link = link.rstrip("/") - if link not in all_pages_to_scrape and remaining_slots > 0: - next_level_queue.append(link) - all_pages_to_scrape.add(link) - remaining_slots -= 1 + remaining_slots = _add_links_to_queue( + body_links, next_level_queue, all_pages_to_scrape, remaining_slots + ) # If current queue is empty, add next level links if not queue: diff --git a/src/tools/vendored/web_search_core.py b/src/tools/vendored/web_search_core.py index 391b5c2e..a465f398 100644 --- a/src/tools/vendored/web_search_core.py +++ b/src/tools/vendored/web_search_core.py @@ -199,13 +199,3 @@ def is_valid_url(url: str) -> bool: if any(ext in url for ext in restricted_extensions): return False return True - - - - - - - - - - diff --git a/src/tools/web_search.py b/src/tools/web_search.py index 318b43d3..f6350e82 100644 --- a/src/tools/web_search.py +++ b/src/tools/web_search.py @@ -3,6 +3,7 @@ import asyncio import structlog + try: from ddgs import DDGS # New package name except ImportError: @@ -55,10 +56,15 @@ def _do_search() -> list[dict[str, str]]: evidence = [] for r in raw_results: + # Truncate title to max 500 characters to match Citation model validation + title = r.get("title", "No Title") + if len(title) > 500: + title = title[:497] + "..." + ev = Evidence( content=r.get("body", ""), citation=Citation( - title=r.get("title", "No Title"), + title=title, url=r.get("href", ""), source="web", date="Unknown", diff --git a/src/tools/web_search_adapter.py b/src/tools/web_search_adapter.py index 4b4bfab0..6ee833f2 100644 --- a/src/tools/web_search_adapter.py +++ b/src/tools/web_search_adapter.py @@ -53,11 +53,3 @@ async def web_search(query: str) -> str: except Exception as e: logger.error("Web search failed", error=str(e), query=query) return f"Error performing web search: {e!s}" - - - - - - - - diff --git a/src/tools/web_search_factory.py b/src/tools/web_search_factory.py index fae4c5ff..213de5de 100644 --- a/src/tools/web_search_factory.py +++ b/src/tools/web_search_factory.py @@ -12,19 +12,66 @@ logger = structlog.get_logger() -def create_web_search_tool() -> SearchTool | None: +def create_web_search_tool(provider: str | None = None) -> SearchTool | None: """Create a web search tool based on configuration. + Args: + provider: Override provider selection. If None, uses settings.web_search_provider. + Returns: SearchTool instance, or None if not available/configured - The tool is selected based on settings.web_search_provider: + The tool is selected based on provider (or settings.web_search_provider if None): - "serper": SerperWebSearchTool (requires SERPER_API_KEY) - "searchxng": SearchXNGWebSearchTool (requires SEARCHXNG_HOST) - "duckduckgo": WebSearchTool (always available, no API key) - "brave" or "tavily": Not yet implemented, returns None + - "auto": Auto-detect best available provider (prefers Serper > SearchXNG > DuckDuckGo) + + Auto-detection logic (when provider is "auto" or not explicitly set): + 1. 
Try Serper if SERPER_API_KEY is available (best quality - Google search + full content scraping) + 2. Try SearchXNG if SEARCHXNG_HOST is available + 3. Fall back to DuckDuckGo (always available, but lower quality - snippets only) """ - provider = settings.web_search_provider + provider = provider or settings.web_search_provider + + # Auto-detect best available provider if "auto" or if provider is duckduckgo but better options exist + if provider == "auto" or (provider == "duckduckgo" and settings.serper_api_key): + # Prefer Serper if API key is available (better quality) + if settings.serper_api_key: + try: + logger.info( + "Auto-detected Serper web search (SERPER_API_KEY found)", + provider="serper", + ) + return SerperWebSearchTool() + except Exception as e: + logger.warning( + "Failed to initialize Serper, falling back", + error=str(e), + ) + + # Try SearchXNG as second choice + if settings.searchxng_host: + try: + logger.info( + "Auto-detected SearchXNG web search (SEARCHXNG_HOST found)", + provider="searchxng", + ) + return SearchXNGWebSearchTool() + except Exception as e: + logger.warning( + "Failed to initialize SearchXNG, falling back", + error=str(e), + ) + + # Fall back to DuckDuckGo + if provider == "auto": + logger.info( + "Auto-detected DuckDuckGo web search (no API keys found)", + provider="duckduckgo", + ) + return WebSearchTool() try: if provider == "serper": @@ -66,13 +113,3 @@ def create_web_search_tool() -> SearchTool | None: except Exception as e: logger.error("Unexpected error creating web search tool", error=str(e), provider=provider) return None - - - - - - - - - - diff --git a/src/utils/config.py b/src/utils/config.py index b6b2a0ba..9d4356d7 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -61,6 +61,15 @@ class Settings(BaseSettings): default="meta-llama/Llama-3.1-8B-Instruct", description="Default HuggingFace model ID for inference", ) + hf_fallback_models: str = Field( + default="Qwen/Qwen3-Next-80B-A3B-Thinking,Qwen/Qwen3-Next-80B-A3B-Instruct,meta-llama/Llama-3.3-70B-Instruct,meta-llama/Llama-3.1-8B-Instruct,HuggingFaceH4/zephyr-7b-beta,Qwen/Qwen2-7B-Instruct", + alias="HF_FALLBACK_MODELS", + description=( + "Comma-separated list of fallback models for provider discovery and error recovery. " + "Reads from HF_FALLBACK_MODELS environment variable. " + "Default value is used only if the environment variable is not set." + ), + ) # PubMed Configuration ncbi_api_key: str | None = Field( @@ -68,9 +77,11 @@ class Settings(BaseSettings): ) # Web Search Configuration - web_search_provider: Literal["serper", "searchxng", "brave", "tavily", "duckduckgo"] = Field( - default="duckduckgo", - description="Web search provider to use", + web_search_provider: Literal["serper", "searchxng", "brave", "tavily", "duckduckgo", "auto"] = ( + Field( + default="auto", + description="Web search provider to use. 'auto' will auto-detect best available (prefers Serper > SearchXNG > DuckDuckGo)", + ) ) serper_api_key: str | None = Field(default=None, description="Serper API key for Google search") searchxng_host: str | None = Field(default=None, description="SearchXNG host URL") @@ -163,6 +174,10 @@ class Settings(BaseSettings): le=2.0, description="TTS speech speed multiplier (0.5x to 2.0x)", ) + tts_use_llm_polish: bool = Field( + default=False, + description="Use LLM for final text polish before TTS (optional, costs API calls)", + ) tts_gpu: str | None = Field( default=None, description="Modal GPU type for TTS (T4, A10, A100, L4, L40S). 
None uses default T4.", @@ -269,6 +284,19 @@ def web_search_available(self) -> bool: return bool(self.tavily_api_key) return False + def get_hf_fallback_models_list(self) -> list[str]: + """Get the list of fallback models as a list. + + Parses the comma-separated HF_FALLBACK_MODELS string into a list, + stripping whitespace from each model ID. + + Returns: + List of model IDs + """ + if not self.hf_fallback_models: + return [] + return [model.strip() for model in self.hf_fallback_models.split(",") if model.strip()] + def get_settings() -> Settings: """Factory function to get settings (allows mocking in tests).""" diff --git a/src/utils/hf_error_handler.py b/src/utils/hf_error_handler.py new file mode 100644 index 00000000..becc6862 --- /dev/null +++ b/src/utils/hf_error_handler.py @@ -0,0 +1,199 @@ +"""Utility functions for handling HuggingFace API errors and token validation.""" + +import re +from typing import Any + +import structlog + +logger = structlog.get_logger() + + +def extract_error_details(error: Exception) -> dict[str, Any]: + """Extract error details from HuggingFace API errors. + + Pydantic AI and HuggingFace Inference API errors often contain + information in the error message string like: + "status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden" + + Args: + error: The exception object + + Returns: + Dictionary with extracted error details: + - status_code: HTTP status code (if found) + - model_name: Model name (if found) + - body: Error body/message (if found) + - error_type: Type of error (403, 422, etc.) + - is_auth_error: Whether this is an authentication/authorization error + - is_model_error: Whether this is a model-specific error + """ + error_str = str(error) + details: dict[str, Any] = { + "status_code": None, + "model_name": None, + "body": None, + "error_type": "unknown", + "is_auth_error": False, + "is_model_error": False, + } + + # Try to extract status_code + status_match = re.search(r"status_code:\s*(\d+)", error_str) + if status_match: + details["status_code"] = int(status_match.group(1)) + details["error_type"] = f"http_{details['status_code']}" + + # Determine error category + if details["status_code"] == 403: + details["is_auth_error"] = True + elif details["status_code"] == 422: + details["is_model_error"] = True + + # Try to extract model_name + model_match = re.search(r"model_name:\s*([^\s,]+)", error_str) + if model_match: + details["model_name"] = model_match.group(1) + + # Try to extract body + body_match = re.search(r"body:\s*(.+)", error_str) + if body_match: + details["body"] = body_match.group(1).strip() + + return details + + +def get_user_friendly_error_message(error: Exception, model_name: str | None = None) -> str: + """Generate a user-friendly error message from an exception. + + Args: + error: The exception object + model_name: Optional model name for context + + Returns: + User-friendly error message + """ + details = extract_error_details(error) + + if details["is_auth_error"]: + return ( + "πŸ” **Authentication Error**\n\n" + "Your HuggingFace token doesn't have permission to access this model or API.\n\n" + "**Possible solutions:**\n" + "1. **Re-authenticate**: Log out and log back in to ensure your token has the `inference-api` scope\n" + "2. **Check model access**: Visit the model page on HuggingFace and request access if it's gated\n" + "3. 
**Use alternative model**: Try a different model that's publicly available\n\n" + f"**Model attempted**: {details['model_name'] or model_name or 'Unknown'}\n" + f"**Error**: {details['body'] or str(error)}" + ) + + if details["is_model_error"]: + return ( + "⚠️ **Model Compatibility Error**\n\n" + "The selected model is not compatible with the current provider or has specific requirements.\n\n" + "**Possible solutions:**\n" + "1. **Try a different model**: Use a model that's compatible with the current provider\n" + "2. **Check provider status**: The provider may be in staging mode or unavailable\n" + "3. **Wait and retry**: If the model is in staging, it may become available later\n\n" + f"**Model attempted**: {details['model_name'] or model_name or 'Unknown'}\n" + f"**Error**: {details['body'] or str(error)}" + ) + + # Generic error + return ( + "❌ **API Error**\n\n" + f"An error occurred while calling the HuggingFace API:\n\n" + f"**Error**: {error!s}\n\n" + "Please try again or contact support if the issue persists." + ) + + +def validate_hf_token(token: str | None) -> tuple[bool, str | None]: + """Validate HuggingFace token format. + + Args: + token: The token to validate + + Returns: + Tuple of (is_valid, error_message) + - is_valid: True if token appears valid + - error_message: Error message if invalid, None if valid + """ + if not token: + return False, "Token is None or empty" + + if not isinstance(token, str): + return False, f"Token is not a string (type: {type(token).__name__})" + + if len(token) < 10: + return False, "Token appears too short (minimum 10 characters expected)" + + # HuggingFace tokens typically start with "hf_" for user tokens + # OAuth tokens may have different formats, so we're lenient + # Just check it's not obviously invalid + + return True, None + + +def log_token_info(token: str | None, context: str = "") -> None: + """Log token information for debugging (without exposing the actual token). + + Args: + token: The token to log info about + context: Additional context for the log message + """ + if token: + is_valid, error_msg = validate_hf_token(token) + logger.debug( + "Token validation", + context=context, + has_token=True, + is_valid=is_valid, + token_length=len(token), + token_prefix=token[:4] + "..." if len(token) > 4 else "***", + validation_error=error_msg, + ) + else: + logger.debug("Token validation", context=context, has_token=False) + + +def should_retry_with_fallback(error: Exception) -> bool: + """Determine if an error should trigger a fallback to alternative models. + + Args: + error: The exception object + + Returns: + True if the error suggests we should try a fallback model + """ + details = extract_error_details(error) + + # Retry with fallback for: + # - 403 errors (authentication/permission issues - might work with different model) + # - 422 errors (model/provider compatibility - definitely try different model) + # - Model-specific errors + return ( + details["is_auth_error"] or details["is_model_error"] or details["model_name"] is not None + ) + + +def get_fallback_models(original_model: str | None = None) -> list[str]: + """Get a list of fallback models to try. 
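+
+    A minimal retry sketch combining the helpers in this module; "call_model"
+    is a placeholder for whatever inference call originally failed:
+
+        try:
+            reply = call_model("meta-llama/Llama-3.1-8B-Instruct")
+        except Exception as err:
+            if not should_retry_with_fallback(err):
+                raise
+            for candidate in get_fallback_models("meta-llama/Llama-3.1-8B-Instruct"):
+                try:
+                    reply = call_model(candidate)
+                    break
+                except Exception:
+                    continue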
+ + Args: + original_model: The original model that failed + + Returns: + List of fallback model names to try in order + """ + # Publicly available models that should work with most tokens + fallbacks = [ + "meta-llama/Llama-3.1-8B-Instruct", # Common, often available + "mistralai/Mistral-7B-Instruct-v0.3", # Alternative + "HuggingFaceH4/zephyr-7b-beta", # Ungated fallback + ] + + # If original model is in the list, remove it + if original_model and original_model in fallbacks: + fallbacks.remove(original_model) + + return fallbacks diff --git a/src/utils/hf_model_validator.py b/src/utils/hf_model_validator.py new file mode 100644 index 00000000..e7dd75ee --- /dev/null +++ b/src/utils/hf_model_validator.py @@ -0,0 +1,477 @@ +"""Validator for querying available HuggingFace models and providers using OAuth token. + +This module provides functions to: +1. Query available models from HuggingFace Hub +2. Query available inference providers (with dynamic discovery) +3. Validate model/provider combinations +4. Return formatted lists for Gradio dropdowns + +Uses Hugging Face Hub API to discover providers dynamically by querying model +information. Falls back to known providers list if discovery fails. +""" + +import asyncio +from time import time +from typing import Any + +import structlog +from huggingface_hub import HfApi + +from src.utils.config import settings + +logger = structlog.get_logger() + + +def extract_oauth_token(oauth_token: Any) -> str | None: + """Extract OAuth token value from Gradio OAuthToken object. + + Handles both gr.OAuthToken objects (with .token attribute) and plain strings. + This is a convenience function for Gradio apps that use OAuth authentication. + + Args: + oauth_token: Gradio OAuthToken object or string token + + Returns: + Token string if available, None otherwise + """ + if oauth_token is None: + return None + + if hasattr(oauth_token, "token"): + return oauth_token.token # type: ignore[no-any-return] + elif isinstance(oauth_token, str): + return oauth_token + + logger.warning( + "Could not extract token from OAuthToken object", + oauth_token_type=type(oauth_token).__name__, + ) + return None + + +# Known providers as fallback (updated from Hugging Face documentation) +# These are used when dynamic discovery fails or times out +KNOWN_PROVIDERS = [ + "auto", # Auto-select (always available) + "hf-inference", # HuggingFace's own Inference API + "nebius", + "together", + "scaleway", + "hyperbolic", + "novita", + "nscale", + "sambanova", + "ovh", + "fireworks-ai", # Note: API uses "fireworks-ai", not "fireworks" + "cerebras", + "fal-ai", + "cohere", +] + + +def get_provider_discovery_models() -> list[str]: + """Get list of models to use for provider discovery. + + Reads from HF_FALLBACK_MODELS environment variable via settings. + The environment variable should be a comma-separated list of model IDs. 
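+
+    For example, a .env entry such as (any comma-separated list of model IDs
+    works):
+
+        HF_FALLBACK_MODELS=meta-llama/Llama-3.1-8B-Instruct,HuggingFaceH4/zephyr-7b-beta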
+ + Returns: + List of model IDs to query for provider discovery + """ + # Get models from HF_FALLBACK_MODELS environment variable + # This is automatically read by Pydantic Settings from the env var + fallback_models = settings.get_hf_fallback_models_list() + + logger.debug( + "Using HF_FALLBACK_MODELS for provider discovery", + count=len(fallback_models), + models=fallback_models, + ) + + return fallback_models + + +# Simple in-memory cache for provider lists (TTL: 1 hour) +_provider_cache: dict[str, tuple[list[str], float]] = {} +PROVIDER_CACHE_TTL = 3600 # 1 hour in seconds + + +async def get_available_providers(token: str | None = None) -> list[str]: + """Get list of available inference providers. + + Discovers providers dynamically by querying model information from HuggingFace Hub. + Uses caching to avoid repeated API calls. Falls back to known providers if discovery fails. + + Strategy: + 1. Check cache (if valid, return cached list) + 2. Query popular models to extract unique providers from their inferenceProviderMapping + 3. Fall back to known providers list if discovery fails + 4. Cache results for future use + + Args: + token: Optional HuggingFace API token for authenticated requests + Can be extracted from gr.OAuthToken.token in Gradio apps + + Returns: + List of provider names sorted alphabetically, with "auto" first + (e.g., ["auto", "fireworks-ai", "hf-inference", "nebius", ...]) + """ + # Check cache first + cache_key = "providers" + (f"_{token[:8]}" if token else "_no_token") + if cache_key in _provider_cache: + cached_providers, cache_time = _provider_cache[cache_key] + if time() - cache_time < PROVIDER_CACHE_TTL: + logger.debug("Returning cached providers", count=len(cached_providers)) + return cached_providers + + try: + providers = set(["auto"]) # Always include "auto" + + # Try dynamic discovery by querying popular models + loop = asyncio.get_running_loop() + api = HfApi(token=token) + + # Get models to query from HF_FALLBACK_MODELS environment variable via settings + discovery_models = get_provider_discovery_models() + + # Query a sample of popular models to discover providers + # This is more efficient than querying all models + discovery_count = 0 + for model_id in discovery_models: + try: + + def _get_model_info(m: str) -> Any: + """Get model info synchronously.""" + return api.model_info(m, expand=["inferenceProviderMapping"]) # type: ignore[arg-type] + + info = await loop.run_in_executor(None, _get_model_info, model_id) + + # Extract providers from inference_provider_mapping + if hasattr(info, "inference_provider_mapping") and info.inference_provider_mapping: + mapping = info.inference_provider_mapping + # mapping is a dict like {'hf-inference': InferenceProviderMapping(...), ...} + providers.update(mapping.keys()) + discovery_count += 1 + logger.debug( + "Discovered providers from model", + model=model_id, + providers=list(mapping.keys()), + ) + except Exception as e: + logger.debug( + "Could not get provider info for model", + model=model_id, + error=str(e), + ) + continue + + # If we discovered providers, use them; otherwise fall back to known providers + if len(providers) > 1: # More than just "auto" + provider_list = sorted(list(providers)) + logger.info( + "Discovered providers dynamically", + count=len(provider_list), + models_queried=discovery_count, + has_token=bool(token), + ) + else: + # Fallback to known providers + provider_list = KNOWN_PROVIDERS.copy() + logger.info( + "Using known providers list (discovery failed or incomplete)", + 
count=len(provider_list), + models_queried=discovery_count, + ) + + # Cache the results + _provider_cache[cache_key] = (provider_list, time()) + + return provider_list + + except Exception as e: + logger.warning("Failed to get providers", error=str(e)) + # Return known providers as fallback + return KNOWN_PROVIDERS.copy() + + +async def get_available_models( + token: str | None = None, + task: str = "text-generation", + limit: int = 100, + inference_provider: str | None = None, +) -> list[str]: + """Get list of available models for text generation. + + Queries HuggingFace Hub API to get models that support text generation. + Optionally filters by inference provider to show only models available via that provider. + + Args: + token: Optional HuggingFace API token for authenticated requests + Can be extracted from gr.OAuthToken.token in Gradio apps + task: Task type to filter models (default: "text-generation") + limit: Maximum number of models to return + inference_provider: Optional provider name to filter models (e.g., "fireworks-ai", "nebius") + If None, returns all models for the task + + Returns: + List of model IDs (e.g., ["meta-llama/Llama-3.1-8B-Instruct", ...]) + """ + try: + loop = asyncio.get_running_loop() + + def _fetch_models() -> list[str]: + """Fetch models synchronously in executor.""" + api = HfApi(token=token) + + # Build query parameters + query_params: dict[str, Any] = { + "task": task, + "sort": "downloads", + "direction": -1, + "limit": limit, + } + + # Filter by inference provider if specified + if inference_provider and inference_provider != "auto": + query_params["inference_provider"] = inference_provider + + # Search for models + models = api.list_models(**query_params) + + # Extract model IDs + model_ids = [model.id for model in models] + return model_ids + + model_ids = await loop.run_in_executor(None, _fetch_models) + + logger.info( + "Fetched available models", + count=len(model_ids), + task=task, + provider=inference_provider or "all", + has_token=bool(token), + ) + + return model_ids + + except Exception as e: + logger.warning("Failed to get models from Hub API", error=str(e)) + # Return popular fallback models + return [ + "meta-llama/Llama-3.1-8B-Instruct", + "mistralai/Mistral-7B-Instruct-v0.3", + "HuggingFaceH4/zephyr-7b-beta", + "google/gemma-2-9b-it", + ] + + +async def validate_model_provider_combination( + model_id: str, + provider: str | None, + token: str | None = None, +) -> tuple[bool, str | None]: + """Validate that a model is available with a specific provider. + + Uses HuggingFace Hub API to check if the provider is listed in the model's + inferenceProviderMapping. This is faster and more reliable than making test API calls. 
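+
+    A minimal usage sketch (the model/provider pair is illustrative; the token
+    is optional):
+
+        ok, error = await validate_model_provider_combination(
+            "meta-llama/Llama-3.1-8B-Instruct", "nebius", token=None
+        )
+        if not ok:
+            logger.warning("invalid_model_provider_combination", reason=error)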
+ + Args: + model_id: HuggingFace model ID + provider: Provider name (or None/empty for auto) + token: Optional HuggingFace API token (from gr.OAuthToken.token) + + Returns: + Tuple of (is_valid, error_message) + - is_valid: True if combination is valid or provider is "auto" + - error_message: Error message if invalid, None if valid + """ + # "auto" is always valid - let HuggingFace select the provider + if not provider or provider == "auto": + return True, None + + try: + loop = asyncio.get_running_loop() + api = HfApi(token=token) + + def _get_model_info() -> Any: + """Get model info with provider mapping synchronously.""" + return api.model_info(model_id, expand=["inferenceProviderMapping"]) # type: ignore[arg-type] + + info = await loop.run_in_executor(None, _get_model_info) + + # Check if provider is in the model's inference provider mapping + if hasattr(info, "inference_provider_mapping") and info.inference_provider_mapping: + mapping = info.inference_provider_mapping + available_providers = set(mapping.keys()) + + # Normalize provider name (some APIs use "fireworks-ai", others use "fireworks") + normalized_provider = provider.lower() + provider_variants = {normalized_provider} + + # Handle common provider name variations + if normalized_provider == "fireworks": + provider_variants.add("fireworks-ai") + elif normalized_provider == "fireworks-ai": + provider_variants.add("fireworks") + + # Check if any variant matches + if any(p in available_providers for p in provider_variants): + logger.debug( + "Model/provider combination validated via API", + model=model_id, + provider=provider, + available_providers=list(available_providers), + ) + return True, None + else: + error_msg = ( + f"Model {model_id} is not available with provider '{provider}'. " + f"Available providers: {', '.join(sorted(available_providers))}" + ) + logger.debug( + "Model/provider combination invalid", + model=model_id, + provider=provider, + available_providers=list(available_providers), + ) + return False, error_msg + else: + # Model doesn't have provider mapping - assume valid and let actual usage determine + logger.debug( + "Model has no provider mapping, assuming valid", + model=model_id, + provider=provider, + ) + return True, None + + except Exception as e: + logger.warning( + "Model/provider validation failed", + model=model_id, + provider=provider, + error=str(e), + ) + # Don't fail validation on error - let the actual request fail + # This is more user-friendly than blocking on validation errors + return True, None + + +async def get_models_for_provider( + provider: str, + token: str | None = None, + limit: int = 50, +) -> list[str]: + """Get models available for a specific provider. + + This is a convenience wrapper around get_available_models() with provider filtering. 
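+
+    A minimal usage sketch (the provider name is illustrative):
+
+        models = await get_models_for_provider("fireworks-ai", token=None, limit=20)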
+ + Args: + provider: Provider name (e.g., "nebius", "together", "fireworks-ai") + Note: Use "fireworks-ai" not "fireworks" for the API + token: Optional HuggingFace API token (from gr.OAuthToken.token) + limit: Maximum number of models to return + + Returns: + List of model IDs available for the provider + """ + # Normalize provider name for API + normalized_provider = provider + if provider.lower() == "fireworks": + normalized_provider = "fireworks-ai" + logger.debug("Normalized provider name", original=provider, normalized=normalized_provider) + + return await get_available_models( + token=token, + task="text-generation", + limit=limit, + inference_provider=normalized_provider, + ) + + +async def validate_oauth_token(token: str | None) -> dict[str, Any]: + """Validate OAuth token and return available resources. + + Args: + token: OAuth token to validate + + Returns: + Dictionary with: + - is_valid: Whether token is valid + - has_inference_api_scope: Whether token has inference-api scope + - available_models: List of available model IDs + - available_providers: List of available provider names + - username: HuggingFace username (if available) + - error: Error message if validation failed + """ + result: dict[str, Any] = { + "is_valid": False, + "has_inference_api_scope": False, + "available_models": [], + "available_providers": [], + "username": None, + "error": None, + } + + if not token: + result["error"] = "No token provided" + return result + + try: + # Validate token format + from src.utils.hf_error_handler import validate_hf_token + + is_valid_format, format_error = validate_hf_token(token) + if not is_valid_format: + result["error"] = f"Invalid token format: {format_error}" + return result + + # Try to get user info to validate token + loop = asyncio.get_running_loop() + + def _get_user_info() -> dict[str, Any] | None: + """Get user info from HuggingFace API.""" + try: + api = HfApi(token=token) + user_info = api.whoami() + return user_info + except Exception: + return None + + user_info = await loop.run_in_executor(None, _get_user_info) + + if user_info: + result["is_valid"] = True + result["username"] = user_info.get("name") or user_info.get("fullname") + logger.info("Token validated", username=result["username"]) + else: + result["error"] = "Token validation failed - could not authenticate" + return result + + # Try to query models to check inference-api scope + try: + models = await get_available_models(token=token, limit=10) + if models: + result["has_inference_api_scope"] = True + result["available_models"] = models + logger.info("Inference API scope confirmed", model_count=len(models)) + except Exception as e: + logger.warning("Could not verify inference-api scope", error=str(e)) + # Token might be valid but without inference-api scope + result["has_inference_api_scope"] = False + result["error"] = f"Token may not have inference-api scope: {e}" + + # Get available providers + try: + providers = await get_available_providers(token=token) + result["available_providers"] = providers + except Exception as e: + logger.warning("Could not get providers", error=str(e)) + # Use fallback providers + result["available_providers"] = ["auto"] + + return result + + except Exception as e: + logger.error("Token validation failed", error=str(e)) + result["error"] = str(e) + return result diff --git a/src/utils/llm_factory.py b/src/utils/llm_factory.py index bcd958bd..20095682 100644 --- a/src/utils/llm_factory.py +++ b/src/utils/llm_factory.py @@ -16,9 +16,13 @@ from typing import 
TYPE_CHECKING, Any +import structlog + from src.utils.config import settings from src.utils.exceptions import ConfigurationError +logger = structlog.get_logger() + if TYPE_CHECKING: from agent_framework.openai import OpenAIChatClient @@ -98,7 +102,7 @@ def get_chat_client_for_agent(oauth_token: str | None = None) -> Any: """ # Check if we have OAuth token or env vars has_hf_key = bool(oauth_token or settings.has_huggingface_key) - + # Prefer HuggingFace if available (free tier) if has_hf_key: return get_huggingface_chat_client(oauth_token=oauth_token) @@ -147,6 +151,19 @@ def get_pydantic_ai_model(oauth_token: str | None = None) -> Any: "3. Set huggingface_api_key in settings" ) + # Validate and log token information + from src.utils.hf_error_handler import log_token_info, validate_hf_token + + log_token_info(effective_hf_token, context="get_pydantic_ai_model") + is_valid, error_msg = validate_hf_token(effective_hf_token) + if not is_valid: + logger.warning( + "Token validation failed in get_pydantic_ai_model", + error=error_msg, + has_oauth=bool(oauth_token), + ) + # Continue anyway - let the API call fail with a clear error + # Always use HuggingFace with available token model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct" hf_provider = HuggingFaceProvider(api_key=effective_hf_token) diff --git a/src/utils/markdown.css b/src/utils/markdown.css index b083c296..1854a4bb 100644 --- a/src/utils/markdown.css +++ b/src/utils/markdown.css @@ -5,12 +5,3 @@ body { color: #000; } - - - - - - - - - diff --git a/src/utils/md_to_pdf.py b/src/utils/md_to_pdf.py index 940d707d..08e7f71d 100644 --- a/src/utils/md_to_pdf.py +++ b/src/utils/md_to_pdf.py @@ -1,6 +1,5 @@ """Utility for converting markdown to PDF.""" -import os from pathlib import Path from typing import TYPE_CHECKING @@ -43,9 +42,7 @@ def md_to_pdf(md_text: str, pdf_file_path: str) -> None: OSError: If PDF file cannot be written """ if not _MD2PDF_AVAILABLE: - raise ImportError( - "md2pdf is not installed. Install it with: pip install md2pdf" - ) + raise ImportError("md2pdf is not installed. Install it with: pip install md2pdf") if not md_text or not md_text.strip(): raise ValueError("Markdown text cannot be empty") @@ -64,13 +61,3 @@ def md_to_pdf(md_text: str, pdf_file_path: str) -> None: md2pdf(pdf_file_path, md_text, css_file_path=str(css_path)) logger.debug("PDF generated successfully", pdf_path=pdf_file_path) - - - - - - - - - - diff --git a/src/utils/message_history.py b/src/utils/message_history.py new file mode 100644 index 00000000..5ddbbe30 --- /dev/null +++ b/src/utils/message_history.py @@ -0,0 +1,169 @@ +"""Message history utilities for Pydantic AI integration.""" + +from typing import Any + +import structlog + +try: + from pydantic_ai import ModelMessage, ModelRequest, ModelResponse + from pydantic_ai.messages import TextPart, UserPromptPart + + _PYDANTIC_AI_AVAILABLE = True +except ImportError: + # Fallback for older pydantic-ai versions + ModelMessage = Any # type: ignore[assignment, misc] + ModelRequest = Any # type: ignore[assignment, misc] + ModelResponse = Any # type: ignore[assignment, misc] + TextPart = Any # type: ignore[assignment, misc] + UserPromptPart = Any # type: ignore[assignment, misc] + _PYDANTIC_AI_AVAILABLE = False + +logger = structlog.get_logger() + + +def convert_gradio_to_message_history( + history: list[dict[str, Any]], + max_messages: int = 20, +) -> list[ModelMessage]: + """ + Convert Gradio chat history to Pydantic AI message history. 
+ + Args: + history: Gradio chat history format [{"role": "user", "content": "..."}, ...] + max_messages: Maximum messages to include (most recent) + + Returns: + List of ModelMessage objects for Pydantic AI + """ + if not history: + return [] + + if not _PYDANTIC_AI_AVAILABLE: + logger.warning( + "Pydantic AI message history not available, returning empty list", + ) + return [] + + messages: list[ModelMessage] = [] + + # Take most recent messages + recent = history[-max_messages:] if len(history) > max_messages else history + + for msg in recent: + role = msg.get("role", "") + content = msg.get("content", "") + + if not content or role not in ("user", "assistant"): + continue + + # Convert content to string if needed + content_str = str(content) + + if role == "user": + messages.append( + ModelRequest(parts=[UserPromptPart(content=content_str)]), + ) + elif role == "assistant": + messages.append( + ModelResponse(parts=[TextPart(content=content_str)]), + ) + + logger.debug( + "Converted Gradio history to message history", + input_turns=len(history), + output_messages=len(messages), + ) + + return messages + + +def message_history_to_string( + messages: list[ModelMessage], + max_messages: int = 5, + include_metadata: bool = False, +) -> str: + """ + Convert message history to string format for backward compatibility. + + Used during transition period when some agents still expect strings. + + Args: + messages: List of ModelMessage objects + max_messages: Maximum messages to include + include_metadata: Whether to include metadata + + Returns: + Formatted string representation + """ + if not messages: + return "" + + recent = messages[-max_messages:] if len(messages) > max_messages else messages + + parts = ["PREVIOUS CONVERSATION:", "---"] + turn_num = 1 + + for msg in recent: + # Extract text content + text = "" + if isinstance(msg, ModelRequest): + for part in msg.parts: + if hasattr(part, "content"): + text += str(part.content) + parts.append(f"[Turn {turn_num}]") + parts.append(f"User: {text}") + turn_num += 1 + elif isinstance(msg, ModelResponse): + for part in msg.parts: # type: ignore[assignment] + if hasattr(part, "content"): + text += str(part.content) + parts.append(f"Assistant: {text}") + + parts.append("---") + return "\n".join(parts) + + +def create_truncation_processor(max_messages: int = 10) -> Any: + """Create a history processor that keeps only the most recent N messages. + + Args: + max_messages: Maximum number of messages to keep + + Returns: + Processor function that takes a list of messages and returns truncated list + """ + + def processor(messages: list[ModelMessage]) -> list[ModelMessage]: + return messages[-max_messages:] if len(messages) > max_messages else messages + + return processor + + +def create_relevance_processor(min_length: int = 10) -> Any: + """Create a history processor that filters out very short messages. 
+ + Args: + min_length: Minimum message length to keep + + Returns: + Processor function that filters messages by length + """ + + def processor(messages: list[ModelMessage]) -> list[ModelMessage]: + filtered = [] + for msg in messages: + text = "" + if isinstance(msg, ModelRequest): + for part in msg.parts: + if hasattr(part, "content"): + text += str(part.content) + elif isinstance(msg, ModelResponse): + for part in msg.parts: # type: ignore[assignment] + if hasattr(part, "content"): + text += str(part.content) + + if len(text.strip()) >= min_length: + filtered.append(msg) + return filtered + + return processor diff --git a/src/utils/models.py b/src/utils/models.py index 0582aab4..e43efac0 100644 --- a/src/utils/models.py +++ b/src/utils/models.py @@ -6,7 +6,9 @@ from pydantic import BaseModel, Field # Centralized source type - add new sources here (e.g., "biorxiv" in Phase 11) -SourceName = Literal["pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint", "rag", "web", "neo4j"] +SourceName = Literal[ + "pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint", "rag", "web", "neo4j" +] class Citation(BaseModel): diff --git a/src/utils/report_generator.py b/src/utils/report_generator.py index ad283050..330f2ad2 100644 --- a/src/utils/report_generator.py +++ b/src/utils/report_generator.py @@ -5,11 +5,101 @@ import structlog if TYPE_CHECKING: - from src.utils.models import Evidence + from src.utils.models import Citation, Evidence logger = structlog.get_logger() +def _format_authors(citation: "Citation") -> str: + """Format authors string from citation.""" + authors = ", ".join(citation.authors[:3]) + if len(citation.authors) > 3: + authors += " et al." + elif not authors: + authors = "Unknown" + return authors + + +def _add_evidence_section(report_parts: list[str], evidence: list["Evidence"]) -> None: + """Add evidence summary section to report.""" + from src.utils.models import SourceName + + report_parts.append("## Evidence Summary\n") + report_parts.append(f"**Total Sources Found:** {len(evidence)}\n\n") + + # Group evidence by source + by_source: dict[SourceName, list[Evidence]] = {} + for ev in evidence: + source = ev.citation.source + if source not in by_source: + by_source[source] = [] + by_source[source].append(ev) + + # Organize by source + for source in sorted(by_source.keys()): # type: ignore[assignment] + source_evidence = by_source[source] + report_parts.append(f"### {source.upper()} Sources ({len(source_evidence)})\n\n") + + for i, ev in enumerate(source_evidence, 1): + authors = _format_authors(ev.citation) + report_parts.append(f"#### {i}. {ev.citation.title}\n") + if authors and authors != "Unknown": + report_parts.append(f"**Authors:** {authors} \n") + report_parts.append(f"**Date:** {ev.citation.date} \n") + report_parts.append(f"**Source:** {ev.citation.source.upper()} \n") + report_parts.append(f"**URL:** {ev.citation.url} \n\n") + + # Content (truncated if too long) + content = ev.content + if len(content) > 500: + content = content[:500] + "... 
[truncated]" + report_parts.append(f"{content}\n\n") + + +def _add_key_findings(report_parts: list[str], evidence: list["Evidence"]) -> None: + """Add key findings section to report.""" + report_parts.append("## Key Findings\n\n") + report_parts.append( + "Based on the evidence collected, the following key points were identified:\n\n" + ) + + # Extract key points from evidence (first sentence or summary) + key_points: list[str] = [] + for ev in evidence[:10]: # Limit to top 10 + # Try to extract first meaningful sentence + content = ev.content.strip() + if content: + # Find first sentence + first_period = content.find(".") + if first_period > 0 and first_period < 200: + key_point = content[: first_period + 1].strip() + else: + # Fallback: first 150 chars + key_point = content[:150].strip() + if len(content) > 150: + key_point += "..." + key_points.append(f"- {key_point} [[{len(key_points) + 1}]](#references)") + + if key_points: + report_parts.append("\n".join(key_points)) + report_parts.append("\n\n") + else: + report_parts.append("*No specific key findings could be extracted from the evidence.*\n\n") + + +def _add_references(report_parts: list[str], evidence: list["Evidence"]) -> None: + """Add references section to report.""" + report_parts.append("## References\n\n") + for i, ev in enumerate(evidence, 1): + authors = _format_authors(ev.citation) + report_parts.append( + f"[{i}] {authors} ({ev.citation.date}). " + f"*{ev.citation.title}*. " + f"{ev.citation.source.upper()}. " + f"Available at: {ev.citation.url}\n\n" + ) + + def generate_report_from_evidence( query: str, evidence: list["Evidence"] | None = None, @@ -36,9 +126,7 @@ def generate_report_from_evidence( # Introduction report_parts.append("## Introduction\n") - report_parts.append( - f"This report addresses the following research query: **{query}**\n" - ) + report_parts.append(f"This report addresses the following research query: **{query}**\n") report_parts.append( "*Note: This report was generated from collected evidence. " "LLM-based synthesis was unavailable due to API limitations.*\n\n" @@ -46,73 +134,8 @@ def generate_report_from_evidence( # Evidence Summary if evidence and len(evidence) > 0: - report_parts.append("## Evidence Summary\n") - report_parts.append( - f"**Total Sources Found:** {len(evidence)}\n\n" - ) - - # Group evidence by source - by_source: dict[str, list["Evidence"]] = {} - for ev in evidence: - source = ev.citation.source - if source not in by_source: - by_source[source] = [] - by_source[source].append(ev) - - # Organize by source - for source in sorted(by_source.keys()): - source_evidence = by_source[source] - report_parts.append(f"### {source.upper()} Sources ({len(source_evidence)})\n\n") - - for i, ev in enumerate(source_evidence, 1): - # Format citation - authors = ", ".join(ev.citation.authors[:3]) - if len(ev.citation.authors) > 3: - authors += " et al." - - report_parts.append(f"#### {i}. {ev.citation.title}\n") - if authors: - report_parts.append(f"**Authors:** {authors} \n") - report_parts.append(f"**Date:** {ev.citation.date} \n") - report_parts.append(f"**Source:** {ev.citation.source.upper()} \n") - report_parts.append(f"**URL:** {ev.citation.url} \n\n") - - # Content (truncated if too long) - content = ev.content - if len(content) > 500: - content = content[:500] + "... 
[truncated]" - report_parts.append(f"{content}\n\n") - - # Key Findings Section - report_parts.append("## Key Findings\n\n") - report_parts.append( - "Based on the evidence collected, the following key points were identified:\n\n" - ) - - # Extract key points from evidence (first sentence or summary) - key_points: list[str] = [] - for ev in evidence[:10]: # Limit to top 10 - # Try to extract first meaningful sentence - content = ev.content.strip() - if content: - # Find first sentence - first_period = content.find(".") - if first_period > 0 and first_period < 200: - key_point = content[: first_period + 1].strip() - else: - # Fallback: first 150 chars - key_point = content[:150].strip() - if len(content) > 150: - key_point += "..." - key_points.append(f"- {key_point} [[{len(key_points) + 1}]](#references)") - - if key_points: - report_parts.append("\n".join(key_points)) - report_parts.append("\n\n") - else: - report_parts.append( - "*No specific key findings could be extracted from the evidence.*\n\n" - ) + _add_evidence_section(report_parts, evidence) + _add_key_findings(report_parts, evidence) elif findings: # Fallback: use findings string if evidence not available @@ -129,20 +152,7 @@ def generate_report_from_evidence( # References Section if evidence and len(evidence) > 0: - report_parts.append("## References\n\n") - for i, ev in enumerate(evidence, 1): - authors = ", ".join(ev.citation.authors[:3]) - if len(ev.citation.authors) > 3: - authors += " et al." - elif not authors: - authors = "Unknown" - - report_parts.append( - f"[{i}] {authors} ({ev.citation.date}). " - f"*{ev.citation.title}*. " - f"{ev.citation.source.upper()}. " - f"Available at: {ev.citation.url}\n\n" - ) + _add_references(report_parts, evidence) # Conclusion report_parts.append("## Conclusion\n\n") @@ -167,13 +177,3 @@ def generate_report_from_evidence( ) return "".join(report_parts) - - - - - - - - - - diff --git a/tests/conftest.py b/tests/conftest.py index 6d942f2b..0d44787e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ """Shared pytest fixtures for all tests.""" import os -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -78,4 +78,63 @@ def default_to_huggingface(monkeypatch): # Set a dummy HF_TOKEN if not set (prevents errors, but tests should mock actual API calls) if "HF_TOKEN" not in os.environ: - monkeypatch.setenv("HF_TOKEN", "dummy_token_for_testing") \ No newline at end of file + monkeypatch.setenv("HF_TOKEN", "dummy_token_for_testing") + + # Unset OpenAI/Anthropic keys to prevent fallback (unless explicitly set for specific tests) + # This ensures get_model() uses HuggingFace + if "OPENAI_API_KEY" in os.environ and os.environ.get("OPENAI_API_KEY"): + # Only unset if it's not explicitly needed (tests can set it if needed) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + if "ANTHROPIC_API_KEY" in os.environ and os.environ.get("ANTHROPIC_API_KEY"): + monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) + + +@pytest.fixture +def mock_hf_model(): + """Create a real HuggingFace model instance for testing. + + This fixture provides a real HuggingFaceModel instance that uses + the HF API/client. The InferenceClient is mocked to prevent real API calls. 
+ """ + from pydantic_ai.models.huggingface import HuggingFaceModel + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + # Create a real HuggingFace model with dummy token + # The InferenceClient will be mocked by auto_mock_hf_inference_client + provider = HuggingFaceProvider(api_key="dummy_token_for_testing") + model = HuggingFaceModel("meta-llama/Llama-3.1-8B-Instruct", provider=provider) + return model + + +@pytest.fixture(autouse=True) +def auto_mock_hf_inference_client(request): + """Automatically mock HuggingFace InferenceClient to prevent real API calls. + + This fixture runs automatically for all tests (except OpenAI tests) and + mocks the InferenceClient so tests use HuggingFace models but don't make + real API calls. This allows tests to use the actual HF model/client setup + without requiring API keys or making network requests. + + Tests marked with @pytest.mark.openai will skip this fixture. + Tests can override by explicitly patching InferenceClient themselves. + """ + # Skip auto-mocking for OpenAI tests + if "openai" in request.keywords: + return + + # Mock InferenceClient to prevent real API calls + # This allows get_model() to create real HuggingFaceModel instances + # but prevents actual network requests + mock_client = MagicMock() + mock_client.text_generation = AsyncMock(return_value="Mocked response") + mock_client.chat_completion = AsyncMock(return_value={"choices": [{"message": {"content": "Mocked response"}}]}) + + # Patch InferenceClient at its source (huggingface_hub) + # This will affect all imports of InferenceClient, including in pydantic_ai + inference_client_patch = patch("huggingface_hub.InferenceClient", return_value=mock_client) + inference_client_patch.start() + + yield + + # Stop patch + inference_client_patch.stop() \ No newline at end of file diff --git a/tests/integration/test_rag_integration.py b/tests/integration/test_rag_integration.py index 38d3f6ec..56f9339c 100644 --- a/tests/integration/test_rag_integration.py +++ b/tests/integration/test_rag_integration.py @@ -8,6 +8,31 @@ import pytest +# Skip if sentence_transformers cannot be imported +# Note: sentence-transformers is a required dependency, but may fail due to: +# - Windows regex circular import bug +# - PyTorch C extensions not loading properly +try: + pytest.importorskip("sentence_transformers", exc_type=ImportError) +except (ImportError, OSError) as e: + # Handle various import issues + error_msg = str(e).lower() + if "regex" in error_msg or "_regex" in error_msg: + pytest.skip( + "sentence_transformers import failed due to Windows regex circular import bug. " + "This is a known issue with the regex package on Windows. " + "Try: uv pip install --upgrade --force-reinstall regex", + allow_module_level=True, + ) + elif "pytorch" in error_msg or "torch" in error_msg: + pytest.skip( + "sentence_transformers import failed due to PyTorch C extensions issue. 
" + "Try: uv pip install --upgrade --force-reinstall torch", + allow_module_level=True, + ) + # Re-raise other import errors + raise + from src.services.llamaindex_rag import get_rag_service from src.tools.rag_tool import create_rag_tool from src.tools.search_handler import SearchHandler diff --git a/tests/integration/test_rag_integration_hf.py b/tests/integration/test_rag_integration_hf.py index cee61ba7..9c58fb1d 100644 --- a/tests/integration/test_rag_integration_hf.py +++ b/tests/integration/test_rag_integration_hf.py @@ -6,6 +6,31 @@ import pytest +# Skip if sentence_transformers cannot be imported +# Note: sentence-transformers is a required dependency, but may fail due to: +# - Windows regex circular import bug +# - PyTorch C extensions not loading properly +try: + pytest.importorskip("sentence_transformers", exc_type=ImportError) +except (ImportError, OSError) as e: + # Handle various import issues + error_msg = str(e).lower() + if "regex" in error_msg or "_regex" in error_msg: + pytest.skip( + "sentence_transformers import failed due to Windows regex circular import bug. " + "This is a known issue with the regex package on Windows. " + "Try: uv pip install --upgrade --force-reinstall regex", + allow_module_level=True, + ) + elif "pytorch" in error_msg or "torch" in error_msg: + pytest.skip( + "sentence_transformers import failed due to PyTorch C extensions issue. " + "Try: uv pip install --upgrade --force-reinstall torch", + allow_module_level=True, + ) + # Re-raise other import errors + raise + from src.services.llamaindex_rag import get_rag_service from src.tools.rag_tool import create_rag_tool from src.tools.search_handler import SearchHandler diff --git a/tests/unit/agent_factory/test_judges.py b/tests/unit/agent_factory/test_judges.py index c2075cda..90551d99 100644 --- a/tests/unit/agent_factory/test_judges.py +++ b/tests/unit/agent_factory/test_judges.py @@ -142,6 +142,135 @@ async def test_assess_handles_llm_failure(self): assert result.recommendation == "continue" assert "failed" in result.reasoning.lower() + @pytest.mark.asyncio + async def test_assess_handles_403_error(self): + """JudgeHandler should handle 403 Forbidden errors with error extraction.""" + error_msg = "status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden" + + with ( + patch("src.agent_factory.judges.get_model") as mock_get_model, + patch("src.agent_factory.judges.Agent") as mock_agent_class, + patch("src.agent_factory.judges.logger") as mock_logger, + ): + mock_get_model.return_value = MagicMock() + mock_agent = AsyncMock() + mock_agent.run = AsyncMock(side_effect=Exception(error_msg)) + mock_agent_class.return_value = mock_agent + + handler = JudgeHandler() + handler.agent = mock_agent + + evidence = [ + Evidence( + content="Some content", + citation=Citation( + source="pubmed", + title="Title", + url="url", + date="2024", + ), + ) + ] + + result = await handler.assess("test question", evidence) + + # Should return fallback + assert result.sufficient is False + assert result.recommendation == "continue" + + # Should log error details + error_calls = [call for call in mock_logger.error.call_args_list if "Assessment failed" in str(call)] + assert len(error_calls) > 0 + + @pytest.mark.asyncio + async def test_assess_handles_422_error(self): + """JudgeHandler should handle 422 Unprocessable Entity errors.""" + error_msg = "status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity" + + with ( + patch("src.agent_factory.judges.get_model") as 
mock_get_model, + patch("src.agent_factory.judges.Agent") as mock_agent_class, + patch("src.agent_factory.judges.logger") as mock_logger, + ): + mock_get_model.return_value = MagicMock() + mock_agent = AsyncMock() + mock_agent.run = AsyncMock(side_effect=Exception(error_msg)) + mock_agent_class.return_value = mock_agent + + handler = JudgeHandler() + handler.agent = mock_agent + + evidence = [ + Evidence( + content="Some content", + citation=Citation( + source="pubmed", + title="Title", + url="url", + date="2024", + ), + ) + ] + + result = await handler.assess("test question", evidence) + + # Should return fallback + assert result.sufficient is False + + # Should log warning with user-friendly message + warning_calls = [call for call in mock_logger.warning.call_args_list if "API error details" in str(call)] + assert len(warning_calls) > 0 + + +class TestGetModel: + """Tests for get_model function with token validation.""" + + @patch("src.agent_factory.judges.settings") + @patch("src.utils.hf_error_handler.log_token_info") + @patch("src.utils.hf_error_handler.validate_hf_token") + def test_get_model_validates_oauth_token(self, mock_validate, mock_log, mock_settings): + """Should validate and log OAuth token when provided.""" + mock_settings.hf_token = None + mock_settings.huggingface_api_key = None + mock_settings.huggingface_model = "test-model" + mock_validate.return_value = (True, None) + + with patch("src.agent_factory.judges.HuggingFaceProvider"), \ + patch("src.agent_factory.judges.HuggingFaceModel") as mock_model_class: + mock_model_class.return_value = MagicMock() + + from src.agent_factory.judges import get_model + + get_model(oauth_token="hf_test_token") + + # Should log token info + mock_log.assert_called_once_with("hf_test_token", context="get_model") + # Should validate token + mock_validate.assert_called_once_with("hf_test_token") + + @patch("src.agent_factory.judges.settings") + @patch("src.utils.hf_error_handler.log_token_info") + @patch("src.utils.hf_error_handler.validate_hf_token") + @patch("src.agent_factory.judges.logger") + def test_get_model_warns_on_invalid_token(self, mock_logger, mock_validate, mock_log, mock_settings): + """Should warn when token validation fails.""" + mock_settings.hf_token = None + mock_settings.huggingface_api_key = None + mock_settings.huggingface_model = "test-model" + mock_validate.return_value = (False, "Token too short") + + with patch("src.agent_factory.judges.HuggingFaceProvider"), \ + patch("src.agent_factory.judges.HuggingFaceModel") as mock_model_class: + mock_model_class.return_value = MagicMock() + + from src.agent_factory.judges import get_model + + get_model(oauth_token="short") + + # Should warn about invalid token + warning_calls = [call for call in mock_logger.warning.call_args_list if "Token validation failed" in str(call)] + assert len(warning_calls) > 0 + class TestMockJudgeHandler: """Tests for MockJudgeHandler.""" diff --git a/tests/unit/agent_factory/test_judges_factory.py b/tests/unit/agent_factory/test_judges_factory.py index 3cc7e331..037ef87e 100644 --- a/tests/unit/agent_factory/test_judges_factory.py +++ b/tests/unit/agent_factory/test_judges_factory.py @@ -21,8 +21,13 @@ def mock_settings(): yield mock_settings +@pytest.mark.openai def test_get_model_openai(mock_settings): - """Test that OpenAI model is returned when provider is openai.""" + """Test that OpenAI model is returned when provider is openai. + + This test instantiates an OpenAI model and requires OpenAI API key. 
+ Marked with @pytest.mark.openai to exclude from pre-commit and CI. + """ mock_settings.llm_provider = "openai" mock_settings.openai_api_key = "sk-test" mock_settings.openai_model = "gpt-5.1" @@ -37,6 +42,11 @@ def test_get_model_anthropic(mock_settings): mock_settings.llm_provider = "anthropic" mock_settings.anthropic_api_key = "sk-ant-test" mock_settings.anthropic_model = "claude-sonnet-4-5-20250929" + # Ensure no HF token is set, otherwise get_model() will prefer HuggingFace + mock_settings.hf_token = None + mock_settings.huggingface_api_key = None + mock_settings.has_openai_key = False + mock_settings.has_anthropic_key = True model = get_model() assert isinstance(model, AnthropicModel) diff --git a/tests/unit/agents/test_audio_refiner.py b/tests/unit/agents/test_audio_refiner.py new file mode 100644 index 00000000..162241ab --- /dev/null +++ b/tests/unit/agents/test_audio_refiner.py @@ -0,0 +1,306 @@ +"""Unit tests for AudioRefiner agent.""" + +import pytest +from unittest.mock import AsyncMock, Mock, patch + +from src.agents.audio_refiner import AudioRefiner, refine_text_for_audio + + +class TestAudioRefiner: + """Test suite for AudioRefiner functionality.""" + + @pytest.fixture + def refiner(self): + """Create AudioRefiner instance.""" + return AudioRefiner() + + def test_remove_markdown_headers(self, refiner): + """Test removal of markdown headers.""" + text = """# Main Title +## Subtitle +### Section +Content here""" + result = refiner._remove_markdown_syntax(text) + assert "#" not in result + assert "Main Title" in result + assert "Subtitle" in result + + def test_remove_bold_italic(self, refiner): + """Test removal of bold and italic formatting.""" + text = "**Bold text** and *italic text* and __another bold__" + result = refiner._remove_markdown_syntax(text) + assert "**" not in result + assert "*" not in result + assert "__" not in result + assert "Bold text" in result + assert "italic text" in result + + def test_remove_links(self, refiner): + """Test removal of markdown links.""" + text = "Check [this link](https://example.com) for details" + result = refiner._remove_markdown_syntax(text) + assert "[" not in result + assert "]" not in result + assert "https://" not in result + assert "this link" in result + + def test_remove_citations_numbered(self, refiner): + """Test removal of numbered citations.""" + text = "Research shows [1] that metformin [2,3] works [4-6]." + result = refiner._remove_citations(text) + assert "[1]" not in result + assert "[2,3]" not in result + assert "[4-6]" not in result + assert "Research shows" in result + + def test_remove_citations_author_year(self, refiner): + """Test removal of author-year citations.""" + text = "Studies (Smith et al., 2023) and (Jones, 2022) confirm this." + result = refiner._remove_citations(text) + assert "(Smith et al., 2023)" not in result + assert "(Jones, 2022)" not in result + assert "Studies" in result + assert "confirm this" in result + + def test_remove_first_references_section(self, refiner): + """Test that References sections are removed while preserving other content.""" + text = """Main content here. + +# References +[1] First reference +[2] Second reference + +# More Content +This should remain. 
+ +## References +This second References should also be removed.""" + + result = refiner._remove_references_sections(text) + assert "Main content here" in result + assert "References" not in result + assert "First reference" not in result + assert "More Content" in result # Content after References should be preserved + assert "This should remain" in result + assert "second References should also be removed" not in result # Second References section removed + + def test_roman_to_int_conversion(self, refiner): + """Test roman numeral to integer conversion.""" + assert refiner._roman_to_int("I") == 1 + assert refiner._roman_to_int("II") == 2 + assert refiner._roman_to_int("III") == 3 + assert refiner._roman_to_int("IV") == 4 + assert refiner._roman_to_int("V") == 5 + assert refiner._roman_to_int("IX") == 9 + assert refiner._roman_to_int("X") == 10 + assert refiner._roman_to_int("XII") == 12 + assert refiner._roman_to_int("XX") == 20 + + def test_int_to_word_conversion(self, refiner): + """Test integer to word conversion.""" + assert refiner._int_to_word(1) == "One" + assert refiner._int_to_word(2) == "Two" + assert refiner._int_to_word(3) == "Three" + assert refiner._int_to_word(10) == "Ten" + assert refiner._int_to_word(20) == "Twenty" + assert refiner._int_to_word(25) == "25" # Falls back to digit + + def test_convert_roman_numerals_with_context(self, refiner): + """Test roman numeral conversion with context words.""" + test_cases = [ + ("Phase I trial", "Phase One trial"), + ("Phase II study", "Phase Two study"), + ("Phase III data", "Phase Three data"), + ("Type I diabetes", "Type One diabetes"), + ("Type II error", "Type Two error"), + ("Stage IV cancer", "Stage Four cancer"), + ("Trial I results", "Trial One results"), + ] + + for input_text, expected in test_cases: + result = refiner._convert_roman_numerals(input_text) + assert expected in result, f"Failed for: {input_text}" + + def test_convert_standalone_roman_numerals(self, refiner): + """Test standalone roman numeral conversion.""" + text = "Results for I, II, and III are positive." + result = refiner._convert_roman_numerals(text) + # Standalone roman numerals should be converted + assert "One" in result or "Two" in result or "Three" in result + + def test_dont_convert_roman_in_words(self, refiner): + """Test that roman numerals inside words aren't converted.""" + text = "INVALID data fromIXIN compound" + result = refiner._convert_roman_numerals(text) + # Should not break words containing I, V, X, etc. + assert "INVALID" in result or "Invalid" in result # May be case-normalized + + def test_clean_special_characters(self, refiner): + """Test special character cleanup.""" + # Using unicode escapes to avoid syntax issues + text = "Text with \u2014 em-dash and \u201csmart quotes\u201d and \u2018apostrophes\u2019." + result = refiner._clean_special_characters(text) + assert "\u2014" not in result # em-dash + assert "\u201c" not in result # smart quote open + assert "\u2018" not in result # smart apostrophe + assert "-" in result + + def test_normalize_whitespace(self, refiner): + """Test whitespace normalization.""" + text = "Text with multiple spaces\n\n\n\nand many newlines" + result = refiner._normalize_whitespace(text) + assert " " not in result # No double spaces + assert "\n\n\n" not in result # Max two newlines + + async def test_full_refine_workflow(self, refiner): + """Test complete refinement workflow.""" + markdown_text = """# Summary + +**Metformin** shows promise for *long COVID* treatment [1]. 
+ +## Phase I Trials + +Research (Smith et al., 2023) indicates [2,3]: +- 50% improvement +- Low side effects + +Check [this study](https://example.com) for details. + +# References +[1] Smith, J. et al. (2023) +[2] Jones, K. (2022) +""" + + result = await refiner.refine_for_audio(markdown_text) + + # Check markdown removed + assert "#" not in result + assert "**" not in result + assert "*" not in result + + # Check citations removed + assert "[1]" not in result + assert "(Smith et al., 2023)" not in result + + # Check roman numerals converted + assert "Phase One" in result + + # Check references section removed + assert "References" not in result + assert "Smith, J. et al." not in result + + # Check content preserved + assert "Metformin" in result + assert "long COVID" in result + + async def test_convenience_function(self): + """Test convenience function.""" + text = "**Bold** text with [link](url)" + result = await refine_text_for_audio(text) + assert "**" not in result + assert "[link]" not in result + assert "Bold" in result + + async def test_empty_text(self, refiner): + """Test handling of empty text.""" + assert await refiner.refine_for_audio("") == "" + assert await refiner.refine_for_audio(" ") == "" + + async def test_no_references_section(self, refiner): + """Test text without References section.""" + text = "Main content without references." + result = await refiner.refine_for_audio(text) + assert "Main content without references" in result + + def test_multiple_reference_formats(self, refiner): + """Test different References section formats.""" + formats = [ + ("# References\nContent", True), # Markdown header - will be removed + ("## References\nContent", True), # Markdown header - will be removed + ("**References**\nContent", True), # Bold heading - will be removed + ("References:\nContent", False), # Standalone without markers - NOT removed (edge case) + ] + + for format_text, should_remove in formats: + text = f"Main content\n{format_text}" + result = refiner._remove_references_sections(text) + assert "Main content" in result + if should_remove: + assert "References" not in result or result.count("References") == 0 + # Standalone "References:" without markers is an edge case we don't handle + + def test_preserve_paragraph_structure(self, refiner): + """Test that paragraph structure is preserved.""" + text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph." 
+ + result = refiner._normalize_whitespace(text) + # Should have paragraph breaks (double newlines) + assert "\n\n" in result + # But not excessive newlines + assert "\n\n\n" not in result + + @patch('src.agents.audio_refiner.get_pydantic_ai_model') + async def test_llm_polish_disabled_by_default(self, mock_get_model, refiner): + """Test that LLM polish is not called by default.""" + text = "Test text" + result = await refiner.refine_for_audio(text, use_llm_polish=False) + + # LLM should not be called when disabled + mock_get_model.assert_not_called() + assert "Test text" in result + + @patch('src.agents.audio_refiner.Agent') + @patch('src.agents.audio_refiner.get_pydantic_ai_model') + async def test_llm_polish_enabled(self, mock_get_model, mock_agent_class, refiner): + """Test that LLM polish is called when enabled.""" + # Setup mock + mock_model = Mock() + mock_get_model.return_value = mock_model + + mock_agent_instance = Mock() + mock_result = Mock() + mock_result.output = "Polished text" + mock_agent_instance.run = AsyncMock(return_value=mock_result) + mock_agent_class.return_value = mock_agent_instance + + # Test with LLM polish enabled + text = "**Test** text" + result = await refiner.refine_for_audio(text, use_llm_polish=True) + + # Verify LLM was called + mock_get_model.assert_called_once() + mock_agent_class.assert_called_once() + mock_agent_instance.run.assert_called_once() + + assert result == "Polished text" + + @patch('src.agents.audio_refiner.Agent') + @patch('src.agents.audio_refiner.get_pydantic_ai_model') + async def test_llm_polish_graceful_fallback(self, mock_get_model, mock_agent_class, refiner): + """Test graceful fallback when LLM polish fails.""" + # Setup mock to raise exception + mock_get_model.return_value = Mock() + mock_agent_instance = Mock() + mock_agent_instance.run = AsyncMock(side_effect=Exception("API Error")) + mock_agent_class.return_value = mock_agent_instance + + # Test with LLM polish enabled but failing + text = "Test text" + result = await refiner.refine_for_audio(text, use_llm_polish=True) + + # Should fall back to rule-based output + assert "Test text" in result + assert result != "" # Should not be empty + + async def test_convenience_function_with_llm_polish(self): + """Test convenience function with LLM polish parameter.""" + with patch.object(AudioRefiner, 'refine_for_audio') as mock_refine: + mock_refine.return_value = AsyncMock(return_value="Refined text")() + + # Test without LLM polish + result = await refine_text_for_audio("Test", use_llm_polish=False) + mock_refine.assert_called_with("Test", use_llm_polish=False) + + # Test with LLM polish + result = await refine_text_for_audio("Test", use_llm_polish=True) + mock_refine.assert_called_with("Test", use_llm_polish=True) diff --git a/tests/unit/agents/test_input_parser.py b/tests/unit/agents/test_input_parser.py index 95324026..f0c1075b 100644 --- a/tests/unit/agents/test_input_parser.py +++ b/tests/unit/agents/test_input_parser.py @@ -11,11 +11,13 @@ @pytest.fixture -def mock_model() -> MagicMock: - """Create a mock Pydantic AI model.""" - model = MagicMock() - model.name = "test-model" - return model +def mock_model(mock_hf_model): + """Create a HuggingFace model for testing. + + Uses the mock_hf_model from conftest which is a real HuggingFaceModel + instance with mocked InferenceClient to prevent real API calls. 
+ """ + return mock_hf_model @pytest.fixture diff --git a/tests/unit/middleware/test_budget_tracker_phase7.py b/tests/unit/middleware/test_budget_tracker_phase7.py index 903addc1..3338ad03 100644 --- a/tests/unit/middleware/test_budget_tracker_phase7.py +++ b/tests/unit/middleware/test_budget_tracker_phase7.py @@ -159,4 +159,3 @@ def test_iteration_tokens_separate_per_loop(self) -> None: assert budget2.iteration_tokens[1] == 200 - diff --git a/tests/unit/middleware/test_state_machine.py b/tests/unit/middleware/test_state_machine.py index 90fc3a4d..73a6008d 100644 --- a/tests/unit/middleware/test_state_machine.py +++ b/tests/unit/middleware/test_state_machine.py @@ -12,6 +12,14 @@ ) from src.utils.models import Citation, Conversation, Evidence, IterationData +try: + from pydantic_ai import ModelRequest, ModelResponse + from pydantic_ai.messages import TextPart, UserPromptPart + + _PYDANTIC_AI_AVAILABLE = True +except ImportError: + _PYDANTIC_AI_AVAILABLE = False + @pytest.mark.unit class TestWorkflowState: @@ -143,6 +151,49 @@ async def test_search_related_handles_empty_authors(self) -> None: assert len(results) == 1 assert results[0].citation.authors == [] + @pytest.mark.skipif(not _PYDANTIC_AI_AVAILABLE, reason="pydantic_ai not available") + def test_user_message_history_initialization(self) -> None: + """WorkflowState should initialize with empty user_message_history.""" + state = WorkflowState() + assert state.user_message_history == [] + + @pytest.mark.skipif(not _PYDANTIC_AI_AVAILABLE, reason="pydantic_ai not available") + def test_add_user_message(self) -> None: + """add_user_message should add messages to history.""" + state = WorkflowState() + message = ModelRequest(parts=[UserPromptPart(content="Test message")]) + state.add_user_message(message) + assert len(state.user_message_history) == 1 + assert state.user_message_history[0] == message + + @pytest.mark.skipif(not _PYDANTIC_AI_AVAILABLE, reason="pydantic_ai not available") + def test_get_user_history(self) -> None: + """get_user_history should return message history.""" + state = WorkflowState() + for i in range(5): + message = ModelRequest(parts=[UserPromptPart(content=f"Message {i}")]) + state.add_user_message(message) + + # Get all history + all_history = state.get_user_history() + assert len(all_history) == 5 + + # Get limited history + limited = state.get_user_history(max_messages=3) + assert len(limited) == 3 + # Should be most recent messages + assert limited[0].parts[0].content == "Message 2" + + @pytest.mark.skipif(not _PYDANTIC_AI_AVAILABLE, reason="pydantic_ai not available") + def test_init_workflow_state_with_message_history(self) -> None: + """init_workflow_state should accept message_history parameter.""" + messages = [ + ModelRequest(parts=[UserPromptPart(content="Question")]), + ModelResponse(parts=[TextPart(content="Answer")]), + ] + state = init_workflow_state(message_history=messages) + assert len(state.user_message_history) == 2 + @pytest.mark.unit class TestConversation: @@ -354,7 +405,3 @@ def context2(): assert len(state2.evidence) == 1 assert state1.evidence[0].citation.url == "https://example.com/1" assert state2.evidence[0].citation.url == "https://example.com/2" - - - - diff --git a/tests/unit/middleware/test_workflow_manager.py b/tests/unit/middleware/test_workflow_manager.py index 3703390c..42d67658 100644 --- a/tests/unit/middleware/test_workflow_manager.py +++ b/tests/unit/middleware/test_workflow_manager.py @@ -285,4 +285,3 @@ async def test_get_shared_evidence(self, monkeypatch) -> None: assert 
len(shared) == 1 assert shared[0].content == "Shared" - diff --git a/tests/unit/orchestrator/test_graph_orchestrator.py b/tests/unit/orchestrator/test_graph_orchestrator.py index 4136663f..44732bf7 100644 --- a/tests/unit/orchestrator/test_graph_orchestrator.py +++ b/tests/unit/orchestrator/test_graph_orchestrator.py @@ -48,7 +48,86 @@ def test_visited_nodes_tracking(self): context = GraphExecutionContext(WorkflowState(), BudgetTracker()) assert not context.has_visited("node1") context.mark_visited("node1") - assert context.has_visited("node1") + + def test_message_history_initialization(self): + """Test message history initialization in context.""" + from src.middleware.budget_tracker import BudgetTracker + from src.middleware.state_machine import WorkflowState + + context = GraphExecutionContext(WorkflowState(), BudgetTracker()) + assert context.message_history == [] + + def test_message_history_with_initial_history(self): + """Test context with initial message history.""" + from src.middleware.budget_tracker import BudgetTracker + from src.middleware.state_machine import WorkflowState + + try: + from pydantic_ai import ModelRequest + from pydantic_ai.messages import UserPromptPart + + messages = [ + ModelRequest(parts=[UserPromptPart(content="Test message")]) + ] + context = GraphExecutionContext( + WorkflowState(), BudgetTracker(), message_history=messages + ) + assert len(context.message_history) == 1 + except ImportError: + pytest.skip("pydantic_ai not available") + + def test_add_message(self): + """Test adding messages to context.""" + from src.middleware.budget_tracker import BudgetTracker + from src.middleware.state_machine import WorkflowState + + try: + from pydantic_ai import ModelRequest, ModelResponse + from pydantic_ai.messages import TextPart, UserPromptPart + + context = GraphExecutionContext(WorkflowState(), BudgetTracker()) + message1 = ModelRequest(parts=[UserPromptPart(content="Question")]) + message2 = ModelResponse(parts=[TextPart(content="Answer")]) + + context.add_message(message1) + context.add_message(message2) + + assert len(context.message_history) == 2 + except ImportError: + pytest.skip("pydantic_ai not available") + + def test_get_message_history(self): + """Test getting message history with limits.""" + from src.middleware.budget_tracker import BudgetTracker + from src.middleware.state_machine import WorkflowState + + try: + from pydantic_ai import ModelRequest + from pydantic_ai.messages import UserPromptPart + + messages = [ + ModelRequest(parts=[UserPromptPart(content=f"Message {i}")]) + for i in range(10) + ] + context = GraphExecutionContext( + WorkflowState(), BudgetTracker(), message_history=messages + ) + + # Get all + all_messages = context.get_message_history() + assert len(all_messages) == 10 + + # Get limited + limited = context.get_message_history(max_messages=5) + assert len(limited) == 5 + # Should be most recent + assert limited[0].parts[0].content == "Message 5" + + # Visit a node to test has_visited + context.visited_nodes.add("node1") + assert context.has_visited("node1") + except ImportError: + pytest.skip("pydantic_ai not available") class TestGraphOrchestrator: @@ -177,7 +256,7 @@ async def mock_build_graph(mode: str): orchestrator._build_graph = mock_build_graph # Mock the graph execution - async def mock_run_with_graph(query: str, mode: str): + async def mock_run_with_graph(query: str, research_mode: str, message_history: list | None = None): yield AgentEvent(type="started", message="Starting", iteration=0) yield 
AgentEvent(type="looping", message="Processing", iteration=1) yield AgentEvent(type="complete", message="# Final Report\n\nContent", iteration=1) diff --git a/tests/unit/services/test_embeddings.py b/tests/unit/services/test_embeddings.py index f7e59e9b..aaa0c99e 100644 --- a/tests/unit/services/test_embeddings.py +++ b/tests/unit/services/test_embeddings.py @@ -6,15 +6,16 @@ import pytest # Skip if embeddings dependencies are not installed -# Handle Windows-specific scipy import issues +# Handle Windows-specific scipy import issues and PyTorch C extensions issues try: pytest.importorskip("chromadb") - pytest.importorskip("sentence_transformers") -except OSError: + pytest.importorskip("sentence_transformers", exc_type=ImportError) +except (OSError, ImportError): # On Windows, scipy import can fail with OSError during collection + # PyTorch C extensions can also fail to load # Skip the entire test module in this case pytest.skip( - "Embeddings dependencies not available (scipy import issue)", allow_module_level=True + "Embeddings dependencies not available (scipy/PyTorch import issue)", allow_module_level=True ) from src.services.embeddings import EmbeddingService diff --git a/tests/unit/test_app_oauth.py b/tests/unit/test_app_oauth.py new file mode 100644 index 00000000..2a906a92 --- /dev/null +++ b/tests/unit/test_app_oauth.py @@ -0,0 +1,260 @@ +"""Unit tests for OAuth-related functions in app.py.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Mock gradio and its dependencies before importing app functions +import sys +from unittest.mock import MagicMock as Mock + +# Create a comprehensive gradio mock +mock_gradio = Mock() +mock_gradio.OAuthToken = Mock +mock_gradio.OAuthProfile = Mock +mock_gradio.Request = Mock +mock_gradio.update = Mock(return_value={"choices": [], "value": ""}) +mock_gradio.Blocks = Mock +mock_gradio.Markdown = Mock +mock_gradio.LoginButton = Mock +mock_gradio.Dropdown = Mock +mock_gradio.Textbox = Mock +mock_gradio.State = Mock + +# Mock gradio.data_classes +mock_gradio.data_classes = Mock() +mock_gradio.data_classes.FileData = Mock() + +sys.modules["gradio"] = mock_gradio +sys.modules["gradio.data_classes"] = mock_gradio.data_classes + +# Mock other dependencies that might be imported +sys.modules["neo4j"] = Mock() +sys.modules["neo4j"].GraphDatabase = Mock() + +from src.app import extract_oauth_info, update_model_provider_dropdowns + + +class TestExtractOAuthInfo: + """Tests for extract_oauth_info function.""" + + def test_extract_from_oauth_token_attribute(self) -> None: + """Should extract token from request.oauth_token.token.""" + mock_request = MagicMock() + mock_oauth_token = MagicMock() + mock_oauth_token.token = "hf_test_token_123" + mock_request.oauth_token = mock_oauth_token + mock_request.username = "testuser" + + token, username = extract_oauth_info(mock_request) + + assert token == "hf_test_token_123" + assert username == "testuser" + + def test_extract_from_string_oauth_token(self) -> None: + """Should extract token when oauth_token is a string.""" + mock_request = MagicMock() + mock_request.oauth_token = "hf_test_token_123" + mock_request.username = "testuser" + + token, username = extract_oauth_info(mock_request) + + assert token == "hf_test_token_123" + assert username == "testuser" + + def test_extract_from_authorization_header(self) -> None: + """Should extract token from Authorization header.""" + mock_request = MagicMock() + mock_request.oauth_token = None + mock_request.headers = {"Authorization": "Bearer 
hf_test_token_123"} + mock_request.username = "testuser" + + token, username = extract_oauth_info(mock_request) + + assert token == "hf_test_token_123" + assert username == "testuser" + + def test_extract_username_from_oauth_profile(self) -> None: + """Should extract username from oauth_profile.""" + mock_request = MagicMock() + mock_request.oauth_token = None + mock_request.username = None + mock_oauth_profile = MagicMock() + mock_oauth_profile.username = "testuser" + mock_request.oauth_profile = mock_oauth_profile + + token, username = extract_oauth_info(mock_request) + + assert username == "testuser" + + def test_extract_name_from_oauth_profile(self) -> None: + """Should extract name from oauth_profile when username not available.""" + mock_request = MagicMock() + mock_request.oauth_token = None + # Ensure username attribute doesn't exist or is explicitly None + # Use delattr to remove it, then set oauth_profile + if hasattr(mock_request, "username"): + delattr(mock_request, "username") + mock_oauth_profile = MagicMock() + mock_oauth_profile.username = None + mock_oauth_profile.name = "Test User" + mock_request.oauth_profile = mock_oauth_profile + + token, username = extract_oauth_info(mock_request) + + assert username == "Test User" + + def test_extract_none_request(self) -> None: + """Should return None for both when request is None.""" + token, username = extract_oauth_info(None) + + assert token is None + assert username is None + + def test_extract_no_oauth_info(self) -> None: + """Should return None when no OAuth info is available.""" + mock_request = MagicMock() + mock_request.oauth_token = None + mock_request.headers = {} + mock_request.username = None + mock_request.oauth_profile = None + + token, username = extract_oauth_info(mock_request) + + assert token is None + assert username is None + + +class TestUpdateModelProviderDropdowns: + """Tests for update_model_provider_dropdowns function.""" + + @pytest.mark.asyncio + async def test_update_with_valid_token(self) -> None: + """Should update dropdowns with available models and providers.""" + mock_oauth_token = MagicMock() + mock_oauth_token.token = "hf_test_token_123" + mock_oauth_profile = MagicMock() + + mock_validation_result = { + "is_valid": True, + "has_inference_api_scope": True, + "available_models": ["model1", "model2"], + "available_providers": ["auto", "nebius"], + "username": "testuser", + } + + with patch("src.utils.hf_model_validator.validate_oauth_token", return_value=mock_validation_result) as mock_validate, \ + patch("src.utils.hf_model_validator.get_available_models", new_callable=AsyncMock) as mock_get_models, \ + patch("src.utils.hf_model_validator.get_available_providers", new_callable=AsyncMock) as mock_get_providers, \ + patch("src.app.gr") as mock_gr, \ + patch("src.app.logger"): + mock_get_models.return_value = ["model1", "model2"] + mock_get_providers.return_value = ["auto", "nebius"] + mock_gr.update.return_value = {"choices": [], "value": ""} + + result = await update_model_provider_dropdowns(mock_oauth_token, mock_oauth_profile) + + assert len(result) == 3 # model_update, provider_update, status_msg + assert "testuser" in result[2] # Status message should contain username + mock_validate.assert_called_once_with("hf_test_token_123") + + @pytest.mark.asyncio + async def test_update_with_no_token(self) -> None: + """Should return defaults when no token provided.""" + with patch("src.app.gr") as mock_gr: + mock_gr.update.return_value = {"choices": [], "value": ""} + + result = await 
update_model_provider_dropdowns(None, None) + + assert len(result) == 3 + assert "Not authenticated" in result[2] # Status message + + @pytest.mark.asyncio + async def test_update_with_invalid_token(self) -> None: + """Should return error message for invalid token.""" + mock_oauth_token = MagicMock() + mock_oauth_token.token = "invalid_token" + + mock_validation_result = { + "is_valid": False, + "error": "Invalid token format", + } + + with patch("src.utils.hf_model_validator.validate_oauth_token", return_value=mock_validation_result), \ + patch("src.app.gr") as mock_gr: + mock_gr.update.return_value = {"choices": [], "value": ""} + + result = await update_model_provider_dropdowns(mock_oauth_token, None) + + assert len(result) == 3 + assert "Token validation failed" in result[2] + + @pytest.mark.asyncio + async def test_update_without_inference_scope(self) -> None: + """Should warn when token lacks inference-api scope.""" + mock_oauth_token = MagicMock() + mock_oauth_token.token = "hf_token_without_scope" + + mock_validation_result = { + "is_valid": True, + "has_inference_api_scope": False, + "available_models": [], + "available_providers": ["auto"], + "username": "testuser", + } + + with patch("src.utils.hf_model_validator.validate_oauth_token", return_value=mock_validation_result), \ + patch("src.utils.hf_model_validator.get_available_models", new_callable=AsyncMock) as mock_get_models, \ + patch("src.utils.hf_model_validator.get_available_providers", new_callable=AsyncMock) as mock_get_providers, \ + patch("src.app.gr") as mock_gr, \ + patch("src.app.logger"): + mock_get_models.return_value = [] + mock_get_providers.return_value = ["auto"] + mock_gr.update.return_value = {"choices": [], "value": ""} + + result = await update_model_provider_dropdowns(mock_oauth_token, None) + + assert len(result) == 3 + assert "inference-api" in result[2] and "scope" in result[2] + + @pytest.mark.asyncio + async def test_update_handles_exception(self) -> None: + """Should handle exceptions gracefully.""" + mock_oauth_token = MagicMock() + mock_oauth_token.token = "hf_test_token" + + with patch("src.utils.hf_model_validator.validate_oauth_token", side_effect=Exception("API error")), \ + patch("src.app.gr") as mock_gr, \ + patch("src.app.logger"): + mock_gr.update.return_value = {"choices": [], "value": ""} + + result = await update_model_provider_dropdowns(mock_oauth_token, None) + + assert len(result) == 3 + assert "Failed to load models" in result[2] + + @pytest.mark.asyncio + async def test_update_with_string_token(self) -> None: + """Should handle string token (edge case).""" + # Edge case: oauth_token is already a string + with patch("src.utils.hf_model_validator.validate_oauth_token") as mock_validate, \ + patch("src.utils.hf_model_validator.get_available_models", new_callable=AsyncMock), \ + patch("src.utils.hf_model_validator.get_available_providers", new_callable=AsyncMock), \ + patch("src.app.gr") as mock_gr, \ + patch("src.app.logger"): + mock_validation_result = { + "is_valid": True, + "has_inference_api_scope": True, + "available_models": ["model1"], + "available_providers": ["auto"], + "username": "testuser", + } + mock_validate.return_value = mock_validation_result + mock_gr.update.return_value = {"choices": [], "value": ""} + + # Pass string directly (shouldn't happen but defensive) + result = await update_model_provider_dropdowns("hf_string_token", None) + + # Should still work (extracts as string) + assert len(result) == 3 + diff --git a/tests/unit/tools/test_web_search.py 
b/tests/unit/tools/test_web_search.py new file mode 100644 index 00000000..d0a49ef8 --- /dev/null +++ b/tests/unit/tools/test_web_search.py @@ -0,0 +1,195 @@ +"""Unit tests for WebSearchTool.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Mock dependencies before importing to avoid import errors +import sys +sys.modules["neo4j"] = MagicMock() +sys.modules["neo4j"].GraphDatabase = MagicMock() + +# Mock ddgs/duckduckgo_search +# Create a proper mock structure to avoid "ddgs.ddgs" import errors +mock_ddgs_module = MagicMock() +mock_ddgs_submodule = MagicMock() +# Create a mock DDGS class that can be instantiated +class MockDDGS: + def __init__(self, *args, **kwargs): + pass + def text(self, *args, **kwargs): + return [] + +mock_ddgs_submodule.DDGS = MockDDGS +mock_ddgs_module.ddgs = mock_ddgs_submodule +mock_ddgs_module.DDGS = MockDDGS +sys.modules["ddgs"] = mock_ddgs_module +sys.modules["ddgs.ddgs"] = mock_ddgs_submodule +sys.modules["duckduckgo_search"] = MagicMock() +sys.modules["duckduckgo_search"].DDGS = MockDDGS + +from src.tools.web_search import WebSearchTool +from src.utils.exceptions import SearchError +from src.utils.models import Citation, Evidence + + +class TestWebSearchTool: + """Tests for WebSearchTool class.""" + + @pytest.fixture + def web_search_tool(self) -> WebSearchTool: + """Create a WebSearchTool instance.""" + return WebSearchTool() + + def test_name_property(self, web_search_tool: WebSearchTool) -> None: + """Should return correct tool name.""" + assert web_search_tool.name == "duckduckgo" + + @pytest.mark.asyncio + async def test_search_success(self, web_search_tool: WebSearchTool) -> None: + """Should return evidence from successful search.""" + mock_results = [ + { + "title": "Test Title 1", + "body": "Test content 1", + "href": "https://example.com/1", + }, + { + "title": "Test Title 2", + "body": "Test content 2", + "href": "https://example.com/2", + }, + ] + + with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)): + with patch("src.tools.web_search.preprocess_query", return_value="clean query"): + evidence = await web_search_tool.search("test query", max_results=10) + + assert len(evidence) == 2 + assert isinstance(evidence[0], Evidence) + assert evidence[0].citation.title == "Test Title 1" + assert evidence[0].citation.url == "https://example.com/1" + assert evidence[0].citation.source == "web" + + @pytest.mark.asyncio + async def test_search_title_truncation_500_chars(self, web_search_tool: WebSearchTool) -> None: + """Should truncate titles longer than 500 characters.""" + long_title = "A" * 600 # 600 characters + mock_results = [ + { + "title": long_title, + "body": "Test content", + "href": "https://example.com/1", + }, + ] + + with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)): + with patch("src.tools.web_search.preprocess_query", return_value="clean query"): + evidence = await web_search_tool.search("test query", max_results=10) + + assert len(evidence) == 1 + truncated_title = evidence[0].citation.title + assert len(truncated_title) == 500 # Should be exactly 500 chars + assert truncated_title.endswith("...") # Should end with ellipsis + assert truncated_title.startswith("A") # Should start with original content + + @pytest.mark.asyncio + async def test_search_title_truncation_exactly_500_chars(self, web_search_tool: WebSearchTool) -> None: + """Should not truncate titles that are exactly 500 characters.""" + exact_title = "A" * 500 # Exactly 500 characters + 
mock_results = [ + { + "title": exact_title, + "body": "Test content", + "href": "https://example.com/1", + }, + ] + + with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)): + with patch("src.tools.web_search.preprocess_query", return_value="clean query"): + evidence = await web_search_tool.search("test query", max_results=10) + + assert len(evidence) == 1 + title = evidence[0].citation.title + assert len(title) == 500 + assert title == exact_title # Should be unchanged + + @pytest.mark.asyncio + async def test_search_title_truncation_501_chars(self, web_search_tool: WebSearchTool) -> None: + """Should truncate titles that are 501 characters.""" + long_title = "A" * 501 # 501 characters + mock_results = [ + { + "title": long_title, + "body": "Test content", + "href": "https://example.com/1", + }, + ] + + with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)): + with patch("src.tools.web_search.preprocess_query", return_value="clean query"): + evidence = await web_search_tool.search("test query", max_results=10) + + assert len(evidence) == 1 + truncated_title = evidence[0].citation.title + assert len(truncated_title) == 500 + assert truncated_title.endswith("...") + + @pytest.mark.asyncio + async def test_search_missing_title(self, web_search_tool: WebSearchTool) -> None: + """Should handle missing title gracefully.""" + mock_results = [ + { + "body": "Test content", + "href": "https://example.com/1", + }, + ] + + with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)): + with patch("src.tools.web_search.preprocess_query", return_value="clean query"): + evidence = await web_search_tool.search("test query", max_results=10) + + assert len(evidence) == 1 + assert evidence[0].citation.title == "No Title" + + @pytest.mark.asyncio + async def test_search_empty_results(self, web_search_tool: WebSearchTool) -> None: + """Should return empty list for no results.""" + with patch.object(web_search_tool._ddgs, "text", return_value=iter([])): + with patch("src.tools.web_search.preprocess_query", return_value="clean query"): + evidence = await web_search_tool.search("test query", max_results=10) + + assert evidence == [] + + @pytest.mark.asyncio + async def test_search_raises_search_error(self, web_search_tool: WebSearchTool) -> None: + """Should raise SearchError on exception.""" + with patch.object(web_search_tool._ddgs, "text", side_effect=Exception("API error")): + with patch("src.tools.web_search.preprocess_query", return_value="clean query"): + with pytest.raises(SearchError, match="DuckDuckGo search failed"): + await web_search_tool.search("test query", max_results=10) + + @pytest.mark.asyncio + async def test_search_uses_preprocessed_query(self, web_search_tool: WebSearchTool) -> None: + """Should use preprocessed query for search.""" + mock_results = [{"title": "Test", "body": "Content", "href": "https://example.com"}] + + with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)) as mock_text: + with patch("src.tools.web_search.preprocess_query", return_value="preprocessed query"): + await web_search_tool.search("original query", max_results=5) + + # Should call text with preprocessed query + mock_text.assert_called_once_with("preprocessed query", max_results=5) + + @pytest.mark.asyncio + async def test_search_falls_back_to_original_query(self, web_search_tool: WebSearchTool) -> None: + """Should fall back to original query if preprocessing returns empty.""" + mock_results = [{"title": 
"Test", "body": "Content", "href": "https://example.com"}] + + with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)) as mock_text: + with patch("src.tools.web_search.preprocess_query", return_value=""): + await web_search_tool.search("original query", max_results=5) + + # Should call text with original query + mock_text.assert_called_once_with("original query", max_results=5) + diff --git a/tests/unit/utils/test_hf_error_handler.py b/tests/unit/utils/test_hf_error_handler.py new file mode 100644 index 00000000..cd7df2c2 --- /dev/null +++ b/tests/unit/utils/test_hf_error_handler.py @@ -0,0 +1,239 @@ +"""Unit tests for HuggingFace error handling utilities.""" + +from unittest.mock import patch + +import pytest + +from src.utils.hf_error_handler import ( + extract_error_details, + get_fallback_models, + get_user_friendly_error_message, + log_token_info, + should_retry_with_fallback, + validate_hf_token, +) + + +class TestExtractErrorDetails: + """Tests for extract_error_details function.""" + + def test_extract_403_error(self) -> None: + """Should extract 403 error details correctly.""" + error = Exception("status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden") + details = extract_error_details(error) + + assert details["status_code"] == 403 + assert details["model_name"] == "Qwen/Qwen3-Next-80B-A3B-Thinking" + assert details["body"] == "Forbidden" + assert details["error_type"] == "http_403" + assert details["is_auth_error"] is True + assert details["is_model_error"] is False + + def test_extract_422_error(self) -> None: + """Should extract 422 error details correctly.""" + error = Exception("status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity") + details = extract_error_details(error) + + assert details["status_code"] == 422 + assert details["model_name"] == "meta-llama/Llama-3.1-70B-Instruct" + assert details["body"] == "Unprocessable Entity" + assert details["error_type"] == "http_422" + assert details["is_auth_error"] is False + assert details["is_model_error"] is True + + def test_extract_partial_error(self) -> None: + """Should handle partial error information.""" + error = Exception("status_code: 500") + details = extract_error_details(error) + + assert details["status_code"] == 500 + assert details["model_name"] is None + assert details["body"] is None + assert details["error_type"] == "http_500" + assert details["is_auth_error"] is False + assert details["is_model_error"] is False + + def test_extract_generic_error(self) -> None: + """Should handle generic errors without status codes.""" + error = Exception("Something went wrong") + details = extract_error_details(error) + + assert details["status_code"] is None + assert details["model_name"] is None + assert details["body"] is None + assert details["error_type"] == "unknown" + assert details["is_auth_error"] is False + assert details["is_model_error"] is False + + +class TestGetUserFriendlyErrorMessage: + """Tests for get_user_friendly_error_message function.""" + + def test_403_error_message(self) -> None: + """Should generate user-friendly 403 error message.""" + error = Exception("status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden") + message = get_user_friendly_error_message(error) + + assert "Authentication Error" in message + assert "inference-api" in message + assert "Qwen/Qwen3-Next-80B-A3B-Thinking" in message + assert "Forbidden" in message + + def test_422_error_message(self) -> None: + """Should 
generate user-friendly 422 error message.""" + error = Exception("status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity") + message = get_user_friendly_error_message(error) + + assert "Model Compatibility Error" in message + assert "meta-llama/Llama-3.1-70B-Instruct" in message + assert "Unprocessable Entity" in message + + def test_generic_error_message(self) -> None: + """Should generate generic error message for unknown errors.""" + error = Exception("Something went wrong") + message = get_user_friendly_error_message(error) + + assert "API Error" in message + assert "Something went wrong" in message + + def test_error_message_with_model_name_param(self) -> None: + """Should use provided model_name parameter when not in error.""" + error = Exception("status_code: 403, body: Forbidden") + message = get_user_friendly_error_message(error, model_name="test-model") + + assert "test-model" in message + + +class TestValidateHfToken: + """Tests for validate_hf_token function.""" + + def test_valid_token(self) -> None: + """Should validate a valid token.""" + token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + is_valid, error_msg = validate_hf_token(token) + + assert is_valid is True + assert error_msg is None + + def test_none_token(self) -> None: + """Should reject None token.""" + is_valid, error_msg = validate_hf_token(None) + + assert is_valid is False + assert "None or empty" in error_msg + + def test_empty_token(self) -> None: + """Should reject empty token.""" + is_valid, error_msg = validate_hf_token("") + + assert is_valid is False + assert "None or empty" in error_msg + + def test_non_string_token(self) -> None: + """Should reject non-string token.""" + is_valid, error_msg = validate_hf_token(123) # type: ignore[arg-type] + + assert is_valid is False + assert "not a string" in error_msg + + def test_short_token(self) -> None: + """Should reject token that's too short.""" + is_valid, error_msg = validate_hf_token("hf_123") + + assert is_valid is False + assert "too short" in error_msg + + def test_oauth_token_format(self) -> None: + """Should accept OAuth tokens (may not start with hf_).""" + # OAuth tokens may have different formats + token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ" + is_valid, error_msg = validate_hf_token(token) + + assert is_valid is True + assert error_msg is None + + +class TestLogTokenInfo: + """Tests for log_token_info function.""" + + @patch("src.utils.hf_error_handler.logger") + def test_log_valid_token(self, mock_logger) -> None: + """Should log token info for valid token.""" + token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + log_token_info(token, context="test") + + mock_logger.debug.assert_called_once() + call_args = mock_logger.debug.call_args + assert call_args[0][0] == "Token validation" + assert call_args[1]["context"] == "test" + assert call_args[1]["has_token"] is True + assert call_args[1]["is_valid"] is True + assert call_args[1]["token_length"] == len(token) + assert "token_prefix" in call_args[1] + + @patch("src.utils.hf_error_handler.logger") + def test_log_none_token(self, mock_logger) -> None: + """Should log None token info.""" + log_token_info(None, context="test") + + mock_logger.debug.assert_called_once() + call_args = mock_logger.debug.call_args + assert call_args[0][0] == "Token validation" + assert call_args[1]["context"] == "test" + assert call_args[1]["has_token"] is False + + +class TestShouldRetryWithFallback: + """Tests for 
should_retry_with_fallback function.""" + + def test_403_error_should_retry(self) -> None: + """Should retry for 403 errors.""" + error = Exception("status_code: 403, model_name: test-model, body: Forbidden") + assert should_retry_with_fallback(error) is True + + def test_422_error_should_retry(self) -> None: + """Should retry for 422 errors.""" + error = Exception("status_code: 422, model_name: test-model, body: Unprocessable Entity") + assert should_retry_with_fallback(error) is True + + def test_model_specific_error_should_retry(self) -> None: + """Should retry for model-specific errors.""" + error = Exception("status_code: 500, model_name: test-model, body: Error") + assert should_retry_with_fallback(error) is True + + def test_generic_error_should_not_retry(self) -> None: + """Should not retry for generic errors without model info.""" + error = Exception("Something went wrong") + assert should_retry_with_fallback(error) is False + + +class TestGetFallbackModels: + """Tests for get_fallback_models function.""" + + def test_get_fallback_models_default(self) -> None: + """Should return default fallback models.""" + fallbacks = get_fallback_models() + + assert len(fallbacks) > 0 + assert "meta-llama/Llama-3.1-8B-Instruct" in fallbacks + assert isinstance(fallbacks, list) + + def test_get_fallback_models_excludes_original(self) -> None: + """Should exclude original model from fallbacks.""" + original = "meta-llama/Llama-3.1-8B-Instruct" + fallbacks = get_fallback_models(original_model=original) + + assert original not in fallbacks + assert len(fallbacks) > 0 + + def test_get_fallback_models_with_unknown_original(self) -> None: + """Should return all fallbacks if original is not in list.""" + original = "unknown/model" + fallbacks = get_fallback_models(original_model=original) + + # Should still have all fallbacks since original is not in the list + assert len(fallbacks) >= 3 # At least 3 fallback models + + + + diff --git a/tests/unit/utils/test_hf_model_validator.py b/tests/unit/utils/test_hf_model_validator.py new file mode 100644 index 00000000..2027ff72 --- /dev/null +++ b/tests/unit/utils/test_hf_model_validator.py @@ -0,0 +1,416 @@ +"""Unit tests for HuggingFace model and provider validator.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from src.utils.hf_model_validator import ( + extract_oauth_token, + get_available_models, + get_available_providers, + get_models_for_provider, + validate_model_provider_combination, + validate_oauth_token, +) + + +class TestExtractOAuthToken: + """Tests for extract_oauth_token function.""" + + def test_extract_from_oauth_token_object(self) -> None: + """Should extract token from OAuthToken object with .token attribute.""" + mock_oauth_token = MagicMock() + mock_oauth_token.token = "hf_test_token_123" + + result = extract_oauth_token(mock_oauth_token) + + assert result == "hf_test_token_123" + + def test_extract_from_string(self) -> None: + """Should return string token as-is.""" + token = "hf_test_token_123" + result = extract_oauth_token(token) + + assert result == token + + def test_extract_none(self) -> None: + """Should return None for None input.""" + result = extract_oauth_token(None) + + assert result is None + + def test_extract_invalid_object(self) -> None: + """Should return None for object without .token attribute.""" + invalid_object = MagicMock() + del invalid_object.token # Remove token attribute + + with patch("src.utils.hf_model_validator.logger") as mock_logger: + result = 
extract_oauth_token(invalid_object) + + assert result is None + mock_logger.warning.assert_called_once() + + +class TestGetAvailableProviders: + """Tests for get_available_providers function.""" + + @pytest.mark.asyncio + async def test_get_providers_with_cache(self) -> None: + """Should return cached providers if available.""" + # First call - should query API + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + + # Mock model_info to return provider mapping + mock_model_info = MagicMock() + mock_model_info.inference_provider_mapping = { + "hf-inference": MagicMock(), + "nebius": MagicMock(), + } + mock_api.model_info.return_value = mock_model_info + + # Mock settings + with patch("src.utils.hf_model_validator.settings") as mock_settings: + mock_settings.get_hf_fallback_models_list.return_value = [ + "meta-llama/Llama-3.1-8B-Instruct" + ] + + providers = await get_available_providers(token="test_token") + + assert "auto" in providers + assert len(providers) > 1 + + @pytest.mark.asyncio + async def test_get_providers_fallback_to_known(self) -> None: + """Should fall back to known providers if discovery fails.""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + mock_api.model_info.side_effect = Exception("API error") + + with patch("src.utils.hf_model_validator.settings") as mock_settings: + mock_settings.get_hf_fallback_models_list.return_value = [ + "meta-llama/Llama-3.1-8B-Instruct" + ] + + providers = await get_available_providers(token="test_token") + + # Should return known providers as fallback + assert "auto" in providers + assert len(providers) > 0 + + @pytest.mark.asyncio + async def test_get_providers_no_token(self) -> None: + """Should work without token.""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + mock_api.model_info.side_effect = Exception("API error") + + with patch("src.utils.hf_model_validator.settings") as mock_settings: + mock_settings.get_hf_fallback_models_list.return_value = [ + "meta-llama/Llama-3.1-8B-Instruct" + ] + + providers = await get_available_providers(token=None) + + # Should return known providers as fallback + assert "auto" in providers + + +class TestGetAvailableModels: + """Tests for get_available_models function.""" + + @pytest.mark.asyncio + async def test_get_models_with_token(self) -> None: + """Should fetch models with token.""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + + # Mock list_models to return model objects + mock_model1 = MagicMock() + mock_model1.id = "model1" + mock_model2 = MagicMock() + mock_model2.id = "model2" + mock_api.list_models.return_value = [mock_model1, mock_model2] + + models = await get_available_models(token="test_token", limit=10) + + assert len(models) == 2 + assert "model1" in models + assert "model2" in models + mock_api.list_models.assert_called_once() + + @pytest.mark.asyncio + async def test_get_models_with_provider_filter(self) -> None: + """Should filter models by provider.""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + + mock_model = MagicMock() + mock_model.id = "model1" + mock_api.list_models.return_value = [mock_model] + + models = await get_available_models( 
+ token="test_token", + inference_provider="nebius", + limit=10, + ) + + # Check that inference_provider was passed to list_models + call_kwargs = mock_api.list_models.call_args[1] + assert call_kwargs.get("inference_provider") == "nebius" + + @pytest.mark.asyncio + async def test_get_models_fallback_on_error(self) -> None: + """Should return fallback models on error.""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + mock_api.list_models.side_effect = Exception("API error") + + models = await get_available_models(token="test_token", limit=10) + + # Should return fallback models + assert len(models) > 0 + assert "meta-llama/Llama-3.1-8B-Instruct" in models + + +class TestValidateModelProviderCombination: + """Tests for validate_model_provider_combination function.""" + + @pytest.mark.asyncio + async def test_validate_auto_provider(self) -> None: + """Should always validate 'auto' provider.""" + is_valid, error_msg = await validate_model_provider_combination( + model_id="test-model", + provider="auto", + token="test_token", + ) + + assert is_valid is True + assert error_msg is None + + @pytest.mark.asyncio + async def test_validate_none_provider(self) -> None: + """Should validate None provider as auto.""" + is_valid, error_msg = await validate_model_provider_combination( + model_id="test-model", + provider=None, + token="test_token", + ) + + assert is_valid is True + assert error_msg is None + + @pytest.mark.asyncio + async def test_validate_valid_combination(self) -> None: + """Should validate valid model/provider combination.""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + + # Mock model_info with provider mapping + mock_model_info = MagicMock() + mock_model_info.inference_provider_mapping = { + "nebius": MagicMock(), + "hf-inference": MagicMock(), + } + mock_api.model_info.return_value = mock_model_info + + is_valid, error_msg = await validate_model_provider_combination( + model_id="test-model", + provider="nebius", + token="test_token", + ) + + assert is_valid is True + assert error_msg is None + + @pytest.mark.asyncio + async def test_validate_invalid_combination(self) -> None: + """Should reject invalid model/provider combination.""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + + # Mock model_info with provider mapping (without requested provider) + mock_model_info = MagicMock() + mock_model_info.inference_provider_mapping = { + "hf-inference": MagicMock(), + } + mock_api.model_info.return_value = mock_model_info + + is_valid, error_msg = await validate_model_provider_combination( + model_id="test-model", + provider="nebius", + token="test_token", + ) + + assert is_valid is False + assert error_msg is not None + assert "nebius" in error_msg + + @pytest.mark.asyncio + async def test_validate_fireworks_variants(self) -> None: + """Should handle fireworks/fireworks-ai name variants.""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + + # Mock model_info with fireworks-ai in mapping + mock_model_info = MagicMock() + mock_model_info.inference_provider_mapping = { + "fireworks-ai": MagicMock(), + } + mock_api.model_info.return_value = mock_model_info + + # Should accept "fireworks" when mapping has "fireworks-ai" + is_valid, 
error_msg = await validate_model_provider_combination( + model_id="test-model", + provider="fireworks", + token="test_token", + ) + + assert is_valid is True + assert error_msg is None + + @pytest.mark.asyncio + async def test_validate_graceful_on_error(self) -> None: + """Should return valid on error (graceful degradation).""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + mock_api.model_info.side_effect = Exception("API error") + + is_valid, error_msg = await validate_model_provider_combination( + model_id="test-model", + provider="nebius", + token="test_token", + ) + + # Should return True to allow actual request to determine validity + assert is_valid is True + + +class TestGetModelsForProvider: + """Tests for get_models_for_provider function.""" + + @pytest.mark.asyncio + async def test_get_models_for_provider(self) -> None: + """Should get models for specific provider.""" + with patch("src.utils.hf_model_validator.get_available_models") as mock_get_models: + mock_get_models.return_value = ["model1", "model2"] + + models = await get_models_for_provider( + provider="nebius", + token="test_token", + limit=10, + ) + + assert len(models) == 2 + mock_get_models.assert_called_once_with( + token="test_token", + task="text-generation", + limit=10, + inference_provider="nebius", + ) + + @pytest.mark.asyncio + async def test_get_models_normalize_fireworks(self) -> None: + """Should normalize 'fireworks' to 'fireworks-ai'.""" + with patch("src.utils.hf_model_validator.get_available_models") as mock_get_models: + mock_get_models.return_value = ["model1"] + + models = await get_models_for_provider( + provider="fireworks", + token="test_token", + ) + + # Should call with "fireworks-ai" not "fireworks" + call_kwargs = mock_get_models.call_args[1] + assert call_kwargs["inference_provider"] == "fireworks-ai" + + +class TestValidateOAuthToken: + """Tests for validate_oauth_token function.""" + + @pytest.mark.asyncio + async def test_validate_none_token(self) -> None: + """Should return invalid for None token.""" + result = await validate_oauth_token(None) + + assert result["is_valid"] is False + assert result["error"] == "No token provided" + + @pytest.mark.asyncio + async def test_validate_invalid_format(self) -> None: + """Should return invalid for malformed token.""" + result = await validate_oauth_token("short") + + assert result["is_valid"] is False + assert "Invalid token format" in result["error"] + + @pytest.mark.asyncio + async def test_validate_valid_token(self) -> None: + """Should validate valid token and return resources.""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + + # Mock whoami to return user info + mock_api.whoami.return_value = {"name": "testuser", "fullname": "Test User"} + + # Mock get_available_models and get_available_providers + with patch("src.utils.hf_model_validator.get_available_models") as mock_get_models, \ + patch("src.utils.hf_model_validator.get_available_providers") as mock_get_providers: + mock_get_models.return_value = ["model1", "model2"] + mock_get_providers.return_value = ["auto", "nebius"] + + result = await validate_oauth_token("hf_valid_token_123") + + assert result["is_valid"] is True + assert result["username"] == "testuser" + assert result["has_inference_api_scope"] is True + assert len(result["available_models"]) == 2 + assert len(result["available_providers"]) == 2 + + 
@pytest.mark.asyncio + async def test_validate_token_without_scope(self) -> None: + """Should detect missing inference-api scope.""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + mock_api.whoami.return_value = {"name": "testuser"} + + # Mock get_available_models to fail (no scope) + with patch("src.utils.hf_model_validator.get_available_models") as mock_get_models, \ + patch("src.utils.hf_model_validator.get_available_providers") as mock_get_providers: + mock_get_models.side_effect = Exception("403 Forbidden") + mock_get_providers.return_value = ["auto"] + + result = await validate_oauth_token("hf_token_without_scope") + + assert result["is_valid"] is True # Token is valid + assert result["has_inference_api_scope"] is False # But no scope + assert "inference-api scope" in result["error"] + + @pytest.mark.asyncio + async def test_validate_invalid_token(self) -> None: + """Should return invalid for token that fails authentication.""" + with patch("src.utils.hf_model_validator.HfApi") as mock_api_class: + mock_api = MagicMock() + mock_api_class.return_value = mock_api + mock_api.whoami.side_effect = Exception("401 Unauthorized") + + result = await validate_oauth_token("hf_invalid_token") + + assert result["is_valid"] is False + assert "could not authenticate" in result["error"] + + + + diff --git a/tests/unit/utils/test_llm_factory_token_validation.py b/tests/unit/utils/test_llm_factory_token_validation.py new file mode 100644 index 00000000..e663958f --- /dev/null +++ b/tests/unit/utils/test_llm_factory_token_validation.py @@ -0,0 +1,142 @@ +"""Unit tests for token validation in llm_factory.py.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from src.utils.exceptions import ConfigurationError + + +class TestGetPydanticAiModelTokenValidation: + """Tests for get_pydantic_ai_model function with token validation.""" + + @patch("src.utils.llm_factory.settings") + @patch("src.utils.hf_error_handler.log_token_info") + @patch("src.utils.hf_error_handler.validate_hf_token") + @patch("pydantic_ai.providers.huggingface.HuggingFaceProvider") + @patch("pydantic_ai.models.huggingface.HuggingFaceModel") + def test_validates_oauth_token( + self, + mock_model_class, + mock_provider_class, + mock_validate, + mock_log, + mock_settings, + ) -> None: + """Should validate and log OAuth token when provided.""" + mock_settings.hf_token = None + mock_settings.huggingface_api_key = None + mock_settings.huggingface_model = "test-model" + mock_validate.return_value = (True, None) + mock_model_class.return_value = MagicMock() + + from src.utils.llm_factory import get_pydantic_ai_model + + get_pydantic_ai_model(oauth_token="hf_test_token") + + # Should log token info + mock_log.assert_called_once_with("hf_test_token", context="get_pydantic_ai_model") + # Should validate token + mock_validate.assert_called_once_with("hf_test_token") + + @patch("src.utils.llm_factory.settings") + @patch("src.utils.hf_error_handler.log_token_info") + @patch("src.utils.hf_error_handler.validate_hf_token") + @patch("src.utils.llm_factory.logger") + @patch("pydantic_ai.providers.huggingface.HuggingFaceProvider") + @patch("pydantic_ai.models.huggingface.HuggingFaceModel") + def test_warns_on_invalid_token( + self, + mock_model_class, + mock_provider_class, + mock_logger, + mock_validate, + mock_log, + mock_settings, + ) -> None: + """Should warn when token validation fails.""" + mock_settings.hf_token = None + 
mock_settings.huggingface_api_key = None + mock_settings.huggingface_model = "test-model" + mock_validate.return_value = (False, "Token too short") + mock_model_class.return_value = MagicMock() + + from src.utils.llm_factory import get_pydantic_ai_model + + get_pydantic_ai_model(oauth_token="short") + + # Should warn about invalid token + warning_calls = [ + call + for call in mock_logger.warning.call_args_list + if "Token validation failed" in str(call) + ] + assert len(warning_calls) > 0 + + @patch("src.utils.llm_factory.settings") + @patch("src.utils.hf_error_handler.log_token_info") + @patch("src.utils.hf_error_handler.validate_hf_token") + @patch("pydantic_ai.providers.huggingface.HuggingFaceProvider") + @patch("pydantic_ai.models.huggingface.HuggingFaceModel") + def test_uses_env_token_when_oauth_not_provided( + self, + mock_model_class, + mock_provider_class, + mock_validate, + mock_log, + mock_settings, + ) -> None: + """Should use environment token when OAuth token not provided.""" + mock_settings.hf_token = "hf_env_token" + mock_settings.huggingface_api_key = None + mock_settings.huggingface_model = "test-model" + mock_validate.return_value = (True, None) + mock_model_class.return_value = MagicMock() + + from src.utils.llm_factory import get_pydantic_ai_model + + get_pydantic_ai_model(oauth_token=None) + + # Should log and validate env token + mock_log.assert_called_once_with("hf_env_token", context="get_pydantic_ai_model") + mock_validate.assert_called_once_with("hf_env_token") + + @patch("src.utils.llm_factory.settings") + def test_raises_when_no_token_available(self, mock_settings) -> None: + """Should raise ConfigurationError when no token is available.""" + mock_settings.hf_token = None + mock_settings.huggingface_api_key = None + + from src.utils.llm_factory import get_pydantic_ai_model + + with pytest.raises(ConfigurationError, match="HuggingFace token required"): + get_pydantic_ai_model(oauth_token=None) + + @patch("src.utils.llm_factory.settings") + @patch("src.utils.hf_error_handler.log_token_info") + @patch("src.utils.hf_error_handler.validate_hf_token") + @patch("pydantic_ai.providers.huggingface.HuggingFaceProvider") + @patch("pydantic_ai.models.huggingface.HuggingFaceModel") + def test_oauth_token_priority_over_env( + self, + mock_model_class, + mock_provider_class, + mock_validate, + mock_log, + mock_settings, + ) -> None: + """Should prefer OAuth token over environment token.""" + mock_settings.hf_token = "hf_env_token" + mock_settings.huggingface_api_key = None + mock_settings.huggingface_model = "test-model" + mock_validate.return_value = (True, None) + mock_model_class.return_value = MagicMock() + + from src.utils.llm_factory import get_pydantic_ai_model + + get_pydantic_ai_model(oauth_token="hf_oauth_token") + + # Should use OAuth token, not env token + mock_log.assert_called_once_with("hf_oauth_token", context="get_pydantic_ai_model") + mock_validate.assert_called_once_with("hf_oauth_token") + diff --git a/tests/unit/utils/test_message_history.py b/tests/unit/utils/test_message_history.py new file mode 100644 index 00000000..19344e58 --- /dev/null +++ b/tests/unit/utils/test_message_history.py @@ -0,0 +1,151 @@ +"""Unit tests for message history utilities.""" + +import pytest + +pytestmark = pytest.mark.unit + +from src.utils.message_history import ( + convert_gradio_to_message_history, + create_relevance_processor, + create_truncation_processor, + message_history_to_string, +) + + +def test_convert_gradio_to_message_history_empty(): + """Test conversion with 
empty history.""" + result = convert_gradio_to_message_history([]) + assert result == [] + + +def test_convert_gradio_to_message_history_single_turn(): + """Test conversion with a single turn.""" + gradio_history = [ + {"role": "user", "content": "What is AI?"}, + {"role": "assistant", "content": "AI is artificial intelligence."}, + ] + result = convert_gradio_to_message_history(gradio_history) + assert len(result) == 2 + + +def test_convert_gradio_to_message_history_multiple_turns(): + """Test conversion with multiple turns.""" + gradio_history = [ + {"role": "user", "content": "What is AI?"}, + {"role": "assistant", "content": "AI is artificial intelligence."}, + {"role": "user", "content": "Tell me more"}, + {"role": "assistant", "content": "AI includes machine learning..."}, + ] + result = convert_gradio_to_message_history(gradio_history) + assert len(result) == 4 + + +def test_convert_gradio_to_message_history_max_messages(): + """Test conversion with max_messages limit.""" + gradio_history = [] + for i in range(15): # Create 15 turns + gradio_history.append({"role": "user", "content": f"Message {i}"}) + gradio_history.append({"role": "assistant", "content": f"Response {i}"}) + + result = convert_gradio_to_message_history(gradio_history, max_messages=10) + # Should only include most recent 10 messages + assert len(result) <= 10 + + +def test_convert_gradio_to_message_history_filters_invalid(): + """Test that invalid entries are filtered out.""" + gradio_history = [ + {"role": "user", "content": "Valid message"}, + {"role": "system", "content": "Should be filtered"}, + {"role": "assistant", "content": ""}, # Empty content should be filtered + {"role": "assistant", "content": "Valid response"}, + ] + result = convert_gradio_to_message_history(gradio_history) + # Should only have 2 valid messages (user + assistant) + assert len(result) == 2 + + +def test_message_history_to_string_empty(): + """Test string conversion with empty history.""" + result = message_history_to_string([]) + assert result == "" + + +def test_message_history_to_string_format(): + """Test string conversion format.""" + # Create mock message history + try: + from pydantic_ai import ModelRequest, ModelResponse + from pydantic_ai.messages import TextPart, UserPromptPart + + messages = [ + ModelRequest(parts=[UserPromptPart(content="Question 1")]), + ModelResponse(parts=[TextPart(content="Answer 1")]), + ] + result = message_history_to_string(messages) + assert "PREVIOUS CONVERSATION" in result + assert "User:" in result + assert "Assistant:" in result + except ImportError: + # Skip if pydantic_ai not available + pytest.skip("pydantic_ai not available") + + +def test_message_history_to_string_max_messages(): + """Test string conversion with max_messages limit.""" + try: + from pydantic_ai import ModelRequest, ModelResponse + from pydantic_ai.messages import TextPart, UserPromptPart + + messages = [] + for i in range(10): # Create 10 turns + messages.append(ModelRequest(parts=[UserPromptPart(content=f"Question {i}")])) + messages.append(ModelResponse(parts=[TextPart(content=f"Answer {i}")])) + + result = message_history_to_string(messages, max_messages=3) + # Should only include most recent 3 messages (1.5 turns) + assert result != "" + except ImportError: + pytest.skip("pydantic_ai not available") + + +def test_create_truncation_processor(): + """Test truncation processor factory.""" + processor = create_truncation_processor(max_messages=5) + assert callable(processor) + + try: + from pydantic_ai import ModelRequest + 
from pydantic_ai.messages import UserPromptPart + + messages = [ + ModelRequest(parts=[UserPromptPart(content=f"Message {i}")]) + for i in range(10) + ] + result = processor(messages) + assert len(result) == 5 + except ImportError: + pytest.skip("pydantic_ai not available") + + +def test_create_relevance_processor(): + """Test relevance processor factory.""" + processor = create_relevance_processor(min_length=10) + assert callable(processor) + + try: + from pydantic_ai import ModelRequest, ModelResponse + from pydantic_ai.messages import TextPart, UserPromptPart + + messages = [ + ModelRequest(parts=[UserPromptPart(content="Short")]), # Too short + ModelRequest(parts=[UserPromptPart(content="This is a longer message")]), # Valid + ModelResponse(parts=[TextPart(content="OK")]), # Too short + ModelResponse(parts=[TextPart(content="This is a valid response")]), # Valid + ] + result = processor(messages) + # Should only keep messages with length >= 10 + assert len(result) == 2 + except ImportError: + pytest.skip("pydantic_ai not available") +
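
The new test modules in this diff pin down a few behavioral contracts (the 500-character title cap exercised by the `WebSearchTool` truncation tests, and the `status_code: ..., model_name: ..., body: ...` parsing exercised by the `src.utils.hf_error_handler` tests) without showing the implementations themselves. The sketch below is a minimal reconstruction of those contracts from the assertions alone, not the repository's actual code: `truncate_title` is a hypothetical stand-alone helper (the real truncation happens inside `WebSearchTool.search`), and the regular expressions are assumptions about how the error strings are parsed.

```python
"""Illustrative sketches of the contracts the unit tests above assert.

NOT the project's implementations; reconstructed only from the test expectations.
"""

import re
from typing import Any

_MAX_TITLE_LEN = 500  # test_search_title_truncation_* expect at most 500 chars


def truncate_title(title: str, max_len: int = _MAX_TITLE_LEN) -> str:
    """Cap a search-result title at max_len characters.

    Titles of exactly max_len are returned unchanged; longer titles come back
    as exactly max_len characters, ending with an ellipsis.
    """
    if len(title) <= max_len:
        return title
    return title[: max_len - 3] + "..."


# Assumed parsing patterns for messages like
# "status_code: 403, model_name: some/model, body: Forbidden"
_STATUS_RE = re.compile(r"status_code:\s*(\d+)")
_MODEL_RE = re.compile(r"model_name:\s*([^,]+)")
_BODY_RE = re.compile(r"body:\s*(.+)$")


def extract_error_details(error: Exception) -> dict[str, Any]:
    """Mirror TestExtractErrorDetails: 403 is an auth error, 422 a model error,
    and anything unparsable is reported as error_type 'unknown'."""
    text = str(error)
    status = _STATUS_RE.search(text)
    model = _MODEL_RE.search(text)
    body = _BODY_RE.search(text)

    status_code = int(status.group(1)) if status else None
    return {
        "status_code": status_code,
        "model_name": model.group(1).strip() if model else None,
        "body": body.group(1).strip() if body else None,
        "error_type": f"http_{status_code}" if status_code else "unknown",
        "is_auth_error": status_code in (401, 403),
        "is_model_error": status_code == 422,
    }


def should_retry_with_fallback(error: Exception) -> bool:
    """Mirror TestShouldRetryWithFallback: retry on auth errors, model
    compatibility errors, or any error that names a specific model."""
    details = extract_error_details(error)
    return bool(
        details["is_auth_error"]
        or details["is_model_error"]
        or details["model_name"] is not None
    )
```

If the real helpers in `src/utils/hf_error_handler.py` or `src/tools/web_search.py` diverge from this sketch, the assertions in the diff above remain the source of truth; the sketch only exists to make the expected 500-character truncation and HTTP 403/422 handling easy to read at a glance.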