diff --git a/.github/README.md b/.github/README.md
index 35cca027..8f3727f7 100644
--- a/.github/README.md
+++ b/.github/README.md
@@ -3,8 +3,7 @@
> **You are reading the Github README!**
>
> - π **Documentation**: See our [technical documentation](https://deepcritical.github.io/GradioDemo/) for detailed information
-> - π **Demo README**: Check out the [Demo README](..README.md) for for more information about our MCP Hackathon submission
-> - π **Hackathon Submission**: Keep reading below for more information about our MCP Hackathon submission
+> - π **Demo README**: Check out the [Demo README](../README.md) for more information
+> - π **Demo**: Try our [Free Demo](https://hf.co/DataQuests/GradioDemo)
diff --git a/.github/scripts/deploy_to_hf_space.py b/.github/scripts/deploy_to_hf_space.py
index 839a38dd..ba4f81ac 100644
--- a/.github/scripts/deploy_to_hf_space.py
+++ b/.github/scripts/deploy_to_hf_space.py
@@ -3,13 +3,22 @@
import os
import shutil
import subprocess
+<<<<<<< HEAD
+import tempfile
+from pathlib import Path
+=======
from pathlib import Path
from typing import Set
+>>>>>>> origin/dev
from huggingface_hub import HfApi
+<<<<<<< HEAD
+def get_excluded_dirs() -> set[str]:
+=======
def get_excluded_dirs() -> Set[str]:
+>>>>>>> origin/dev
"""Get set of directory names to exclude from deployment."""
return {
"docs",
@@ -38,17 +47,28 @@ def get_excluded_dirs() -> Set[str]:
"dist",
".eggs",
"htmlcov",
+<<<<<<< HEAD
+ "hf_space", # Exclude the cloned HF Space directory itself
+ }
+
+
+def get_excluded_files() -> set[str]:
+=======
}
def get_excluded_files() -> Set[str]:
+>>>>>>> origin/dev
"""Get set of file names to exclude from deployment."""
return {
".pre-commit-config.yaml",
"mkdocs.yml",
"uv.lock",
"AGENTS.txt",
+<<<<<<< HEAD
+=======
"CONTRIBUTING.md",
+>>>>>>> origin/dev
".env",
".env.local",
"*.local",
@@ -60,17 +80,29 @@ def get_excluded_files() -> Set[str]:
}
+<<<<<<< HEAD
+def should_exclude(path: Path, excluded_dirs: set[str], excluded_files: set[str]) -> bool:
+=======
def should_exclude(path: Path, excluded_dirs: Set[str], excluded_files: Set[str]) -> bool:
+>>>>>>> origin/dev
"""Check if a path should be excluded from deployment."""
# Check if any parent directory is excluded
for parent in path.parents:
if parent.name in excluded_dirs:
return True
+<<<<<<< HEAD
+
+ # Check if the path itself is a directory that should be excluded
+ if path.is_dir() and path.name in excluded_dirs:
+ return True
+
+=======
# Check if the path itself is a directory that should be excluded
if path.is_dir() and path.name in excluded_dirs:
return True
+>>>>>>> origin/dev
# Check if the file name matches excluded patterns
if path.is_file():
# Check exact match
@@ -83,39 +115,82 @@ def should_exclude(path: Path, excluded_dirs: Set[str], excluded_files: Set[str]
suffix = pattern.replace("*", "")
if path.name.endswith(suffix):
return True
+<<<<<<< HEAD
+
+=======
+>>>>>>> origin/dev
return False
def deploy_to_hf_space() -> None:
"""Deploy repository to Hugging Face Space.
+<<<<<<< HEAD
+
+ Supports both user and organization Spaces:
+ - User Space: username/space-name
+ - Organization Space: organization-name/space-name
+
+=======
Supports both user and organization Spaces:
- User Space: username/space-name
- Organization Space: organization-name/space-name
+>>>>>>> origin/dev
Works with both classic tokens and fine-grained tokens.
"""
# Get configuration from environment variables
hf_token = os.getenv("HF_TOKEN")
hf_username = os.getenv("HF_USERNAME") # Can be username or organization name
space_name = os.getenv("HF_SPACE_NAME")
+<<<<<<< HEAD
+
+ # Check which variables are missing and provide helpful error message
+ missing = []
+ if not hf_token:
+ missing.append("HF_TOKEN (should be in repository secrets)")
+ if not hf_username:
+ missing.append("HF_USERNAME (should be in repository variables)")
+ if not space_name:
+ missing.append("HF_SPACE_NAME (should be in repository variables)")
+
+ if missing:
+ raise ValueError(
+ f"Missing required environment variables: {', '.join(missing)}\n"
+ f"Please configure:\n"
+ f" - HF_TOKEN in Settings > Secrets and variables > Actions > Secrets\n"
+ f" - HF_USERNAME in Settings > Secrets and variables > Actions > Variables\n"
+ f" - HF_SPACE_NAME in Settings > Secrets and variables > Actions > Variables"
+ )
+
+=======
if not all([hf_token, hf_username, space_name]):
raise ValueError(
"Missing required environment variables: HF_TOKEN, HF_USERNAME, HF_SPACE_NAME"
)
+>>>>>>> origin/dev
# HF_USERNAME can be either a username or organization name
# Format: {username|organization}/{space_name}
repo_id = f"{hf_username}/{space_name}"
local_dir = "hf_space"
+<<<<<<< HEAD
+
+ print(f"π Deploying to Hugging Face Space: {repo_id}")
+
+ # Initialize HF API
+ api = HfApi(token=hf_token)
+
+=======
print(f"π Deploying to Hugging Face Space: {repo_id}")
# Initialize HF API
api = HfApi(token=hf_token)
+>>>>>>> origin/dev
# Create Space if it doesn't exist
try:
api.repo_info(repo_id=repo_id, repo_type="space", token=hf_token)
@@ -133,6 +208,47 @@ def deploy_to_hf_space() -> None:
exist_ok=True,
)
print(f"β
Created new Space: {repo_id}")
+<<<<<<< HEAD
+
+ # Configure Git credential helper for authentication
+ # This is needed for Git LFS to work properly with fine-grained tokens
+ print("π Configuring Git credentials...")
+
+ # Use Git credential store to store the token
+ # This allows Git LFS to authenticate properly
+ temp_dir = Path(tempfile.gettempdir())
+ credential_store = temp_dir / ".git-credentials-hf"
+
+ # Write credentials in the format: https://username:token@huggingface.co
+ credential_store.write_text(
+ f"https://{hf_username}:{hf_token}@huggingface.co\n", encoding="utf-8"
+ )
+ try:
+ credential_store.chmod(0o600) # Secure permissions (Unix only)
+ except OSError:
+ # Windows doesn't support chmod, skip
+ pass
+
+ # Configure Git to use the credential store
+ subprocess.run(
+ ["git", "config", "--global", "credential.helper", f"store --file={credential_store}"],
+ check=True,
+ capture_output=True,
+ )
+
+ # Also set environment variable for Git LFS
+ os.environ["GIT_CREDENTIAL_HELPER"] = f"store --file={credential_store}"
+
+ # Clone repository using git
+ # Use the token in the URL for initial clone, but LFS will use credential store
+ space_url = f"https://{hf_username}:{hf_token}@huggingface.co/spaces/{repo_id}"
+
+ if Path(local_dir).exists():
+ print(f"π§Ή Removing existing {local_dir} directory...")
+ shutil.rmtree(local_dir)
+
+ print("π₯ Cloning Space repository...")
+=======
# Clone repository using git
space_url = f"https://{hf_token}@huggingface.co/spaces/{repo_id}"
@@ -142,6 +258,7 @@ def deploy_to_hf_space() -> None:
shutil.rmtree(local_dir)
print(f"π₯ Cloning Space repository...")
+>>>>>>> origin/dev
try:
result = subprocess.run(
["git", "clone", space_url, local_dir],
@@ -149,6 +266,66 @@ def deploy_to_hf_space() -> None:
capture_output=True,
text=True,
)
+<<<<<<< HEAD
+        print("✅ Cloned Space repository")
+
+ # After clone, configure the remote to use credential helper
+ # This ensures future operations (like push) use the credential store
+ os.chdir(local_dir)
+ subprocess.run(
+ ["git", "remote", "set-url", "origin", f"https://huggingface.co/spaces/{repo_id}"],
+ check=True,
+ capture_output=True,
+ )
+ os.chdir("..")
+
+ except subprocess.CalledProcessError as e:
+ error_msg = e.stderr if e.stderr else e.stdout if e.stdout else "Unknown error"
+ print(f"β Failed to clone Space repository: {error_msg}")
+
+ # Try alternative: clone with LFS skip, then fetch LFS files separately
+ print("π Trying alternative clone method (skip LFS during clone)...")
+ try:
+ env = os.environ.copy()
+ env["GIT_LFS_SKIP_SMUDGE"] = "1" # Skip LFS during clone
+
+ subprocess.run(
+ ["git", "clone", space_url, local_dir],
+ check=True,
+ capture_output=True,
+ text=True,
+ env=env,
+ )
+            print("✅ Cloned Space repository (LFS skipped)")
+
+ # Configure remote
+ os.chdir(local_dir)
+ subprocess.run(
+ ["git", "remote", "set-url", "origin", f"https://huggingface.co/spaces/{repo_id}"],
+ check=True,
+ capture_output=True,
+ )
+
+ # Try to fetch LFS files with proper authentication
+ print("π₯ Fetching LFS files...")
+ subprocess.run(
+ ["git", "lfs", "pull"],
+ check=False, # Don't fail if LFS pull fails - we'll continue without LFS files
+ capture_output=True,
+ text=True,
+ )
+ os.chdir("..")
+            print("✅ Repository cloned (LFS files may be incomplete, but deployment can continue)")
+ except subprocess.CalledProcessError as e2:
+ error_msg2 = e2.stderr if e2.stderr else e2.stdout if e2.stdout else "Unknown error"
+ print(f"β Alternative clone method also failed: {error_msg2}")
+ raise RuntimeError(f"Git clone failed: {error_msg}") from e
+
+ # Get exclusion sets
+ excluded_dirs = get_excluded_dirs()
+ excluded_files = get_excluded_files()
+
+=======
print(f"β
Cloned Space repository")
except subprocess.CalledProcessError as e:
error_msg = e.stderr if e.stderr else e.stdout if e.stdout else "Unknown error"
@@ -159,6 +336,7 @@ def deploy_to_hf_space() -> None:
excluded_dirs = get_excluded_dirs()
excluded_files = get_excluded_files()
+>>>>>>> origin/dev
# Remove all existing files in HF Space (except .git)
print("π§Ή Cleaning existing files...")
for item in Path(local_dir).iterdir():
@@ -168,28 +346,61 @@ def deploy_to_hf_space() -> None:
shutil.rmtree(item)
else:
item.unlink()
+<<<<<<< HEAD
+
+=======
+>>>>>>> origin/dev
# Copy files from repository root
print("π¦ Copying files...")
repo_root = Path(".")
files_copied = 0
dirs_copied = 0
+<<<<<<< HEAD
+
+=======
+>>>>>>> origin/dev
for item in repo_root.rglob("*"):
# Skip if in .git directory
if ".git" in item.parts:
continue
+<<<<<<< HEAD
+
+ # Skip if in hf_space directory (the cloned Space directory)
+ if "hf_space" in item.parts:
+ continue
+
+ # Skip if should be excluded
+ if should_exclude(item, excluded_dirs, excluded_files):
+ continue
+
+=======
# Skip if should be excluded
if should_exclude(item, excluded_dirs, excluded_files):
continue
+>>>>>>> origin/dev
# Calculate relative path
try:
rel_path = item.relative_to(repo_root)
except ValueError:
# Item is outside repo root, skip
continue
+<<<<<<< HEAD
+
+ # Skip if in excluded directory
+ if any(part in excluded_dirs for part in rel_path.parts):
+ continue
+
+ # Destination path
+ dest_path = Path(local_dir) / rel_path
+
+ # Create parent directories
+ dest_path.parent.mkdir(parents=True, exist_ok=True)
+
+=======
# Skip if in excluded directory
if any(part in excluded_dirs for part in rel_path.parts):
@@ -201,6 +412,7 @@ def deploy_to_hf_space() -> None:
# Create parent directories
dest_path.parent.mkdir(parents=True, exist_ok=True)
+>>>>>>> origin/dev
# Copy file or directory
if item.is_file():
shutil.copy2(item, dest_path)
@@ -208,6 +420,18 @@ def deploy_to_hf_space() -> None:
elif item.is_dir():
# Directory will be created by parent mkdir, but we track it
dirs_copied += 1
+<<<<<<< HEAD
+
+    print(f"✅ Copied {files_copied} files and {dirs_copied} directories")
+
+ # Commit and push changes using git
+ print("πΎ Committing changes...")
+
+ # Change to the Space directory
+ original_cwd = os.getcwd()
+ os.chdir(local_dir)
+
+=======
print(f"β
Copied {files_copied} files and {dirs_copied} directories")
@@ -218,6 +442,7 @@ def deploy_to_hf_space() -> None:
original_cwd = os.getcwd()
os.chdir(local_dir)
+>>>>>>> origin/dev
try:
# Configure git user (required for commit)
subprocess.run(
@@ -230,13 +455,28 @@ def deploy_to_hf_space() -> None:
check=True,
capture_output=True,
)
+<<<<<<< HEAD
+
+=======
+>>>>>>> origin/dev
# Add all files
subprocess.run(
["git", "add", "."],
check=True,
capture_output=True,
)
+<<<<<<< HEAD
+
+ # Check if there are changes to commit
+ result = subprocess.run(
+ ["git", "status", "--porcelain"],
+ check=False,
+ capture_output=True,
+ text=True,
+ )
+
+=======
# Check if there are changes to commit
result = subprocess.run(
@@ -245,6 +485,7 @@ def deploy_to_hf_space() -> None:
text=True,
)
+>>>>>>> origin/dev
if result.stdout.strip():
# There are changes, commit and push
subprocess.run(
@@ -253,6 +494,15 @@ def deploy_to_hf_space() -> None:
capture_output=True,
)
print("π€ Pushing to Hugging Face Space...")
+<<<<<<< HEAD
+ # Ensure remote URL uses credential helper (not token in URL)
+ subprocess.run(
+ ["git", "remote", "set-url", "origin", f"https://huggingface.co/spaces/{repo_id}"],
+ check=True,
+ capture_output=True,
+ )
+=======
+>>>>>>> origin/dev
subprocess.run(
["git", "push"],
check=True,
@@ -273,10 +523,25 @@ def deploy_to_hf_space() -> None:
finally:
# Return to original directory
os.chdir(original_cwd)
+<<<<<<< HEAD
+
+ # Clean up credential store for security
+ try:
+ if credential_store.exists():
+ credential_store.unlink()
+ except Exception:
+ # Ignore cleanup errors
+ pass
+
+=======
+>>>>>>> origin/dev
print(f"π Successfully deployed to: https://huggingface.co/spaces/{repo_id}")
if __name__ == "__main__":
deploy_to_hf_space()
+<<<<<<< HEAD
+=======
+>>>>>>> origin/dev
diff --git a/.github/workflows/deploy-hf-space.yml b/.github/workflows/deploy-hf-space.yml
index 5e788686..2561b279 100644
--- a/.github/workflows/deploy-hf-space.yml
+++ b/.github/workflows/deploy-hf-space.yml
@@ -30,9 +30,18 @@ jobs:
- name: Deploy to Hugging Face Space
env:
+<<<<<<< HEAD
+ # Token from secrets (sensitive data)
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
+ # Username/Organization from repository variables (non-sensitive)
+ HF_USERNAME: ${{ vars.HF_USERNAME }}
+ # Space name from repository variables (non-sensitive)
+ HF_SPACE_NAME: ${{ vars.HF_SPACE_NAME }}
+=======
HF_TOKEN: ${{ secrets.HF_TOKEN }}
HF_USERNAME: ${{ secrets.HF_USERNAME }}
HF_SPACE_NAME: ${{ secrets.HF_SPACE_NAME }}
+>>>>>>> origin/dev
run: |
python .github/scripts/deploy_to_hf_space.py
@@ -40,5 +49,9 @@ jobs:
if: success()
run: |
echo "β
Deployment completed successfully!"
+<<<<<<< HEAD
+ echo "Space URL: https://huggingface.co/spaces/${{ vars.HF_USERNAME }}/${{ vars.HF_SPACE_NAME }}"
+=======
echo "Space URL: https://huggingface.co/spaces/${{ secrets.HF_USERNAME }}/${{ secrets.HF_SPACE_NAME }}"
+>>>>>>> origin/dev
diff --git a/.gitignore b/.gitignore
index 84c2e770..9a982ecb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -57,6 +57,9 @@ reference_repos/DeepCritical/
# Keep the README in reference_repos
!reference_repos/README.md
+# Development directory
+dev/
+
# OS
.DS_Store
Thumbs.db
@@ -70,6 +73,7 @@ logs/
.mypy_cache/
.coverage
htmlcov/
+test_output*.txt
# Database files
chroma_db/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 12a77673..8b1184f5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
hooks:
- id: mypy
files: ^src/
- exclude: ^folder
+ exclude: ^folder|^src/app.py
additional_dependencies:
- pydantic>=2.7
- pydantic-settings>=2.2
diff --git a/README.md b/README.md
index aedb9d90..a4c0e90e 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,14 @@ app_file: src/app.py
hf_oauth: true
hf_oauth_expiration_minutes: 480
hf_oauth_scopes:
- - inference-api
+ # Required for HuggingFace Inference API (includes all third-party providers)
+ # This scope grants access to:
+ # - HuggingFace's own Inference API
+ # - Third-party inference providers (nebius, together, scaleway, hyperbolic, novita, nscale, sambanova, ovh, fireworks, etc.)
+ # - All models available through the Inference Providers API
+ - inference-api
+ # Optional: Uncomment if you need to access user's billing information
+ # - read-billing
pinned: true
license: mit
tags:
diff --git a/deployments/README.md b/deployments/README.md
new file mode 100644
index 00000000..3a4f4a4d
--- /dev/null
+++ b/deployments/README.md
@@ -0,0 +1,46 @@
+# Deployments
+
+This directory contains infrastructure deployment scripts for DeepCritical services.
+
+## Modal Deployments
+
+### TTS Service (`modal_tts.py`)
+
+Deploys the Kokoro TTS (Text-to-Speech) function to Modal's GPU infrastructure.
+
+**Deploy:**
+```bash
+modal deploy deployments/modal_tts.py
+```
+
+**Features:**
+- Kokoro 82M TTS model
+- GPU-accelerated (T4)
+- Voice options: af_heart, af_bella, am_michael, etc.
+- Configurable speech speed
+
+**Requirements:**
+- Modal account and credentials (`MODAL_TOKEN_ID`, `MODAL_TOKEN_SECRET` in `.env`)
+- GPU quota on Modal
+
+**After Deployment:**
+The function will be available at:
+- App: `deepcritical-tts`
+- Function: `kokoro_tts_function`
+
+The main application (`src/services/tts_modal.py`) will call this deployed function.
+
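+**Calling the Deployed Function:**
+A minimal sketch of how a caller could look up and invoke the deployed function. The app and function names match the deployment above; the lookup API (`modal.Function.from_name`) is assumed here and may differ between Modal versions.
+
+```python
+import modal
+
+# Look up the deployed function by app name and function name (illustrative)
+kokoro = modal.Function.from_name("deepcritical-tts", "kokoro_tts_function")
+
+# Run remotely on Modal and receive (sample_rate, audio_array)
+sample_rate, audio = kokoro.remote("Hello from DeepCritical.", "af_heart", 1.0)
+print(sample_rate, audio.shape)
+```
+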
+---
+
+## Adding New Deployments
+
+When adding new deployment scripts:
+
+1. Create a new file: `deployments/<service_name>.py`
+2. Use Modal's app pattern:
+ ```python
+ import modal
+   app = modal.App("deepcritical-<service_name>")
+ ```
+3. Document in this README
+4. Test deployment: `modal deploy deployments/<service_name>.py` (a fuller skeleton follows below)
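+
+A minimal skeleton for a new deployment script, following the same pattern as `modal_tts.py` (the app suffix, image packages, and function body below are placeholders to adapt):
+
+```python
+import modal
+
+# Replace "example" with the service name used in step 2
+app = modal.App("deepcritical-example")
+
+# Declare the runtime image with whatever dependencies the service needs
+image = modal.Image.debian_slim(python_version="3.11").pip_install("numpy")
+
+
+@app.function(image=image, timeout=60)
+def example_function(text: str) -> str:
+    """Placeholder body - replace with the real service logic."""
+    return text.upper()
+
+
+@app.local_entrypoint()
+def test() -> None:
+    """Quick check: `modal run deployments/<service_name>.py`."""
+    print(example_function.remote("hello"))
+```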
diff --git a/deployments/modal_tts.py b/deployments/modal_tts.py
new file mode 100644
index 00000000..9987a339
--- /dev/null
+++ b/deployments/modal_tts.py
@@ -0,0 +1,97 @@
+"""Deploy Kokoro TTS function to Modal.
+
+This script deploys the TTS function to Modal so it can be called
+from the main DeepCritical application.
+
+Usage:
+    modal deploy deployments/modal_tts.py
+
+After deployment, the function will be available at:
+ App: deepcritical-tts
+ Function: kokoro_tts_function
+"""
+
+import modal
+import numpy as np
+
+# Create Modal app
+app = modal.App("deepcritical-tts")
+
+# Define Kokoro TTS dependencies
+KOKORO_DEPENDENCIES = [
+ "torch>=2.0.0",
+ "transformers>=4.30.0",
+ "numpy<2.0",
+]
+
+# Create Modal image with Kokoro
+tts_image = (
+ modal.Image.debian_slim(python_version="3.11")
+ .apt_install("git") # Install git first for pip install from github
+ .pip_install(*KOKORO_DEPENDENCIES)
+ .pip_install("git+https://github.com/hexgrad/kokoro.git")
+)
+
+
+@app.function(
+ image=tts_image,
+ gpu="T4",
+ timeout=60,
+)
+def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, np.ndarray]:
+ """Modal GPU function for Kokoro TTS.
+
+ This function runs on Modal's GPU infrastructure.
+ Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS
+
+ Args:
+ text: Text to synthesize
+ voice: Voice ID (e.g., af_heart, af_bella, am_michael)
+ speed: Speech speed multiplier (0.5-2.0)
+
+ Returns:
+ Tuple of (sample_rate, audio_array)
+ """
+ import numpy as np
+
+ try:
+ import torch
+ from kokoro import KModel, KPipeline
+
+ # Initialize model (cached on GPU)
+ model = KModel().to("cuda").eval()
+ pipeline = KPipeline(lang_code=voice[0])
+ pack = pipeline.load_voice(voice)
+
+ # Generate audio - accumulate all chunks
+ audio_chunks = []
+ for _, ps, _ in pipeline(text, voice, speed):
+ ref_s = pack[len(ps) - 1]
+ audio = model(ps, ref_s, speed)
+ audio_chunks.append(audio.numpy())
+
+ # Concatenate all audio chunks
+ if audio_chunks:
+ full_audio = np.concatenate(audio_chunks)
+ return (24000, full_audio)
+
+ # If no audio generated, return empty
+ return (24000, np.zeros(1, dtype=np.float32))
+
+ except ImportError as e:
+ raise RuntimeError(
+ f"Kokoro not installed: {e}. "
+ "Install with: pip install git+https://github.com/hexgrad/kokoro.git"
+ ) from e
+ except Exception as e:
+ raise RuntimeError(f"TTS synthesis failed: {e}") from e
+
+
+# Optional: Add a test entrypoint
+@app.local_entrypoint()
+def test():
+ """Test the TTS function."""
+ print("Testing Modal TTS function...")
+ sample_rate, audio = kokoro_tts_function.remote("Hello, this is a test.", "af_heart", 1.0)
+ print(f"Generated audio: {sample_rate}Hz, shape={audio.shape}")
+ print("β TTS function works!")
diff --git a/docs/LICENSE.md b/docs/LICENSE.md
index 7a0c1fad..f44ce878 100644
--- a/docs/LICENSE.md
+++ b/docs/LICENSE.md
@@ -25,3 +25,14 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+<<<<<<< HEAD
+
+
+
+
+
+
+
+
+=======
+>>>>>>> origin/dev
diff --git a/docs/api/orchestrators.md b/docs/api/orchestrators.md
index 9e3d22df..1f157036 100644
--- a/docs/api/orchestrators.md
+++ b/docs/api/orchestrators.md
@@ -23,9 +23,18 @@ Runs iterative research flow.
- `background_context`: Background context (default: "")
- `output_length`: Optional description of desired output length (default: "")
- `output_instructions`: Optional additional instructions for report generation (default: "")
+<<<<<<< HEAD
+- `message_history`: Optional user conversation history in Pydantic AI `ModelMessage` format (default: None)
**Returns**: Final report string.
+**Note**: The `message_history` parameter enables multi-turn conversations by providing context from previous interactions.
+
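+A minimal sketch of a multi-turn call (assumes an already-constructed `IterativeResearchFlow` and a message history collected from earlier turns):
+
+```python
+# Illustrative only - parameter names follow the description above
+report = await flow.run(
+    "How does this compare to the approach from my previous question?",
+    message_history=previous_messages,  # list[ModelMessage] from prior turns
+)
+```
+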
+=======
+
+**Returns**: Final report string.
+
+>>>>>>> origin/dev
**Note**: `max_iterations`, `max_time_minutes`, and `token_budget` are constructor parameters, not `run()` parameters.
## DeepResearchFlow
@@ -46,9 +55,18 @@ Runs deep research flow.
**Parameters**:
- `query`: Research query string
+<<<<<<< HEAD
+- `message_history`: Optional user conversation history in Pydantic AI `ModelMessage` format (default: None)
**Returns**: Final report string.
+**Note**: The `message_history` parameter enables multi-turn conversations by providing context from previous interactions.
+
+=======
+
+**Returns**: Final report string.
+
+>>>>>>> origin/dev
**Note**: `max_iterations_per_section`, `max_time_minutes`, and `token_budget` are constructor parameters, not `run()` parameters.
## GraphOrchestrator
@@ -69,10 +87,20 @@ Runs graph-based research orchestration.
**Parameters**:
- `query`: Research query string
+<<<<<<< HEAD
+- `message_history`: Optional user conversation history in Pydantic AI `ModelMessage` format (default: None)
+
+**Yields**: `AgentEvent` objects during graph execution.
+
+**Note**:
+- `research_mode` and `use_graph` are constructor parameters, not `run()` parameters.
+- The `message_history` parameter enables multi-turn conversations by providing context from previous interactions. Message history is stored in `GraphExecutionContext` and passed to agents during execution.
+=======
**Yields**: `AgentEvent` objects during graph execution.
**Note**: `research_mode` and `use_graph` are constructor parameters, not `run()` parameters.
+>>>>>>> origin/dev
## Orchestrator Factory
diff --git a/docs/architecture/graph_orchestration.md b/docs/architecture/graph_orchestration.md
index cf8c4dbd..5cb48dd9 100644
--- a/docs/architecture/graph_orchestration.md
+++ b/docs/architecture/graph_orchestration.md
@@ -4,6 +4,47 @@
DeepCritical implements a graph-based orchestration system for research workflows using Pydantic AI agents as nodes. This enables better parallel execution, conditional routing, and state management compared to simple agent chains.
+<<<<<<< HEAD
+## Conversation History
+
+DeepCritical supports multi-turn conversations through Pydantic AI's native message history format. The system maintains two types of history:
+
+1. **User Conversation History**: Multi-turn user interactions (from Gradio chat interface) stored as `list[ModelMessage]`
+2. **Research Iteration History**: Internal research process state (existing `Conversation` model)
+
+### Message History Flow
+
+```
+Gradio Chat History β convert_gradio_to_message_history() β GraphOrchestrator.run(message_history)
+ β
+GraphExecutionContext (stores message_history)
+ β
+Agent Nodes (receive message_history via agent.run())
+ β
+WorkflowState (persists user_message_history)
+```
+
+### Usage
+
+Message history is automatically converted from Gradio format and passed through the orchestrator:
+
+```python
+# In app.py - automatic conversion
+message_history = convert_gradio_to_message_history(history) if history else None
+async for event in orchestrator.run(query, message_history=message_history):
+ yield event
+```
+
+Agents receive message history through their `run()` methods:
+
+```python
+# In agent execution
+if message_history:
+ result = await agent.run(input_data, message_history=message_history)
+```
+
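+A simplified sketch of the kind of conversion `convert_gradio_to_message_history()` performs (the actual helper lives in the application code and may differ; the message classes are from `pydantic_ai.messages`):
+
+```python
+from pydantic_ai.messages import (
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    TextPart,
+    UserPromptPart,
+)
+
+
+def convert_gradio_to_message_history(history: list[dict]) -> list[ModelMessage]:
+    """Map Gradio 'messages'-format chat history to Pydantic AI messages (sketch)."""
+    messages: list[ModelMessage] = []
+    for turn in history:
+        if turn["role"] == "user":
+            messages.append(ModelRequest(parts=[UserPromptPart(content=turn["content"])]))
+        else:  # assistant turns become model responses
+            messages.append(ModelResponse(parts=[TextPart(content=turn["content"])]))
+    return messages
+```
+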
+=======
+>>>>>>> origin/dev
## Graph Patterns
### Iterative Research Graph
diff --git a/pyproject.toml b/pyproject.toml
index 6bef4f10..6f59cac5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -127,6 +127,7 @@ ignore = [
"PLR0913", # Too many arguments (agents need many params)
"PLR0912", # Too many branches (complex orchestrator logic)
"PLR0911", # Too many return statements (complex agent logic)
+ "PLR0915", # Too many statements (Gradio UI setup functions)
"PLR2004", # Magic values (statistical constants like p-values)
"PLW0603", # Global statement (singleton pattern for Modal)
"PLC0415", # Lazy imports for optional dependencies
@@ -152,6 +153,7 @@ exclude = [
"^reference_repos/",
"^examples/",
"^folder/",
+ "^src/app.py",
]
# ============== PYTEST CONFIG ==============
diff --git a/src/agent_factory/agents.py b/src/agent_factory/agents.py
index 69d01380..c676140e 100644
--- a/src/agent_factory/agents.py
+++ b/src/agent_factory/agents.py
@@ -27,7 +27,9 @@
logger = structlog.get_logger()
-def create_input_parser_agent(model: Any | None = None, oauth_token: str | None = None) -> "InputParserAgent":
+def create_input_parser_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> "InputParserAgent":
"""
Create input parser agent for query analysis and research mode detection.
@@ -51,7 +53,9 @@ def create_input_parser_agent(model: Any | None = None, oauth_token: str | None
raise ConfigurationError(f"Failed to create input parser agent: {e}") from e
-def create_planner_agent(model: Any | None = None, oauth_token: str | None = None) -> "PlannerAgent":
+def create_planner_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> "PlannerAgent":
"""
Create planner agent with web search and crawl tools.
@@ -76,7 +80,9 @@ def create_planner_agent(model: Any | None = None, oauth_token: str | None = Non
raise ConfigurationError(f"Failed to create planner agent: {e}") from e
-def create_knowledge_gap_agent(model: Any | None = None, oauth_token: str | None = None) -> "KnowledgeGapAgent":
+def create_knowledge_gap_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> "KnowledgeGapAgent":
"""
Create knowledge gap agent for evaluating research completeness.
@@ -100,7 +106,9 @@ def create_knowledge_gap_agent(model: Any | None = None, oauth_token: str | None
raise ConfigurationError(f"Failed to create knowledge gap agent: {e}") from e
-def create_tool_selector_agent(model: Any | None = None, oauth_token: str | None = None) -> "ToolSelectorAgent":
+def create_tool_selector_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> "ToolSelectorAgent":
"""
Create tool selector agent for choosing tools to address gaps.
@@ -124,7 +132,9 @@ def create_tool_selector_agent(model: Any | None = None, oauth_token: str | None
raise ConfigurationError(f"Failed to create tool selector agent: {e}") from e
-def create_thinking_agent(model: Any | None = None, oauth_token: str | None = None) -> "ThinkingAgent":
+def create_thinking_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> "ThinkingAgent":
"""
Create thinking agent for generating observations.
@@ -172,7 +182,9 @@ def create_writer_agent(model: Any | None = None, oauth_token: str | None = None
raise ConfigurationError(f"Failed to create writer agent: {e}") from e
-def create_long_writer_agent(model: Any | None = None, oauth_token: str | None = None) -> "LongWriterAgent":
+def create_long_writer_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> "LongWriterAgent":
"""
Create long writer agent for iteratively writing report sections.
@@ -196,7 +208,9 @@ def create_long_writer_agent(model: Any | None = None, oauth_token: str | None =
raise ConfigurationError(f"Failed to create long writer agent: {e}") from e
-def create_proofreader_agent(model: Any | None = None, oauth_token: str | None = None) -> "ProofreaderAgent":
+def create_proofreader_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> "ProofreaderAgent":
"""
Create proofreader agent for finalizing report drafts.
diff --git a/src/agent_factory/graph_builder.py b/src/agent_factory/graph_builder.py
index 0c03b89c..c06c1b0a 100644
--- a/src/agent_factory/graph_builder.py
+++ b/src/agent_factory/graph_builder.py
@@ -487,12 +487,13 @@ def create_iterative_graph(
# Add nodes
builder.add_agent_node("thinking", thinking_agent, "Generate observations")
builder.add_agent_node("knowledge_gap", knowledge_gap_agent, "Evaluate knowledge gaps")
+
def _decision_function(result: Any) -> str:
"""Decision function for continue_decision node.
-
+
Args:
result: Result from knowledge_gap node (KnowledgeGapOutput or tuple)
-
+
Returns:
Next node ID: "writer" if research complete, "tool_selector" otherwise
"""
@@ -510,11 +511,11 @@ def _decision_function(result: Any) -> str:
return "writer" if item["research_complete"] else "tool_selector"
# Default to continuing research if we can't determine
return "tool_selector"
-
+
# Normal case: result is KnowledgeGapOutput object
research_complete = getattr(result, "research_complete", False)
return "writer" if research_complete else "tool_selector"
-
+
builder.add_decision_node(
"continue_decision",
decision_function=_decision_function,
diff --git a/src/agent_factory/judges.py b/src/agent_factory/judges.py
index 92fe3b1e..59ccdc08 100644
--- a/src/agent_factory/judges.py
+++ b/src/agent_factory/judges.py
@@ -37,7 +37,7 @@ def get_model(oauth_token: str | None = None) -> Any:
1. HuggingFace (if OAuth token or API key available - preferred for free tier)
2. OpenAI (if API key available)
3. Anthropic (if API key available)
-
+
If OAuth token is available, prefer HuggingFace (even if provider is set to OpenAI).
This ensures users logged in via HuggingFace Spaces get the free tier.
@@ -50,9 +50,23 @@ def get_model(oauth_token: str | None = None) -> Any:
Raises:
ConfigurationError: If no LLM provider is available
"""
+ from src.utils.hf_error_handler import log_token_info, validate_hf_token
+
# Priority: oauth_token > settings.hf_token > settings.huggingface_api_key
effective_hf_token = oauth_token or settings.hf_token or settings.huggingface_api_key
+ # Validate and log token information
+ if effective_hf_token:
+ log_token_info(effective_hf_token, context="get_model")
+ is_valid, error_msg = validate_hf_token(effective_hf_token)
+ if not is_valid:
+ logger.warning(
+ "Token validation failed",
+ error=error_msg,
+ has_oauth=bool(oauth_token),
+ )
+ # Continue anyway - let the API call fail with a clear error
+
# Try HuggingFace first (preferred for free tier)
if effective_hf_token:
model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct"
@@ -157,7 +171,27 @@ async def assess(
return assessment
except Exception as e:
- logger.error("Assessment failed", error=str(e))
+ # Extract error details for better logging and handling
+ from src.utils.hf_error_handler import (
+ extract_error_details,
+ get_user_friendly_error_message,
+ )
+
+ error_details = extract_error_details(e)
+ logger.error(
+ "Assessment failed",
+ error=str(e),
+ status_code=error_details.get("status_code"),
+ model_name=error_details.get("model_name"),
+ is_auth_error=error_details.get("is_auth_error"),
+ is_model_error=error_details.get("is_model_error"),
+ )
+
+ # Log user-friendly message for debugging
+ if error_details.get("is_auth_error") or error_details.get("is_model_error"):
+ user_msg = get_user_friendly_error_message(e, error_details.get("model_name"))
+ logger.warning("API error details", user_message=user_msg[:200])
+
# Return a safe default assessment on failure
return self._create_fallback_assessment(question, str(e))
@@ -209,9 +243,7 @@ class HFInferenceJudgeHandler:
"HuggingFaceH4/zephyr-7b-beta", # Fallback (Ungated)
]
- def __init__(
- self, model_id: str | None = None, api_key: str | None = None
- ) -> None:
+ def __init__(self, model_id: str | None = None, api_key: str | None = None) -> None:
"""
Initialize with HF Inference client.
diff --git a/src/agents/audio_refiner.py b/src/agents/audio_refiner.py
new file mode 100644
index 00000000..257c6f5e
--- /dev/null
+++ b/src/agents/audio_refiner.py
@@ -0,0 +1,402 @@
+"""Audio Refiner Agent - Cleans markdown reports for TTS audio clarity.
+
+This agent transforms markdown-formatted research reports into clean,
+audio-friendly plain text suitable for text-to-speech synthesis.
+"""
+
+import re
+
+import structlog
+from pydantic_ai import Agent
+
+from src.utils.llm_factory import get_pydantic_ai_model
+
+logger = structlog.get_logger(__name__)
+
+
+class AudioRefiner:
+ """Refines markdown reports for optimal TTS audio output.
+
+ Handles common formatting issues that make text difficult to listen to:
+ - Markdown syntax (headers, bold, italic, links)
+ - Citations and reference markers
+ - Roman numerals in medical contexts
+ - Multiple References sections
+ - Special characters and formatting artifacts
+ """
+
+ # Roman numeral to integer mapping
+ ROMAN_VALUES = {"I": 1, "V": 5, "X": 10, "L": 50, "C": 100, "D": 500, "M": 1000}
+
+ # Number to word mapping (1-20, common in medical literature)
+ NUMBER_TO_WORD = {
+ 1: "One",
+ 2: "Two",
+ 3: "Three",
+ 4: "Four",
+ 5: "Five",
+ 6: "Six",
+ 7: "Seven",
+ 8: "Eight",
+ 9: "Nine",
+ 10: "Ten",
+ 11: "Eleven",
+ 12: "Twelve",
+ 13: "Thirteen",
+ 14: "Fourteen",
+ 15: "Fifteen",
+ 16: "Sixteen",
+ 17: "Seventeen",
+ 18: "Eighteen",
+ 19: "Nineteen",
+ 20: "Twenty",
+ }
+
+ async def refine_for_audio(self, markdown_text: str, use_llm_polish: bool = False) -> str:
+ """Transform markdown report into audio-friendly plain text.
+
+ Args:
+ markdown_text: Markdown-formatted research report
+ use_llm_polish: If True, apply LLM-based final polish (optional)
+
+ Returns:
+ Clean plain text optimized for TTS audio
+ """
+ logger.info("Refining report for audio output", use_llm_polish=use_llm_polish)
+
+ text = markdown_text
+
+ # Step 1: Remove References sections first (before other processing)
+ text = self._remove_references_sections(text)
+
+ # Step 2: Remove markdown formatting
+ text = self._remove_markdown_syntax(text)
+
+ # Step 3: Convert roman numerals to words
+ text = self._convert_roman_numerals(text)
+
+ # Step 4: Remove citations
+ text = self._remove_citations(text)
+
+ # Step 5: Clean up special characters and artifacts
+ text = self._clean_special_characters(text)
+
+ # Step 6: Normalize whitespace
+ text = self._normalize_whitespace(text)
+
+ # Step 7 (Optional): LLM polish for edge cases
+ if use_llm_polish:
+ text = await self._llm_polish(text)
+
+ logger.info(
+ "Audio refinement complete",
+ original_length=len(markdown_text),
+ refined_length=len(text),
+ llm_polish_applied=use_llm_polish,
+ )
+
+ return text.strip()
+
+ def _remove_references_sections(self, text: str) -> str:
+ """Remove References sections while preserving other content.
+
+ Removes the References section and its content until the next section
+ heading or end of document. Handles multiple References sections.
+
+ Matches various References heading formats:
+ - # References
+ - ## References
+ - **References:**
+ - **Additional References:**
+ - References: (plain text)
+ """
+ # Pattern to match References section heading (case-insensitive)
+ # Matches: markdown headers (# References), bold (**References:**), or plain text (References:)
+ references_pattern = r"\n(?:#+\s*References?:?\s*\n|\*\*\s*(?:Additional\s+)?References?:?\s*\*\*\s*\n|References?:?\s*\n)"
+
+ # Find all References sections
+ while True:
+ match = re.search(references_pattern, text, re.IGNORECASE)
+ if not match:
+ break
+
+ # Find the start of the References section
+ section_start = match.start()
+
+ # Find the next section (markdown header or bold heading) or end of document
+ # Match: "# Header", "## Header", or "**Header**"
+ next_section_patterns = [
+ r"\n#+\s+\w+", # Markdown headers (# Section, ## Section)
+ r"\n\*\*[A-Z][^*]+\*\*", # Bold headings (**Section Name**)
+ ]
+
+ remaining_text = text[match.end() :]
+ next_section_match = None
+
+ # Try all patterns and find the earliest match
+ earliest_match = None
+ for pattern in next_section_patterns:
+ m = re.search(pattern, remaining_text)
+ if m and (earliest_match is None or m.start() < earliest_match.start()):
+ earliest_match = m
+
+ next_section_match = earliest_match
+
+ if next_section_match:
+ # Remove from References heading to next section
+ section_end = match.end() + next_section_match.start()
+ else:
+ # No next section - remove to end of document
+ section_end = len(text)
+
+ # Remove the References section
+ text = text[:section_start] + text[section_end:]
+ logger.debug("Removed References section", removed_chars=section_end - section_start)
+
+ return text
+
+ def _remove_markdown_syntax(self, text: str) -> str:
+ """Remove markdown formatting syntax."""
+
+ # Headers (# ## ###)
+ text = re.sub(r"^\s*#+\s+", "", text, flags=re.MULTILINE)
+
+ # Bold (**text** or __text__)
+ text = re.sub(r"\*\*([^*]+)\*\*", r"\1", text)
+ text = re.sub(r"__([^_]+)__", r"\1", text)
+
+ # Italic (*text* or _text_)
+ text = re.sub(r"\*([^*]+)\*", r"\1", text)
+ text = re.sub(r"_([^_]+)_", r"\1", text)
+
+ # Links [text](url) β text
+ text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text)
+
+ # Inline code `code` β code
+ text = re.sub(r"`([^`]+)`", r"\1", text)
+
+ # Strikethrough ~~text~~
+ text = re.sub(r"~~([^~]+)~~", r"\1", text)
+
+ # Blockquotes (> text)
+ text = re.sub(r"^\s*>\s+", "", text, flags=re.MULTILINE)
+
+ # Horizontal rules (---, ***, ___)
+ text = re.sub(r"^\s*[-*_]{3,}\s*$", "", text, flags=re.MULTILINE)
+
+ # List markers (-, *, 1., 2.)
+ text = re.sub(r"^\s*[-*]\s+", "", text, flags=re.MULTILINE)
+ text = re.sub(r"^\s*\d+\.\s+", "", text, flags=re.MULTILINE)
+
+ return text
+
+ def _roman_to_int(self, roman: str) -> int | None:
+ """Convert roman numeral string to integer.
+
+ Args:
+ roman: Roman numeral string (e.g., 'IV', 'XII')
+
+ Returns:
+ Integer value, or None if invalid roman numeral
+ """
+ roman = roman.upper()
+ result = 0
+ prev_value = 0
+
+ for char in reversed(roman):
+ if char not in self.ROMAN_VALUES:
+ return None
+
+ value = self.ROMAN_VALUES[char]
+
+ # Subtractive notation (IV = 4, IX = 9)
+ if value < prev_value:
+ result -= value
+ else:
+ result += value
+
+ prev_value = value
+
+ return result
+
+ def _int_to_word(self, num: int) -> str:
+ """Convert integer to word representation.
+
+ Args:
+ num: Integer to convert (1-20 supported)
+
+ Returns:
+ Word representation (e.g., 'One', 'Twelve')
+ """
+ if num in self.NUMBER_TO_WORD:
+ return self.NUMBER_TO_WORD[num]
+ else:
+ # For numbers > 20, just return the digit
+ return str(num)
+
+ def _convert_roman_numerals(self, text: str) -> str:
+ """Convert roman numerals to words for better TTS pronunciation.
+
+ Handles patterns like:
+ - Phase I, Phase II, Phase III
+ - Trial I, Trial II
+ - Type I, Type II
+ - Stage I, Stage II
+ - Standalone I, II, III (with word boundaries)
+ """
+
+ def replace_roman(match: re.Match[str]) -> str:
+ """Callback to replace matched roman numeral."""
+ prefix = match.group(1) # Word before roman numeral (if any)
+ roman = match.group(2) # The roman numeral
+
+ # Convert to integer
+ num = self._roman_to_int(roman)
+ if num is None:
+ return match.group(0) # Return original if invalid
+
+ # Convert to word
+ word = self._int_to_word(num)
+
+ # Return with prefix if present
+ if prefix:
+ return f"{prefix} {word}"
+ else:
+ return word
+
+ # Pattern: Optional word + space + roman numeral
+ # Matches: "Phase I", "Trial II", standalone "I", "II"
+ # Uses word boundaries to avoid matching "I" in "INVALID"
+ pattern = r"\b(Phase|Trial|Type|Stage|Class|Group|Arm|Cohort)?\s*([IVXLCDM]+)\b"
+
+ text = re.sub(pattern, replace_roman, text)
+
+ return text
+
+ def _remove_citations(self, text: str) -> str:
+ """Remove citation markers and references."""
+
+ # Numbered citations [1], [2], [1,2], [1-3]
+ text = re.sub(r"\[\d+(?:[-,]\d+)*\]", "", text)
+
+ # Author citations (Smith et al., 2023) or (Smith et al. 2023)
+ text = re.sub(r"\([A-Z][a-z]+\s+et\s+al\.?,?\s+\d{4}\)", "", text)
+
+ # Simple year citations (2023)
+ text = re.sub(r"\(\d{4}\)", "", text)
+
+ # Author-year (Smith, 2023)
+ text = re.sub(r"\([A-Z][a-z]+,?\s+\d{4}\)", "", text)
+
+ # Footnote markers (ΒΉ, Β², Β³)
+ text = re.sub(r"[ΒΉΒ²Β³β΄β΅βΆβ·βΈβΉβ°]+", "", text)
+
+ return text
+
+ def _clean_special_characters(self, text: str) -> str:
+ """Clean up special characters and formatting artifacts."""
+
+ # Replace em dashes with regular dashes
+ text = text.replace("\u2014", "-") # em dash
+ text = text.replace("\u2013", "-") # en dash
+
+ # Replace smart quotes with regular quotes
+ text = text.replace("\u201c", '"') # left double quote
+ text = text.replace("\u201d", '"') # right double quote
+ text = text.replace("\u2018", "'") # left single quote
+ text = text.replace("\u2019", "'") # right single quote
+
+ # Remove excessive punctuation (!!!, ???)
+ text = re.sub(r"([!?]){2,}", r"\1", text)
+
+ # Remove asterisks used for footnotes
+ text = re.sub(r"\*+", "", text)
+
+ # Remove hash symbols (from headers)
+ text = text.replace("#", "")
+
+ # Remove excessive dots (...)
+ text = re.sub(r"\.{4,}", "...", text)
+
+ return text
+
+ def _normalize_whitespace(self, text: str) -> str:
+ """Normalize whitespace for clean audio output."""
+
+ # Replace multiple spaces with single space
+ text = re.sub(r" {2,}", " ", text)
+
+ # Replace multiple newlines with double newline (paragraph break)
+ text = re.sub(r"\n{3,}", "\n\n", text)
+
+ # Remove trailing/leading whitespace from lines
+ text = "\n".join(line.strip() for line in text.split("\n"))
+
+ # Remove empty lines at start/end
+ text = text.strip()
+
+ return text
+
+ async def _llm_polish(self, text: str) -> str:
+ """Apply LLM-based final polish to catch edge cases.
+
+ This is a lightweight pass that removes any remaining formatting
+ artifacts the rule-based methods might have missed.
+
+ Args:
+ text: Pre-cleaned text from rule-based methods
+
+ Returns:
+ Final polished text ready for TTS
+ """
+ try:
+ # Create a simple agent for text cleanup
+ model = get_pydantic_ai_model()
+ polish_agent = Agent(
+ model=model,
+ system_prompt=(
+ "You are a text cleanup assistant. Your ONLY job is to remove "
+ "any remaining formatting artifacts (markdown, citations, special "
+ "characters) that make text unsuitable for text-to-speech audio. "
+ "DO NOT rewrite, improve, or change the content. "
+ "DO NOT add explanations. "
+ "ONLY output the cleaned text."
+ ),
+ )
+
+ # Run asynchronously
+ result = await polish_agent.run(
+ f"Clean this text for audio (remove any formatting artifacts):\n\n{text}"
+ )
+
+ polished_text = result.output.strip()
+
+ logger.info(
+ "llm_polish_applied", original_length=len(text), polished_length=len(polished_text)
+ )
+
+ return polished_text
+
+ except Exception as e:
+ logger.warning(
+ "llm_polish_failed", error=str(e), message="Falling back to rule-based output"
+ )
+ # Graceful fallback: return original text if LLM fails
+ return text
+
+
+# Singleton instance for easy import
+audio_refiner = AudioRefiner()
+
+
+async def refine_text_for_audio(markdown_text: str, use_llm_polish: bool = False) -> str:
+ """Convenience function to refine markdown text for audio.
+
+ Args:
+ markdown_text: Markdown-formatted text
+ use_llm_polish: If True, apply LLM-based final polish (optional)
+
+ Returns:
+ Audio-friendly plain text
+ """
+ return await audio_refiner.refine_for_audio(markdown_text, use_llm_polish=use_llm_polish)
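+
+
+# Illustrative usage (not part of the pipeline): from the repo root, run
+# `python -m src.agents.audio_refiner` to sanity-check the rule-based cleanup.
+if __name__ == "__main__":  # pragma: no cover
+    import asyncio
+
+    _sample = "## Phase II Results\n\nThe trial **met** its endpoint [1] (Smith et al., 2023)."
+    print(asyncio.run(refine_text_for_audio(_sample)))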
diff --git a/src/agents/knowledge_gap.py b/src/agents/knowledge_gap.py
index 92114eaf..4ceab6cc 100644
--- a/src/agents/knowledge_gap.py
+++ b/src/agents/knowledge_gap.py
@@ -9,6 +9,11 @@
import structlog
from pydantic_ai import Agent
+try:
+ from pydantic_ai import ModelMessage
+except ImportError:
+ ModelMessage = Any # type: ignore[assignment, misc]
+
from src.agent_factory.judges import get_model
from src.utils.exceptions import ConfigurationError
from src.utils.models import KnowledgeGapOutput
@@ -68,6 +73,7 @@ async def evaluate(
query: str,
background_context: str = "",
conversation_history: str = "",
+ message_history: list[ModelMessage] | None = None,
iteration: int = 0,
time_elapsed_minutes: float = 0.0,
max_time_minutes: int = 10,
@@ -78,7 +84,8 @@ async def evaluate(
Args:
query: The original research query
background_context: Optional background context
- conversation_history: History of actions, findings, and thoughts
+ conversation_history: History of actions, findings, and thoughts (backward compat)
+ message_history: Optional user conversation history (Pydantic AI format)
iteration: Current iteration number
time_elapsed_minutes: Time elapsed so far
max_time_minutes: Maximum time allowed
@@ -111,8 +118,11 @@ async def evaluate(
"""
try:
- # Run the agent
- result = await self.agent.run(user_message)
+ # Run the agent with message_history if provided
+ if message_history:
+ result = await self.agent.run(user_message, message_history=message_history)
+ else:
+ result = await self.agent.run(user_message)
evaluation = result.output
self.logger.info(
@@ -132,7 +142,9 @@ async def evaluate(
)
-def create_knowledge_gap_agent(model: Any | None = None, oauth_token: str | None = None) -> KnowledgeGapAgent:
+def create_knowledge_gap_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> KnowledgeGapAgent:
"""
Factory function to create a knowledge gap agent.
diff --git a/src/agents/long_writer.py b/src/agents/long_writer.py
index 0e5f0804..5b0553aa 100644
--- a/src/agents/long_writer.py
+++ b/src/agents/long_writer.py
@@ -225,25 +225,27 @@ async def write_next_section(
"Section writing failed after all attempts",
error=str(last_exception) if last_exception else "Unknown error",
)
-
+
# Try to enhance fallback with evidence if available
try:
from src.middleware.state_machine import get_workflow_state
-
+
state = get_workflow_state()
if state and state.evidence:
# Include evidence citations in fallback
evidence_refs: list[str] = []
for i, ev in enumerate(state.evidence[:10], 1): # Limit to 10
- authors = ", ".join(ev.citation.authors[:2]) if ev.citation.authors else "Unknown"
+ authors = (
+ ", ".join(ev.citation.authors[:2]) if ev.citation.authors else "Unknown"
+ )
evidence_refs.append(
f"[{i}] {authors}. *{ev.citation.title}*. {ev.citation.url}"
)
-
+
enhanced_draft = f"## {next_section_title}\n\n{next_section_draft}"
if evidence_refs:
enhanced_draft += "\n\n### Sources\n\n" + "\n".join(evidence_refs)
-
+
return LongWriterOutput(
next_section_markdown=enhanced_draft,
references=evidence_refs,
@@ -253,7 +255,7 @@ async def write_next_section(
"Failed to enhance fallback with evidence",
error=str(e),
)
-
+
# Basic fallback
return LongWriterOutput(
next_section_markdown=f"## {next_section_title}\n\n{next_section_draft}",
@@ -437,7 +439,9 @@ def adjust_heading_level(match: re.Match[str]) -> str:
return re.sub(r"^(#+)\s(.+)$", adjust_heading_level, section_markdown, flags=re.MULTILINE)
-def create_long_writer_agent(model: Any | None = None, oauth_token: str | None = None) -> LongWriterAgent:
+def create_long_writer_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> LongWriterAgent:
"""
Factory function to create a long writer agent.
diff --git a/src/agents/proofreader.py b/src/agents/proofreader.py
index 5fdf15e3..7a209c36 100644
--- a/src/agents/proofreader.py
+++ b/src/agents/proofreader.py
@@ -181,7 +181,9 @@ async def proofread(
return f"# Research Report\n\n## Query\n{query}\n\n" + "\n\n".join(sections)
-def create_proofreader_agent(model: Any | None = None, oauth_token: str | None = None) -> ProofreaderAgent:
+def create_proofreader_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> ProofreaderAgent:
"""
Factory function to create a proofreader agent.
diff --git a/src/agents/thinking.py b/src/agents/thinking.py
index 225543b7..eff08e51 100644
--- a/src/agents/thinking.py
+++ b/src/agents/thinking.py
@@ -9,6 +9,11 @@
import structlog
from pydantic_ai import Agent
+try:
+ from pydantic_ai import ModelMessage
+except ImportError:
+ ModelMessage = Any # type: ignore[assignment, misc]
+
from src.agent_factory.judges import get_model
from src.utils.exceptions import ConfigurationError
@@ -72,6 +77,7 @@ async def generate_observations(
query: str,
background_context: str = "",
conversation_history: str = "",
+ message_history: list[ModelMessage] | None = None,
iteration: int = 1,
) -> str:
"""
@@ -80,7 +86,8 @@ async def generate_observations(
Args:
query: The original research query
background_context: Optional background context
- conversation_history: History of actions, findings, and thoughts
+ conversation_history: History of actions, findings, and thoughts (backward compat)
+ message_history: Optional user conversation history (Pydantic AI format)
iteration: Current iteration number
Returns:
@@ -110,8 +117,11 @@ async def generate_observations(
"""
try:
- # Run the agent
- result = await self.agent.run(user_message)
+ # Run the agent with message_history if provided
+ if message_history:
+ result = await self.agent.run(user_message, message_history=message_history)
+ else:
+ result = await self.agent.run(user_message)
observations = result.output
self.logger.info("Observations generated", length=len(observations))
@@ -124,7 +134,9 @@ async def generate_observations(
return f"Starting iteration {iteration}. Need to gather information about: {query}"
-def create_thinking_agent(model: Any | None = None, oauth_token: str | None = None) -> ThinkingAgent:
+def create_thinking_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> ThinkingAgent:
"""
Factory function to create a thinking agent.
diff --git a/src/agents/tool_selector.py b/src/agents/tool_selector.py
index 3f06fe92..0da35f43 100644
--- a/src/agents/tool_selector.py
+++ b/src/agents/tool_selector.py
@@ -9,6 +9,11 @@
import structlog
from pydantic_ai import Agent
+try:
+ from pydantic_ai import ModelMessage
+except ImportError:
+ ModelMessage = Any # type: ignore[assignment, misc]
+
from src.agent_factory.judges import get_model
from src.utils.exceptions import ConfigurationError
from src.utils.models import AgentSelectionPlan
@@ -81,6 +86,7 @@ async def select_tools(
query: str,
background_context: str = "",
conversation_history: str = "",
+ message_history: list[ModelMessage] | None = None,
) -> AgentSelectionPlan:
"""
Select tools to address a knowledge gap.
@@ -89,7 +95,8 @@ async def select_tools(
gap: The knowledge gap to address
query: The original research query
background_context: Optional background context
- conversation_history: History of actions, findings, and thoughts
+ conversation_history: History of actions, findings, and thoughts (backward compat)
+ message_history: Optional user conversation history (Pydantic AI format)
Returns:
AgentSelectionPlan with tasks for selected agents
@@ -115,8 +122,11 @@ async def select_tools(
"""
try:
- # Run the agent
- result = await self.agent.run(user_message)
+ # Run the agent with message_history if provided
+ if message_history:
+ result = await self.agent.run(user_message, message_history=message_history)
+ else:
+ result = await self.agent.run(user_message)
selection_plan = result.output
self.logger.info(
@@ -144,7 +154,9 @@ async def select_tools(
)
-def create_tool_selector_agent(model: Any | None = None, oauth_token: str | None = None) -> ToolSelectorAgent:
+def create_tool_selector_agent(
+ model: Any | None = None, oauth_token: str | None = None
+) -> ToolSelectorAgent:
"""
Factory function to create a tool selector agent.
diff --git a/src/agents/writer.py b/src/agents/writer.py
index 3bfe7ee0..9fc2edff 100644
--- a/src/agents/writer.py
+++ b/src/agents/writer.py
@@ -175,12 +175,12 @@ async def write_report(
"Report writing failed after all attempts",
error=str(last_exception) if last_exception else "Unknown error",
)
-
+
# Try to use evidence-based report generator for better fallback
try:
from src.middleware.state_machine import get_workflow_state
from src.utils.report_generator import generate_report_from_evidence
-
+
state = get_workflow_state()
if state and state.evidence:
self.logger.info(
@@ -197,7 +197,7 @@ async def write_report(
"Failed to use evidence-based report generator",
error=str(e),
)
-
+
# Fallback to simple report if evidence generator fails
# Truncate findings in fallback if too long
fallback_findings = findings[:500] + "..." if len(findings) > 500 else findings
diff --git a/src/app.py b/src/app.py
index ecb13412..e700eaec 100644
--- a/src/app.py
+++ b/src/app.py
@@ -1,4 +1,12 @@
-"""Gradio UI for The DETERMINATOR agent with MCP server support."""
+"""Main Gradio application for DeepCritical research agent.
+
+This module provides the Gradio interface with:
+- OAuth authentication via HuggingFace
+- Multimodal input support (text, images, audio)
+- Research agent orchestration
+- Real-time event streaming
+- MCP server integration
+"""
import os
from collections.abc import AsyncGenerator
@@ -6,38 +14,43 @@
import gradio as gr
import numpy as np
-from gradio.components.multimodal_textbox import MultimodalPostprocess
-
-# Try to import HuggingFace support (may not be available in all pydantic-ai versions)
-# According to https://ai.pydantic.dev/models/huggingface/, HuggingFace support requires
-# pydantic-ai with huggingface extra or pydantic-ai-slim[huggingface]
-# There are two ways to use HuggingFace:
-# 1. Inference API: HuggingFaceModel with HuggingFaceProvider (uses AsyncInferenceClient internally)
-# 2. Local models: Would use transformers directly (not via pydantic-ai)
+import structlog
+
+from src.agent_factory.judges import HFInferenceJudgeHandler, JudgeHandler, MockJudgeHandler
+from src.orchestrator_factory import create_orchestrator
+from src.services.multimodal_processing import get_multimodal_service
+from src.utils.config import settings
+from src.utils.models import AgentEvent, OrchestratorConfig
+
+# Import ModelMessage from pydantic_ai with fallback
+try:
+ from pydantic_ai import ModelMessage
+except ImportError:
+ from typing import Any
+
+ ModelMessage = Any # type: ignore[assignment, misc]
+
+# Type alias for Gradio multimodal input
+MultimodalPostprocess = dict[str, Any] | str
+
+# Import HuggingFace components with graceful fallback
try:
- from huggingface_hub import AsyncInferenceClient
from pydantic_ai.models.huggingface import HuggingFaceModel
from pydantic_ai.providers.huggingface import HuggingFaceProvider
_HUGGINGFACE_AVAILABLE = True
except ImportError:
+ _HUGGINGFACE_AVAILABLE = False
HuggingFaceModel = None # type: ignore[assignment, misc]
HuggingFaceProvider = None # type: ignore[assignment, misc]
- AsyncInferenceClient = None # type: ignore[assignment, misc]
- _HUGGINGFACE_AVAILABLE = False
-from src.agent_factory.judges import HFInferenceJudgeHandler, JudgeHandler, MockJudgeHandler
-from src.orchestrator_factory import create_orchestrator
-from src.services.audio_processing import get_audio_service
-from src.services.multimodal_processing import get_multimodal_service
-import structlog
-from src.tools.clinicaltrials import ClinicalTrialsTool
-from src.tools.europepmc import EuropePMCTool
-from src.tools.pubmed import PubMedTool
-from src.tools.search_handler import SearchHandler
-from src.tools.neo4j_search import Neo4jSearchTool
-from src.utils.config import settings
-from src.utils.models import AgentEvent, OrchestratorConfig
+try:
+ from huggingface_hub import AsyncInferenceClient
+
+ _ASYNC_INFERENCE_AVAILABLE = True
+except ImportError:
+ _ASYNC_INFERENCE_AVAILABLE = False
+ AsyncInferenceClient = None # type: ignore[assignment, misc]
logger = structlog.get_logger()
@@ -50,40 +63,39 @@ def configure_orchestrator(
hf_provider: str | None = None,
graph_mode: str | None = None,
use_graph: bool = True,
+ web_search_provider: str | None = None,
) -> tuple[Any, str]:
"""
- Create an orchestrator instance.
+ Configure and create the research orchestrator.
Args:
- use_mock: If True, use MockJudgeHandler (no API key needed)
- mode: Orchestrator mode ("simple", "advanced", "iterative", "deep", or "auto")
- oauth_token: Optional OAuth token from HuggingFace login
- hf_model: Selected HuggingFace model ID
- hf_provider: Selected inference provider
- graph_mode: Graph research mode ("iterative", "deep", or "auto") - used when mode is graph-based
- use_graph: Whether to use graph execution (True) or agent chains (False)
+ use_mock: Force mock judge handler (for testing)
+ mode: Orchestrator mode ("simple", "iterative", "deep", "auto", "advanced")
+ oauth_token: Optional OAuth token from HuggingFace login (takes priority over env vars)
+ hf_model: Optional HuggingFace model ID (overrides settings)
+ hf_provider: Optional inference provider (currently not used by HuggingFaceProvider)
+ graph_mode: Optional graph execution mode
+ use_graph: Whether to use graph execution
+ web_search_provider: Optional web search provider ("auto", "serper", "duckduckgo")
Returns:
- Tuple of (Orchestrator instance, backend_name)
+ Tuple of (orchestrator, backend_info_string)
"""
- # Create orchestrator config
- config = OrchestratorConfig(
- max_iterations=10,
- max_results_per_tool=10,
- )
-
- # Create search tools with RAG enabled
- # Pass OAuth token to SearchHandler so it can be used by RAG service
- tools = [Neo4jSearchTool(),PubMedTool(), ClinicalTrialsTool(), EuropePMCTool()]
-
- # Add web search tool if available
+ from src.tools.search_handler import SearchHandler
from src.tools.web_search_factory import create_web_search_tool
- web_search_tool = create_web_search_tool()
- if web_search_tool is not None:
+ # Create search handler with tools
+ tools = []
+
+ # Add web search tool
+ web_search_tool = create_web_search_tool(provider=web_search_provider or "auto")
+ if web_search_tool:
tools.append(web_search_tool)
logger.info("Web search tool added to search handler", provider=web_search_tool.name)
+    # Create orchestrator config with default settings
+ config = OrchestratorConfig()
+
search_handler = SearchHandler(
tools=tools,
timeout=config.search_timeout,
@@ -104,11 +116,37 @@ def configure_orchestrator(
# 2. API Key (OAuth or Env) - HuggingFace only (OAuth provides HF token)
# Priority: oauth_token > env vars
# On HuggingFace Spaces, OAuth token is available via request.oauth_token
+ #
+ # OAuth Scope Requirements:
+ # - 'inference-api': Required for HuggingFace Inference API access
+ # This scope grants access to:
+ # * HuggingFace's own Inference API
+ # * All third-party inference providers (nebius, together, scaleway, hyperbolic, novita, nscale, sambanova, ovh, fireworks, etc.)
+ # * All models available through the Inference Providers API
+ # See: https://huggingface.co/docs/hub/oauth#currently-supported-scopes
+ #
+ # Note: The hf_provider parameter is accepted but not used here because HuggingFaceProvider
+ # from pydantic-ai doesn't support provider selection. Provider selection happens at the
+ # InferenceClient level (used in HuggingFaceChatClient for advanced mode).
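+    # Example of the fallback on the next line: an OAuth token wins over HF_TOKEN and
+    # HUGGINGFACE_API_KEY; with none of them set, effective_api_key is None and the free
+    # HF Inference fallback is used.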
effective_api_key = oauth_token or os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_API_KEY")
+ # Log which authentication source is being used
+ if effective_api_key:
+ auth_source = (
+ "OAuth token"
+ if oauth_token
+ else ("HF_TOKEN env var" if os.getenv("HF_TOKEN") else "HUGGINGFACE_API_KEY env var")
+ )
+ logger.info(
+ "Using HuggingFace authentication",
+ source=auth_source,
+ has_token=bool(effective_api_key),
+ )
+
if effective_api_key:
# We have an API key (OAuth or env) - use pydantic-ai with JudgeHandler
- # This uses HuggingFace's own inference API, not third-party providers
+ # This uses HuggingFace Inference API, which includes access to all third-party providers
+ # via the Inference Providers API (router.huggingface.co)
model: Any | None = None
# Use selected model or fall back to env var/settings
model_name = (
@@ -126,6 +164,7 @@ def configure_orchestrator(
# Per https://ai.pydantic.dev/models/huggingface/#configure-the-provider
# HuggingFaceProvider accepts api_key parameter directly
# This is consistent with usage in src/utils/llm_factory.py and src/agent_factory/judges.py
+ # The OAuth token with 'inference-api' scope provides access to all inference providers
provider = HuggingFaceProvider(api_key=effective_api_key) # type: ignore[misc]
model = HuggingFaceModel(model_name, provider=provider) # type: ignore[misc]
backend_info = "API (HuggingFace OAuth)" if oauth_token else "API (Env Config)"
@@ -167,203 +206,44 @@ def configure_orchestrator(
def _is_file_path(text: str) -> bool:
"""Check if text appears to be a file path.
-
+
Args:
text: Text to check
-
+
Returns:
True if text looks like a file path
"""
- import os
- # Check for common file extensions
- file_extensions = ['.md', '.pdf', '.txt', '.json', '.csv', '.xlsx', '.docx', '.html']
- text_lower = text.lower().strip()
-
- # Check if it ends with a file extension
- if any(text_lower.endswith(ext) for ext in file_extensions):
- # Check if it's a valid path (absolute or relative)
- if os.path.sep in text or '/' in text or '\\' in text:
- return True
- # Or if it's just a filename with extension
- if '.' in text and len(text.split('.')) == 2:
- return True
-
- # Check if it's an absolute path
- if os.path.isabs(text):
- return True
-
- return False
-
-
-def _get_file_name(file_path: str) -> str:
- """Extract filename from file path.
-
- Args:
- file_path: Full file path
-
- Returns:
- Filename with extension
- """
- import os
- return os.path.basename(file_path)
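+    # Heuristic only: "reports/summary.md" or "C:\tmp\report.pdf" match, while a plain
+    # question or a bare filename without a path separator does not.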
+ return ("/" in text or "\\" in text) and (
+ "." in text.split("/")[-1] or "." in text.split("\\")[-1]
+ )
def event_to_chat_message(event: AgentEvent) -> dict[str, Any]:
- """
- Convert AgentEvent to gr.ChatMessage with metadata for accordion display.
+ """Convert AgentEvent to Gradio chat message format.
Args:
- event: The AgentEvent to convert
+ event: AgentEvent to convert
Returns:
- ChatMessage with metadata for collapsible accordion
+ Dictionary with 'role' and 'content' keys for Gradio Chatbot
"""
- # Map event types to accordion titles and determine if pending
- event_configs: dict[str, dict[str, Any]] = {
- "started": {"title": "π Starting Research", "status": "done", "icon": "π"},
- "searching": {"title": "π Searching Literature", "status": "pending", "icon": "π"},
- "search_complete": {"title": "π Search Results", "status": "done", "icon": "π"},
- "judging": {"title": "π§ Evaluating Evidence", "status": "pending", "icon": "π§ "},
-        "judge_complete": {"title": "β Evidence Assessment", "status": "done", "icon": "β"},
- "looping": {"title": "π Research Iteration", "status": "pending", "icon": "π"},
- "synthesizing": {"title": "π Synthesizing Report", "status": "pending", "icon": "π"},
- "hypothesizing": {"title": "π¬ Generating Hypothesis", "status": "pending", "icon": "π¬"},
- "analyzing": {"title": "π Statistical Analysis", "status": "pending", "icon": "π"},
- "analysis_complete": {"title": "π Analysis Results", "status": "done", "icon": "π"},
- "streaming": {"title": "π‘ Processing", "status": "pending", "icon": "π‘"},
- "complete": {"title": None, "status": "done", "icon": "π"}, # Main response, no accordion
- "error": {"title": "β Error", "status": "done", "icon": "β"},
- }
-
- config = event_configs.get(
- event.type, {"title": f"β’ {event.type}", "status": "done", "icon": "β’"}
- )
-
- # For complete events, return main response without accordion
- if event.type == "complete":
- # Check if event contains file information
- content = event.message
- files: list[str] | None = None
-
- # Check event.data for file paths
- if event.data and isinstance(event.data, dict):
- # Support both "files" (list) and "file" (single path) keys
- if "files" in event.data:
- files = event.data["files"]
- if isinstance(files, str):
- files = [files]
- elif not isinstance(files, list):
- files = None
- else:
- # Filter to only valid file paths
- files = [f for f in files if isinstance(f, str) and _is_file_path(f)]
- elif "file" in event.data:
- file_path = event.data["file"]
- if isinstance(file_path, str) and _is_file_path(file_path):
- files = [file_path]
-
- # Also check if message itself is a file path (less common, but possible)
- if not files and isinstance(event.message, str) and _is_file_path(event.message):
- files = [event.message]
- # Keep message as text description
- content = "Report generated. Download available below."
-
- # Return as dict format for Gradio Chatbot compatibility
- result: dict[str, Any] = {
- "role": "assistant",
- "content": content,
- }
-
- # Add files if present
- # Gradio Chatbot supports file paths in content as markdown links
- # The links will be clickable and downloadable
- if files:
- # Validate files exist before including them
- import os
- valid_files = [f for f in files if os.path.exists(f)]
-
- if valid_files:
- # Format files for Gradio: include as markdown download links
- # Gradio ChatInterface automatically renders file links as downloadable files
- import os
- file_links = []
- for f in valid_files:
- file_name = _get_file_name(f)
- try:
- file_size = os.path.getsize(f)
- # Format file size (bytes to KB/MB)
- if file_size < 1024:
- size_str = f"{file_size} B"
- elif file_size < 1024 * 1024:
- size_str = f"{file_size / 1024:.1f} KB"
- else:
- size_str = f"{file_size / (1024 * 1024):.1f} MB"
- file_links.append(f"π [Download: {file_name} ({size_str})]({f})")
- except OSError:
- # If we can't get file size, just show the name
- file_links.append(f"π [Download: {file_name}]({f})")
-
- result["content"] = f"{content}\n\n" + "\n\n".join(file_links)
-
- # Also store in metadata for potential future use
- if "metadata" not in result:
- result["metadata"] = {}
- result["metadata"]["files"] = valid_files
-
- return result
-
- # Build metadata for accordion according to Gradio ChatMessage spec
- # Metadata keys: title (str), status ("pending"|"done"), log (str), duration (float)
- # See: https://www.gradio.app/guides/agents-and-tool-usage
- metadata: dict[str, Any] = {}
-
- # Title is required for accordion display - must be string
- if config["title"]:
- metadata["title"] = str(config["title"])
-
- # Set status (pending shows spinner, done is collapsed)
- # Must be exactly "pending" or "done" per Gradio spec
- if config["status"] == "pending":
- metadata["status"] = "pending"
- elif config["status"] == "done":
- metadata["status"] = "done"
-
- # Add duration if available in data (must be float)
- if event.data and isinstance(event.data, dict) and "duration" in event.data:
- duration = event.data["duration"]
- if isinstance(duration, int | float):
- metadata["duration"] = float(duration)
-
- # Add log info (iteration number, etc.) - must be string
- log_parts: list[str] = []
- if event.iteration > 0:
- log_parts.append(f"Iteration {event.iteration}")
- if event.data and isinstance(event.data, dict):
- if "tool" in event.data:
- log_parts.append(f"Tool: {event.data['tool']}")
- if "results_count" in event.data:
- log_parts.append(f"Results: {event.data['results_count']}")
- if log_parts:
- metadata["log"] = " | ".join(log_parts)
-
- # Return as dict format for Gradio Chatbot compatibility
- # According to Gradio docs: https://www.gradio.app/guides/agents-and-tool-usage
- # ChatMessage format: {"role": "assistant", "content": "...", "metadata": {...}}
- # Metadata must have "title" key for accordion display
- # Valid metadata keys: title (str), status ("pending"|"done"), log (str), duration (float)
result: dict[str, Any] = {
"role": "assistant",
- "content": event.message,
+ "content": event.to_markdown(),
}
- # Only add metadata if it has a title (required for accordion display)
- # Ensure metadata values match Gradio's expected types
- if metadata and metadata.get("title"):
- # Ensure status is valid if present
- if "status" in metadata:
- status = metadata["status"]
- if status not in ("pending", "done"):
- metadata["status"] = "done" # Default to "done" if invalid
- result["metadata"] = metadata
+
+ # Add metadata if available
+ if event.data:
+ metadata: dict[str, Any] = {}
+
+ # Extract file path if present
+ if isinstance(event.data, dict):
+ file_path = event.data.get("file_path")
+ if file_path:
+ metadata["file_path"] = file_path
+
+ if metadata:
+ result["metadata"] = metadata
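+    # Resulting shape (illustrative): {"role": "assistant", "content": "<markdown>",
+    # "metadata": {"file_path": "<path>"}}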
return result
@@ -402,9 +282,9 @@ def extract_oauth_info(request: gr.Request | None) -> tuple[str | None, str | No
oauth_username = request.username
# Also try accessing via oauth_profile if available
elif hasattr(request, "oauth_profile") and request.oauth_profile is not None:
- if hasattr(request.oauth_profile, "username"):
+ if hasattr(request.oauth_profile, "username") and request.oauth_profile.username:
oauth_username = request.oauth_profile.username
- elif hasattr(request.oauth_profile, "name"):
+ elif hasattr(request.oauth_profile, "name") and request.oauth_profile.name:
oauth_username = request.oauth_profile.name
return oauth_token, oauth_username
@@ -417,134 +297,141 @@ async def yield_auth_messages(
mode: str,
) -> AsyncGenerator[dict[str, Any], None]:
"""
- Yield authentication and mode status messages.
+ Yield authentication status messages.
Args:
oauth_username: OAuth username if available
oauth_token: OAuth token if available
- has_huggingface: Whether HuggingFace credentials are available
- mode: Orchestrator mode
+ has_huggingface: Whether HuggingFace authentication is available
+ mode: Research mode
Yields:
- ChatMessage objects with authentication status
+ Chat message dictionaries
"""
- # Show user greeting if logged in via OAuth
if oauth_username:
yield {
"role": "assistant",
- "content": f"π **Welcome, {oauth_username}!** Using your HuggingFace account.\n\n",
+ "content": f"π **Welcome, {oauth_username}!**\n\nAuthenticated via HuggingFace OAuth.",
}
- # Advanced mode is not currently supported with HuggingFace inference
- # For now, we only support simple mode with HuggingFace
- if mode == "advanced":
+ if oauth_token:
yield {
"role": "assistant",
"content": (
- "β οΈ **Note**: Advanced mode is not available with HuggingFace inference providers. "
- "Falling back to simple mode.\n\n"
+                "π **Authentication Status**: β Authenticated\n\n"
+ "Your OAuth token has been validated. You can now use all AI models and research tools."
),
}
-
- # Inform user about authentication status
- if oauth_token:
+ elif has_huggingface:
yield {
"role": "assistant",
"content": (
- "π **Using HuggingFace OAuth token** - "
- "Authenticated via your HuggingFace account.\n\n"
+                "π **Authentication Status**: β Using environment token\n\n"
+ "Using HF_TOKEN from environment variables."
),
}
- elif not has_huggingface:
- # No keys at all - will use FREE HuggingFace Inference (public models)
+ else:
yield {
"role": "assistant",
"content": (
- "π€ **Free Tier**: Using HuggingFace Inference (Llama 3.1 / Mistral) for AI analysis.\n"
- "For premium models or higher rate limits, sign in with HuggingFace above.\n\n"
+ "β οΈ **Authentication Status**: β No authentication\n\n"
+ "Please sign in with HuggingFace or set HF_TOKEN environment variable."
),
}
+ yield {
+ "role": "assistant",
+ "content": f"π **Mode**: {mode.upper()}\n\nStarting research agent...",
+ }
-async def handle_orchestrator_events(
- orchestrator: Any,
- message: str,
-) -> AsyncGenerator[dict[str, Any], None]:
- """
- Handle orchestrator events and yield ChatMessages.
- Args:
- orchestrator: The orchestrator instance
- message: The research question
+def _extract_oauth_token(oauth_token: gr.OAuthToken | None) -> str | None:
+ """Extract token value from OAuth token object."""
+ if oauth_token is None:
+ return None
- Yields:
- ChatMessage objects from orchestrator events
- """
- # Track pending accordions for real-time updates
- pending_accordions: dict[str, str] = {} # title -> accumulated content
-
- async for event in orchestrator.run(message):
- # Convert event to ChatMessage with metadata
- chat_msg = event_to_chat_message(event)
-
- # Handle complete events (main response)
- if event.type == "complete":
- # Close any pending accordions first
- if pending_accordions:
- for title, content in pending_accordions.items():
- yield {
- "role": "assistant",
- "content": content.strip(),
- "metadata": {"title": title, "status": "done"},
- }
- pending_accordions.clear()
-
- # Yield final response (no accordion for main response)
- # chat_msg is already a dict from event_to_chat_message
- yield chat_msg
- continue
+ if hasattr(oauth_token, "token"):
+ token_value: str | None = getattr(oauth_token, "token", None) # type: ignore[assignment]
+ if token_value is None:
+ return None
+ logger.debug("OAuth token extracted from oauth_token.token attribute")
- # Handle events with metadata (accordions)
- # chat_msg is always a dict from event_to_chat_message
- metadata: dict[str, Any] = chat_msg.get("metadata", {})
- if metadata:
- msg_title: str | None = metadata.get("title")
- msg_status: str | None = metadata.get("status")
-
- if msg_title:
- # For pending operations, accumulate content and show spinner
- if msg_status == "pending":
- if msg_title not in pending_accordions:
- pending_accordions[msg_title] = ""
- # chat_msg is always a dict, so access content via key
- content = chat_msg.get("content", "")
- pending_accordions[msg_title] += content + "\n"
- # Yield updated accordion with accumulated content
- yield {
- "role": "assistant",
- "content": pending_accordions[msg_title].strip(),
- "metadata": chat_msg.get("metadata", {}),
- }
- elif msg_title in pending_accordions:
- # Combine pending content with final content
- # chat_msg is always a dict, so access content via key
- content = chat_msg.get("content", "")
- final_content = pending_accordions[msg_title] + content
- del pending_accordions[msg_title]
- yield {
- "role": "assistant",
- "content": final_content.strip(),
- "metadata": {"title": msg_title, "status": "done"},
- }
- else:
- # New done accordion (no pending state)
- yield chat_msg
- else:
- # No title, yield as-is
- yield chat_msg
- else:
- # No metadata, yield as plain message
- yield chat_msg
+ # Validate token format
+ from src.utils.hf_error_handler import log_token_info, validate_hf_token
+
+ log_token_info(token_value, context="research_agent")
+ is_valid, error_msg = validate_hf_token(token_value)
+ if not is_valid:
+ logger.warning(
+ "OAuth token validation failed",
+ error=error_msg,
+ oauth_token_type=type(oauth_token).__name__,
+ )
+ return token_value
+
+ if isinstance(oauth_token, str):
+ logger.debug("OAuth token extracted as string")
+
+ # Validate token format
+ from src.utils.hf_error_handler import log_token_info, validate_hf_token
+
+ log_token_info(oauth_token, context="research_agent")
+ return oauth_token
+
+ logger.warning(
+ "OAuth token object present but token extraction failed",
+ oauth_token_type=type(oauth_token).__name__,
+ )
+ return None
+
+
+def _extract_username(oauth_profile: gr.OAuthProfile | None) -> str | None:
+ """Extract username from OAuth profile."""
+ if oauth_profile is None:
+ return None
+
+ username: str | None = None
+ if hasattr(oauth_profile, "username") and oauth_profile.username:
+ username = str(oauth_profile.username)
+ elif hasattr(oauth_profile, "name") and oauth_profile.name:
+ username = str(oauth_profile.name)
+
+ if username:
+ logger.info("OAuth user authenticated", username=username)
+ return username
+
+
+async def _process_multimodal_input(
+ message: str | MultimodalPostprocess,
+ enable_image_input: bool,
+ enable_audio_input: bool,
+ token_value: str | None,
+) -> tuple[str, tuple[int, np.ndarray[Any, Any]] | None]: # type: ignore[type-arg]
+ """Process multimodal input and return processed text and audio data."""
+ processed_text = ""
+ audio_input_data: tuple[int, np.ndarray[Any, Any]] | None = None # type: ignore[type-arg]
+
+ if isinstance(message, dict):
+ processed_text = message.get("text", "") or ""
+ files = message.get("files", []) or []
+ audio_input_data = message.get("audio") or None
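+        # When present, audio arrives as a (sample_rate, waveform ndarray) tuple from the mic input.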
+
+ if (files and enable_image_input) or (audio_input_data is not None and enable_audio_input):
+ try:
+ multimodal_service = get_multimodal_service()
+ processed_text = await multimodal_service.process_multimodal_input(
+ processed_text,
+ files=files if enable_image_input else [],
+ audio_input=audio_input_data if enable_audio_input else None,
+ hf_token=token_value,
+ prepend_multimodal=True,
+ )
+ except Exception as e:
+ logger.warning("multimodal_processing_failed", error=str(e))
+ else:
+ processed_text = str(message) if message else ""
+
+ return processed_text, audio_input_data
async def research_agent(
@@ -557,57 +444,33 @@ async def research_agent(
use_graph: bool = True,
enable_image_input: bool = True,
enable_audio_input: bool = True,
- tts_voice: str = "af_heart",
- tts_speed: float = 1.0,
+ web_search_provider: str = "auto",
oauth_token: gr.OAuthToken | None = None,
oauth_profile: gr.OAuthProfile | None = None,
-) -> AsyncGenerator[dict[str, Any] | tuple[dict[str, Any], tuple[int, np.ndarray] | None], None]:
+) -> AsyncGenerator[dict[str, Any], None]:
"""
- Gradio chat function that runs the research agent.
+ Main research agent function that processes queries and streams results.
Args:
- message: User's research question (str or MultimodalPostprocess with text/files)
- history: Chat history (Gradio format)
- mode: Orchestrator mode ("simple" or "advanced")
- hf_model: Selected HuggingFace model ID (from dropdown)
- hf_provider: Selected inference provider (from dropdown)
+ message: User message (text, image, or audio)
+ history: Conversation history
+ mode: Orchestrator mode
+ hf_model: Optional HuggingFace model ID
+ hf_provider: Optional inference provider
+ graph_mode: Graph execution mode
+ use_graph: Whether to use graph execution
+ enable_image_input: Whether to process image inputs
+ enable_audio_input: Whether to process audio inputs
+ web_search_provider: Web search provider selection
oauth_token: Gradio OAuth token (None if user not logged in)
oauth_profile: Gradio OAuth profile (None if user not logged in)
Yields:
- ChatMessage objects with metadata for accordion display, optionally with audio output
+ Chat message dictionaries
"""
- import structlog
-
- logger = structlog.get_logger()
-
- # REQUIRE LOGIN BEFORE USE
- # Extract OAuth token and username using Gradio's OAuth types
- # According to Gradio docs: OAuthToken and OAuthProfile are None if user not logged in
- token_value: str | None = None
- username: str | None = None
-
- if oauth_token is not None:
- # OAuthToken has a .token attribute containing the access token
- if hasattr(oauth_token, "token"):
- token_value = oauth_token.token
- elif isinstance(oauth_token, str):
- # Handle case where oauth_token is already a string (shouldn't happen but defensive)
- token_value = oauth_token
- else:
- token_value = None
-
- if oauth_profile is not None:
- # OAuthProfile has .username, .name, .profile_image attributes
- username = (
- oauth_profile.username
- if hasattr(oauth_profile, "username") and oauth_profile.username
- else (
- oauth_profile.name
- if hasattr(oauth_profile, "name") and oauth_profile.name
- else None
- )
- )
+ # Extract OAuth token and username
+ token_value = _extract_oauth_token(oauth_token)
+ username = _extract_username(oauth_profile)
# Check if user is logged in (OAuth token or env var)
# Fallback to env vars for local development or Spaces with HF_TOKEN secret
@@ -624,47 +487,19 @@ async def research_agent(
"before using this application.\n\n"
"The login button is required to access the AI models and research tools."
),
- }, None
+ }
return
- # Process multimodal input (text + images + audio)
- processed_text = ""
- audio_input_data: tuple[int, np.ndarray] | None = None
-
- if isinstance(message, dict):
- # MultimodalPostprocess format: {"text": str, "files": list[FileData], "audio": tuple | None}
- processed_text = message.get("text", "") or ""
- files = message.get("files", [])
- # Check for audio input in message (Gradio may include it as a separate field)
- audio_input_data = message.get("audio") or None
-
- # Process multimodal input (images, audio files, audio input)
- # Process if we have files (and image input enabled) or audio input (and audio input enabled)
- # Use UI settings from function parameters
- if (files and enable_image_input) or (audio_input_data is not None and enable_audio_input):
- try:
- multimodal_service = get_multimodal_service()
- # Prepend audio/image text to original text (prepend_multimodal=True)
- # Filter files and audio based on UI settings
- processed_text = await multimodal_service.process_multimodal_input(
- processed_text,
- files=files if enable_image_input else [],
- audio_input=audio_input_data if enable_audio_input else None,
- hf_token=token_value,
- prepend_multimodal=True, # Prepend audio/image text to text input
- )
- except Exception as e:
- logger.warning("multimodal_processing_failed", error=str(e))
- # Continue with text-only input
- else:
- # Plain string message
- processed_text = str(message) if message else ""
+ # Process multimodal input
+ processed_text, audio_input_data = await _process_multimodal_input(
+ message, enable_image_input, enable_audio_input, token_value
+ )
if not processed_text.strip():
yield {
"role": "assistant",
"content": "Please enter a research question or provide an image/audio input.",
- }, None
+ }
return
# Check available keys (use token_value instead of oauth_token)
@@ -687,56 +522,64 @@ async def research_agent(
model_id = hf_model if hf_model and hf_model.strip() else None
provider_name = hf_provider if hf_provider and hf_provider.strip() else None
+ # Log authentication source for debugging
+ auth_source = (
+ "OAuth"
+ if token_value
+ else (
+ "Env (HF_TOKEN)"
+ if os.getenv("HF_TOKEN")
+ else ("Env (HUGGINGFACE_API_KEY)" if os.getenv("HUGGINGFACE_API_KEY") else "None")
+ )
+ )
+ logger.info(
+ "Configuring orchestrator",
+ mode=effective_mode,
+ auth_source=auth_source,
+ has_oauth_token=bool(token_value),
+ model=model_id or "default",
+ provider=provider_name or "auto",
+ )
+
+ # Convert empty string to None for web_search_provider
+ web_search_provider_value = (
+ web_search_provider if web_search_provider and web_search_provider.strip() else None
+ )
+
orchestrator, backend_name = configure_orchestrator(
use_mock=False, # Never use mock in production - HF Inference is the free fallback
mode=effective_mode,
- oauth_token=token_value, # Use extracted token value
+ oauth_token=token_value, # Use extracted token value - passed to all agents and services
hf_model=model_id, # None will use defaults in configure_orchestrator
hf_provider=provider_name, # None will use defaults in configure_orchestrator
graph_mode=graph_mode if graph_mode else None,
use_graph=use_graph,
+ web_search_provider=web_search_provider_value, # None will use settings default
)
yield {
"role": "assistant",
- "content": f"π§ **Backend**: {backend_name}\n\n",
+ "content": f"π§ **Backend**: {backend_name}\n\nProcessing your query...",
}
- # Handle orchestrator events and generate audio output
- audio_output_data: tuple[int, np.ndarray] | None = None
- final_message = ""
-
- async for msg in handle_orchestrator_events(orchestrator, processed_text):
- # Track final message for TTS
- if isinstance(msg, dict) and msg.get("role") == "assistant":
+ # Convert history to ModelMessage format if needed
+ message_history: list[ModelMessage] = []
+ if history:
+ for msg in history:
+ role = msg.get("role", "user")
content = msg.get("content", "")
- metadata = msg.get("metadata", {})
- # This is the main response (not an accordion) if no title in metadata
- if content and not metadata.get("title"):
- final_message = content
-
- # Yield without audio for intermediate messages
- yield msg, None
-
- # Generate audio output for final response
- if final_message and settings.enable_audio_output:
- try:
- audio_service = get_audio_service()
- # Use UI-configured voice and speed, fallback to settings defaults
- audio_output_data = await audio_service.generate_audio_output(
- final_message,
- voice=tts_voice or settings.tts_voice,
- speed=tts_speed if tts_speed else settings.tts_speed,
- )
- except Exception as e:
- logger.warning("audio_synthesis_failed", error=str(e))
- # Continue without audio output
+ if isinstance(content, str) and content.strip():
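+                # Assumption: ModelMessage is treated here as constructible with role/content
+                # kwargs; in pydantic_ai it is a message union, hence the type: ignore below.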
+ message_history.append(ModelMessage(role=role, content=content)) # type: ignore[operator]
+
+ # Run orchestrator and stream events
+ async for event in orchestrator.run(
+ processed_text, message_history=message_history if message_history else None
+ ):
+ chat_msg = event_to_chat_message(event)
+ yield chat_msg
- # If we have audio output, we need to yield it with the final message
- # Note: The final message was already yielded above, so we yield None, audio_output_data
- # This will update the audio output component
- if audio_output_data is not None:
- yield None, audio_output_data
+ # Note: Audio output is now handled via on-demand TTS button
+ # Users click "Generate Audio" button to create TTS for the last response
except Exception as e:
# Return error message without metadata to avoid issues during example caching
@@ -747,7 +590,110 @@ async def research_agent(
yield {
"role": "assistant",
"content": f"Error: {error_msg}. Please check your configuration and try again.",
- }, None
+ }
+
+
+async def update_model_provider_dropdowns(
+ oauth_token: gr.OAuthToken | None = None,
+ oauth_profile: gr.OAuthProfile | None = None,
+) -> tuple[dict[str, Any], dict[str, Any], str]:
+ """Update model and provider dropdowns based on OAuth token.
+
+ This function is called when OAuth token/profile changes (user logs in/out).
+ It queries HuggingFace API to get available models and providers.
+
+ Args:
+ oauth_token: Gradio OAuth token
+ oauth_profile: Gradio OAuth profile
+
+ Returns:
+ Tuple of (model_dropdown_update, provider_dropdown_update, status_message)
+ """
+ from src.utils.hf_model_validator import (
+ get_available_models,
+ get_available_providers,
+ validate_oauth_token,
+ )
+
+ # Extract token value
+ token_value: str | None = None
+ if oauth_token is not None:
+ if hasattr(oauth_token, "token"):
+ token_value = oauth_token.token
+ elif isinstance(oauth_token, str):
+ token_value = oauth_token
+
+ # Default values (empty = use default)
+ default_models = [""]
+ default_providers = [""]
+ status_msg = "β οΈ Not authenticated - using default models"
+
+ if not token_value:
+ # No token - return defaults
+ return (
+ gr.update(choices=default_models, value=""),
+ gr.update(choices=default_providers, value=""),
+ status_msg,
+ )
+
+ try:
+ # Validate token and get available resources
+ validation_result = await validate_oauth_token(token_value)
+
+ if not validation_result["is_valid"]:
+ status_msg = (
+ f"β Token validation failed: {validation_result.get('error', 'Unknown error')}"
+ )
+ return (
+ gr.update(choices=default_models, value=""),
+ gr.update(choices=default_providers, value=""),
+ status_msg,
+ )
+
+ # Get available models and providers
+ models = await get_available_models(token=token_value, limit=50)
+ providers = await get_available_providers(token=token_value)
+
+ # Combine with defaults
+ model_choices = ["", *models[:49]] # Keep first 49 + empty option
+ provider_choices = providers # Already includes "auto"
+
+ username = validation_result.get("username", "User")
+
+ # Build status message with warning if scope is missing
+ scope_warning = ""
+ if not validation_result["has_inference_api_scope"]:
+ scope_warning = (
+ "β οΈ Token may not have 'inference-api' scope - some models may not work\n\n"
+ )
+
+ status_msg = (
+            f"{scope_warning}β Authenticated as {username}\n\n"
+ f"π Found {len(models)} available models\n"
+ f"π§ Found {len(providers)} available providers"
+ )
+
+ logger.info(
+ "Updated model/provider dropdowns",
+ model_count=len(model_choices),
+ provider_count=len(provider_choices),
+ username=username,
+ )
+
+ return (
+ gr.update(choices=model_choices, value=""),
+ gr.update(choices=provider_choices, value=""),
+ status_msg,
+ )
+
+ except Exception as e:
+ logger.error("Failed to update dropdowns", error=str(e))
+ status_msg = f"β οΈ Failed to load models: {e!s}"
+ return (
+ gr.update(choices=default_models, value=""),
+ gr.update(choices=default_providers, value=""),
+ status_msg,
+ )
def create_demo() -> gr.Blocks:
@@ -768,24 +714,31 @@ def create_demo() -> gr.Blocks:
)
gr.LoginButton("Sign in with Hugging Face")
gr.Markdown("---")
- gr.Markdown("### βΉοΈ About") # noqa: RUF001
- gr.Markdown(
- "**The DETERMINATOR** - Generalist Deep Research Agent\n\n"
- "A powerful research agent that stops at nothing until finding precise answers to complex questions.\n\n"
- "**Available Sources**:\n"
- "- Web Search (general knowledge)\n"
- "- PubMed (biomedical literature)\n"
- "- ClinicalTrials.gov (clinical trials)\n"
- "- Europe PMC (preprints & papers)\n"
- "- RAG (semantic search)\n\n"
- "**Automatic Detection**: Automatically determines if medical knowledge sources are needed for your query.\n\n"
- "β οΈ **Research tool only** - Synthesizes evidence but cannot provide medical advice."
- )
+
+ # About Section - Collapsible with details
+ with gr.Accordion("βΉοΈ About", open=False):
+ gr.Markdown(
+ "**The DETERMINATOR** - Generalist Deep Research Agent\n\n"
+ "Stops at nothing until finding precise answers to complex questions.\n\n"
+ "**How It Works**:\n"
+ "- π Multi-source search (Web, PubMed, ClinicalTrials.gov, Europe PMC, RAG)\n"
+ "- π§ Automatic medical knowledge detection\n"
+ "- π Iterative refinement with search-judge loops\n"
+ "- βΉοΈ Continues until budget/time/iteration limits\n"
+ "- π Evidence synthesis with citations\n\n"
+ "**Multimodal Input**:\n"
+ "- π· **Images**: Click image icon in textbox (OCR)\n"
+ "- π€ **Audio**: Click microphone icon (speech-to-text)\n"
+ "- π **Files**: Drag & drop or click to upload\n\n"
+ "**MCP Server**: Connect Claude Desktop to `/gradio_api/mcp/`\n\n"
+ "β οΈ **Research tool only** - Synthesizes evidence but cannot provide medical advice."
+ )
+
gr.Markdown("---")
-
+
# Settings Section - Organized in Accordions
gr.Markdown("## βοΈ Settings")
-
+
# Research Configuration Accordion
with gr.Accordion("π¬ Research Configuration", open=True):
mode_radio = gr.Radio(
@@ -800,24 +753,30 @@ def create_demo() -> gr.Blocks:
"Auto: Smart routing"
),
)
-
+
graph_mode_radio = gr.Radio(
choices=["iterative", "deep", "auto"],
value="auto",
label="Graph Research Mode",
info="Iterative: Single loop | Deep: Parallel sections | Auto: Detect from query",
)
-
+
use_graph_checkbox = gr.Checkbox(
value=True,
label="Use Graph Execution",
info="Enable graph-based workflow execution",
)
-
+
# Model and Provider selection
gr.Markdown("### π€ Model & Provider")
-
- # Popular models list
+
+ # Status message for model/provider loading
+ model_provider_status = gr.Markdown(
+ value="β οΈ Sign in to see available models and providers",
+ visible=True,
+ )
+
+ # Popular models list (will be updated by validator)
popular_models = [
"", # Empty = use default
"Qwen/Qwen3-Next-80B-A3B-Thinking",
@@ -828,16 +787,16 @@ def create_demo() -> gr.Blocks:
"mistralai/Mistral-7B-Instruct-v0.2",
"google/gemma-2-9b-it",
]
-
+
hf_model_dropdown = gr.Dropdown(
choices=popular_models,
value="", # Empty string - will be converted to None in research_agent
label="Reasoning Model",
- info="Select a HuggingFace model (leave empty for default)",
+ info="Select a HuggingFace model (leave empty for default). Sign in to see all available models.",
allow_custom_value=True, # Allow users to type custom model IDs
)
- # Provider list from README
+ # Provider list from README (will be updated by validator)
providers = [
"", # Empty string = auto-select
"nebius",
@@ -850,131 +809,411 @@ def create_demo() -> gr.Blocks:
"ovh",
"fireworks",
]
-
+
hf_provider_dropdown = gr.Dropdown(
choices=providers,
value="", # Empty string - will be converted to None in research_agent
label="Inference Provider",
- info="Select inference provider (leave empty for auto-select)",
+ info="Select inference provider (leave empty for auto-select). Sign in to see all available providers.",
+ )
+
+ # Refresh button for updating models/providers after login
+ def refresh_models_and_providers(
+ request: gr.Request,
+ ) -> tuple[dict[str, Any], dict[str, Any], str]:
+ """Handle refresh button click and update dropdowns."""
+ import asyncio
+
+ # Extract OAuth token and profile from request
+ oauth_token: gr.OAuthToken | None = None
+ oauth_profile: gr.OAuthProfile | None = None
+
+ if request is not None:
+ # Try to get OAuth token from request
+ if hasattr(request, "oauth_token"):
+ oauth_token = request.oauth_token
+ if hasattr(request, "oauth_profile"):
+ oauth_profile = request.oauth_profile
+
+ # Run async function in sync context
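+        # Assumes Gradio runs this sync handler in a worker thread without an active event
+        # loop, so creating a dedicated loop here is safe.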
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ result = loop.run_until_complete(
+ update_model_provider_dropdowns(oauth_token, oauth_profile)
+ )
+ return result
+ finally:
+ loop.close()
+
+ refresh_models_btn = gr.Button(
+ value="π Refresh Available Models",
+ visible=True,
+ size="sm",
)
-
- # Multimodal Input Configuration Accordion
- with gr.Accordion("π· Multimodal Input", open=False):
+
+ # Pass request to get OAuth token from Gradio context
+ refresh_models_btn.click(
+ fn=refresh_models_and_providers,
+ inputs=[], # Request is automatically available in Gradio context
+ outputs=[hf_model_dropdown, hf_provider_dropdown, model_provider_status],
+ )
+
+ # Web Search Provider selection
+ gr.Markdown("### π Web Search Provider")
+
+ # Available providers with labels indicating availability
+ # Format: (display_label, value) - Gradio Dropdown supports tuples
+ web_search_provider_options = [
+ ("Auto-detect (Recommended)", "auto"),
+ ("Serper (Google Search + Full Content)", "serper"),
+ ("DuckDuckGo (Free, Snippets Only)", "duckduckgo"),
+ ("SearchXNG (Self-hosted) - Coming Soon", "searchxng"), # Not fully implemented
+ ("Brave - Coming Soon", "brave"), # Not implemented
+ ("Tavily - Coming Soon", "tavily"), # Not implemented
+ ]
+
+ # Create Dropdown with label-value pairs
+ # Gradio will display labels but return values
+ # Disabled options are marked with "Coming Soon" in the label
+ # The factory will handle "not implemented" cases gracefully
+ web_search_provider_dropdown = gr.Dropdown(
+ choices=web_search_provider_options,
+ value="auto",
+ label="Web Search Provider",
+ info="Select web search provider. 'Auto' detects best available.",
+ )
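+            # Example: selecting "Serper (Google Search + Full Content)" in the UI passes the
+            # value "serper" through to research_agent's web_search_provider argument.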
+
+ # Multimodal Input Configuration
+ gr.Markdown("### π·π€ Multimodal Input")
+
enable_image_input_checkbox = gr.Checkbox(
value=settings.enable_image_input,
label="Enable Image Input (OCR)",
- info="Extract text from uploaded images using OCR",
+ info="Process uploaded images with OCR",
)
-
+
enable_audio_input_checkbox = gr.Checkbox(
value=settings.enable_audio_input,
label="Enable Audio Input (STT)",
- info="Transcribe audio recordings using speech-to-text",
- )
-
- # Audio/TTS Configuration Accordion
- with gr.Accordion("π Audio Output", open=False):
- enable_audio_output_checkbox = gr.Checkbox(
- value=settings.enable_audio_output,
- label="Enable Audio Output",
- info="Generate audio responses using TTS",
+ info="Process uploaded/recorded audio with speech-to-text",
)
-
- tts_voice_dropdown = gr.Dropdown(
- choices=[
- "af_heart",
- "af_bella",
- "af_nicole",
- "af_aoede",
- "af_kore",
- "af_sarah",
- "af_nova",
- "af_sky",
- "af_alloy",
- "af_jessica",
- "af_river",
- "am_michael",
- "am_fenrir",
- "am_puck",
- "am_echo",
- "am_eric",
- "am_liam",
- "am_onyx",
- "am_santa",
- "am_adam",
- ],
- value=settings.tts_voice,
- label="TTS Voice",
- info="Select TTS voice (American English voices: af_*, am_*)",
+
+ # Audio Output Configuration - Collapsible
+ with gr.Accordion("π Audio Output (TTS)", open=False):
+ gr.Markdown(
+ "**Generate audio for research responses on-demand.**\n\n"
+ "Enter Modal keys below or set `MODAL_TOKEN_ID`/`MODAL_TOKEN_SECRET` in `.env` for local development."
)
-
- tts_speed_slider = gr.Slider(
- minimum=0.5,
- maximum=2.0,
- value=settings.tts_speed,
- step=0.1,
- label="TTS Speech Speed",
- info="Adjust TTS speech speed (0.5x to 2.0x)",
+
+ with gr.Accordion("π Modal Credentials (Optional)", open=False):
+ modal_token_id_input = gr.Textbox(
+ label="Modal Token ID",
+ placeholder="ak-... (leave empty to use .env)",
+ type="password",
+ value="",
+ )
+
+ modal_token_secret_input = gr.Textbox(
+ label="Modal Token Secret",
+ placeholder="as-... (leave empty to use .env)",
+ type="password",
+ value="",
+ )
+
+ with gr.Accordion("ποΈ Voice & Quality Settings", open=False):
+ tts_voice_dropdown = gr.Dropdown(
+ choices=[
+ "af_heart",
+ "af_bella",
+ "af_sarah",
+ "af_sky",
+ "af_nova",
+ "af_shimmer",
+ "af_echo",
+ "af_fable",
+ "af_onyx",
+ "af_angel",
+ "af_asteria",
+ "af_jessica",
+ "af_elli",
+ "af_domi",
+ "af_gigi",
+ "af_freya",
+ "af_glinda",
+ "af_cora",
+ "af_serena",
+ "af_liv",
+ "af_naomi",
+ "af_rachel",
+ "af_antoni",
+ "af_thomas",
+ "af_charlie",
+ "af_emily",
+ "af_george",
+ "af_arnold",
+ "af_adam",
+ "af_sam",
+ "af_paul",
+ "af_josh",
+ "af_daniel",
+ "af_liam",
+ "af_dave",
+ "af_fin",
+ "af_grace",
+ "af_dorothy",
+ "af_michael",
+ "af_james",
+ "af_joseph",
+ "af_jeremy",
+ "af_ryan",
+ "af_oliver",
+ "af_harry",
+ "af_kyle",
+ "af_leo",
+ "af_otto",
+ "af_owen",
+ "af_pepper",
+ "af_phil",
+ "af_raven",
+ "af_rocky",
+ "af_rusty",
+ "af_spark",
+ "af_stella",
+ "af_storm",
+ "af_taylor",
+ "af_vera",
+ "af_will",
+ "af_aria",
+ "af_ash",
+ "af_ballad",
+ "af_breeze",
+ "af_cove",
+ "af_dusk",
+ "af_ember",
+ "af_flash",
+ "af_flow",
+ "af_glow",
+ "af_harmony",
+ "af_journey",
+ "af_lullaby",
+ "af_lyra",
+ "af_melody",
+ "af_midnight",
+ "af_moon",
+ "af_muse",
+ "af_music",
+ "af_narrator",
+ "af_nightingale",
+ "af_poet",
+ "af_rain",
+ "af_redwood",
+ "af_rewind",
+ "af_river",
+ "af_sage",
+ "af_seashore",
+ "af_shadow",
+ "af_silver",
+ "af_song",
+ "af_starshine",
+ "af_story",
+ "af_summer",
+ "af_sun",
+ "af_thunder",
+ "af_tide",
+ "af_time",
+ "af_valentino",
+ "af_verdant",
+ "af_verse",
+ "af_vibrant",
+ "af_vivid",
+ "af_warmth",
+ "af_whisper",
+ "af_wilderness",
+ "af_willow",
+ "af_winter",
+ "af_wit",
+ "af_witness",
+ "af_wren",
+ "af_writer",
+ "af_zara",
+ "af_zeus",
+ "af_ziggy",
+ "af_zoom",
+ "am_michael",
+ "am_fenrir",
+ "am_puck",
+ "am_echo",
+ "am_eric",
+ "am_liam",
+ "am_onyx",
+ "am_santa",
+ "am_adam",
+ ],
+ value=settings.tts_voice,
+ label="TTS Voice",
+ info="Select TTS voice (American English voices: af_*, am_*)",
+ )
+
+ tts_speed_slider = gr.Slider(
+ minimum=0.5,
+ maximum=2.0,
+ value=settings.tts_speed,
+ step=0.1,
+ label="TTS Speech Speed",
+ info="Adjust TTS speech speed (0.5x to 2.0x)",
+ )
+
+ gr.Dropdown(
+ choices=["T4", "A10", "A100", "L4", "L40S"],
+ value=settings.tts_gpu or "T4",
+ label="TTS GPU Type",
+ info="Modal GPU type for TTS (T4 is cheapest, A100 is fastest). Note: GPU changes require app restart.",
+ visible=settings.modal_available,
+ interactive=False, # GPU type set at function definition time, requires restart
+ )
+
+ tts_use_llm_polish_checkbox = gr.Checkbox(
+ value=settings.tts_use_llm_polish,
+ label="Use LLM Polish for Audio",
+ info="Apply LLM-based final polish to remove remaining formatting artifacts (costs API calls)",
+ )
+
+ tts_generate_button = gr.Button(
+ "π΅ Generate Audio for Last Response",
+ variant="primary",
+ size="lg",
)
-
- tts_gpu_dropdown = gr.Dropdown(
- choices=["T4", "A10", "A100", "L4", "L40S"],
- value=settings.tts_gpu or "T4",
- label="TTS GPU Type",
- info="Modal GPU type for TTS (T4 is cheapest, A100 is fastest). Note: GPU changes require app restart.",
- visible=settings.modal_available,
- interactive=False, # GPU type set at function definition time, requires restart
+
+ tts_status_text = gr.Markdown(
+ "Click the button above to generate audio for the last research response.",
+ elem_classes="tts-status",
)
-
- # Audio output component (for TTS response) - moved to sidebar
+
+ # Audio output component (for TTS response)
audio_output = gr.Audio(
- label="π Audio Response",
- visible=settings.enable_audio_output,
+ label="π Audio Output",
+ visible=True,
)
-
- # Update TTS component visibility based on enable_audio_output_checkbox
- # This must be after audio_output is defined
- def update_tts_visibility(enabled: bool) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
- """Update visibility of TTS components based on enable checkbox."""
- return (
- gr.update(visible=enabled),
- gr.update(visible=enabled),
- gr.update(visible=enabled),
+
+ # TTS on-demand generation handler
+ async def handle_tts_generation(
+ history: list[dict[str, Any]],
+ modal_token_id: str,
+ modal_token_secret: str,
+ voice: str,
+ speed: float,
+ use_llm_polish: bool,
+ ) -> tuple[Any | None, str]:
+ """Generate audio on-demand for the last response.
+
+ Args:
+ history: Chat history
+ modal_token_id: Modal token ID from UI
+ modal_token_secret: Modal token secret from UI
+ voice: TTS voice selection
+ speed: TTS speed
+ use_llm_polish: Enable LLM polish
+
+ Returns:
+ Tuple of (audio_output, status_message)
+ """
+ from src.services.tts_modal import generate_audio_on_demand
+
+        # Get the last assistant message from history.
+        # Depending on the Chatbot format, history is either a list of
+        # (user_msg, assistant_msg) tuples or a list of role/content dicts; both are handled below.
+ if not history:
+ logger.warning("tts_no_history", history=history)
+ return None, "β No messages in history to generate audio for"
+
+ # Debug: Log history format
+ logger.info(
+ "tts_history_debug",
+ history_type=type(history).__name__,
+ history_length=len(history) if isinstance(history, list) else 0,
+ first_entry_type=type(history[0]).__name__
+ if isinstance(history, list) and len(history) > 0
+ else None,
+ first_entry_sample=str(history[0])[:200]
+ if isinstance(history, list) and len(history) > 0
+ else None,
)
-
- enable_audio_output_checkbox.change(
- fn=update_tts_visibility,
- inputs=[enable_audio_output_checkbox],
- outputs=[tts_voice_dropdown, tts_speed_slider, audio_output],
- )
+
+ # Get the last assistant message (second element of last tuple)
+ last_message = None
+ if isinstance(history, list) and len(history) > 0:
+ last_entry = history[-1]
+ # ChatInterface format: (user_message, assistant_message)
+ if isinstance(last_entry, (tuple, list)) and len(last_entry) >= 2:
+ last_message = last_entry[1]
+ logger.info(
+ "tts_extracted_from_tuple", message_type=type(last_message).__name__
+ )
+ # Dict format: {"role": "assistant", "content": "..."}
+ elif isinstance(last_entry, dict):
+ if last_entry.get("role") == "assistant":
+ content = last_entry.get("content", "")
+ # Content might be a list (multimodal) or string
+ if isinstance(content, list):
+ # Extract text from multimodal content list
+ last_message = " ".join(str(item) for item in content if item)
+ else:
+ last_message = content
+ logger.info(
+ "tts_extracted_from_dict",
+ message_type=type(content).__name__,
+ message_length=len(last_message)
+ if isinstance(last_message, str)
+ else 0,
+ )
+ else:
+ logger.warning(
+ "tts_unknown_format",
+ entry_type=type(last_entry).__name__,
+ entry=str(last_entry)[:200],
+ )
+
+ # Also handle if last_message itself is a list
+ if isinstance(last_message, list):
+ last_message = " ".join(str(item) for item in last_message if item)
+
+ if not last_message or not isinstance(last_message, str) or not last_message.strip():
+ logger.error(
+ "tts_no_message_found",
+ last_message_type=type(last_message).__name__ if last_message else None,
+ last_message_value=str(last_message)[:100] if last_message else None,
+ )
+ return None, "β No assistant response found in history"
+
+ # Generate audio
+ audio_output, status_message = await generate_audio_on_demand(
+ text=last_message,
+ modal_token_id=modal_token_id,
+ modal_token_secret=modal_token_secret,
+ voice=voice,
+ speed=speed,
+ use_llm_polish=use_llm_polish,
+ )
+
+ return audio_output, status_message
# Chat interface with multimodal support
# Examples are provided but will NOT run at startup (cache_examples=False)
# Users must log in first before using examples or submitting queries
- gr.ChatInterface(
+ chat_interface = gr.ChatInterface(
fn=research_agent,
multimodal=True, # Enable multimodal input (text + images + audio)
title="π¬ The DETERMINATOR",
description=(
- "*Generalist Deep Research Agent β stops at nothing until finding precise answers to complex questions*\n\n"
- "---\n"
- "**The DETERMINATOR** uses iterative search-and-judge loops to comprehensively investigate any research question. "
- "It automatically determines if medical knowledge sources (PubMed, ClinicalTrials.gov) are needed and adapts its search strategy accordingly.\n\n"
- "**Key Features**:\n"
- "- π Multi-source search (Web, PubMed, ClinicalTrials.gov, Europe PMC, RAG)\n"
- "- π§ Automatic medical knowledge detection\n"
- "- π Iterative refinement until precise answers are found\n"
- "- βΉοΈ Stops only at configured limits (budget, time, iterations)\n"
- "- π Evidence synthesis with citations\n\n"
- "**MCP Server Active**: Connect Claude Desktop to `/gradio_api/mcp/`\n\n"
- "**π·π€ Multimodal Input Support**:\n"
- "- **Images**: Click the π· image icon in the textbox to upload images (OCR)\n"
- "- **Audio**: Click the π€ microphone icon in the textbox to record audio (STT)\n"
- "- **Files**: Drag & drop or click to upload image/audio files\n"
- "- **Text**: Type your research questions directly\n\n"
- "π‘ **Tip**: Look for the π· and π€ icons in the text input box below!\n\n"
- "Configure multimodal inputs in the sidebar settings.\n\n"
- "**β οΈ Authentication Required**: Please **sign in with HuggingFace** above before using this application."
+ "*Generalist Deep Research Agent β stops at nothing until finding precise answers*\n\n"
+ "π‘ **Quick Start**: Type your research question below. Use π· for images, π€ for audio.\n\n"
+ "β οΈ **Sign in with HuggingFace** (sidebar) before starting."
),
examples=[
# When additional_inputs are provided, examples must be lists of lists
@@ -997,24 +1236,38 @@ def update_tts_visibility(enabled: bool) -> tuple[dict[str, Any], dict[str, Any]
"Analyze the current state of quantum computing architectures: compare different qubit technologies, error correction methods, and scalability challenges across major platforms including IBM, Google, and IonQ.",
"deep",
"Qwen/Qwen3-Next-80B-A3B-Thinking",
- "",
+ "nebius",
"deep",
True,
],
[
- # Business/Scientific example requiring iterative search
- "Investigate the economic and environmental impact of renewable energy transition: analyze cost trends, grid integration challenges, policy frameworks, and market dynamics across solar, wind, and battery storage technologies, in china",
+ # Historical/Social Science example
+ "Research and synthesize information about the economic impact of the Industrial Revolution on European social structures, including changes in class dynamics, urbanization patterns, and labor movements from 1750-1900.",
+ "deep",
+ "meta-llama/Llama-3.1-70B-Instruct",
+ "together",
+ "deep",
+ True,
+ ],
+ [
+ # Scientific/Physics example
+ "Investigate the latest developments in fusion energy research: compare ITER, SPARC, and other major projects, analyze recent breakthroughs in plasma confinement, and assess the timeline to commercial fusion power.",
"deep",
"Qwen/Qwen3-235B-A22B-Instruct-2507",
- "",
+ "hyperbolic",
+ "deep",
+ True,
+ ],
+ [
+ # Technology/Business example
+ "Research the competitive landscape of AI chip manufacturers: analyze NVIDIA, AMD, Intel, and emerging players, compare architectures (GPU vs. TPU vs. NPU), and assess market positioning and future trends.",
+ "deep",
+ "zai-org/GLM-4.5-Air",
+ "fireworks",
"deep",
True,
],
],
- cache_examples=False, # CRITICAL: Disable example caching to prevent examples from running at startup
- # Examples will only run when user explicitly clicks them (after login)
- # Note: additional_inputs_accordion is not a valid parameter in Gradio 6.0 ChatInterface
- # Components will be displayed in the order provided
additional_inputs=[
mode_radio,
hf_model_dropdown,
@@ -1023,28 +1276,29 @@ def update_tts_visibility(enabled: bool) -> tuple[dict[str, Any], dict[str, Any]
use_graph_checkbox,
enable_image_input_checkbox,
enable_audio_input_checkbox,
+ web_search_provider_dropdown,
+ # Note: gr.OAuthToken and gr.OAuthProfile are automatically passed as function parameters
+ ],
+ cache_examples=False, # Don't cache examples - requires authentication
+ )
+
+ # Wire up TTS generation button
+ tts_generate_button.click(
+ fn=handle_tts_generation,
+ inputs=[
+ chat_interface.chatbot, # Get chat history from ChatInterface
+ modal_token_id_input,
+ modal_token_secret_input,
tts_voice_dropdown,
tts_speed_slider,
- # Note: gr.OAuthToken and gr.OAuthProfile are automatically passed as function parameters
- # when user is logged in - they should NOT be added to additional_inputs
+ tts_use_llm_polish_checkbox,
],
- additional_outputs=[audio_output], # Add audio output for TTS
+ outputs=[audio_output, tts_status_text],
)
return demo # type: ignore[no-any-return]
-def main() -> None:
- """Run the Gradio app with MCP server enabled."""
- demo = create_demo()
- demo.launch(
- # server_name="0.0.0.0",
- # server_port=7860,
- # share=False,
- mcp_server=True, # Enable MCP server for Claude Desktop integration
- ssr_mode=False, # Fix for intermittent loading/hydration issues in HF Spaces
- )
-
-
if __name__ == "__main__":
- main()
+ demo = create_demo()
+    demo.launch(server_name="0.0.0.0", server_port=7860, mcp_server=True)  # MCP endpoint advertised in the About section
diff --git a/src/legacy_orchestrator.py b/src/legacy_orchestrator.py
index ac1ee46c..b41ba8aa 100644
--- a/src/legacy_orchestrator.py
+++ b/src/legacy_orchestrator.py
@@ -6,6 +6,11 @@
import structlog
+try:
+ from pydantic_ai import ModelMessage
+except ImportError:
+ ModelMessage = Any # type: ignore[assignment, misc]
+
from src.utils.config import settings
from src.utils.models import (
AgentEvent,
@@ -153,7 +158,9 @@ async def _run_analysis_phase(
iteration=iteration,
)
- async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: # noqa: PLR0915
+ async def run(
+ self, query: str, message_history: list[ModelMessage] | None = None
+ ) -> AsyncGenerator[AgentEvent, None]: # noqa: PLR0915
"""
Run the agent loop for a query.
@@ -161,11 +168,16 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]: # noqa: PL
Args:
query: The user's research question
+ message_history: Optional user conversation history (for compatibility)
Yields:
AgentEvent objects for each step of the process
"""
- logger.info("Starting orchestrator", query=query)
+ logger.info(
+ "Starting orchestrator",
+ query=query,
+ has_history=bool(message_history),
+ )
yield AgentEvent(
type="started",
diff --git a/src/mcp_tools.py b/src/mcp_tools.py
index fdd15ac0..7b59f9be 100644
--- a/src/mcp_tools.py
+++ b/src/mcp_tools.py
@@ -242,7 +242,6 @@ async def extract_text_from_image(
Extracted text from the image
"""
from src.services.image_ocr import get_image_ocr_service
-
from src.utils.config import settings
try:
@@ -280,7 +279,6 @@ async def transcribe_audio_file(
Transcribed text from the audio file
"""
from src.services.stt_gradio import get_stt_service
-
from src.utils.config import settings
try:
@@ -300,4 +298,4 @@ async def transcribe_audio_file(
return f"## Audio Transcription\n\n{transcribed_text}"
except Exception as e:
- return f"Error transcribing audio: {e}"
\ No newline at end of file
+ return f"Error transcribing audio: {e}"
diff --git a/src/middleware/state_machine.py b/src/middleware/state_machine.py
index d43e131e..8fbdf793 100644
--- a/src/middleware/state_machine.py
+++ b/src/middleware/state_machine.py
@@ -11,6 +11,11 @@
import structlog
from pydantic import BaseModel, Field
+try:
+ from pydantic_ai import ModelMessage
+except ImportError:
+ ModelMessage = Any # type: ignore[assignment, misc]
+
from src.utils.models import Citation, Conversation, Evidence
if TYPE_CHECKING:
@@ -28,6 +33,10 @@ class WorkflowState(BaseModel):
evidence: list[Evidence] = Field(default_factory=list)
conversation: Conversation = Field(default_factory=Conversation)
+ user_message_history: list[ModelMessage] = Field(
+ default_factory=list,
+ description="User conversation history (multi-turn interactions)",
+ )
# Type as Any to avoid circular imports/runtime resolution issues
# The actual object injected will be an EmbeddingService instance
embedding_service: Any = Field(default=None)
@@ -90,6 +99,31 @@ async def search_related(self, query: str, n_results: int = 5) -> list[Evidence]
return evidence_list
+ def add_user_message(self, message: ModelMessage) -> None:
+ """Add a user message to conversation history.
+
+ Args:
+ message: Message to add
+ """
+ self.user_message_history.append(message)
+
+ def get_user_history(self, max_messages: int | None = None) -> list[ModelMessage]:
+ """Get user conversation history.
+
+ Args:
+ max_messages: Maximum messages to return (None for all)
+
+ Returns:
+ List of messages
+ """
+ if max_messages is None:
+ return self.user_message_history.copy()
+ return (
+ self.user_message_history[-max_messages:]
+ if len(self.user_message_history) > max_messages
+ else self.user_message_history.copy()
+ )
+
# The ContextVar holds the WorkflowState for the current execution context
_workflow_state_var: ContextVar[WorkflowState | None] = ContextVar("workflow_state", default=None)
@@ -97,18 +131,26 @@ async def search_related(self, query: str, n_results: int = 5) -> list[Evidence]
def init_workflow_state(
embedding_service: "EmbeddingService | None" = None,
+ message_history: list[ModelMessage] | None = None,
) -> WorkflowState:
"""Initialize a new state for the current context.
Args:
embedding_service: Optional embedding service for semantic search.
+ message_history: Optional user conversation history.
Returns:
The initialized WorkflowState instance.
"""
state = WorkflowState(embedding_service=embedding_service)
+ if message_history:
+ state.user_message_history = message_history.copy()
_workflow_state_var.set(state)
- logger.debug("Workflow state initialized", has_embeddings=embedding_service is not None)
+ logger.debug(
+ "Workflow state initialized",
+ has_embeddings=embedding_service is not None,
+ has_history=bool(message_history),
+ )
return state
@@ -126,14 +168,4 @@ def get_workflow_state() -> WorkflowState:
# Auto-initialize if missing (e.g. during tests or simple scripts)
logger.debug("Workflow state not found, auto-initializing")
return init_workflow_state()
- return state
-
-
-
-
-
-
-
-
-
-
+ return state
\ No newline at end of file
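A short usage sketch of the history plumbing added to WorkflowState, using only calls introduced in this hunk; previous_turns stands for a list[ModelMessage] built by the caller.

    from src.middleware.state_machine import get_workflow_state, init_workflow_state

    # previous_turns: list[ModelMessage] assembled by the caller (placeholder)
    # Seed the context-local state with prior turns, then read back a bounded window.
    state = init_workflow_state(embedding_service=None, message_history=previous_turns)
    recent = state.get_user_history(max_messages=10)  # copy of the last 10 messages
    assert get_workflow_state() is state              # same ContextVar-backed instance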
diff --git a/src/orchestrator/graph_orchestrator.py b/src/orchestrator/graph_orchestrator.py
index 82650b9f..d71a337d 100644
--- a/src/orchestrator/graph_orchestrator.py
+++ b/src/orchestrator/graph_orchestrator.py
@@ -10,6 +10,11 @@
import structlog
+try:
+ from pydantic_ai.messages import ModelMessage
+except ImportError:
+ ModelMessage = Any # type: ignore[assignment, misc]
+
from src.agent_factory.agents import (
create_input_parser_agent,
create_knowledge_gap_agent,
@@ -44,12 +49,18 @@
class GraphExecutionContext:
"""Context for managing graph execution state."""
- def __init__(self, state: WorkflowState, budget_tracker: BudgetTracker) -> None:
+ def __init__(
+ self,
+ state: WorkflowState,
+ budget_tracker: BudgetTracker,
+ message_history: list[ModelMessage] | None = None,
+ ) -> None:
"""Initialize execution context.
Args:
state: Current workflow state
budget_tracker: Budget tracker instance
+ message_history: Optional user conversation history
"""
self.current_node: str = ""
self.visited_nodes: set[str] = set()
@@ -57,6 +68,7 @@ def __init__(self, state: WorkflowState, budget_tracker: BudgetTracker) -> None:
self.state = state
self.budget_tracker = budget_tracker
self.iteration_count = 0
+ self.message_history: list[ModelMessage] = message_history or []
def set_node_result(self, node_id: str, result: Any) -> None:
"""Store result from node execution.
@@ -108,6 +120,31 @@ def update_state(
"""
self.state = updater(self.state, data)
+ def add_message(self, message: ModelMessage) -> None:
+ """Add a message to the history.
+
+ Args:
+ message: Message to add
+ """
+ self.message_history.append(message)
+
+ def get_message_history(self, max_messages: int | None = None) -> list[ModelMessage]:
+ """Get message history, optionally truncated.
+
+ Args:
+ max_messages: Maximum messages to return (None for all)
+
+ Returns:
+ List of messages
+ """
+ if max_messages is None:
+ return self.message_history.copy()
+ return (
+ self.message_history[-max_messages:]
+ if len(self.message_history) > max_messages
+ else self.message_history.copy()
+ )
+
class GraphOrchestrator:
"""
@@ -174,12 +211,15 @@ def _get_file_service(self) -> ReportFileService | None:
return None
return self._file_service
- async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
+ async def run(
+ self, query: str, message_history: list[ModelMessage] | None = None
+ ) -> AsyncGenerator[AgentEvent, None]:
"""
Run the research workflow.
Args:
query: The user's research query
+ message_history: Optional user conversation history
Yields:
AgentEvent objects for real-time UI updates
@@ -189,6 +229,7 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
query=query[:100],
mode=self.mode,
use_graph=self.use_graph,
+ has_history=bool(message_history),
)
yield AgentEvent(
@@ -205,10 +246,10 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
# Use graph execution if enabled, otherwise fall back to agent chains
if self.use_graph:
- async for event in self._run_with_graph(query, research_mode):
+ async for event in self._run_with_graph(query, research_mode, message_history):
yield event
else:
- async for event in self._run_with_chains(query, research_mode):
+ async for event in self._run_with_chains(query, research_mode, message_history):
yield event
except Exception as e:
@@ -220,13 +261,17 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
)
async def _run_with_graph(
- self, query: str, research_mode: Literal["iterative", "deep"]
+ self,
+ query: str,
+ research_mode: Literal["iterative", "deep"],
+ message_history: list[ModelMessage] | None = None,
) -> AsyncGenerator[AgentEvent, None]:
"""Run workflow using graph execution.
Args:
query: The research query
research_mode: The research mode
+ message_history: Optional user conversation history
Yields:
AgentEvent objects
@@ -235,7 +280,10 @@ async def _run_with_graph(
from src.services.embeddings import get_embedding_service
embedding_service = get_embedding_service()
- state = init_workflow_state(embedding_service=embedding_service)
+ state = init_workflow_state(
+ embedding_service=embedding_service,
+ message_history=message_history,
+ )
budget_tracker = BudgetTracker()
budget_tracker.create_budget(
loop_id="graph_execution",
@@ -245,7 +293,11 @@ async def _run_with_graph(
)
budget_tracker.start_timer("graph_execution")
- context = GraphExecutionContext(state, budget_tracker)
+ context = GraphExecutionContext(
+ state,
+ budget_tracker,
+ message_history=message_history or [],
+ )
# Build graph
self._graph = await self._build_graph(research_mode)
@@ -255,13 +307,17 @@ async def _run_with_graph(
yield event
async def _run_with_chains(
- self, query: str, research_mode: Literal["iterative", "deep"]
+ self,
+ query: str,
+ research_mode: Literal["iterative", "deep"],
+ message_history: list[ModelMessage] | None = None,
) -> AsyncGenerator[AgentEvent, None]:
"""Run workflow using agent chains (backward compatibility).
Args:
query: The research query
research_mode: The research mode
+ message_history: Optional user conversation history
Yields:
AgentEvent objects
@@ -282,7 +338,9 @@ async def _run_with_chains(
)
try:
- final_report = await self._iterative_flow.run(query)
+ final_report = await self._iterative_flow.run(
+ query, message_history=message_history
+ )
except Exception as e:
self.logger.error("Iterative flow failed", error=str(e), exc_info=True)
# Yield error event - outer handler will also catch and yield error event
@@ -318,7 +376,7 @@ async def _run_with_chains(
)
try:
- final_report = await self._deep_flow.run(query)
+ final_report = await self._deep_flow.run(query, message_history=message_history)
except Exception as e:
self.logger.error("Deep flow failed", error=str(e), exc_info=True)
# Yield error event before re-raising so test can capture it
@@ -488,73 +546,17 @@ def _emit_completion_event(
iteration=iteration,
)
- async def _execute_graph(
- self, query: str, context: GraphExecutionContext
- ) -> AsyncGenerator[AgentEvent, None]:
- """Execute the graph from entry node.
-
- Args:
- query: The research query
- context: Execution context
-
- Yields:
- AgentEvent objects
- """
+ def _get_final_result_from_exit_nodes(
+ self, context: GraphExecutionContext, current_node_id: str | None
+ ) -> tuple[Any, str | None]:
+ """Get final result from exit nodes, prioritizing synthesizer/writer."""
if not self._graph:
- raise ValueError("Graph not built")
+ return None, current_node_id
- current_node_id = self._graph.entry_node
- iteration = 0
-
- # Execute nodes until we reach an exit node
- while current_node_id:
- # Check budget
- if not context.budget_tracker.can_continue("graph_execution"):
- self.logger.warning("Budget exceeded, exiting graph execution")
- break
-
- # Execute current node
- iteration += 1
- context.current_node = current_node_id
- node = self._graph.get_node(current_node_id)
-
- # Emit start event
- yield self._emit_start_event(node, current_node_id, iteration, context)
-
- try:
- result = await self._execute_node(current_node_id, query, context)
- context.set_node_result(current_node_id, result)
- context.mark_visited(current_node_id)
-
- # Yield completion event
- yield self._emit_completion_event(node, current_node_id, result, iteration)
-
- except Exception as e:
- self.logger.error("Node execution failed", node_id=current_node_id, error=str(e))
- yield AgentEvent(
- type="error",
- message=f"Node {current_node_id} failed: {e!s}",
- iteration=iteration,
- )
- break
-
- # Check if current node is an exit node - if so, we're done
- if current_node_id in self._graph.exit_nodes:
- break
-
- # Get next node(s)
- next_nodes = self._get_next_node(current_node_id, context)
-
- if not next_nodes:
- # No more nodes, we've reached a dead end
- self.logger.warning("Reached dead end in graph", node_id=current_node_id)
- break
-
- current_node_id = next_nodes[0] # For now, take first next node (handle parallel later)
+ final_result = None
+ result_node_id = current_node_id
- # Final event - get result from exit nodes (prioritize synthesizer/writer nodes)
# First try to get result from current node (if it's an exit node)
- final_result = None
if current_node_id and current_node_id in self._graph.exit_nodes:
final_result = context.get_node_result(current_node_id)
self.logger.debug(
@@ -563,7 +565,7 @@ async def _execute_graph(
has_result=final_result is not None,
result_type=type(final_result).__name__ if final_result else None,
)
-
+
# If no result from current node, check all exit nodes for results
# Prioritize synthesizer (deep research) or writer (iterative research)
if not final_result:
@@ -573,28 +575,28 @@ async def _execute_graph(
result = context.get_node_result(exit_node_id)
if result:
final_result = result
- current_node_id = exit_node_id
+ result_node_id = exit_node_id
self.logger.debug(
"Final result from priority exit node",
node_id=exit_node_id,
result_type=type(final_result).__name__,
)
break
-
+
# If still no result, check all exit nodes
if not final_result:
for exit_node_id in self._graph.exit_nodes:
result = context.get_node_result(exit_node_id)
if result:
final_result = result
- current_node_id = exit_node_id
+ result_node_id = exit_node_id
self.logger.debug(
"Final result from any exit node",
node_id=exit_node_id,
result_type=type(final_result).__name__,
)
break
-
+
# Log warning if no result found
if not final_result:
self.logger.warning(
@@ -604,8 +606,11 @@ async def _execute_graph(
all_node_results=list(context.node_results.keys()),
)
- # Check if final result contains file information
- event_data: dict[str, Any] = {"mode": self.mode, "iterations": iteration}
+ return final_result, result_node_id
+
+ def _extract_final_message_and_files(self, final_result: Any) -> tuple[str, dict[str, Any]]:
+ """Extract message and file information from final result."""
+ event_data: dict[str, Any] = {"mode": self.mode}
message: str = "Research completed"
if isinstance(final_result, str):
@@ -619,7 +624,7 @@ async def _execute_graph(
"Final message extracted from dict 'message' key",
length=len(message) if isinstance(message, str) else 0,
)
-
+
# Then check for file paths
if "file" in final_result:
file_path = final_result["file"]
@@ -629,26 +634,89 @@ async def _execute_graph(
if "message" not in final_result:
message = "Report generated. Download available."
self.logger.debug("File path added to event data", file_path=file_path)
- elif "files" in final_result:
+
+ # Check for multiple files
+ if "files" in final_result:
files = final_result["files"]
if isinstance(files, list):
event_data["files"] = files
- # Only override message if not already set from "message" key
- if "message" not in final_result:
- message = "Report generated. Downloads available."
- elif isinstance(files, str):
- event_data["files"] = [files]
- # Only override message if not already set from "message" key
- if "message" not in final_result:
- message = "Report generated. Download available."
- self.logger.debug("File paths added to event data", count=len(event_data.get("files", [])))
- else:
- # Log warning if result type is unexpected
- self.logger.warning(
- "Final result has unexpected type",
- result_type=type(final_result).__name__ if final_result else None,
- result_repr=str(final_result)[:200] if final_result else None,
- )
+ self.logger.debug("Multiple files added to event data", count=len(files))
+
+ return message, event_data
+
+ async def _execute_graph(
+ self, query: str, context: GraphExecutionContext
+ ) -> AsyncGenerator[AgentEvent, None]:
+ """Execute the graph from entry node.
+
+ Args:
+ query: The research query
+ context: Execution context
+
+ Yields:
+ AgentEvent objects
+ """
+ if not self._graph:
+ raise ValueError("Graph not built")
+
+ current_node_id = self._graph.entry_node
+ iteration = 0
+
+ # Execute nodes until we reach an exit node
+ while current_node_id:
+ # Check budget
+ if not context.budget_tracker.can_continue("graph_execution"):
+ self.logger.warning("Budget exceeded, exiting graph execution")
+ break
+
+ # Execute current node
+ iteration += 1
+ context.current_node = current_node_id
+ node = self._graph.get_node(current_node_id)
+
+ # Emit start event
+ yield self._emit_start_event(node, current_node_id, iteration, context)
+
+ try:
+ result = await self._execute_node(current_node_id, query, context)
+ context.set_node_result(current_node_id, result)
+ context.mark_visited(current_node_id)
+
+ # Yield completion event
+ yield self._emit_completion_event(node, current_node_id, result, iteration)
+
+ except Exception as e:
+ self.logger.error("Node execution failed", node_id=current_node_id, error=str(e))
+ yield AgentEvent(
+ type="error",
+ message=f"Node {current_node_id} failed: {e!s}",
+ iteration=iteration,
+ )
+ break
+
+ # Check if current node is an exit node - if so, we're done
+ if current_node_id in self._graph.exit_nodes:
+ break
+
+ # Get next node(s)
+ next_nodes = self._get_next_node(current_node_id, context)
+
+ if not next_nodes:
+ # No more nodes, we've reached a dead end
+ self.logger.warning("Reached dead end in graph", node_id=current_node_id)
+ break
+
+ current_node_id = next_nodes[0] # For now, take first next node (handle parallel later)
+
+ # Final event - get result from exit nodes (prioritize synthesizer/writer nodes)
+ final_result, result_node_id = self._get_final_result_from_exit_nodes(
+ context, current_node_id
+ )
+
+ # Check if final result contains file information
+ event_data: dict[str, Any] = {"mode": self.mode, "iterations": iteration}
+ message, file_event_data = self._extract_final_message_and_files(final_result)
+ event_data.update(file_event_data)
yield AgentEvent(
type="complete",
@@ -686,170 +754,121 @@ async def _execute_node(self, node_id: str, query: str, context: GraphExecutionC
else:
raise ValueError(f"Unknown node type: {type(node)}")
- async def _execute_agent_node(
- self, node: AgentNode, query: str, context: GraphExecutionContext
- ) -> Any:
- """Execute an agent node.
-
- Special handling for deep research nodes:
- - "planner": Takes query string, returns ReportPlan
- - "synthesizer": Takes query + ReportPlan + section drafts, returns final report
-
- Args:
- node: The agent node
- query: The research query
- context: Execution context
-
- Returns:
- Agent execution result
- """
- # Special handling for synthesizer node (deep research)
- if node.node_id == "synthesizer":
- # Call LongWriterAgent.write_report() directly instead of using agent.run()
- from src.agent_factory.agents import create_long_writer_agent
- from src.utils.models import ReportDraft, ReportDraftSection, ReportPlan
+ async def _execute_synthesizer_node(self, query: str, context: GraphExecutionContext) -> Any:
+ """Execute synthesizer node for deep research."""
+ from src.agent_factory.agents import create_long_writer_agent
+ from src.utils.models import ReportDraft, ReportDraftSection, ReportPlan
- report_plan = context.get_node_result("planner")
- section_drafts = context.get_node_result("parallel_loops") or []
+ report_plan = context.get_node_result("planner")
+ section_drafts = context.get_node_result("parallel_loops") or []
- if not isinstance(report_plan, ReportPlan):
- raise ValueError("ReportPlan not found for synthesizer")
+ if not isinstance(report_plan, ReportPlan):
+ raise ValueError("ReportPlan not found for synthesizer")
- if not section_drafts:
- raise ValueError("Section drafts not found for synthesizer")
+ if not section_drafts:
+ raise ValueError("Section drafts not found for synthesizer")
- # Create ReportDraft from section drafts
- report_draft = ReportDraft(
- sections=[
- ReportDraftSection(
- section_title=section.title,
- section_content=draft,
- )
- for section, draft in zip(
- report_plan.report_outline, section_drafts, strict=False
- )
- ]
- )
-
- # Get LongWriterAgent instance and call write_report directly
- long_writer_agent = create_long_writer_agent(oauth_token=self.oauth_token)
- final_report = await long_writer_agent.write_report(
- original_query=query,
- report_title=report_plan.report_title,
- report_draft=report_draft,
- )
+ # Create ReportDraft from section drafts
+ report_draft = ReportDraft(
+ sections=[
+ ReportDraftSection(
+ section_title=section.title,
+ section_content=draft,
+ )
+ for section, draft in zip(report_plan.report_outline, section_drafts, strict=False)
+ ]
+ )
- # Estimate tokens (rough estimate)
- estimated_tokens = len(final_report) // 4 # Rough token estimate
- context.budget_tracker.add_tokens("graph_execution", estimated_tokens)
+ # Get LongWriterAgent instance and call write_report directly
+ long_writer_agent = create_long_writer_agent(oauth_token=self.oauth_token)
+ final_report = await long_writer_agent.write_report(
+ original_query=query,
+ report_title=report_plan.report_title,
+ report_draft=report_draft,
+ )
- # Save report to file if enabled (may generate multiple formats)
- file_path: str | None = None
- pdf_path: str | None = None
- try:
- file_service = self._get_file_service()
- if file_service:
- # Use save_report_multiple_formats to get both MD and PDF if enabled
- saved_files = file_service.save_report_multiple_formats(
- report_content=final_report,
- query=query,
- )
- file_path = saved_files.get("md")
- pdf_path = saved_files.get("pdf")
- self.logger.info(
- "Report saved to file",
- md_path=file_path,
- pdf_path=pdf_path,
- )
- except Exception as e:
- # Don't fail the entire operation if file saving fails
- self.logger.warning("Failed to save report to file", error=str(e))
- file_path = None
- pdf_path = None
-
- # Return dict with file paths if available, otherwise return string (backward compatible)
- if file_path:
- result: dict[str, Any] = {
- "message": final_report,
- "file": file_path,
- }
- # Add PDF path if generated
- if pdf_path:
- result["files"] = [file_path, pdf_path]
- return result
- return final_report
+ # Rough token estimate (~4 characters per token)
+ estimated_tokens = len(final_report) // 4
+ context.budget_tracker.add_tokens("graph_execution", estimated_tokens)
- # Special handling for writer node (iterative research)
- if node.node_id == "writer":
- # Call WriterAgent.write_report() directly instead of using agent.run()
- # Collect all findings from workflow state
- from src.agent_factory.agents import create_writer_agent
+ # Save report to file if enabled (may generate multiple formats)
+ return self._save_report_and_return_result(final_report, query)
- # Get all evidence from workflow state and convert to findings string
- evidence = context.state.evidence
- if evidence:
- # Convert evidence to findings format (similar to conversation.get_all_findings())
- findings_parts: list[str] = []
- for ev in evidence:
- finding = f"**{ev.title}**\n{ev.content}"
- if ev.url:
- finding += f"\nSource: {ev.url}"
- findings_parts.append(finding)
- all_findings = "\n\n".join(findings_parts)
- else:
- all_findings = "No findings available yet."
+ def _save_report_and_return_result(self, final_report: str, query: str) -> dict[str, Any] | str:
+ """Save report to file and return result with file paths if available."""
+ file_path: str | None = None
+ pdf_path: str | None = None
+ try:
+ file_service = self._get_file_service()
+ if file_service:
+ # Use save_report_multiple_formats to get both MD and PDF if enabled
+ saved_files = file_service.save_report_multiple_formats(
+ report_content=final_report,
+ query=query,
+ )
+ file_path = saved_files.get("md")
+ pdf_path = saved_files.get("pdf")
+ self.logger.info(
+ "Report saved to file",
+ md_path=file_path,
+ pdf_path=pdf_path,
+ )
+ except Exception as e:
+ # Don't fail the entire operation if file saving fails
+ self.logger.warning("Failed to save report to file", error=str(e))
+ file_path = None
+ pdf_path = None
+
+ # Return dict with file paths if available, otherwise return string (backward compatible)
+ if file_path:
+ result: dict[str, Any] = {
+ "message": final_report,
+ "file": file_path,
+ }
+ # Add PDF path if generated
+ if pdf_path:
+ result["files"] = [file_path, pdf_path]
+ return result
+ return final_report
+
+ async def _execute_writer_node(self, query: str, context: GraphExecutionContext) -> Any:
+ """Execute writer node for iterative research."""
+ from src.agent_factory.agents import create_writer_agent
+
+ # Get all evidence from workflow state and convert to findings string
+ evidence = context.state.evidence
+ if evidence:
+ # Convert evidence to findings format (similar to conversation.get_all_findings())
+ findings_parts: list[str] = []
+ for ev in evidence:
+ finding = f"**{ev.citation.title}**\n{ev.content}"
+ if ev.citation.url:
+ finding += f"\nSource: {ev.citation.url}"
+ findings_parts.append(finding)
+ all_findings = "\n\n".join(findings_parts)
+ else:
+ all_findings = "No findings available yet."
+
+ # Get WriterAgent instance and call write_report directly
+ writer_agent = create_writer_agent(oauth_token=self.oauth_token)
+ final_report = await writer_agent.write_report(
+ query=query,
+ findings=all_findings,
+ output_length="",
+ output_instructions="",
+ )
- # Get WriterAgent instance and call write_report directly
- writer_agent = create_writer_agent(oauth_token=self.oauth_token)
- final_report = await writer_agent.write_report(
- query=query,
- findings=all_findings,
- output_length="",
- output_instructions="",
- )
+ # Rough token estimate (~4 characters per token)
+ estimated_tokens = len(final_report) // 4
+ context.budget_tracker.add_tokens("graph_execution", estimated_tokens)
- # Estimate tokens (rough estimate)
- estimated_tokens = len(final_report) // 4 # Rough token estimate
- context.budget_tracker.add_tokens("graph_execution", estimated_tokens)
+ # Save report to file if enabled (may generate multiple formats)
+ return self._save_report_and_return_result(final_report, query)
- # Save report to file if enabled (may generate multiple formats)
- file_path: str | None = None
- pdf_path: str | None = None
- try:
- file_service = self._get_file_service()
- if file_service:
- # Use save_report_multiple_formats to get both MD and PDF if enabled
- saved_files = file_service.save_report_multiple_formats(
- report_content=final_report,
- query=query,
- )
- file_path = saved_files.get("md")
- pdf_path = saved_files.get("pdf")
- self.logger.info(
- "Report saved to file",
- md_path=file_path,
- pdf_path=pdf_path,
- )
- except Exception as e:
- # Don't fail the entire operation if file saving fails
- self.logger.warning("Failed to save report to file", error=str(e))
- file_path = None
- pdf_path = None
-
- # Return dict with file paths if available, otherwise return string (backward compatible)
- if file_path:
- result: dict[str, Any] = {
- "message": final_report,
- "file": file_path,
- }
- # Add PDF path if generated
- if pdf_path:
- result["files"] = [file_path, pdf_path]
- return result
- return final_report
-
- # Standard agent execution
- # Prepare input based on node type
+ def _prepare_agent_input(
+ self, node: AgentNode, query: str, context: GraphExecutionContext
+ ) -> Any:
+ """Prepare input data for agent execution."""
if node.node_id == "planner":
# Planner takes the original query
input_data = query
@@ -862,94 +881,348 @@ async def _execute_agent_node(
if node.input_transformer:
input_data = node.input_transformer(input_data)
- # Execute agent with error handling
+ return input_data
+
+ async def _execute_standard_agent(
+ self, node: AgentNode, input_data: Any, query: str, context: GraphExecutionContext
+ ) -> Any:
+ """Execute standard agent with error handling and fallback models."""
+ # Get message history from context (limit to most recent 10 messages for token efficiency)
+ message_history = context.get_message_history(max_messages=10)
+
+ # Try with the original agent first
try:
- result = await node.agent.run(input_data)
+ # Pass message_history if available (Pydantic AI agents support this)
+ if message_history:
+ result = await node.agent.run(input_data, message_history=message_history)
+ else:
+ result = await node.agent.run(input_data)
+
+ # Accumulate new messages from agent result if available
+ if hasattr(result, "new_messages"):
+ try:
+ new_messages = result.new_messages()
+ for msg in new_messages:
+ context.add_message(msg)
+ except Exception as e:
+ # Don't fail if message accumulation fails
+ self.logger.debug(
+ "Failed to accumulate messages from agent result", error=str(e)
+ )
+ return result
except Exception as e:
- # Handle validation errors and API errors for planner node
+ # Check if we should retry with fallback models
+ from src.utils.hf_error_handler import (
+ extract_error_details,
+ should_retry_with_fallback,
+ )
+
+ error_details = extract_error_details(e)
+ should_retry = should_retry_with_fallback(e)
+
+ # Handle validation errors and API errors for planner node (with fallback)
if node.node_id == "planner":
- self.logger.error(
- "Planner agent execution failed, using fallback plan",
- error=str(e),
- error_type=type(e).__name__,
+ if should_retry:
+ self.logger.warning(
+ "Planner failed, trying fallback models",
+ original_error=str(e),
+ status_code=error_details.get("status_code"),
+ )
+ # Try fallback models for planner
+ fallback_result = await self._try_fallback_models(
+ node, input_data, message_history, query, context, e
+ )
+ if fallback_result is not None:
+ return fallback_result
+ # If fallback failed or not applicable, use fallback plan
+ return self._create_fallback_plan(query, input_data)
+
+ # For other nodes, try fallback models if applicable
+ if should_retry:
+ self.logger.warning(
+ "Agent node failed, trying fallback models",
+ node_id=node.node_id,
+ original_error=str(e),
+ status_code=error_details.get("status_code"),
)
- # Return a minimal fallback ReportPlan
- from src.utils.models import ReportPlan, ReportPlanSection
-
- # Extract query from input_data if possible
- fallback_query = query
- if isinstance(input_data, str):
- # Try to extract query from input string
- if "QUERY:" in input_data:
- fallback_query = input_data.split("QUERY:")[-1].strip()
-
- return ReportPlan(
- background_context="",
- report_outline=[
- ReportPlanSection(
- title="Research Findings",
- key_question=fallback_query,
- )
- ],
- report_title=f"Research Report: {fallback_query[:50]}",
+ fallback_result = await self._try_fallback_models(
+ node, input_data, message_history, query, context, e
)
- # For other nodes, re-raise the exception
+ if fallback_result is not None:
+ return fallback_result
+
+ # If fallback didn't work or wasn't applicable, re-raise the exception
raise
- # Transform output if needed
+ async def _try_fallback_models(
+ self,
+ node: AgentNode,
+ input_data: Any,
+ message_history: list[Any],
+ query: str,
+ context: GraphExecutionContext,
+ original_error: Exception,
+ ) -> Any | None:
+ """Try executing agent with fallback models.
+
+ Args:
+ node: The agent node that failed
+ input_data: Input data for the agent
+ message_history: Message history for the agent
+ query: The research query
+ context: Execution context
+ original_error: The original error that triggered fallback
+
+ Returns:
+ Agent result if successful, None if all fallbacks failed
+ """
+ from src.utils.hf_error_handler import extract_error_details, get_fallback_models
+
+ error_details = extract_error_details(original_error)
+ original_model = error_details.get("model_name")
+ fallback_models = get_fallback_models(original_model)
+
+ # Also try models from settings fallback list
+ from src.utils.config import settings
+
+ settings_fallbacks = settings.get_hf_fallback_models_list()
+ for model in settings_fallbacks:
+ if model not in fallback_models:
+ fallback_models.append(model)
+
+ self.logger.info(
+ "Trying fallback models",
+ node_id=node.node_id,
+ original_model=original_model,
+ fallback_count=len(fallback_models),
+ )
+
+ # Try each fallback model
+ for fallback_model in fallback_models:
+ try:
+ # Recreate agent with fallback model
+ fallback_agent = self._recreate_agent_with_model(node.node_id, fallback_model)
+ if fallback_agent is None:
+ continue
+
+ # Try running with fallback agent
+ if message_history:
+ result = await fallback_agent.run(input_data, message_history=message_history)
+ else:
+ result = await fallback_agent.run(input_data)
+
+ self.logger.info(
+ "Fallback model succeeded",
+ node_id=node.node_id,
+ fallback_model=fallback_model,
+ )
+
+ # Accumulate new messages from agent result if available
+ if hasattr(result, "new_messages"):
+ try:
+ new_messages = result.new_messages()
+ for msg in new_messages:
+ context.add_message(msg)
+ except Exception as e:
+ self.logger.debug(
+ "Failed to accumulate messages from fallback agent result", error=str(e)
+ )
+
+ return result
+
+ except Exception as e:
+ self.logger.warning(
+ "Fallback model failed",
+ node_id=node.node_id,
+ fallback_model=fallback_model,
+ error=str(e),
+ )
+ continue
+
+ # All fallback models failed
+ self.logger.error(
+ "All fallback models failed",
+ node_id=node.node_id,
+ fallback_count=len(fallback_models),
+ )
+ return None
+
+ def _recreate_agent_with_model(self, node_id: str, model_name: str) -> Any | None:
+ """Recreate an agent with a specific model.
+
+ Args:
+ node_id: The node ID (e.g., "thinking", "knowledge_gap")
+ model_name: The model name to use
+
+ Returns:
+ Agent instance or None if recreation failed
+ """
+ try:
+ from pydantic_ai.models.huggingface import HuggingFaceModel
+ from pydantic_ai.providers.huggingface import HuggingFaceProvider
+
+ # Create model with fallback model name
+ hf_provider = HuggingFaceProvider(api_key=self.oauth_token)
+ model = HuggingFaceModel(model_name, provider=hf_provider)
+
+ # Recreate agent based on node_id
+ if node_id == "thinking":
+ from src.agent_factory.agents import create_thinking_agent
+
+ agent_wrapper = create_thinking_agent(model=model, oauth_token=self.oauth_token)
+ return agent_wrapper.agent
+ elif node_id == "knowledge_gap":
+ from src.agent_factory.agents import create_knowledge_gap_agent
+
+ agent_wrapper = create_knowledge_gap_agent( # type: ignore[assignment]
+ model=model, oauth_token=self.oauth_token
+ )
+ return agent_wrapper.agent
+ elif node_id == "tool_selector":
+ from src.agent_factory.agents import create_tool_selector_agent
+
+ agent_wrapper = create_tool_selector_agent( # type: ignore[assignment]
+ model=model, oauth_token=self.oauth_token
+ )
+ return agent_wrapper.agent
+ elif node_id == "planner":
+ from src.agent_factory.agents import create_planner_agent
+
+ agent_wrapper = create_planner_agent(model=model, oauth_token=self.oauth_token) # type: ignore[assignment]
+ return agent_wrapper.agent
+ elif node_id == "writer":
+ from src.agent_factory.agents import create_writer_agent
+
+ agent_wrapper = create_writer_agent(model=model, oauth_token=self.oauth_token) # type: ignore[assignment]
+ return agent_wrapper.agent
+ else:
+ self.logger.warning("Unknown node_id for agent recreation", node_id=node_id)
+ return None
+
+ except Exception as e:
+ self.logger.error(
+ "Failed to recreate agent with fallback model",
+ node_id=node_id,
+ model_name=model_name,
+ error=str(e),
+ )
+ return None
+
+ def _create_fallback_plan(self, query: str, input_data: Any) -> Any:
+ """Create fallback ReportPlan when planner fails."""
+ from src.utils.models import ReportPlan, ReportPlanSection
+
+ self.logger.error(
+ "Planner agent execution failed, using fallback plan",
+ input_type=type(input_data).__name__,
+ )
+
+ # Extract query from input_data if possible
+ fallback_query = query
+ if isinstance(input_data, str):
+ # Try to extract query from input string
+ if "QUERY:" in input_data:
+ fallback_query = input_data.split("QUERY:")[-1].strip()
+
+ return ReportPlan(
+ background_context="",
+ report_outline=[
+ ReportPlanSection(
+ title="Research Findings",
+ key_question=fallback_query,
+ )
+ ],
+ report_title=f"Research Report: {fallback_query[:50]}",
+ )
+
+ def _extract_agent_output(self, node: AgentNode, result: Any) -> Any:
+ """Extract and transform output from agent result."""
# Defensively extract output - handle various result formats
output = result.output if hasattr(result, "output") else result
# Handle case where output might be a tuple (from pydantic-ai validation errors)
if isinstance(output, tuple):
- # If tuple contains a dict-like structure, try to reconstruct the object
- if len(output) == 2 and isinstance(output[0], str) and output[0] == "research_complete":
- # This is likely a validation error format: ('research_complete', False)
- # Try to get the actual output from result
- self.logger.warning(
- "Agent result output is a tuple, attempting to extract actual output",
+ output = self._handle_tuple_output(node, output, result)
+ return output
+
+ def _handle_tuple_output(self, node: AgentNode, output: tuple[Any, ...], result: Any) -> Any:
+ """Handle tuple output from agent (validation errors)."""
+ # If tuple contains a dict-like structure, try to reconstruct the object
+ if len(output) == 2 and isinstance(output[0], str) and output[0] == "research_complete":
+ # This is likely a validation error format: ('research_complete', False)
+ # Try to get the actual output from result
+ self.logger.warning(
+ "Agent result output is a tuple, attempting to extract actual output",
+ node_id=node.node_id,
+ tuple_value=output,
+ )
+ # Try to get output from result attributes
+ if hasattr(result, "data"):
+ return result.data
+ if hasattr(result, "response"):
+ return result.response
+ # Last resort: try to reconstruct from tuple
+ # This shouldn't happen, but handle gracefully
+ from src.utils.models import KnowledgeGapOutput
+
+ if node.node_id == "knowledge_gap":
+ # Reconstruct KnowledgeGapOutput from validation error tuple
+ reconstructed = KnowledgeGapOutput(
+ research_complete=output[1] if len(output) > 1 else False,
+ outstanding_gaps=[],
+ )
+ self.logger.info(
+ "Reconstructed KnowledgeGapOutput from validation error tuple",
node_id=node.node_id,
- tuple_value=output,
+ research_complete=reconstructed.research_complete,
)
- # Try to get output from result attributes
- if hasattr(result, "data"):
- output = result.data
- elif hasattr(result, "response"):
- output = result.response
- else:
- # Last resort: try to reconstruct from tuple
- # This shouldn't happen, but handle gracefully
- from src.utils.models import KnowledgeGapOutput
+ return reconstructed
- if node.node_id == "knowledge_gap":
- # Reconstruct KnowledgeGapOutput from validation error tuple
- output = KnowledgeGapOutput(
- research_complete=output[1] if len(output) > 1 else False,
- outstanding_gaps=[],
- )
- self.logger.info(
- "Reconstructed KnowledgeGapOutput from validation error tuple",
- node_id=node.node_id,
- research_complete=output.research_complete,
- )
- else:
- # For other nodes, try to extract meaningful output or use fallback
- self.logger.warning(
- "Agent node output is tuple format, attempting extraction",
- node_id=node.node_id,
- tuple_value=output,
- )
- # Try to extract first meaningful element
- if len(output) > 0:
- # If first element is a string or dict, might be the actual output
- if isinstance(output[0], (str, dict)):
- output = output[0]
- else:
- # Last resort: use first element
- output = output[0]
- else:
- # Empty tuple - use None and let downstream handle it
- output = None
+ # For other nodes, try to extract meaningful output or use fallback
+ self.logger.warning(
+ "Agent node output is tuple format, attempting extraction",
+ node_id=node.node_id,
+ tuple_value=output,
+ )
+ # Try to extract first meaningful element
+ if len(output) > 0:
+ # If first element is a string or dict, might be the actual output
+ if isinstance(output[0], str | dict):
+ return output[0]
+ # Last resort: use first element
+ return output[0]
+ # Empty tuple - use None and let downstream handle it
+ return None
+
+ async def _execute_agent_node(
+ self, node: AgentNode, query: str, context: GraphExecutionContext
+ ) -> Any:
+ """Execute an agent node.
+
+ Special handling for deep research nodes:
+ - "planner": Takes query string, returns ReportPlan
+ - "synthesizer": Takes query + ReportPlan + section drafts, returns final report
+
+ Args:
+ node: The agent node
+ query: The research query
+ context: Execution context
+
+ Returns:
+ Agent execution result
+ """
+ # Special handling for synthesizer node (deep research)
+ if node.node_id == "synthesizer":
+ return await self._execute_synthesizer_node(query, context)
+
+ # Special handling for writer node (iterative research)
+ if node.node_id == "writer":
+ return await self._execute_writer_node(query, context)
+
+ # Standard agent execution
+ input_data = self._prepare_agent_input(node, query, context)
+ result = await self._execute_standard_agent(node, input_data, query, context)
+ output = self._extract_agent_output(node, result)
if node.output_transformer:
output = node.output_transformer(output)
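A condensed view of the retry path introduced above, assuming the hf_error_handler helpers behave as they are used in this hunk; recreate_agent stands in for _recreate_agent_with_model and the signature is illustrative only.

    from src.utils.hf_error_handler import (
        extract_error_details,
        get_fallback_models,
        should_retry_with_fallback,
    )

    async def run_with_fallback(node, input_data, context):  # illustrative signature
        history = context.get_message_history(max_messages=10)
        try:
            return await node.agent.run(input_data, message_history=history or None)
        except Exception as exc:
            if not should_retry_with_fallback(exc):
                raise
            original_model = extract_error_details(exc).get("model_name")
            for model_name in get_fallback_models(original_model):
                agent = recreate_agent(node.node_id, model_name)  # placeholder helper
                if agent is None:
                    continue
                try:
                    return await agent.run(input_data, message_history=history or None)
                except Exception:
                    continue  # try the next fallback model
            raise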
@@ -1133,10 +1406,15 @@ async def _execute_decision_node(
prev_result = prev_result[0]
elif len(prev_result) > 1 and hasattr(prev_result[1], "research_complete"):
prev_result = prev_result[1]
- elif len(prev_result) == 2 and isinstance(prev_result[0], str) and prev_result[0] == "research_complete":
+ elif (
+ len(prev_result) == 2
+ and isinstance(prev_result[0], str)
+ and prev_result[0] == "research_complete"
+ ):
# Handle validation error format: ('research_complete', False)
# Reconstruct KnowledgeGapOutput from tuple
from src.utils.models import KnowledgeGapOutput
+
self.logger.warning(
"Decision node received validation error tuple, reconstructing KnowledgeGapOutput",
node_id=node.node_id,
@@ -1157,6 +1435,7 @@ async def _execute_decision_node(
# Try to reconstruct KnowledgeGapOutput if this is from knowledge_gap node
if prev_node_id == "knowledge_gap":
from src.utils.models import KnowledgeGapOutput
+
# Try to extract research_complete from tuple
research_complete = False
for item in prev_result:
diff --git a/src/orchestrator/research_flow.py b/src/orchestrator/research_flow.py
index 52756654..3ce777bd 100644
--- a/src/orchestrator/research_flow.py
+++ b/src/orchestrator/research_flow.py
@@ -10,6 +10,11 @@
import structlog
+try:
+ from pydantic_ai.messages import ModelMessage
+except ImportError:
+ ModelMessage = Any # type: ignore[assignment, misc]
+
from src.agent_factory.agents import (
create_graph_orchestrator,
create_knowledge_gap_agent,
@@ -137,6 +142,7 @@ async def run(
background_context: str = "",
output_length: str = "",
output_instructions: str = "",
+ message_history: list[ModelMessage] | None = None,
) -> str:
"""
Run the iterative research flow.
@@ -146,17 +152,18 @@ async def run(
background_context: Optional background context
output_length: Optional description of desired output length
output_instructions: Optional additional instructions
+ message_history: Optional user conversation history
Returns:
Final report string
"""
if self.use_graph:
return await self._run_with_graph(
- query, background_context, output_length, output_instructions
+ query, background_context, output_length, output_instructions, message_history
)
else:
return await self._run_with_chains(
- query, background_context, output_length, output_instructions
+ query, background_context, output_length, output_instructions, message_history
)
async def _run_with_chains(
@@ -165,6 +172,7 @@ async def _run_with_chains(
background_context: str = "",
output_length: str = "",
output_instructions: str = "",
+ message_history: list[ModelMessage] | None = None,
) -> str:
"""
Run the iterative research flow using agent chains.
@@ -174,6 +182,7 @@ async def _run_with_chains(
background_context: Optional background context
output_length: Optional description of desired output length
output_instructions: Optional additional instructions
+ message_history: Optional user conversation history
Returns:
Final report string
@@ -193,10 +202,10 @@ async def _run_with_chains(
self.conversation.add_iteration()
# 1. Generate observations
- await self._generate_observations(query, background_context)
+ await self._generate_observations(query, background_context, message_history)
# 2. Evaluate gaps
- evaluation = await self._evaluate_gaps(query, background_context)
+ evaluation = await self._evaluate_gaps(query, background_context, message_history)
# 3. Assess with judge (after tools execute, we'll assess again)
# For now, check knowledge gap evaluation
@@ -210,7 +219,9 @@ async def _run_with_chains(
# 4. Select tools for next gap
next_gap = evaluation.outstanding_gaps[0] if evaluation.outstanding_gaps else query
- selection_plan = await self._select_agents(next_gap, query, background_context)
+ selection_plan = await self._select_agents(
+ next_gap, query, background_context, message_history
+ )
# 5. Execute tools
await self._execute_tools(selection_plan.tasks)
@@ -250,6 +261,7 @@ async def _run_with_graph(
background_context: str = "",
output_length: str = "",
output_instructions: str = "",
+ message_history: list[ModelMessage] | None = None,
) -> str:
"""
Run the iterative research flow using graph execution.
@@ -313,7 +325,12 @@ def _check_constraints(self) -> bool:
return True
- async def _generate_observations(self, query: str, background_context: str = "") -> str:
+ async def _generate_observations(
+ self,
+ query: str,
+ background_context: str = "",
+ message_history: list[ModelMessage] | None = None,
+ ) -> str:
"""Generate observations from current research state."""
# Build input prompt for token estimation
conversation_history = self.conversation.compile_conversation_history()
@@ -335,6 +352,7 @@ async def _generate_observations(self, query: str, background_context: str = "")
query=query,
background_context=background_context,
conversation_history=conversation_history,
+ message_history=message_history,
iteration=self.iteration,
)
@@ -350,7 +368,12 @@ async def _generate_observations(self, query: str, background_context: str = "")
self.conversation.set_latest_thought(observations)
return observations
- async def _evaluate_gaps(self, query: str, background_context: str = "") -> KnowledgeGapOutput:
+ async def _evaluate_gaps(
+ self,
+ query: str,
+ background_context: str = "",
+ message_history: list[ModelMessage] | None = None,
+ ) -> KnowledgeGapOutput:
"""Evaluate knowledge gaps in current research."""
if self.start_time:
elapsed_minutes = (time.time() - self.start_time) / 60
@@ -377,6 +400,7 @@ async def _evaluate_gaps(self, query: str, background_context: str = "") -> Know
query=query,
background_context=background_context,
conversation_history=conversation_history,
+ message_history=message_history,
iteration=self.iteration,
time_elapsed_minutes=elapsed_minutes,
max_time_minutes=self.max_time_minutes,
@@ -437,7 +461,11 @@ async def _assess_with_judge(self, query: str) -> JudgeAssessment:
return assessment
async def _select_agents(
- self, gap: str, query: str, background_context: str = ""
+ self,
+ gap: str,
+ query: str,
+ background_context: str = "",
+ message_history: list[ModelMessage] | None = None,
) -> AgentSelectionPlan:
"""Select tools to address knowledge gap."""
# Build input prompt for token estimation
@@ -461,6 +489,7 @@ async def _select_agents(
query=query,
background_context=background_context,
conversation_history=conversation_history,
+ message_history=message_history,
)
# Track tokens for this iteration
@@ -775,27 +804,31 @@ def _get_file_service(self) -> ReportFileService | None:
return None
return self._file_service
- async def run(self, query: str) -> str:
+ async def run(self, query: str, message_history: list[ModelMessage] | None = None) -> str:
"""
Run the deep research flow.
Args:
query: The research query
+ message_history: Optional user conversation history
Returns:
Final report string
"""
if self.use_graph:
- return await self._run_with_graph(query)
+ return await self._run_with_graph(query, message_history)
else:
- return await self._run_with_chains(query)
+ return await self._run_with_chains(query, message_history)
- async def _run_with_chains(self, query: str) -> str:
+ async def _run_with_chains(
+ self, query: str, message_history: list[ModelMessage] | None = None
+ ) -> str:
"""
Run the deep research flow using agent chains.
Args:
query: The research query
+ message_history: Optional user conversation history
Returns:
Final report string
@@ -812,11 +845,11 @@ async def _run_with_chains(self, query: str) -> str:
embedding_service = None
self.logger.debug("Embedding service unavailable, initializing state without it")
- init_workflow_state(embedding_service=embedding_service)
+ init_workflow_state(embedding_service=embedding_service, message_history=message_history)
self.logger.debug("Workflow state initialized for deep research")
# 1. Build report plan
- report_plan = await self._build_report_plan(query)
+ report_plan = await self._build_report_plan(query, message_history)
self.logger.info(
"Report plan created",
sections=len(report_plan.report_outline),
@@ -824,7 +857,7 @@ async def _run_with_chains(self, query: str) -> str:
)
# 2. Run parallel research loops with state synchronization
- section_drafts = await self._run_research_loops(report_plan)
+ section_drafts = await self._run_research_loops(report_plan, message_history)
# Verify state synchronization - log evidence count
state = get_workflow_state()
@@ -845,12 +878,15 @@ async def _run_with_chains(self, query: str) -> str:
return final_report
- async def _run_with_graph(self, query: str) -> str:
+ async def _run_with_graph(
+ self, query: str, message_history: list[ModelMessage] | None = None
+ ) -> str:
"""
Run the deep research flow using graph execution.
Args:
query: The research query
+ message_history: Optional user conversation history
Returns:
Final report string
@@ -868,7 +904,7 @@ async def _run_with_graph(self, query: str) -> str:
# Run orchestrator and collect events
final_report = ""
- async for event in self._graph_orchestrator.run(query):
+ async for event in self._graph_orchestrator.run(query, message_history=message_history):
if event.type == "complete":
final_report = event.message
break
@@ -884,13 +920,17 @@ async def _run_with_graph(self, query: str) -> str:
return final_report
- async def _build_report_plan(self, query: str) -> ReportPlan:
+ async def _build_report_plan(
+ self, query: str, message_history: list[ModelMessage] | None = None
+ ) -> ReportPlan:
"""Build the initial report plan."""
self.logger.info("Building report plan")
# Build input prompt for token estimation
input_prompt = f"QUERY: {query}"
+ # Planner agent does not accept message_history yet; use the standard run() call for now
report_plan = await self.planner_agent.run(query)
# Track tokens for planner agent
@@ -913,7 +953,9 @@ async def _build_report_plan(self, query: str) -> ReportPlan:
return report_plan
- async def _run_research_loops(self, report_plan: ReportPlan) -> list[str]:
+ async def _run_research_loops(
+ self, report_plan: ReportPlan, message_history: list[ModelMessage] | None = None
+ ) -> list[str]:
"""Run parallel iterative research loops for each section."""
self.logger.info("Running research loops", sections=len(report_plan.report_outline))
@@ -950,10 +992,11 @@ async def run_research_for_section(config: dict[str, Any]) -> str:
judge_handler=self.judge_handler if not self.use_graph else None,
)
- # Run research
+ # Run research with message_history
result = await flow.run(
query=query,
background_context=background_context,
+ message_history=message_history,
)
# Sync evidence from flow to loop
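How the history threading above would be exercised end to end; the flow constructor arguments are simplified placeholders, while run()'s keyword arguments match this diff.

    async def follow_up(previous_turns):
        flow = DeepResearchFlow(use_graph=False)  # constructor arguments simplified
        return await flow.run(
            query="Summarize recent CAR-T safety findings",
            message_history=previous_turns,  # list[ModelMessage] from earlier turns
        )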
diff --git a/src/orchestrator_hierarchical.py b/src/orchestrator_hierarchical.py
index bf3848ad..a7bfb85a 100644
--- a/src/orchestrator_hierarchical.py
+++ b/src/orchestrator_hierarchical.py
@@ -2,9 +2,15 @@
import asyncio
from collections.abc import AsyncGenerator
+from typing import Any
import structlog
+try:
+ from pydantic_ai.messages import ModelMessage
+except ImportError:
+ ModelMessage = Any # type: ignore[assignment, misc]
+
from src.agents.judge_agent_llm import LLMSubIterationJudge
from src.agents.magentic_agents import create_search_agent
from src.middleware.sub_iteration import SubIterationMiddleware, SubIterationTeam
@@ -38,8 +44,14 @@ def __init__(self) -> None:
self.judge = LLMSubIterationJudge()
self.middleware = SubIterationMiddleware(self.team, self.judge, max_iterations=5)
- async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
- logger.info("Starting hierarchical orchestrator", query=query)
+ async def run(
+ self, query: str, message_history: list[ModelMessage] | None = None
+ ) -> AsyncGenerator[AgentEvent, None]:
+ logger.info(
+ "Starting hierarchical orchestrator",
+ query=query,
+ has_history=bool(message_history),
+ )
try:
service = get_embedding_service()
@@ -58,6 +70,8 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
async def event_callback(event: AgentEvent) -> None:
await queue.put(event)
+ # Note: middleware.run() does not accept message_history yet, so it is not forwarded here;
+ # it can be plumbed into the middleware later if needed
task_future = asyncio.create_task(self.middleware.run(query, event_callback))
while not task_future.done():
diff --git a/src/orchestrator_magentic.py b/src/orchestrator_magentic.py
index fd9d4f72..416d8968 100644
--- a/src/orchestrator_magentic.py
+++ b/src/orchestrator_magentic.py
@@ -4,6 +4,11 @@
from typing import TYPE_CHECKING, Any
import structlog
+
+try:
+ from pydantic_ai.messages import ModelMessage
+except ImportError:
+ ModelMessage = Any # type: ignore[assignment, misc]
from agent_framework import (
MagenticAgentDeltaEvent,
MagenticAgentMessageEvent,
@@ -98,17 +103,24 @@ def _build_workflow(self) -> Any:
.build()
)
- async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
+ async def run(
+ self, query: str, message_history: list[ModelMessage] | None = None
+ ) -> AsyncGenerator[AgentEvent, None]:
"""
Run the Magentic workflow.
Args:
query: User's research question
+ message_history: Optional user conversation history (for compatibility)
Yields:
AgentEvent objects for real-time UI updates
"""
- logger.info("Starting Magentic orchestrator", query=query)
+ logger.info(
+ "Starting Magentic orchestrator",
+ query=query,
+ has_history=bool(message_history),
+ )
yield AgentEvent(
type="started",
@@ -122,7 +134,17 @@ async def run(self, query: str) -> AsyncGenerator[AgentEvent, None]:
workflow = self._build_workflow()
- task = f"""Research query: {query}
+ # Include conversation history context if provided
+ history_context = ""
+ if message_history:
+ # Convert message history to string context for task
+ from src.utils.message_history import message_history_to_string
+
+ history_str = message_history_to_string(message_history, max_messages=5)
+ if history_str:
+ history_context = f"\n\nPrevious conversation context:\n{history_str}"
+
+ task = f"""Research query: {query}{history_context}
Workflow:
1. SearchAgent: Find evidence from available sources (automatically selects: web search, PubMed, ClinicalTrials.gov, Europe PMC, or RAG based on query)
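The Magentic path flattens prior turns into plain text for the task prompt. The real src.utils.message_history.message_history_to_string is referenced above but not shown in this diff; a minimal stand-in with the same call shape might look like this.

    from pydantic_ai.messages import ModelMessage, ModelRequest

    def message_history_to_string(history: list[ModelMessage], max_messages: int = 5) -> str:
        """Render the last few turns as 'User:' / 'Assistant:' lines."""
        lines: list[str] = []
        for msg in history[-max_messages:]:
            role = "User" if isinstance(msg, ModelRequest) else "Assistant"
            text = " ".join(
                part.content
                for part in msg.parts
                if isinstance(getattr(part, "content", None), str)
            )
            if text:
                lines.append(f"{role}: {text}")
        return "\n".join(lines)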
diff --git a/src/services/audio_processing.py b/src/services/audio_processing.py
index 076ba571..588497f9 100644
--- a/src/services/audio_processing.py
+++ b/src/services/audio_processing.py
@@ -6,9 +6,9 @@
import numpy as np
import structlog
+from src.agents.audio_refiner import audio_refiner
from src.services.stt_gradio import STTService, get_stt_service
from src.utils.config import settings
-from src.utils.exceptions import ConfigurationError
logger = structlog.get_logger(__name__)
@@ -53,7 +53,7 @@ def __init__(
async def process_audio_input(
self,
- audio_input: tuple[int, np.ndarray] | None,
+ audio_input: tuple[int, np.ndarray[Any, Any]] | None, # type: ignore[type-arg]
hf_token: str | None = None,
) -> str | None:
"""Process audio input and return transcribed text.
@@ -82,11 +82,11 @@ async def generate_audio_output(
text: str,
voice: str | None = None,
speed: float | None = None,
- ) -> tuple[int, np.ndarray] | None:
+ ) -> tuple[int, np.ndarray[Any, Any]] | None: # type: ignore[type-arg]
"""Generate audio output from text.
Args:
- text: Text to synthesize
+ text: Text to synthesize (markdown will be cleaned for audio)
voice: Voice ID (default: settings.tts_voice)
speed: Speech speed (default: settings.tts_speed)
@@ -102,11 +102,23 @@ async def generate_audio_output(
return None
try:
+ # Refine text for audio (remove markdown, citations, etc.)
+ # Use LLM polish if enabled in settings
+ refined_text = await audio_refiner.refine_for_audio(
+ text, use_llm_polish=settings.tts_use_llm_polish
+ )
+ logger.info(
+ "text_refined_for_audio",
+ original_length=len(text),
+ refined_length=len(refined_text),
+ llm_polish_enabled=settings.tts_use_llm_polish,
+ )
+
# Use provided voice/speed or fallback to settings defaults
voice = voice if voice else settings.tts_voice
speed = speed if speed is not None else settings.tts_speed
- audio_output = await self.tts.synthesize_async(text, voice, speed) # type: ignore[misc]
+ audio_output = await self.tts.synthesize_async(refined_text, voice, speed) # type: ignore[misc]
if audio_output:
logger.info(
@@ -115,7 +127,7 @@ async def generate_audio_output(
sample_rate=audio_output[0],
)
- return audio_output
+ return audio_output # type: ignore[no-any-return]
except Exception as e:
logger.error("audio_output_generation_failed", error=str(e))
@@ -131,4 +143,3 @@ def get_audio_service() -> AudioService:
AudioService instance
"""
return AudioService()
-
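The refiner lives in src.agents.audio_refiner and is not part of this diff; as a rough approximation of the markdown/citation stripping it performs before synthesis (the optional LLM polish pass is omitted), something like the following would do.

    import re

    def strip_markdown_for_tts(text: str) -> str:
        text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)  # drop fenced code blocks
        text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", text)    # keep link text, drop URLs
        text = re.sub(r"\[\d+\]", "", text)                     # drop numeric citations like [3]
        text = re.sub(r"[#*_>`]+", "", text)                    # strip markdown markers
        return re.sub(r"\s+", " ", text).strip()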
diff --git a/src/services/image_ocr.py b/src/services/image_ocr.py
index b885c6d4..ed888193 100644
--- a/src/services/image_ocr.py
+++ b/src/services/image_ocr.py
@@ -31,7 +31,10 @@ def __init__(self, api_url: str | None = None, hf_token: str | None = None) -> N
ConfigurationError: If API URL not configured
"""
# Defensively access ocr_api_url - may not exist in older config versions
- default_url = getattr(settings, "ocr_api_url", None) or "https://prithivmlmods-multimodal-ocr3.hf.space"
+ default_url = (
+ getattr(settings, "ocr_api_url", None)
+ or "https://prithivmlmods-multimodal-ocr3.hf.space"
+ )
self.api_url = api_url or default_url
if not self.api_url:
raise ConfigurationError("OCR API URL not configured")
@@ -49,11 +52,11 @@ async def _get_client(self, hf_token: str | None = None) -> Client:
"""
# Use provided token or instance token
token = hf_token or self.hf_token
-
+
# If client exists but token changed, recreate it
if self.client is not None and token != self.hf_token:
self.client = None
-
+
if self.client is None:
loop = asyncio.get_running_loop()
# Pass token to Client for authenticated Spaces
@@ -129,7 +132,7 @@ async def extract_text(
async def extract_text_from_image(
self,
- image_data: np.ndarray | Image.Image | str,
+ image_data: np.ndarray[Any, Any] | Image.Image | str, # type: ignore[type-arg]
hf_token: str | None = None,
) -> str:
"""Extract text from image data (numpy array, PIL Image, or file path).
@@ -240,10 +243,3 @@ def get_image_ocr_service() -> ImageOCRService:
ImageOCRService instance
"""
return ImageOCRService()
-
-
-
-
-
-
-
diff --git a/src/services/llamaindex_rag.py b/src/services/llamaindex_rag.py
index 9107c240..5de92f55 100644
--- a/src/services/llamaindex_rag.py
+++ b/src/services/llamaindex_rag.py
@@ -86,13 +86,15 @@ def __init__(
self._initialize_chromadb()
def _import_dependencies(self) -> dict[str, Any]:
- """Import LlamaIndex dependencies and return as dict."""
+ """Import LlamaIndex dependencies and return as dict.
+
+ OpenAI dependencies are imported lazily (only when needed) to avoid
+ tiktoken circular import issues on Windows when using local embeddings.
+ """
try:
import chromadb
from llama_index.core import Document, Settings, StorageContext, VectorStoreIndex
from llama_index.core.retrievers import VectorIndexRetriever
- from llama_index.embeddings.openai import OpenAIEmbedding
- from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.chroma import ChromaVectorStore
# Try to import Hugging Face embeddings (may not be available in all versions)
@@ -120,10 +122,22 @@ def _import_dependencies(self) -> dict[str, Any]:
HuggingFaceLLM as _HuggingFaceLLM, # type: ignore[import-untyped]
)
- huggingface_llm = _HuggingFaceLLM
+ huggingface_llm = _HuggingFaceLLM # type: ignore[assignment]
except ImportError:
huggingface_llm = None # type: ignore[assignment]
+ # OpenAI imports are optional - only import when actually needed
+ # This avoids tiktoken circular import issues on Windows
+ try:
+ from llama_index.embeddings.openai import OpenAIEmbedding
+ except ImportError:
+ OpenAIEmbedding = None # type: ignore[assignment, misc] # noqa: N806
+
+ try:
+ from llama_index.llms.openai import OpenAI
+ except ImportError:
+ OpenAI = None # type: ignore[assignment, misc] # noqa: N806
+
return {
"chromadb": chromadb,
"Document": Document,
@@ -151,6 +165,10 @@ def _configure_embeddings(
) -> None:
"""Configure embedding model."""
if use_openai_embeddings:
+ if openai_embedding is None:
+ raise ConfigurationError(
+ "OpenAI embeddings not available. Install with: uv sync --extra modal"
+ )
if not settings.openai_api_key:
raise ConfigurationError("OPENAI_API_KEY required for OpenAI embeddings")
self.embedding_model = embedding_model or settings.openai_embedding_model
@@ -167,8 +185,33 @@ def _configure_embeddings(
self._Settings.embed_model = self._create_sentence_transformer_embedding(model_name)
def _create_sentence_transformer_embedding(self, model_name: str) -> Any:
- """Create sentence-transformer embedding wrapper."""
- from sentence_transformers import SentenceTransformer
+ """Create sentence-transformer embedding wrapper.
+
+ Note: sentence-transformers is a required dependency (in pyproject.toml).
+ If this fails, it's likely a Windows-specific regex package issue.
+
+ Raises:
+ ConfigurationError: If sentence_transformers cannot be imported
+ (e.g., due to circular import issues on Windows with regex package)
+ """
+ try:
+ from sentence_transformers import SentenceTransformer
+ except ImportError as e:
+ # Handle Windows-specific circular import issues with regex package
+ # This is a known bug: https://github.com/mrabarnett/mrab-regex/issues/417
+ error_msg = str(e)
+ if "regex" in error_msg.lower() or "_regex" in error_msg:
+ raise ConfigurationError(
+ "sentence_transformers cannot be imported due to circular import issue "
+ "with regex package (Windows-specific bug). "
+ "sentence-transformers is installed but regex has a circular import. "
+                    "Try: uv pip install --upgrade --force-reinstall regex, "
+                    "or use HuggingFace embeddings via llama-index-embeddings-huggingface instead."
+ ) from e
+ raise ConfigurationError(
+ f"sentence_transformers not available: {e}. "
+ "This is a required dependency - check your uv sync installation."
+ ) from e
try:
from llama_index.embeddings.base import (
@@ -205,11 +248,7 @@ async def _aget_text_embedding(self, text: str) -> list[float]:
def _configure_llm(self, huggingface_llm: Any, openai_llm: Any) -> None:
"""Configure LLM for query synthesis."""
# Priority: oauth_token > env vars
- effective_token = (
- self.oauth_token
- or settings.hf_token
- or settings.huggingface_api_key
- )
+ effective_token = self.oauth_token or settings.hf_token or settings.huggingface_api_key
if huggingface_llm is not None and effective_token:
model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct"
token = effective_token
@@ -245,7 +284,7 @@ def _configure_llm(self, huggingface_llm: Any, openai_llm: Any) -> None:
tokenizer_name=model_name,
)
logger.info("Using HuggingFace LLM for query synthesis", model=model_name)
- elif settings.openai_api_key:
+ elif settings.openai_api_key and openai_llm is not None:
self._Settings.llm = openai_llm(
model=settings.openai_model,
api_key=settings.openai_api_key,
@@ -461,6 +500,4 @@ def get_rag_service(
# Default to local embeddings if not explicitly set
if "use_openai_embeddings" not in kwargs:
kwargs["use_openai_embeddings"] = False
- return LlamaIndexRAGService(
- collection_name=collection_name, oauth_token=oauth_token, **kwargs
- )
+ return LlamaIndexRAGService(collection_name=collection_name, oauth_token=oauth_token, **kwargs)
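
Because the OpenAI imports are now lazy and get_rag_service defaults use_openai_embeddings to False, the service can be constructed without any OpenAI packages installed. A hedged sketch, assuming the LlamaIndex/Chroma dependencies from pyproject are present (the collection name is illustrative):

from src.services.llamaindex_rag import get_rag_service

# No OpenAI modules are imported on this path: use_openai_embeddings defaults to
# False, so embeddings come from sentence-transformers / HuggingFace instead.
rag = get_rag_service(collection_name="demo-papers")
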
diff --git a/src/services/multimodal_processing.py b/src/services/multimodal_processing.py
index cefb9d7b..9199f194 100644
--- a/src/services/multimodal_processing.py
+++ b/src/services/multimodal_processing.py
@@ -83,7 +83,9 @@ async def process_multimodal_input(
# For now, log a warning
logger.warning("audio_file_upload_not_supported", file_path=file_path)
except Exception as e:
- logger.warning("audio_file_processing_failed", file_path=file_path, error=str(e))
+ logger.warning(
+ "audio_file_processing_failed", file_path=file_path, error=str(e)
+ )
# Add original text if present
if text and text.strip():
@@ -142,7 +144,3 @@ def get_multimodal_service() -> MultimodalService:
MultimodalService instance
"""
return MultimodalService()
-
-
-
-
diff --git a/src/services/neo4j_service.py b/src/services/neo4j_service.py
index 26a2340b..6a9c6f7e 100644
--- a/src/services/neo4j_service.py
+++ b/src/services/neo4j_service.py
@@ -1,25 +1,28 @@
"""Neo4j Knowledge Graph Service for Drug Repurposing"""
-from neo4j import GraphDatabase
-from typing import List, Dict, Optional, Any
+
+import logging
import os
+from typing import Any
+
from dotenv import load_dotenv
-import logging
+from neo4j import GraphDatabase
load_dotenv()
logger = logging.getLogger(__name__)
+
class Neo4jService:
- def __init__(self):
+ def __init__(self) -> None:
self.uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
self.user = os.getenv("NEO4J_USER", "neo4j")
self.password = os.getenv("NEO4J_PASSWORD")
self.database = os.getenv("NEO4J_DATABASE", "neo4j")
-
+
if not self.password:
logger.warning("β οΈ NEO4J_PASSWORD not set")
self.driver = None
return
-
+
try:
self.driver = GraphDatabase.driver(self.uri, auth=(self.user, self.password))
self.driver.verify_connectivity()
@@ -27,80 +30,96 @@ def __init__(self):
except Exception as e:
logger.error(f"β Neo4j connection failed: {e}")
self.driver = None
-
+
def is_connected(self) -> bool:
return self.driver is not None
-
- def close(self):
+
+ def close(self) -> None:
if self.driver:
self.driver.close()
-
- def ingest_search_results(self, disease_name: str, papers: List[Dict[str, Any]],
- drugs_mentioned: List[str] = None) -> Dict[str, int]:
+
+ def ingest_search_results(
+ self,
+ disease_name: str,
+ papers: list[dict[str, Any]],
+ drugs_mentioned: list[str] | None = None,
+ ) -> dict[str, int]:
if not self.driver:
- return {"error": "Neo4j not connected"}
-
+ return {"error": "Neo4j not connected"} # type: ignore[dict-item]
+
stats = {"papers": 0, "drugs": 0, "relationships": 0, "errors": 0}
-
+
try:
with self.driver.session(database=self.database) as session:
session.run("MERGE (d:Disease {name: $name})", name=disease_name)
-
+
for paper in papers:
try:
- paper_id = paper.get('id') or paper.get('url', '')
+ paper_id = paper.get("id") or paper.get("url", "")
if not paper_id:
continue
-
- session.run("""
+
+ session.run(
+ """
MERGE (p:Paper {paper_id: $id})
SET p.title = $title,
p.abstract = $abstract,
p.url = $url,
p.source = $source,
p.updated_at = datetime()
- """,
- id=paper_id,
- title=str(paper.get('title', ''))[:500],
- abstract=str(paper.get('abstract', ''))[:2000],
- url=str(paper.get('url', ''))[:500],
- source=str(paper.get('source', ''))[:100])
-
- session.run("""
+ """,
+ id=paper_id,
+ title=str(paper.get("title", ""))[:500],
+ abstract=str(paper.get("abstract", ""))[:2000],
+ url=str(paper.get("url", ""))[:500],
+ source=str(paper.get("source", ""))[:100],
+ )
+
+ session.run(
+ """
MATCH (p:Paper {paper_id: $id})
MATCH (d:Disease {name: $disease})
MERGE (p)-[r:ABOUT]->(d)
- """, id=paper_id, disease=disease_name)
-
- stats['papers'] += 1
- stats['relationships'] += 1
- except Exception as e:
- stats['errors'] += 1
-
+ """,
+ id=paper_id,
+ disease=disease_name,
+ )
+
+ stats["papers"] += 1
+ stats["relationships"] += 1
+ except Exception:
+ stats["errors"] += 1
+
if drugs_mentioned:
for drug in drugs_mentioned:
try:
session.run("MERGE (d:Drug {name: $name})", name=drug)
- session.run("""
+ session.run(
+ """
MATCH (drug:Drug {name: $drug})
MATCH (disease:Disease {name: $disease})
MERGE (drug)-[r:POTENTIAL_TREATMENT]->(disease)
- """, drug=drug, disease=disease_name)
- stats['drugs'] += 1
- stats['relationships'] += 1
- except Exception as e:
- stats['errors'] += 1
-
+ """,
+ drug=drug,
+ disease=disease_name,
+ )
+ stats["drugs"] += 1
+ stats["relationships"] += 1
+ except Exception:
+ stats["errors"] += 1
+
logger.info(f"οΏ½οΏ½ Neo4j ingestion: {stats['papers']} papers, {stats['drugs']} drugs")
except Exception as e:
logger.error(f"Neo4j ingestion error: {e}")
- stats['errors'] += 1
-
+ stats["errors"] += 1
+
return stats
+
_neo4j_service = None
-def get_neo4j_service() -> Optional[Neo4jService]:
+
+def get_neo4j_service() -> Neo4jService | None:
global _neo4j_service
if _neo4j_service is None:
_neo4j_service = Neo4jService()
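
For reference, a hypothetical calling sketch for the ingestion method above (the paper dict and drug name are invented for illustration):

from src.services.neo4j_service import get_neo4j_service

service = get_neo4j_service()
if service and service.is_connected():
    papers = [
        {
            "id": "https://example.org/paper-1",  # becomes Paper.paper_id
            "title": "Metformin and pulmonary fibrosis",
            "abstract": "Short abstract text.",
            "url": "https://example.org/paper-1",
            "source": "pubmed",
        }
    ]
    stats = service.ingest_search_results(
        disease_name="pulmonary fibrosis",
        papers=papers,
        drugs_mentioned=["metformin"],
    )
    # One ABOUT edge plus one POTENTIAL_TREATMENT edge:
    # stats == {"papers": 1, "drugs": 1, "relationships": 2, "errors": 0}
    service.close()
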
diff --git a/src/services/report_file_service.py b/src/services/report_file_service.py
index 9b9f82fe..14496640 100644
--- a/src/services/report_file_service.py
+++ b/src/services/report_file_service.py
@@ -329,4 +329,3 @@ def _get_service() -> ReportFileService:
return ReportFileService()
return _get_service()
-
diff --git a/src/services/stt_gradio.py b/src/services/stt_gradio.py
index 28062205..44b993b4 100644
--- a/src/services/stt_gradio.py
+++ b/src/services/stt_gradio.py
@@ -46,11 +46,11 @@ async def _get_client(self, hf_token: str | None = None) -> Client:
"""
# Use provided token or instance token
token = hf_token or self.hf_token
-
+
# If client exists but token changed, recreate it
if self.client is not None and token != self.hf_token:
self.client = None
-
+
if self.client is None:
loop = asyncio.get_running_loop()
# Pass token to Client for authenticated Spaces
@@ -130,7 +130,7 @@ async def transcribe_file(
async def transcribe_audio(
self,
- audio_data: tuple[int, np.ndarray],
+ audio_data: tuple[int, np.ndarray[Any, Any]], # type: ignore[type-arg]
hf_token: str | None = None,
) -> str:
"""Transcribe audio numpy array to text.
@@ -163,7 +163,7 @@ async def transcribe_audio(
except Exception as e:
logger.warning("failed_to_cleanup_temp_file", path=temp_path, error=str(e))
- def _extract_transcription(self, api_result: tuple) -> str:
+ def _extract_transcription(self, api_result: tuple[Any, ...]) -> str:
"""Extract transcription text from API result.
Args:
@@ -210,7 +210,7 @@ def _extract_transcription(self, api_result: tuple) -> str:
def _save_audio_temp(
self,
- audio_data: tuple[int, np.ndarray],
+ audio_data: tuple[int, np.ndarray[Any, Any]], # type: ignore[type-arg]
) -> str:
"""Save audio numpy array to temporary WAV file.
@@ -269,4 +269,3 @@ def get_stt_service() -> STTService:
STTService instance
"""
return STTService()
-
diff --git a/src/services/tts_modal.py b/src/services/tts_modal.py
index e0e30afb..ce55c49a 100644
--- a/src/services/tts_modal.py
+++ b/src/services/tts_modal.py
@@ -1,12 +1,22 @@
"""Text-to-Speech service using Kokoro 82M via Modal GPU."""
import asyncio
+import os
+from collections.abc import Iterator
+from contextlib import contextmanager
from functools import lru_cache
-from typing import Any
+from typing import Any, cast
import numpy as np
+from numpy.typing import NDArray
import structlog
+# Load .env file BEFORE importing Modal SDK
+# Modal SDK reads MODAL_TOKEN_ID and MODAL_TOKEN_SECRET from environment on import
+from dotenv import load_dotenv
+
+load_dotenv()
+
from src.utils.config import settings
from src.utils.exceptions import ConfigurationError
@@ -24,39 +34,107 @@
# Modal app and function definitions (module-level for Modal)
_modal_app: Any | None = None
_tts_function: Any | None = None
+_tts_image: Any | None = None
+
+
+@contextmanager
+def modal_credentials_override(token_id: str | None, token_secret: str | None) -> Iterator[None]:
+ """Context manager to temporarily override Modal credentials.
+
+ Args:
+ token_id: Modal token ID (overrides env if provided)
+ token_secret: Modal token secret (overrides env if provided)
+
+ Yields:
+ None
+
+ Note:
+ Resets global Modal state to force re-initialization with new credentials.
+ """
+ global _modal_app, _tts_function
+
+ # Save original credentials
+ original_token_id = os.environ.get("MODAL_TOKEN_ID")
+ original_token_secret = os.environ.get("MODAL_TOKEN_SECRET")
+
+ # Save original Modal state
+ original_app = _modal_app
+ original_function = _tts_function
+
+ try:
+ # Override environment variables if provided
+ if token_id:
+ os.environ["MODAL_TOKEN_ID"] = token_id
+ if token_secret:
+ os.environ["MODAL_TOKEN_SECRET"] = token_secret
+
+ # Reset Modal state to force re-initialization
+ _modal_app = None
+ _tts_function = None
+
+ yield
+
+ finally:
+ # Restore original credentials
+ if original_token_id is not None:
+ os.environ["MODAL_TOKEN_ID"] = original_token_id
+ elif "MODAL_TOKEN_ID" in os.environ:
+ del os.environ["MODAL_TOKEN_ID"]
+
+ if original_token_secret is not None:
+ os.environ["MODAL_TOKEN_SECRET"] = original_token_secret
+ elif "MODAL_TOKEN_SECRET" in os.environ:
+ del os.environ["MODAL_TOKEN_SECRET"]
+
+ # Restore original Modal state
+ _modal_app = original_app
+ _tts_function = original_function
def _get_modal_app() -> Any:
- """Get or create Modal app instance."""
+ """Get or create Modal app instance.
+
+ Retrieves Modal credentials directly from environment variables (.env file)
+ instead of relying on settings configuration.
+ """
global _modal_app
if _modal_app is None:
try:
import modal
- # Validate Modal credentials before attempting lookup
- if not settings.modal_available:
+ # Get credentials directly from environment variables
+ token_id = os.getenv("MODAL_TOKEN_ID")
+ token_secret = os.getenv("MODAL_TOKEN_SECRET")
+
+ # Validate Modal credentials
+ if not token_id or not token_secret:
raise ConfigurationError(
- "Modal credentials not configured. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET environment variables."
+ "Modal credentials not found in environment. "
+ "Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file."
)
# Validate token ID format (Modal token IDs are typically UUIDs or specific formats)
- token_id = settings.modal_token_id
- if token_id:
- # Basic validation: token ID should not be empty and should be a reasonable length
- if len(token_id.strip()) < 10:
- raise ConfigurationError(
- f"Modal token ID appears malformed (too short: {len(token_id)} chars). "
- "Token ID should be a valid Modal token identifier."
- )
+ if len(token_id.strip()) < 10:
+ raise ConfigurationError(
+ f"Modal token ID appears malformed (too short: {len(token_id)} chars). "
+ "Token ID should be a valid Modal token identifier."
+ )
+
+ logger.info(
+ "modal_credentials_loaded",
+ token_id_prefix=token_id[:8] + "...", # Log prefix for debugging
+ has_secret=bool(token_secret),
+ )
try:
+ # Use lookup with create_if_missing for inline function fallback
_modal_app = modal.App.lookup("deepcritical-tts", create_if_missing=True)
except Exception as e:
error_msg = str(e).lower()
if "token" in error_msg or "malformed" in error_msg or "invalid" in error_msg:
raise ConfigurationError(
f"Modal token validation failed: {e}. "
- "Please check that MODAL_TOKEN_ID and MODAL_TOKEN_SECRET are correctly set."
+ "Please check that MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env are correctly set."
) from e
raise
except ImportError as e:
@@ -69,103 +147,137 @@ def _get_modal_app() -> Any:
# Define Modal image with Kokoro dependencies (module-level)
def _get_tts_image() -> Any:
"""Get Modal image with Kokoro dependencies."""
+ global _tts_image
+ if _tts_image is not None:
+ return _tts_image
+
try:
import modal
- return (
+ _tts_image = (
modal.Image.debian_slim(python_version="3.11")
.pip_install(*KOKORO_DEPENDENCIES)
.pip_install("git+https://github.com/hexgrad/kokoro.git")
)
+ return _tts_image
except ImportError:
return None
-def _setup_modal_function() -> None:
- """Setup Modal GPU function for TTS (called once, lazy initialization).
+# Modal TTS function - Using serialized=True to allow dynamic creation
+# This will be initialized lazily when _setup_modal_function() is called
+def _create_tts_function() -> Any:
+ """Create the Modal TTS function using serialized=True.
- Note: GPU type is set at function definition time. Changes to settings.tts_gpu
- require app restart to take effect.
+ The serialized=True parameter allows the function to be defined outside
+ of global scope, which is necessary for dynamic initialization.
"""
- global _tts_function, _modal_app
+ app = _get_modal_app()
+ tts_image = _get_tts_image()
+
+ if tts_image is None:
+ raise ConfigurationError("Modal image setup failed")
+
+ # Get GPU and timeout from settings (with defaults)
+ gpu_type = getattr(settings, "tts_gpu", None) or "T4"
+ timeout_seconds = getattr(settings, "tts_timeout", None) or 120 # 2 minutes for cold starts
+
+ @app.function(
+ image=tts_image,
+ gpu=gpu_type,
+ timeout=timeout_seconds,
+ serialized=True, # Allow function to be defined outside global scope
+ )
+ def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, NDArray[np.float32]]:
+ """Modal GPU function for Kokoro TTS.
+
+ This function runs on Modal's GPU infrastructure.
+ Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS
+ Reference: https://huggingface.co/spaces/hexgrad/Kokoro-TTS/raw/main/app.py
+ """
+ import numpy as np
- if _tts_function is not None:
- return # Already set up
+ # Import Kokoro inside function (lazy load)
+ try:
+ import torch
+ from kokoro import KModel, KPipeline
- try:
- app = _get_modal_app()
- tts_image = _get_tts_image()
-
- if tts_image is None:
- raise ConfigurationError("Modal image setup failed")
-
- # Get GPU and timeout from settings (with defaults)
- # Note: These are evaluated at function definition time, not at call time
- # Changes to settings require app restart
- gpu_type = getattr(settings, "tts_gpu", None) or "T4"
- timeout_seconds = getattr(settings, "tts_timeout", None) or 60
-
- # Define GPU function at module level (required by Modal)
- # Modal functions are immutable once defined, so GPU changes require restart
- @app.function(
- image=tts_image,
- gpu=gpu_type,
- timeout=timeout_seconds,
- )
- def kokoro_tts_function(text: str, voice: str, speed: float) -> tuple[int, np.ndarray]:
- """Modal GPU function for Kokoro TTS.
+ # Initialize model (cached on GPU)
+ model = KModel().to("cuda").eval()
+ pipeline = KPipeline(lang_code=voice[0])
+ pack = pipeline.load_voice(voice)
- This function runs on Modal's GPU infrastructure.
- Based on: https://huggingface.co/spaces/hexgrad/Kokoro-TTS
- Reference: https://huggingface.co/spaces/hexgrad/Kokoro-TTS/raw/main/app.py
- """
- import numpy as np
+ # Generate audio
+ for _, ps, _ in pipeline(text, voice, speed):
+ ref_s = pack[len(ps) - 1]
+ audio = model(ps, ref_s, speed)
+ return (24000, audio.numpy())
- # Import Kokoro inside function (lazy load)
- try:
- import torch
- from kokoro import KModel, KPipeline
+ # If no audio generated, return empty
+ return (24000, np.zeros(1, dtype=np.float32))
- # Initialize model (cached on GPU)
- model = KModel().to("cuda").eval()
- pipeline = KPipeline(lang_code=voice[0])
- pack = pipeline.load_voice(voice)
+ except ImportError as e:
+ raise ConfigurationError(
+ "Kokoro not installed. Install with: pip install git+https://github.com/hexgrad/kokoro.git"
+ ) from e
+ except Exception as e:
+ raise ConfigurationError(f"TTS synthesis failed: {e}") from e
- # Generate audio
- for _, ps, _ in pipeline(text, voice, speed):
- ref_s = pack[len(ps) - 1]
- audio = model(ps, ref_s, speed)
- return (24000, audio.numpy())
+ return kokoro_tts_function
- # If no audio generated, return empty
- return (24000, np.zeros(1, dtype=np.float32))
- except ImportError as e:
- raise ConfigurationError(
- "Kokoro not installed. Install with: pip install git+https://github.com/hexgrad/kokoro.git"
- ) from e
- except Exception as e:
- raise ConfigurationError(f"TTS synthesis failed: {e}") from e
+def _setup_modal_function() -> None:
+ """Setup Modal GPU function for TTS (called once, lazy initialization).
- # Store function reference for remote calls
- _tts_function = kokoro_tts_function
+ Hybrid approach:
+    1. Try to look up a pre-deployed function (fast path for advanced users)
+    2. If lookup fails, create the function inline (fallback for casual users)
- # Verify function is properly attached to app
- if not hasattr(app, kokoro_tts_function.__name__):
- logger.warning(
- "modal_function_not_attached", function_name=kokoro_tts_function.__name__
+ This allows both workflows:
+ - Advanced: Deploy with `modal deploy deployments/modal_tts.py` for best performance
+ - Casual: Just add Modal keys and it auto-creates function on first use
+ """
+ global _tts_function
+
+ if _tts_function is not None:
+ return # Already set up
+
+ try:
+ import modal
+
+        # Try path 1: Look up the pre-deployed function (fast path)
+ try:
+ _tts_function = modal.Function.from_name("deepcritical-tts", "kokoro_tts_function")
+ logger.info(
+ "modal_tts_function_lookup_success",
+ app_name="deepcritical-tts",
+ function_name="kokoro_tts_function",
+ method="lookup",
+ )
+ return
+ except Exception as lookup_error:
+ logger.info(
+ "modal_tts_function_lookup_failed",
+ error=str(lookup_error),
+ fallback="Creating function inline",
)
+ # Try path 2: Create function inline (fallback for casual users)
+ logger.info("modal_tts_creating_inline_function")
+ _tts_function = _create_tts_function()
logger.info(
"modal_tts_function_setup_complete",
- gpu=gpu_type,
- timeout=timeout_seconds,
- function_name=kokoro_tts_function.__name__,
+ app_name="deepcritical-tts",
+ function_name="kokoro_tts_function",
+ method="inline",
)
except Exception as e:
logger.error("modal_tts_function_setup_failed", error=str(e))
- raise ConfigurationError(f"Failed to setup Modal TTS function: {e}") from e
+ raise ConfigurationError(
+ f"Failed to setup Modal TTS function: {e}. "
+ "Ensure Modal credentials (MODAL_TOKEN_ID, MODAL_TOKEN_SECRET) are valid."
+ ) from e
class ModalTTSExecutor:
@@ -179,13 +291,17 @@ def __init__(self) -> None:
"""Initialize Modal TTS executor.
Note:
- Logs a warning if Modal credentials are not configured.
- Execution will fail at runtime without valid credentials.
+ Logs a warning if Modal credentials are not configured in environment.
+ Execution will fail at runtime without valid credentials in .env file.
"""
- # Check for Modal credentials
- if not settings.modal_available:
+ # Check for Modal credentials directly from environment
+ token_id = os.getenv("MODAL_TOKEN_ID")
+ token_secret = os.getenv("MODAL_TOKEN_SECRET")
+
+ if not token_id or not token_secret:
logger.warning(
- "Modal credentials not found. TTS will not be available unless modal setup is run."
+ "Modal credentials not found in environment. "
+ "TTS will not be available. Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file."
)
def synthesize(
@@ -193,8 +309,8 @@ def synthesize(
text: str,
voice: str = "af_heart",
speed: float = 1.0,
- timeout: int = 60,
- ) -> tuple[int, np.ndarray]:
+ timeout: int = 120,
+ ) -> tuple[int, NDArray[np.float32]]:
"""Synthesize text to speech using Kokoro on Modal GPU.
Args:
@@ -219,7 +335,7 @@ def synthesize(
try:
# Call the GPU function remotely
- result = _tts_function.remote(text, voice, speed)
+ result = cast(tuple[int, NDArray[np.float32]], _tts_function.remote(text, voice, speed))
logger.info(
"tts_synthesis_complete", sample_rate=result[0], audio_shape=result[1].shape
@@ -236,9 +352,19 @@ class TTSService:
"""TTS service wrapper for async usage."""
def __init__(self) -> None:
- """Initialize TTS service."""
- if not settings.modal_available:
- raise ConfigurationError("Modal credentials required for TTS")
+ """Initialize TTS service.
+
+ Validates Modal credentials from environment variables (.env file).
+ """
+ # Check credentials directly from environment
+ token_id = os.getenv("MODAL_TOKEN_ID")
+ token_secret = os.getenv("MODAL_TOKEN_SECRET")
+
+ if not token_id or not token_secret:
+ raise ConfigurationError(
+ "Modal credentials required for TTS. "
+ "Set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env file."
+ )
self.executor = ModalTTSExecutor()
async def synthesize_async(
@@ -246,7 +372,7 @@ async def synthesize_async(
text: str,
voice: str = "af_heart",
speed: float = 1.0,
- ) -> tuple[int, np.ndarray] | None:
+ ) -> tuple[int, NDArray[np.float32]] | None:
"""Async wrapper for TTS synthesis.
Args:
@@ -284,3 +410,73 @@ def get_tts_service() -> TTSService:
ConfigurationError: If Modal credentials not configured
"""
return TTSService()
+
+
+async def generate_audio_on_demand(
+ text: str,
+ modal_token_id: str | None = None,
+ modal_token_secret: str | None = None,
+ voice: str = "af_heart",
+ speed: float = 1.0,
+ use_llm_polish: bool = False,
+) -> tuple[tuple[int, NDArray[np.float32]] | None, str]:
+ """Generate audio on-demand with optional runtime credentials.
+
+ Args:
+ text: Text to synthesize
+ modal_token_id: Modal token ID (UI input, overrides .env)
+ modal_token_secret: Modal token secret (UI input, overrides .env)
+ voice: Voice ID (default: af_heart)
+ speed: Speech speed (default: 1.0)
+ use_llm_polish: Apply LLM polish to text (default: False)
+
+ Returns:
+ Tuple of (audio_output, status_message)
+ - audio_output: (sample_rate, audio_array) or None if failed
+ - status_message: Status/error message for user
+
+ Priority: UI credentials > .env credentials
+ """
+ # Priority: UI keys > .env keys
+ token_id = (modal_token_id or "").strip() or os.getenv("MODAL_TOKEN_ID")
+ token_secret = (modal_token_secret or "").strip() or os.getenv("MODAL_TOKEN_SECRET")
+
+ if not token_id or not token_secret:
+ return (
+ None,
+ "β Modal credentials required. Enter keys above or set MODAL_TOKEN_ID and MODAL_TOKEN_SECRET in .env",
+ )
+
+ try:
+ # Use credentials override context
+ with modal_credentials_override(token_id, token_secret):
+ # Import audio_processing here to avoid circular import
+ from src.services.audio_processing import AudioService
+
+ # Temporarily override LLM polish setting
+ original_llm_polish = settings.tts_use_llm_polish
+ try:
+ settings.tts_use_llm_polish = use_llm_polish
+
+ # Create fresh AudioService instance (bypass cache to pick up new credentials)
+ audio_service = AudioService()
+ audio_output = await audio_service.generate_audio_output(
+ text=text,
+ voice=voice,
+ speed=speed,
+ )
+
+ if audio_output:
+                return audio_output, "Audio generated successfully"
+ else:
+ return None, "β οΈ Audio generation returned no output"
+
+ finally:
+ settings.tts_use_llm_polish = original_llm_polish
+
+ except ConfigurationError as e:
+ logger.error("audio_generation_config_error", error=str(e))
+ return None, f"β Configuration error: {e}"
+ except Exception as e:
+ logger.error("audio_generation_failed", error=str(e), exc_info=True)
+ return None, f"β Audio generation failed: {e}"
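
generate_audio_on_demand bundles the credential override, a fresh AudioService, and the temporary tts_use_llm_polish flip into one call. A hedged sketch of how a UI handler might use it (token values are placeholders, not real credentials):

from src.services.tts_modal import generate_audio_on_demand


async def handle_tts_click(text: str, token_id: str, token_secret: str):
    # UI-supplied keys take priority; pass empty strings to fall back to .env.
    audio, status = await generate_audio_on_demand(
        text=text,
        modal_token_id=token_id or None,
        modal_token_secret=token_secret or None,
        voice="af_heart",
        speed=1.0,
        use_llm_polish=False,
    )
    # audio is (sample_rate, float32 array) on success, None on failure;
    # status is the user-facing message returned by the function.
    return audio, status
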
diff --git a/src/tools/crawl_adapter.py b/src/tools/crawl_adapter.py
index 53394234..332569c5 100644
--- a/src/tools/crawl_adapter.py
+++ b/src/tools/crawl_adapter.py
@@ -56,8 +56,3 @@ async def crawl_website(starting_url: str) -> str:
except Exception as e:
logger.error("Crawl failed", error=str(e), url=starting_url)
return f"Error crawling website: {e!s}"
-
-
-
-
-
diff --git a/src/tools/neo4j_search.py b/src/tools/neo4j_search.py
index 93b49ace..01983889 100644
--- a/src/tools/neo4j_search.py
+++ b/src/tools/neo4j_search.py
@@ -1,16 +1,19 @@
"""Neo4j knowledge graph search tool."""
+
import structlog
-from src.utils.models import Citation, Evidence
+
from src.services.neo4j_service import get_neo4j_service
+from src.utils.models import Citation, Evidence
logger = structlog.get_logger()
+
class Neo4jSearchTool:
"""Search Neo4j knowledge graph for papers."""
-
- def __init__(self):
+
+ def __init__(self) -> None:
self.name = "neo4j" # β
Definir explΓcitamente
-
+
async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
"""Search Neo4j for papers about diseases in the query."""
try:
@@ -18,25 +21,32 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
if not service:
logger.warning("Neo4j service not available")
return []
-
+
# Extract disease name from query
disease = query
if "for" in query.lower():
disease = query.split("for")[-1].strip().rstrip("?")
-
+
# Query Neo4j
+ if not service.driver:
+ logger.warning("Neo4j driver not available")
+ return []
with service.driver.session(database=service.database) as session:
- result = session.run("""
+ result = session.run(
+ """
MATCH (p:Paper)-[:ABOUT]->(d:Disease)
WHERE d.name CONTAINS $disease
RETURN p.title as title, p.abstract as abstract,
p.url as url, p.source as source
ORDER BY p.updated_at DESC
LIMIT $max_results
- """, disease=disease, max_results=max_results)
-
+ """,
+ disease=disease,
+ max_results=max_results,
+ )
+
records = list(result)
-
+
results = []
for record in records:
citation = Citation(
@@ -44,17 +54,14 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
title=record["title"] or "Untitled",
url=record["url"] or "",
date="",
- authors=[]
+ authors=[],
)
-
+
evidence = Evidence(
content=record["abstract"] or record["title"] or "",
citation=citation,
relevance=1.0,
- metadata={
- "from_kb": True,
- "original_source": record["source"]
- }
+ metadata={"from_kb": True, "original_source": record["source"]},
)
results.append(evidence)
diff --git a/src/tools/search_handler.py b/src/tools/search_handler.py
index 77edba21..8814a8b7 100644
--- a/src/tools/search_handler.py
+++ b/src/tools/search_handler.py
@@ -5,11 +5,11 @@
import structlog
+from src.services.neo4j_service import get_neo4j_service
from src.tools.base import SearchTool
from src.tools.rag_tool import create_rag_tool
from src.utils.exceptions import ConfigurationError, SearchError
from src.utils.models import Evidence, SearchResult, SourceName
-from src.services.neo4j_service import get_neo4j_service
if TYPE_CHECKING:
from src.services.llamaindex_rag import LlamaIndexRAGService
@@ -113,6 +113,8 @@ async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchRes
# Some tools have internal names that differ from SourceName literals
tool_name_to_source: dict[str, SourceName] = {
"duckduckgo": "web",
+ "serper": "web", # Serper uses Google search but maps to "web" source
+ "searchxng": "web", # SearchXNG also maps to "web" source
"pubmed": "pubmed",
"clinicaltrials": "clinicaltrials",
"europepmc": "europepmc",
@@ -131,7 +133,15 @@ async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchRes
# Map tool.name to SourceName (handle tool names that don't match SourceName literals)
tool_name = tool_name_to_source.get(tool.name, cast(SourceName, tool.name))
- if tool_name not in ["pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint", "rag", "web"]:
+ if tool_name not in [
+ "pubmed",
+ "clinicaltrials",
+ "biorxiv",
+ "europepmc",
+ "preprint",
+ "rag",
+ "web",
+ ]:
logger.warning(
"Tool name not in SourceName literals, defaulting to 'web'",
tool_name=tool.name,
@@ -173,18 +183,20 @@ async def execute(self, query: str, max_results_per_tool: int = 10) -> SearchRes
disease = query
if "for" in query.lower():
disease = query.split("for")[-1].strip().rstrip("?")
-
+
# Convert Evidence objects to dicts for Neo4j
papers = []
for ev in all_evidence:
- papers.append({
- 'id': ev.citation.url or '',
- 'title': ev.citation.title or '',
- 'abstract': ev.content,
- 'url': ev.citation.url or '',
- 'source': ev.citation.source,
- })
-
+ papers.append(
+ {
+ "id": ev.citation.url or "",
+ "title": ev.citation.title or "",
+ "abstract": ev.content,
+ "url": ev.citation.url or "",
+ "source": ev.citation.source,
+ }
+ )
+
stats = neo4j_service.ingest_search_results(disease, papers)
logger.info("πΎ Saved to Neo4j", stats=stats)
except Exception as e:
diff --git a/src/tools/searchxng_web_search.py b/src/tools/searchxng_web_search.py
index 80cf8be5..e120b61f 100644
--- a/src/tools/searchxng_web_search.py
+++ b/src/tools/searchxng_web_search.py
@@ -85,12 +85,17 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
# Convert ScrapeResult to Evidence objects
evidence = []
for result in scraped:
+ # Truncate title to max 500 characters to match Citation model validation
+ title = result.title
+ if len(title) > 500:
+ title = title[:497] + "..."
+
ev = Evidence(
content=result.text,
citation=Citation(
- title=result.title,
+ title=title,
url=result.url,
- source="searchxng",
+ source="web", # Use "web" to match SourceName literal, not "searchxng"
date="Unknown",
authors=[],
),
@@ -113,13 +118,3 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
except Exception as e:
logger.error("Unexpected error in SearchXNG search", error=str(e), query=final_query)
raise SearchError(f"SearchXNG search failed: {e}") from e
-
-
-
-
-
-
-
-
-
-
diff --git a/src/tools/serper_web_search.py b/src/tools/serper_web_search.py
index 79e9449e..77fc6eb0 100644
--- a/src/tools/serper_web_search.py
+++ b/src/tools/serper_web_search.py
@@ -85,12 +85,17 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
# Convert ScrapeResult to Evidence objects
evidence = []
for result in scraped:
+ # Truncate title to max 500 characters to match Citation model validation
+ title = result.title
+ if len(title) > 500:
+ title = title[:497] + "..."
+
ev = Evidence(
content=result.text,
citation=Citation(
- title=result.title,
+ title=title,
url=result.url,
- source="serper",
+ source="web", # Use "web" to match SourceName literal, not "serper"
date="Unknown",
authors=[],
),
@@ -113,13 +118,3 @@ async def search(self, query: str, max_results: int = 10) -> list[Evidence]:
except Exception as e:
logger.error("Unexpected error in Serper search", error=str(e), query=final_query)
raise SearchError(f"Serper search failed: {e}") from e
-
-
-
-
-
-
-
-
-
-
diff --git a/src/tools/tool_executor.py b/src/tools/tool_executor.py
index 04099226..e4999607 100644
--- a/src/tools/tool_executor.py
+++ b/src/tools/tool_executor.py
@@ -182,9 +182,9 @@ async def execute_tool_tasks(
results[f"{task.agent}_{i}"] = ToolAgentOutput(output=f"Error: {result!s}", sources=[])
else:
# Type narrowing: result is ToolAgentOutput after Exception check
- assert isinstance(
- result, ToolAgentOutput
- ), "Expected ToolAgentOutput after Exception check"
+ assert isinstance(result, ToolAgentOutput), (
+ "Expected ToolAgentOutput after Exception check"
+ )
key = f"{task.agent}_{task.gap or i}" if task.gap else f"{task.agent}_{i}"
results[key] = result
diff --git a/src/tools/vendored/__init__.py b/src/tools/vendored/__init__.py
index 6db9f15a..e96e8252 100644
--- a/src/tools/vendored/__init__.py
+++ b/src/tools/vendored/__init__.py
@@ -16,12 +16,12 @@
__all__ = [
"CONTENT_LENGTH_LIMIT",
"ScrapeResult",
- "WebpageSnippet",
- "SerperClient",
"SearchXNGClient",
- "scrape_urls",
+ "SerperClient",
+ "WebpageSnippet",
+ "crawl_website",
"fetch_and_process_url",
"html_to_text",
"is_valid_url",
- "crawl_website",
+ "scrape_urls",
]
diff --git a/src/tools/vendored/crawl_website.py b/src/tools/vendored/crawl_website.py
index 32fc2d94..cd75bb50 100644
--- a/src/tools/vendored/crawl_website.py
+++ b/src/tools/vendored/crawl_website.py
@@ -20,6 +20,63 @@
logger = structlog.get_logger()
+async def _extract_links(
+ html: str, current_url: str, base_domain: str
+) -> tuple[list[str], list[str]]:
+ """Extract prioritized links from HTML content."""
+ soup = BeautifulSoup(html, "html.parser")
+ nav_links = set()
+ body_links = set()
+
+ # Find navigation/header links
+ for nav_element in soup.find_all(["nav", "header"]):
+ for a in nav_element.find_all("a", href=True):
+ href = str(a["href"])
+ link = urljoin(current_url, href)
+ if urlparse(link).netloc == base_domain:
+ nav_links.add(link)
+
+ # Find remaining body links
+ for a in soup.find_all("a", href=True):
+ href = str(a["href"])
+ link = urljoin(current_url, href)
+ if urlparse(link).netloc == base_domain and link not in nav_links:
+ body_links.add(link)
+
+ return list(nav_links), list(body_links)
+
+
+async def _fetch_page(url: str) -> str:
+ """Fetch HTML content from a URL."""
+ connector = aiohttp.TCPConnector(ssl=ssl_context)
+ async with aiohttp.ClientSession(connector=connector) as session:
+ try:
+ timeout = aiohttp.ClientTimeout(total=30)
+ async with session.get(url, timeout=timeout) as response:
+ if response.status == 200:
+ return await response.text()
+ return ""
+ except Exception as e:
+ logger.warning("Error fetching URL", url=url, error=str(e))
+ return ""
+
+
+def _add_links_to_queue(
+ links: list[str],
+ queue: list[str],
+ all_pages_to_scrape: set[str],
+ remaining_slots: int,
+) -> int:
+ """Add normalized links to queue if not already visited."""
+ for link in links:
+ normalized_link = link.rstrip("/")
+ if normalized_link not in all_pages_to_scrape and remaining_slots > 0:
+ queue.append(normalized_link)
+ all_pages_to_scrape.add(normalized_link)
+ remaining_slots -= 1
+ return remaining_slots
+
+
async def crawl_website(starting_url: str) -> list[ScrapeResult] | str:
"""Crawl the pages of a website starting with the starting_url and then descending into the pages linked from there.
@@ -45,41 +102,6 @@ async def crawl_website(starting_url: str) -> list[ScrapeResult] | str:
max_pages = 10
base_domain = urlparse(starting_url).netloc
- async def extract_links(html: str, current_url: str) -> tuple[list[str], list[str]]:
- """Extract prioritized links from HTML content"""
- soup = BeautifulSoup(html, "html.parser")
- nav_links = set()
- body_links = set()
-
- # Find navigation/header links
- for nav_element in soup.find_all(["nav", "header"]):
- for a in nav_element.find_all("a", href=True):
- link = urljoin(current_url, a["href"])
- if urlparse(link).netloc == base_domain:
- nav_links.add(link)
-
- # Find remaining body links
- for a in soup.find_all("a", href=True):
- link = urljoin(current_url, a["href"])
- if urlparse(link).netloc == base_domain and link not in nav_links:
- body_links.add(link)
-
- return list(nav_links), list(body_links)
-
- async def fetch_page(url: str) -> str:
- """Fetch HTML content from a URL"""
- connector = aiohttp.TCPConnector(ssl=ssl_context)
- async with aiohttp.ClientSession(connector=connector) as session:
- try:
- timeout = aiohttp.ClientTimeout(total=30)
- async with session.get(url, timeout=timeout) as response:
- if response.status == 200:
- return await response.text()
- return ""
- except Exception as e:
- logger.warning("Error fetching URL", url=url, error=str(e))
- return ""
-
# Initialize with starting URL
queue: list[str] = [starting_url]
next_level_queue: list[str] = []
@@ -90,26 +112,20 @@ async def fetch_page(url: str) -> str:
current_url = queue.pop(0)
# Fetch and process the page
- html_content = await fetch_page(current_url)
+ html_content = await _fetch_page(current_url)
if html_content:
- nav_links, body_links = await extract_links(html_content, current_url)
+ nav_links, body_links = await _extract_links(html_content, current_url, base_domain)
# Add unvisited nav links to current queue (higher priority)
remaining_slots = max_pages - len(all_pages_to_scrape)
- for link in nav_links:
- link = link.rstrip("/")
- if link not in all_pages_to_scrape and remaining_slots > 0:
- queue.append(link)
- all_pages_to_scrape.add(link)
- remaining_slots -= 1
+ remaining_slots = _add_links_to_queue(
+ nav_links, queue, all_pages_to_scrape, remaining_slots
+ )
# Add unvisited body links to next level queue (lower priority)
- for link in body_links:
- link = link.rstrip("/")
- if link not in all_pages_to_scrape and remaining_slots > 0:
- next_level_queue.append(link)
- all_pages_to_scrape.add(link)
- remaining_slots -= 1
+ remaining_slots = _add_links_to_queue(
+ body_links, next_level_queue, all_pages_to_scrape, remaining_slots
+ )
# If current queue is empty, add next level links
if not queue:
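
The extracted helpers keep the crawl bounded by max_pages: _add_links_to_queue normalizes each link, skips pages already scheduled, and returns the updated slot budget. A small standalone sketch of that bookkeeping (URLs are invented):

from src.tools.vendored.crawl_website import _add_links_to_queue

queue: list[str] = []
seen: set[str] = {"https://example.org"}

remaining = _add_links_to_queue(
    ["https://example.org/docs/", "https://example.org", "https://example.org/blog"],
    queue,
    seen,
    remaining_slots=2,
)
# "https://example.org" is already in `seen`, so only the other two are queued:
# queue == ["https://example.org/docs", "https://example.org/blog"], remaining == 0
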
diff --git a/src/tools/vendored/web_search_core.py b/src/tools/vendored/web_search_core.py
index 391b5c2e..a465f398 100644
--- a/src/tools/vendored/web_search_core.py
+++ b/src/tools/vendored/web_search_core.py
@@ -199,13 +199,3 @@ def is_valid_url(url: str) -> bool:
if any(ext in url for ext in restricted_extensions):
return False
return True
-
-
-
-
-
-
-
-
-
-
diff --git a/src/tools/web_search.py b/src/tools/web_search.py
index 318b43d3..f6350e82 100644
--- a/src/tools/web_search.py
+++ b/src/tools/web_search.py
@@ -3,6 +3,7 @@
import asyncio
import structlog
+
try:
from ddgs import DDGS # New package name
except ImportError:
@@ -55,10 +56,15 @@ def _do_search() -> list[dict[str, str]]:
evidence = []
for r in raw_results:
+ # Truncate title to max 500 characters to match Citation model validation
+ title = r.get("title", "No Title")
+ if len(title) > 500:
+ title = title[:497] + "..."
+
ev = Evidence(
content=r.get("body", ""),
citation=Citation(
- title=r.get("title", "No Title"),
+ title=title,
url=r.get("href", ""),
source="web",
date="Unknown",
diff --git a/src/tools/web_search_adapter.py b/src/tools/web_search_adapter.py
index 4b4bfab0..6ee833f2 100644
--- a/src/tools/web_search_adapter.py
+++ b/src/tools/web_search_adapter.py
@@ -53,11 +53,3 @@ async def web_search(query: str) -> str:
except Exception as e:
logger.error("Web search failed", error=str(e), query=query)
return f"Error performing web search: {e!s}"
-
-
-
-
-
-
-
-
diff --git a/src/tools/web_search_factory.py b/src/tools/web_search_factory.py
index fae4c5ff..213de5de 100644
--- a/src/tools/web_search_factory.py
+++ b/src/tools/web_search_factory.py
@@ -12,19 +12,66 @@
logger = structlog.get_logger()
-def create_web_search_tool() -> SearchTool | None:
+def create_web_search_tool(provider: str | None = None) -> SearchTool | None:
"""Create a web search tool based on configuration.
+ Args:
+ provider: Override provider selection. If None, uses settings.web_search_provider.
+
Returns:
SearchTool instance, or None if not available/configured
- The tool is selected based on settings.web_search_provider:
+ The tool is selected based on provider (or settings.web_search_provider if None):
- "serper": SerperWebSearchTool (requires SERPER_API_KEY)
- "searchxng": SearchXNGWebSearchTool (requires SEARCHXNG_HOST)
- "duckduckgo": WebSearchTool (always available, no API key)
- "brave" or "tavily": Not yet implemented, returns None
+ - "auto": Auto-detect best available provider (prefers Serper > SearchXNG > DuckDuckGo)
+
+ Auto-detection logic (when provider is "auto" or not explicitly set):
+ 1. Try Serper if SERPER_API_KEY is available (best quality - Google search + full content scraping)
+ 2. Try SearchXNG if SEARCHXNG_HOST is available
+ 3. Fall back to DuckDuckGo (always available, but lower quality - snippets only)
"""
- provider = settings.web_search_provider
+ provider = provider or settings.web_search_provider
+
+ # Auto-detect best available provider if "auto" or if provider is duckduckgo but better options exist
+ if provider == "auto" or (provider == "duckduckgo" and settings.serper_api_key):
+ # Prefer Serper if API key is available (better quality)
+ if settings.serper_api_key:
+ try:
+ logger.info(
+ "Auto-detected Serper web search (SERPER_API_KEY found)",
+ provider="serper",
+ )
+ return SerperWebSearchTool()
+ except Exception as e:
+ logger.warning(
+ "Failed to initialize Serper, falling back",
+ error=str(e),
+ )
+
+ # Try SearchXNG as second choice
+ if settings.searchxng_host:
+ try:
+ logger.info(
+ "Auto-detected SearchXNG web search (SEARCHXNG_HOST found)",
+ provider="searchxng",
+ )
+ return SearchXNGWebSearchTool()
+ except Exception as e:
+ logger.warning(
+ "Failed to initialize SearchXNG, falling back",
+ error=str(e),
+ )
+
+ # Fall back to DuckDuckGo
+ if provider == "auto":
+ logger.info(
+ "Auto-detected DuckDuckGo web search (no API keys found)",
+ provider="duckduckgo",
+ )
+ return WebSearchTool()
try:
if provider == "serper":
@@ -66,13 +113,3 @@ def create_web_search_tool() -> SearchTool | None:
except Exception as e:
logger.error("Unexpected error creating web search tool", error=str(e), provider=provider)
return None
-
-
-
-
-
-
-
-
-
-
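
With the new "auto" default, callers no longer need to know which provider is configured; the factory checks settings.serper_api_key and settings.searchxng_host and falls back to DuckDuckGo. A minimal sketch of the two call styles:

from src.tools.web_search_factory import create_web_search_tool

# Let auto-detection pick Serper > SearchXNG > DuckDuckGo based on configured keys.
tool = create_web_search_tool()

# An explicit provider skips auto-detection (note that per this diff, "duckduckgo"
# is still upgraded to Serper when SERPER_API_KEY is set). May return None if the
# provider's configuration (e.g. SEARCHXNG_HOST) is missing.
searchxng_tool = create_web_search_tool(provider="searchxng")
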
diff --git a/src/utils/config.py b/src/utils/config.py
index b6b2a0ba..9d4356d7 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -61,6 +61,15 @@ class Settings(BaseSettings):
default="meta-llama/Llama-3.1-8B-Instruct",
description="Default HuggingFace model ID for inference",
)
+ hf_fallback_models: str = Field(
+ default="Qwen/Qwen3-Next-80B-A3B-Thinking,Qwen/Qwen3-Next-80B-A3B-Instruct,meta-llama/Llama-3.3-70B-Instruct,meta-llama/Llama-3.1-8B-Instruct,HuggingFaceH4/zephyr-7b-beta,Qwen/Qwen2-7B-Instruct",
+ alias="HF_FALLBACK_MODELS",
+ description=(
+ "Comma-separated list of fallback models for provider discovery and error recovery. "
+ "Reads from HF_FALLBACK_MODELS environment variable. "
+ "Default value is used only if the environment variable is not set."
+ ),
+ )
# PubMed Configuration
ncbi_api_key: str | None = Field(
@@ -68,9 +77,11 @@ class Settings(BaseSettings):
)
# Web Search Configuration
- web_search_provider: Literal["serper", "searchxng", "brave", "tavily", "duckduckgo"] = Field(
- default="duckduckgo",
- description="Web search provider to use",
+ web_search_provider: Literal["serper", "searchxng", "brave", "tavily", "duckduckgo", "auto"] = (
+ Field(
+ default="auto",
+ description="Web search provider to use. 'auto' will auto-detect best available (prefers Serper > SearchXNG > DuckDuckGo)",
+ )
)
serper_api_key: str | None = Field(default=None, description="Serper API key for Google search")
searchxng_host: str | None = Field(default=None, description="SearchXNG host URL")
@@ -163,6 +174,10 @@ class Settings(BaseSettings):
le=2.0,
description="TTS speech speed multiplier (0.5x to 2.0x)",
)
+ tts_use_llm_polish: bool = Field(
+ default=False,
+ description="Use LLM for final text polish before TTS (optional, costs API calls)",
+ )
tts_gpu: str | None = Field(
default=None,
description="Modal GPU type for TTS (T4, A10, A100, L4, L40S). None uses default T4.",
@@ -269,6 +284,19 @@ def web_search_available(self) -> bool:
return bool(self.tavily_api_key)
return False
+ def get_hf_fallback_models_list(self) -> list[str]:
+ """Get the list of fallback models as a list.
+
+ Parses the comma-separated HF_FALLBACK_MODELS string into a list,
+ stripping whitespace from each model ID.
+
+ Returns:
+ List of model IDs
+ """
+ if not self.hf_fallback_models:
+ return []
+ return [model.strip() for model in self.hf_fallback_models.split(",") if model.strip()]
+
def get_settings() -> Settings:
"""Factory function to get settings (allows mocking in tests)."""
diff --git a/src/utils/hf_error_handler.py b/src/utils/hf_error_handler.py
new file mode 100644
index 00000000..becc6862
--- /dev/null
+++ b/src/utils/hf_error_handler.py
@@ -0,0 +1,199 @@
+"""Utility functions for handling HuggingFace API errors and token validation."""
+
+import re
+from typing import Any
+
+import structlog
+
+logger = structlog.get_logger()
+
+
+def extract_error_details(error: Exception) -> dict[str, Any]:
+ """Extract error details from HuggingFace API errors.
+
+ Pydantic AI and HuggingFace Inference API errors often contain
+ information in the error message string like:
+ "status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden"
+
+ Args:
+ error: The exception object
+
+ Returns:
+ Dictionary with extracted error details:
+ - status_code: HTTP status code (if found)
+ - model_name: Model name (if found)
+ - body: Error body/message (if found)
+ - error_type: Type of error (403, 422, etc.)
+ - is_auth_error: Whether this is an authentication/authorization error
+ - is_model_error: Whether this is a model-specific error
+ """
+ error_str = str(error)
+ details: dict[str, Any] = {
+ "status_code": None,
+ "model_name": None,
+ "body": None,
+ "error_type": "unknown",
+ "is_auth_error": False,
+ "is_model_error": False,
+ }
+
+ # Try to extract status_code
+ status_match = re.search(r"status_code:\s*(\d+)", error_str)
+ if status_match:
+ details["status_code"] = int(status_match.group(1))
+ details["error_type"] = f"http_{details['status_code']}"
+
+ # Determine error category
+ if details["status_code"] == 403:
+ details["is_auth_error"] = True
+ elif details["status_code"] == 422:
+ details["is_model_error"] = True
+
+ # Try to extract model_name
+ model_match = re.search(r"model_name:\s*([^\s,]+)", error_str)
+ if model_match:
+ details["model_name"] = model_match.group(1)
+
+ # Try to extract body
+ body_match = re.search(r"body:\s*(.+)", error_str)
+ if body_match:
+ details["body"] = body_match.group(1).strip()
+
+ return details
+
+
+def get_user_friendly_error_message(error: Exception, model_name: str | None = None) -> str:
+ """Generate a user-friendly error message from an exception.
+
+ Args:
+ error: The exception object
+ model_name: Optional model name for context
+
+ Returns:
+ User-friendly error message
+ """
+ details = extract_error_details(error)
+
+ if details["is_auth_error"]:
+ return (
+ "π **Authentication Error**\n\n"
+ "Your HuggingFace token doesn't have permission to access this model or API.\n\n"
+ "**Possible solutions:**\n"
+ "1. **Re-authenticate**: Log out and log back in to ensure your token has the `inference-api` scope\n"
+ "2. **Check model access**: Visit the model page on HuggingFace and request access if it's gated\n"
+ "3. **Use alternative model**: Try a different model that's publicly available\n\n"
+ f"**Model attempted**: {details['model_name'] or model_name or 'Unknown'}\n"
+ f"**Error**: {details['body'] or str(error)}"
+ )
+
+ if details["is_model_error"]:
+ return (
+ "β οΈ **Model Compatibility Error**\n\n"
+ "The selected model is not compatible with the current provider or has specific requirements.\n\n"
+ "**Possible solutions:**\n"
+ "1. **Try a different model**: Use a model that's compatible with the current provider\n"
+ "2. **Check provider status**: The provider may be in staging mode or unavailable\n"
+ "3. **Wait and retry**: If the model is in staging, it may become available later\n\n"
+ f"**Model attempted**: {details['model_name'] or model_name or 'Unknown'}\n"
+ f"**Error**: {details['body'] or str(error)}"
+ )
+
+ # Generic error
+ return (
+ "β **API Error**\n\n"
+ f"An error occurred while calling the HuggingFace API:\n\n"
+ f"**Error**: {error!s}\n\n"
+ "Please try again or contact support if the issue persists."
+ )
+
+
+def validate_hf_token(token: str | None) -> tuple[bool, str | None]:
+ """Validate HuggingFace token format.
+
+ Args:
+ token: The token to validate
+
+ Returns:
+ Tuple of (is_valid, error_message)
+ - is_valid: True if token appears valid
+ - error_message: Error message if invalid, None if valid
+ """
+ if not token:
+ return False, "Token is None or empty"
+
+ if not isinstance(token, str):
+ return False, f"Token is not a string (type: {type(token).__name__})"
+
+ if len(token) < 10:
+ return False, "Token appears too short (minimum 10 characters expected)"
+
+ # HuggingFace tokens typically start with "hf_" for user tokens
+ # OAuth tokens may have different formats, so we're lenient
+ # Just check it's not obviously invalid
+
+ return True, None
+
+
+def log_token_info(token: str | None, context: str = "") -> None:
+ """Log token information for debugging (without exposing the actual token).
+
+ Args:
+ token: The token to log info about
+ context: Additional context for the log message
+ """
+ if token:
+ is_valid, error_msg = validate_hf_token(token)
+ logger.debug(
+ "Token validation",
+ context=context,
+ has_token=True,
+ is_valid=is_valid,
+ token_length=len(token),
+ token_prefix=token[:4] + "..." if len(token) > 4 else "***",
+ validation_error=error_msg,
+ )
+ else:
+ logger.debug("Token validation", context=context, has_token=False)
+
+
+def should_retry_with_fallback(error: Exception) -> bool:
+ """Determine if an error should trigger a fallback to alternative models.
+
+ Args:
+ error: The exception object
+
+ Returns:
+ True if the error suggests we should try a fallback model
+ """
+ details = extract_error_details(error)
+
+ # Retry with fallback for:
+ # - 403 errors (authentication/permission issues - might work with different model)
+ # - 422 errors (model/provider compatibility - definitely try different model)
+ # - Model-specific errors
+ return (
+ details["is_auth_error"] or details["is_model_error"] or details["model_name"] is not None
+ )
+
+
+def get_fallback_models(original_model: str | None = None) -> list[str]:
+ """Get a list of fallback models to try.
+
+ Args:
+ original_model: The original model that failed
+
+ Returns:
+ List of fallback model names to try in order
+ """
+ # Publicly available models that should work with most tokens
+ fallbacks = [
+ "meta-llama/Llama-3.1-8B-Instruct", # Common, often available
+ "mistralai/Mistral-7B-Instruct-v0.3", # Alternative
+ "HuggingFaceH4/zephyr-7b-beta", # Ungated fallback
+ ]
+
+ # If original model is in the list, remove it
+ if original_model and original_model in fallbacks:
+ fallbacks.remove(original_model)
+
+ return fallbacks
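
A hedged sketch of how these helpers compose, using an exception message in the format quoted in the docstring above:

from src.utils.hf_error_handler import (
    extract_error_details,
    get_fallback_models,
    should_retry_with_fallback,
)

err = RuntimeError(
    "status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden"
)

details = extract_error_details(err)
# details["status_code"] == 403, details["is_auth_error"] is True,
# details["model_name"] == "Qwen/Qwen3-Next-80B-A3B-Thinking"

if should_retry_with_fallback(err):
    for model in get_fallback_models(details["model_name"]):
        ...  # retry the request against `model`
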
diff --git a/src/utils/hf_model_validator.py b/src/utils/hf_model_validator.py
new file mode 100644
index 00000000..e7dd75ee
--- /dev/null
+++ b/src/utils/hf_model_validator.py
@@ -0,0 +1,477 @@
+"""Validator for querying available HuggingFace models and providers using OAuth token.
+
+This module provides functions to:
+1. Query available models from HuggingFace Hub
+2. Query available inference providers (with dynamic discovery)
+3. Validate model/provider combinations
+4. Return formatted lists for Gradio dropdowns
+
+Uses Hugging Face Hub API to discover providers dynamically by querying model
+information. Falls back to known providers list if discovery fails.
+"""
+
+import asyncio
+from time import time
+from typing import Any
+
+import structlog
+from huggingface_hub import HfApi
+
+from src.utils.config import settings
+
+logger = structlog.get_logger()
+
+
+def extract_oauth_token(oauth_token: Any) -> str | None:
+ """Extract OAuth token value from Gradio OAuthToken object.
+
+ Handles both gr.OAuthToken objects (with .token attribute) and plain strings.
+ This is a convenience function for Gradio apps that use OAuth authentication.
+
+ Args:
+ oauth_token: Gradio OAuthToken object or string token
+
+ Returns:
+ Token string if available, None otherwise
+ """
+ if oauth_token is None:
+ return None
+
+ if hasattr(oauth_token, "token"):
+ return oauth_token.token # type: ignore[no-any-return]
+ elif isinstance(oauth_token, str):
+ return oauth_token
+
+ logger.warning(
+ "Could not extract token from OAuthToken object",
+ oauth_token_type=type(oauth_token).__name__,
+ )
+ return None
+
+
+# Known providers as fallback (updated from Hugging Face documentation)
+# These are used when dynamic discovery fails or times out
+KNOWN_PROVIDERS = [
+ "auto", # Auto-select (always available)
+ "hf-inference", # HuggingFace's own Inference API
+ "nebius",
+ "together",
+ "scaleway",
+ "hyperbolic",
+ "novita",
+ "nscale",
+ "sambanova",
+ "ovh",
+ "fireworks-ai", # Note: API uses "fireworks-ai", not "fireworks"
+ "cerebras",
+ "fal-ai",
+ "cohere",
+]
+
+
+def get_provider_discovery_models() -> list[str]:
+ """Get list of models to use for provider discovery.
+
+ Reads from HF_FALLBACK_MODELS environment variable via settings.
+ The environment variable should be a comma-separated list of model IDs.
+
+ Returns:
+ List of model IDs to query for provider discovery
+ """
+ # Get models from HF_FALLBACK_MODELS environment variable
+ # This is automatically read by Pydantic Settings from the env var
+ fallback_models = settings.get_hf_fallback_models_list()
+
+ logger.debug(
+ "Using HF_FALLBACK_MODELS for provider discovery",
+ count=len(fallback_models),
+ models=fallback_models,
+ )
+
+ return fallback_models
+
+
+# Simple in-memory cache for provider lists (TTL: 1 hour)
+_provider_cache: dict[str, tuple[list[str], float]] = {}
+PROVIDER_CACHE_TTL = 3600 # 1 hour in seconds
+
+
+async def get_available_providers(token: str | None = None) -> list[str]:
+ """Get list of available inference providers.
+
+ Discovers providers dynamically by querying model information from HuggingFace Hub.
+ Uses caching to avoid repeated API calls. Falls back to known providers if discovery fails.
+
+ Strategy:
+ 1. Check cache (if valid, return cached list)
+ 2. Query popular models to extract unique providers from their inferenceProviderMapping
+ 3. Fall back to known providers list if discovery fails
+ 4. Cache results for future use
+
+ Args:
+ token: Optional HuggingFace API token for authenticated requests
+ Can be extracted from gr.OAuthToken.token in Gradio apps
+
+ Returns:
+ List of provider names sorted alphabetically, with "auto" first
+ (e.g., ["auto", "fireworks-ai", "hf-inference", "nebius", ...])
+ """
+ # Check cache first
+ cache_key = "providers" + (f"_{token[:8]}" if token else "_no_token")
+ if cache_key in _provider_cache:
+ cached_providers, cache_time = _provider_cache[cache_key]
+ if time() - cache_time < PROVIDER_CACHE_TTL:
+ logger.debug("Returning cached providers", count=len(cached_providers))
+ return cached_providers
+
+ try:
+        providers = {"auto"}  # Always include "auto"
+
+ # Try dynamic discovery by querying popular models
+ loop = asyncio.get_running_loop()
+ api = HfApi(token=token)
+
+ # Get models to query from HF_FALLBACK_MODELS environment variable via settings
+ discovery_models = get_provider_discovery_models()
+
+ # Query a sample of popular models to discover providers
+ # This is more efficient than querying all models
+ discovery_count = 0
+ for model_id in discovery_models:
+ try:
+
+ def _get_model_info(m: str) -> Any:
+ """Get model info synchronously."""
+ return api.model_info(m, expand=["inferenceProviderMapping"]) # type: ignore[arg-type]
+
+ info = await loop.run_in_executor(None, _get_model_info, model_id)
+
+ # Extract providers from inference_provider_mapping
+ if hasattr(info, "inference_provider_mapping") and info.inference_provider_mapping:
+ mapping = info.inference_provider_mapping
+ # mapping is a dict like {'hf-inference': InferenceProviderMapping(...), ...}
+ providers.update(mapping.keys())
+ discovery_count += 1
+ logger.debug(
+ "Discovered providers from model",
+ model=model_id,
+ providers=list(mapping.keys()),
+ )
+ except Exception as e:
+ logger.debug(
+ "Could not get provider info for model",
+ model=model_id,
+ error=str(e),
+ )
+ continue
+
+ # If we discovered providers, use them; otherwise fall back to known providers
+ if len(providers) > 1: # More than just "auto"
+            provider_list = sorted(providers)
+ logger.info(
+ "Discovered providers dynamically",
+ count=len(provider_list),
+ models_queried=discovery_count,
+ has_token=bool(token),
+ )
+ else:
+ # Fallback to known providers
+ provider_list = KNOWN_PROVIDERS.copy()
+ logger.info(
+ "Using known providers list (discovery failed or incomplete)",
+ count=len(provider_list),
+ models_queried=discovery_count,
+ )
+
+ # Cache the results
+ _provider_cache[cache_key] = (provider_list, time())
+
+ return provider_list
+
+ except Exception as e:
+ logger.warning("Failed to get providers", error=str(e))
+ # Return known providers as fallback
+ return KNOWN_PROVIDERS.copy()
+
+
+async def get_available_models(
+ token: str | None = None,
+ task: str = "text-generation",
+ limit: int = 100,
+ inference_provider: str | None = None,
+) -> list[str]:
+ """Get list of available models for text generation.
+
+ Queries HuggingFace Hub API to get models that support text generation.
+ Optionally filters by inference provider to show only models available via that provider.
+
+ Args:
+ token: Optional HuggingFace API token for authenticated requests
+ Can be extracted from gr.OAuthToken.token in Gradio apps
+ task: Task type to filter models (default: "text-generation")
+ limit: Maximum number of models to return
+ inference_provider: Optional provider name to filter models (e.g., "fireworks-ai", "nebius")
+ If None, returns all models for the task
+
+ Returns:
+ List of model IDs (e.g., ["meta-llama/Llama-3.1-8B-Instruct", ...])
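+
+    Example (sketch; the provider name is illustrative):
+
+        models = await get_available_models(
+            token=token, inference_provider="nebius", limit=20
+        )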
+ """
+ try:
+ loop = asyncio.get_running_loop()
+
+ def _fetch_models() -> list[str]:
+ """Fetch models synchronously in executor."""
+ api = HfApi(token=token)
+
+ # Build query parameters
+ query_params: dict[str, Any] = {
+ "task": task,
+ "sort": "downloads",
+ "direction": -1,
+ "limit": limit,
+ }
+
+ # Filter by inference provider if specified
+ if inference_provider and inference_provider != "auto":
+ query_params["inference_provider"] = inference_provider
+
+ # Search for models
+ models = api.list_models(**query_params)
+
+ # Extract model IDs
+ model_ids = [model.id for model in models]
+ return model_ids
+
+ model_ids = await loop.run_in_executor(None, _fetch_models)
+
+ logger.info(
+ "Fetched available models",
+ count=len(model_ids),
+ task=task,
+ provider=inference_provider or "all",
+ has_token=bool(token),
+ )
+
+ return model_ids
+
+ except Exception as e:
+ logger.warning("Failed to get models from Hub API", error=str(e))
+ # Return popular fallback models
+ return [
+ "meta-llama/Llama-3.1-8B-Instruct",
+ "mistralai/Mistral-7B-Instruct-v0.3",
+ "HuggingFaceH4/zephyr-7b-beta",
+ "google/gemma-2-9b-it",
+ ]
+
+
+async def validate_model_provider_combination(
+ model_id: str,
+ provider: str | None,
+ token: str | None = None,
+) -> tuple[bool, str | None]:
+ """Validate that a model is available with a specific provider.
+
+ Uses HuggingFace Hub API to check if the provider is listed in the model's
+ inferenceProviderMapping. This is faster and more reliable than making test API calls.
+
+ Args:
+ model_id: HuggingFace model ID
+ provider: Provider name (or None/empty for auto)
+ token: Optional HuggingFace API token (from gr.OAuthToken.token)
+
+ Returns:
+ Tuple of (is_valid, error_message)
+ - is_valid: True if combination is valid or provider is "auto"
+ - error_message: Error message if invalid, None if valid
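+
+    Example (sketch; model and provider values are illustrative):
+
+        ok, err = await validate_model_provider_combination(
+            "meta-llama/Llama-3.1-8B-Instruct", "nebius", token=token
+        )
+        if not ok:
+            raise ValueError(err)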
+ """
+ # "auto" is always valid - let HuggingFace select the provider
+ if not provider or provider == "auto":
+ return True, None
+
+ try:
+ loop = asyncio.get_running_loop()
+ api = HfApi(token=token)
+
+ def _get_model_info() -> Any:
+ """Get model info with provider mapping synchronously."""
+ return api.model_info(model_id, expand=["inferenceProviderMapping"]) # type: ignore[arg-type]
+
+ info = await loop.run_in_executor(None, _get_model_info)
+
+ # Check if provider is in the model's inference provider mapping
+ if hasattr(info, "inference_provider_mapping") and info.inference_provider_mapping:
+ mapping = info.inference_provider_mapping
+ available_providers = set(mapping.keys())
+
+ # Normalize provider name (some APIs use "fireworks-ai", others use "fireworks")
+ normalized_provider = provider.lower()
+ provider_variants = {normalized_provider}
+
+ # Handle common provider name variations
+ if normalized_provider == "fireworks":
+ provider_variants.add("fireworks-ai")
+ elif normalized_provider == "fireworks-ai":
+ provider_variants.add("fireworks")
+
+ # Check if any variant matches
+ if any(p in available_providers for p in provider_variants):
+ logger.debug(
+ "Model/provider combination validated via API",
+ model=model_id,
+ provider=provider,
+ available_providers=list(available_providers),
+ )
+ return True, None
+ else:
+ error_msg = (
+ f"Model {model_id} is not available with provider '{provider}'. "
+ f"Available providers: {', '.join(sorted(available_providers))}"
+ )
+ logger.debug(
+ "Model/provider combination invalid",
+ model=model_id,
+ provider=provider,
+ available_providers=list(available_providers),
+ )
+ return False, error_msg
+ else:
+ # Model doesn't have provider mapping - assume valid and let actual usage determine
+ logger.debug(
+ "Model has no provider mapping, assuming valid",
+ model=model_id,
+ provider=provider,
+ )
+ return True, None
+
+ except Exception as e:
+ logger.warning(
+ "Model/provider validation failed",
+ model=model_id,
+ provider=provider,
+ error=str(e),
+ )
+ # Don't fail validation on error - let the actual request fail
+ # This is more user-friendly than blocking on validation errors
+ return True, None
+
+
+async def get_models_for_provider(
+ provider: str,
+ token: str | None = None,
+ limit: int = 50,
+) -> list[str]:
+ """Get models available for a specific provider.
+
+ This is a convenience wrapper around get_available_models() with provider filtering.
+
+ Args:
+ provider: Provider name (e.g., "nebius", "together", "fireworks-ai")
+ Note: Use "fireworks-ai" not "fireworks" for the API
+ token: Optional HuggingFace API token (from gr.OAuthToken.token)
+ limit: Maximum number of models to return
+
+ Returns:
+ List of model IDs available for the provider
+ """
+ # Normalize provider name for API
+ normalized_provider = provider
+ if provider.lower() == "fireworks":
+ normalized_provider = "fireworks-ai"
+ logger.debug("Normalized provider name", original=provider, normalized=normalized_provider)
+
+ return await get_available_models(
+ token=token,
+ task="text-generation",
+ limit=limit,
+ inference_provider=normalized_provider,
+ )
+
+
+async def validate_oauth_token(token: str | None) -> dict[str, Any]:
+ """Validate OAuth token and return available resources.
+
+ Args:
+ token: OAuth token to validate
+
+ Returns:
+ Dictionary with:
+ - is_valid: Whether token is valid
+ - has_inference_api_scope: Whether token has inference-api scope
+ - available_models: List of available model IDs
+ - available_providers: List of available provider names
+ - username: HuggingFace username (if available)
+ - error: Error message if validation failed
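+
+    Example (sketch; how a caller might gate UI features on the result):
+
+        info = await validate_oauth_token(token)
+        if not info["is_valid"]:
+            show_login_error(info["error"])  # hypothetical UI helper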
+ """
+ result: dict[str, Any] = {
+ "is_valid": False,
+ "has_inference_api_scope": False,
+ "available_models": [],
+ "available_providers": [],
+ "username": None,
+ "error": None,
+ }
+
+ if not token:
+ result["error"] = "No token provided"
+ return result
+
+ try:
+ # Validate token format
+ from src.utils.hf_error_handler import validate_hf_token
+
+ is_valid_format, format_error = validate_hf_token(token)
+ if not is_valid_format:
+ result["error"] = f"Invalid token format: {format_error}"
+ return result
+
+ # Try to get user info to validate token
+ loop = asyncio.get_running_loop()
+
+ def _get_user_info() -> dict[str, Any] | None:
+ """Get user info from HuggingFace API."""
+ try:
+ api = HfApi(token=token)
+ user_info = api.whoami()
+ return user_info
+ except Exception:
+ return None
+
+ user_info = await loop.run_in_executor(None, _get_user_info)
+
+ if user_info:
+ result["is_valid"] = True
+ result["username"] = user_info.get("name") or user_info.get("fullname")
+ logger.info("Token validated", username=result["username"])
+ else:
+ result["error"] = "Token validation failed - could not authenticate"
+ return result
+
+ # Try to query models to check inference-api scope
+ try:
+ models = await get_available_models(token=token, limit=10)
+ if models:
+ result["has_inference_api_scope"] = True
+ result["available_models"] = models
+ logger.info("Inference API scope confirmed", model_count=len(models))
+ except Exception as e:
+ logger.warning("Could not verify inference-api scope", error=str(e))
+ # Token might be valid but without inference-api scope
+ result["has_inference_api_scope"] = False
+ result["error"] = f"Token may not have inference-api scope: {e}"
+
+ # Get available providers
+ try:
+ providers = await get_available_providers(token=token)
+ result["available_providers"] = providers
+ except Exception as e:
+ logger.warning("Could not get providers", error=str(e))
+ # Use fallback providers
+ result["available_providers"] = ["auto"]
+
+ return result
+
+ except Exception as e:
+ logger.error("Token validation failed", error=str(e))
+ result["error"] = str(e)
+ return result
diff --git a/src/utils/llm_factory.py b/src/utils/llm_factory.py
index bcd958bd..20095682 100644
--- a/src/utils/llm_factory.py
+++ b/src/utils/llm_factory.py
@@ -16,9 +16,13 @@
from typing import TYPE_CHECKING, Any
+import structlog
+
from src.utils.config import settings
from src.utils.exceptions import ConfigurationError
+logger = structlog.get_logger()
+
if TYPE_CHECKING:
from agent_framework.openai import OpenAIChatClient
@@ -98,7 +102,7 @@ def get_chat_client_for_agent(oauth_token: str | None = None) -> Any:
"""
# Check if we have OAuth token or env vars
has_hf_key = bool(oauth_token or settings.has_huggingface_key)
-
+
# Prefer HuggingFace if available (free tier)
if has_hf_key:
return get_huggingface_chat_client(oauth_token=oauth_token)
@@ -147,6 +151,19 @@ def get_pydantic_ai_model(oauth_token: str | None = None) -> Any:
"3. Set huggingface_api_key in settings"
)
+ # Validate and log token information
+ from src.utils.hf_error_handler import log_token_info, validate_hf_token
+
+ log_token_info(effective_hf_token, context="get_pydantic_ai_model")
+ is_valid, error_msg = validate_hf_token(effective_hf_token)
+ if not is_valid:
+ logger.warning(
+ "Token validation failed in get_pydantic_ai_model",
+ error=error_msg,
+ has_oauth=bool(oauth_token),
+ )
+ # Continue anyway - let the API call fail with a clear error
+
# Always use HuggingFace with available token
model_name = settings.huggingface_model or "meta-llama/Llama-3.1-8B-Instruct"
hf_provider = HuggingFaceProvider(api_key=effective_hf_token)
diff --git a/src/utils/markdown.css b/src/utils/markdown.css
index b083c296..1854a4bb 100644
--- a/src/utils/markdown.css
+++ b/src/utils/markdown.css
@@ -5,12 +5,3 @@ body {
color: #000;
}
-
-
-
-
-
-
-
-
-
diff --git a/src/utils/md_to_pdf.py b/src/utils/md_to_pdf.py
index 940d707d..08e7f71d 100644
--- a/src/utils/md_to_pdf.py
+++ b/src/utils/md_to_pdf.py
@@ -1,6 +1,5 @@
"""Utility for converting markdown to PDF."""
-import os
from pathlib import Path
from typing import TYPE_CHECKING
@@ -43,9 +42,7 @@ def md_to_pdf(md_text: str, pdf_file_path: str) -> None:
OSError: If PDF file cannot be written
"""
if not _MD2PDF_AVAILABLE:
- raise ImportError(
- "md2pdf is not installed. Install it with: pip install md2pdf"
- )
+ raise ImportError("md2pdf is not installed. Install it with: pip install md2pdf")
if not md_text or not md_text.strip():
raise ValueError("Markdown text cannot be empty")
@@ -64,13 +61,3 @@ def md_to_pdf(md_text: str, pdf_file_path: str) -> None:
md2pdf(pdf_file_path, md_text, css_file_path=str(css_path))
logger.debug("PDF generated successfully", pdf_path=pdf_file_path)
-
-
-
-
-
-
-
-
-
-
diff --git a/src/utils/message_history.py b/src/utils/message_history.py
new file mode 100644
index 00000000..5ddbbe30
--- /dev/null
+++ b/src/utils/message_history.py
@@ -0,0 +1,169 @@
+"""Message history utilities for Pydantic AI integration."""
+
+from typing import Any
+
+import structlog
+
+try:
+ from pydantic_ai import ModelMessage, ModelRequest, ModelResponse
+ from pydantic_ai.messages import TextPart, UserPromptPart
+
+ _PYDANTIC_AI_AVAILABLE = True
+except ImportError:
+ # Fallback for older pydantic-ai versions
+ ModelMessage = Any # type: ignore[assignment, misc]
+ ModelRequest = Any # type: ignore[assignment, misc]
+ ModelResponse = Any # type: ignore[assignment, misc]
+ TextPart = Any # type: ignore[assignment, misc]
+ UserPromptPart = Any # type: ignore[assignment, misc]
+ _PYDANTIC_AI_AVAILABLE = False
+
+logger = structlog.get_logger()
+
+
+def convert_gradio_to_message_history(
+ history: list[dict[str, Any]],
+ max_messages: int = 20,
+) -> list[ModelMessage]:
+ """
+ Convert Gradio chat history to Pydantic AI message history.
+
+ Args:
+ history: Gradio chat history format [{"role": "user", "content": "..."}, ...]
+ max_messages: Maximum messages to include (most recent)
+
+ Returns:
+ List of ModelMessage objects for Pydantic AI
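+
+    Example (sketch; the history literal is illustrative):
+
+        history = [
+            {"role": "user", "content": "What is metformin?"},
+            {"role": "assistant", "content": "Metformin is a first-line diabetes drug."},
+        ]
+        messages = convert_gradio_to_message_history(history)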
+ """
+ if not history:
+ return []
+
+ if not _PYDANTIC_AI_AVAILABLE:
+ logger.warning(
+ "Pydantic AI message history not available, returning empty list",
+ )
+ return []
+
+ messages: list[ModelMessage] = []
+
+ # Take most recent messages
+ recent = history[-max_messages:] if len(history) > max_messages else history
+
+ for msg in recent:
+ role = msg.get("role", "")
+ content = msg.get("content", "")
+
+ if not content or role not in ("user", "assistant"):
+ continue
+
+ # Convert content to string if needed
+ content_str = str(content)
+
+ if role == "user":
+ messages.append(
+ ModelRequest(parts=[UserPromptPart(content=content_str)]),
+ )
+ elif role == "assistant":
+ messages.append(
+ ModelResponse(parts=[TextPart(content=content_str)]),
+ )
+
+ logger.debug(
+ "Converted Gradio history to message history",
+ input_turns=len(history),
+ output_messages=len(messages),
+ )
+
+ return messages
+
+
+def message_history_to_string(
+ messages: list[ModelMessage],
+ max_messages: int = 5,
+ include_metadata: bool = False,
+) -> str:
+ """
+ Convert message history to string format for backward compatibility.
+
+    Used during the transition period while some agents still expect strings.
+
+ Args:
+ messages: List of ModelMessage objects
+ max_messages: Maximum messages to include
+ include_metadata: Whether to include metadata
+
+ Returns:
+ Formatted string representation
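+
+    Example (sketch):
+
+        context_block = message_history_to_string(messages, max_messages=5)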
+ """
+ if not messages:
+ return ""
+
+ recent = messages[-max_messages:] if len(messages) > max_messages else messages
+
+ parts = ["PREVIOUS CONVERSATION:", "---"]
+ turn_num = 1
+
+ for msg in recent:
+ # Extract text content
+ text = ""
+ if isinstance(msg, ModelRequest):
+ for part in msg.parts:
+ if hasattr(part, "content"):
+ text += str(part.content)
+ parts.append(f"[Turn {turn_num}]")
+ parts.append(f"User: {text}")
+ turn_num += 1
+ elif isinstance(msg, ModelResponse):
+ for part in msg.parts: # type: ignore[assignment]
+ if hasattr(part, "content"):
+ text += str(part.content)
+ parts.append(f"Assistant: {text}")
+
+ parts.append("---")
+ return "\n".join(parts)
+
+
+def create_truncation_processor(max_messages: int = 10) -> Any:
+ """Create a history processor that keeps only the most recent N messages.
+
+ Args:
+ max_messages: Maximum number of messages to keep
+
+ Returns:
+ Processor function that takes a list of messages and returns truncated list
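+
+    Example (sketch):
+
+        processor = create_truncation_processor(max_messages=10)
+        recent = processor(messages)  # keeps at most the 10 most recent messages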
+ """
+
+ def processor(messages: list[ModelMessage]) -> list[ModelMessage]:
+ return messages[-max_messages:] if len(messages) > max_messages else messages
+
+ return processor
+
+
+def create_relevance_processor(min_length: int = 10) -> Any:
+ """Create a history processor that filters out very short messages.
+
+ Args:
+ min_length: Minimum message length to keep
+
+ Returns:
+ Processor function that filters messages by length
+ """
+
+ def processor(messages: list[ModelMessage]) -> list[ModelMessage]:
+ filtered = []
+ for msg in messages:
+ text = ""
+ if isinstance(msg, ModelRequest):
+ for part in msg.parts:
+ if hasattr(part, "content"):
+ text += str(part.content)
+ elif isinstance(msg, ModelResponse):
+ for part in msg.parts: # type: ignore[assignment]
+ if hasattr(part, "content"):
+ text += str(part.content)
+
+ if len(text.strip()) >= min_length:
+ filtered.append(msg)
+ return filtered
+
+ return processor
diff --git a/src/utils/models.py b/src/utils/models.py
index 0582aab4..e43efac0 100644
--- a/src/utils/models.py
+++ b/src/utils/models.py
@@ -6,7 +6,9 @@
from pydantic import BaseModel, Field
# Centralized source type - add new sources here (e.g., "biorxiv" in Phase 11)
-SourceName = Literal["pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint", "rag", "web", "neo4j"]
+SourceName = Literal[
+ "pubmed", "clinicaltrials", "biorxiv", "europepmc", "preprint", "rag", "web", "neo4j"
+]
class Citation(BaseModel):
diff --git a/src/utils/report_generator.py b/src/utils/report_generator.py
index ad283050..330f2ad2 100644
--- a/src/utils/report_generator.py
+++ b/src/utils/report_generator.py
@@ -5,11 +5,101 @@
import structlog
if TYPE_CHECKING:
- from src.utils.models import Evidence
+ from src.utils.models import Citation, Evidence
logger = structlog.get_logger()
+def _format_authors(citation: "Citation") -> str:
+ """Format authors string from citation."""
+ authors = ", ".join(citation.authors[:3])
+ if len(citation.authors) > 3:
+ authors += " et al."
+ elif not authors:
+ authors = "Unknown"
+ return authors
+
+
+def _add_evidence_section(report_parts: list[str], evidence: list["Evidence"]) -> None:
+ """Add evidence summary section to report."""
+ from src.utils.models import SourceName
+
+ report_parts.append("## Evidence Summary\n")
+ report_parts.append(f"**Total Sources Found:** {len(evidence)}\n\n")
+
+ # Group evidence by source
+ by_source: dict[SourceName, list[Evidence]] = {}
+ for ev in evidence:
+ source = ev.citation.source
+ if source not in by_source:
+ by_source[source] = []
+ by_source[source].append(ev)
+
+ # Organize by source
+ for source in sorted(by_source.keys()): # type: ignore[assignment]
+ source_evidence = by_source[source]
+ report_parts.append(f"### {source.upper()} Sources ({len(source_evidence)})\n\n")
+
+ for i, ev in enumerate(source_evidence, 1):
+ authors = _format_authors(ev.citation)
+ report_parts.append(f"#### {i}. {ev.citation.title}\n")
+ if authors and authors != "Unknown":
+ report_parts.append(f"**Authors:** {authors} \n")
+ report_parts.append(f"**Date:** {ev.citation.date} \n")
+ report_parts.append(f"**Source:** {ev.citation.source.upper()} \n")
+ report_parts.append(f"**URL:** {ev.citation.url} \n\n")
+
+ # Content (truncated if too long)
+ content = ev.content
+ if len(content) > 500:
+ content = content[:500] + "... [truncated]"
+ report_parts.append(f"{content}\n\n")
+
+
+def _add_key_findings(report_parts: list[str], evidence: list["Evidence"]) -> None:
+ """Add key findings section to report."""
+ report_parts.append("## Key Findings\n\n")
+ report_parts.append(
+ "Based on the evidence collected, the following key points were identified:\n\n"
+ )
+
+ # Extract key points from evidence (first sentence or summary)
+ key_points: list[str] = []
+ for ev in evidence[:10]: # Limit to top 10
+ # Try to extract first meaningful sentence
+ content = ev.content.strip()
+ if content:
+ # Find first sentence
+ first_period = content.find(".")
+ if first_period > 0 and first_period < 200:
+ key_point = content[: first_period + 1].strip()
+ else:
+ # Fallback: first 150 chars
+ key_point = content[:150].strip()
+ if len(content) > 150:
+ key_point += "..."
+ key_points.append(f"- {key_point} [[{len(key_points) + 1}]](#references)")
+
+ if key_points:
+ report_parts.append("\n".join(key_points))
+ report_parts.append("\n\n")
+ else:
+ report_parts.append("*No specific key findings could be extracted from the evidence.*\n\n")
+
+
+def _add_references(report_parts: list[str], evidence: list["Evidence"]) -> None:
+ """Add references section to report."""
+ report_parts.append("## References\n\n")
+ for i, ev in enumerate(evidence, 1):
+ authors = _format_authors(ev.citation)
+ report_parts.append(
+ f"[{i}] {authors} ({ev.citation.date}). "
+ f"*{ev.citation.title}*. "
+ f"{ev.citation.source.upper()}. "
+ f"Available at: {ev.citation.url}\n\n"
+ )
+
+
def generate_report_from_evidence(
query: str,
evidence: list["Evidence"] | None = None,
@@ -36,9 +126,7 @@ def generate_report_from_evidence(
# Introduction
report_parts.append("## Introduction\n")
- report_parts.append(
- f"This report addresses the following research query: **{query}**\n"
- )
+ report_parts.append(f"This report addresses the following research query: **{query}**\n")
report_parts.append(
"*Note: This report was generated from collected evidence. "
"LLM-based synthesis was unavailable due to API limitations.*\n\n"
@@ -46,73 +134,8 @@ def generate_report_from_evidence(
# Evidence Summary
if evidence and len(evidence) > 0:
- report_parts.append("## Evidence Summary\n")
- report_parts.append(
- f"**Total Sources Found:** {len(evidence)}\n\n"
- )
-
- # Group evidence by source
- by_source: dict[str, list["Evidence"]] = {}
- for ev in evidence:
- source = ev.citation.source
- if source not in by_source:
- by_source[source] = []
- by_source[source].append(ev)
-
- # Organize by source
- for source in sorted(by_source.keys()):
- source_evidence = by_source[source]
- report_parts.append(f"### {source.upper()} Sources ({len(source_evidence)})\n\n")
-
- for i, ev in enumerate(source_evidence, 1):
- # Format citation
- authors = ", ".join(ev.citation.authors[:3])
- if len(ev.citation.authors) > 3:
- authors += " et al."
-
- report_parts.append(f"#### {i}. {ev.citation.title}\n")
- if authors:
- report_parts.append(f"**Authors:** {authors} \n")
- report_parts.append(f"**Date:** {ev.citation.date} \n")
- report_parts.append(f"**Source:** {ev.citation.source.upper()} \n")
- report_parts.append(f"**URL:** {ev.citation.url} \n\n")
-
- # Content (truncated if too long)
- content = ev.content
- if len(content) > 500:
- content = content[:500] + "... [truncated]"
- report_parts.append(f"{content}\n\n")
-
- # Key Findings Section
- report_parts.append("## Key Findings\n\n")
- report_parts.append(
- "Based on the evidence collected, the following key points were identified:\n\n"
- )
-
- # Extract key points from evidence (first sentence or summary)
- key_points: list[str] = []
- for ev in evidence[:10]: # Limit to top 10
- # Try to extract first meaningful sentence
- content = ev.content.strip()
- if content:
- # Find first sentence
- first_period = content.find(".")
- if first_period > 0 and first_period < 200:
- key_point = content[: first_period + 1].strip()
- else:
- # Fallback: first 150 chars
- key_point = content[:150].strip()
- if len(content) > 150:
- key_point += "..."
- key_points.append(f"- {key_point} [[{len(key_points) + 1}]](#references)")
-
- if key_points:
- report_parts.append("\n".join(key_points))
- report_parts.append("\n\n")
- else:
- report_parts.append(
- "*No specific key findings could be extracted from the evidence.*\n\n"
- )
+ _add_evidence_section(report_parts, evidence)
+ _add_key_findings(report_parts, evidence)
elif findings:
# Fallback: use findings string if evidence not available
@@ -129,20 +152,7 @@ def generate_report_from_evidence(
# References Section
if evidence and len(evidence) > 0:
- report_parts.append("## References\n\n")
- for i, ev in enumerate(evidence, 1):
- authors = ", ".join(ev.citation.authors[:3])
- if len(ev.citation.authors) > 3:
- authors += " et al."
- elif not authors:
- authors = "Unknown"
-
- report_parts.append(
- f"[{i}] {authors} ({ev.citation.date}). "
- f"*{ev.citation.title}*. "
- f"{ev.citation.source.upper()}. "
- f"Available at: {ev.citation.url}\n\n"
- )
+ _add_references(report_parts, evidence)
# Conclusion
report_parts.append("## Conclusion\n\n")
@@ -167,13 +177,3 @@ def generate_report_from_evidence(
)
return "".join(report_parts)
-
-
-
-
-
-
-
-
-
-
diff --git a/tests/conftest.py b/tests/conftest.py
index 6d942f2b..0d44787e 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,7 @@
"""Shared pytest fixtures for all tests."""
import os
-from unittest.mock import AsyncMock
+from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -78,4 +78,63 @@ def default_to_huggingface(monkeypatch):
# Set a dummy HF_TOKEN if not set (prevents errors, but tests should mock actual API calls)
if "HF_TOKEN" not in os.environ:
- monkeypatch.setenv("HF_TOKEN", "dummy_token_for_testing")
\ No newline at end of file
+ monkeypatch.setenv("HF_TOKEN", "dummy_token_for_testing")
+
+ # Unset OpenAI/Anthropic keys to prevent fallback (unless explicitly set for specific tests)
+ # This ensures get_model() uses HuggingFace
+ if "OPENAI_API_KEY" in os.environ and os.environ.get("OPENAI_API_KEY"):
+ # Only unset if it's not explicitly needed (tests can set it if needed)
+ monkeypatch.delenv("OPENAI_API_KEY", raising=False)
+ if "ANTHROPIC_API_KEY" in os.environ and os.environ.get("ANTHROPIC_API_KEY"):
+ monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)
+
+
+@pytest.fixture
+def mock_hf_model():
+ """Create a real HuggingFace model instance for testing.
+
+ This fixture provides a real HuggingFaceModel instance that uses
+ the HF API/client. The InferenceClient is mocked to prevent real API calls.
+ """
+ from pydantic_ai.models.huggingface import HuggingFaceModel
+ from pydantic_ai.providers.huggingface import HuggingFaceProvider
+
+ # Create a real HuggingFace model with dummy token
+ # The InferenceClient will be mocked by auto_mock_hf_inference_client
+ provider = HuggingFaceProvider(api_key="dummy_token_for_testing")
+ model = HuggingFaceModel("meta-llama/Llama-3.1-8B-Instruct", provider=provider)
+ return model
+
+
+@pytest.fixture(autouse=True)
+def auto_mock_hf_inference_client(request):
+ """Automatically mock HuggingFace InferenceClient to prevent real API calls.
+
+ This fixture runs automatically for all tests (except OpenAI tests) and
+ mocks the InferenceClient so tests use HuggingFace models but don't make
+ real API calls. This allows tests to use the actual HF model/client setup
+ without requiring API keys or making network requests.
+
+ Tests marked with @pytest.mark.openai will skip this fixture.
+ Tests can override by explicitly patching InferenceClient themselves.
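+
+    Example (sketch; a test that opts out of the auto-mock):
+
+        @pytest.mark.openai
+        def test_uses_real_openai_client():
+            ...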
+ """
+ # Skip auto-mocking for OpenAI tests
+ if "openai" in request.keywords:
+ return
+
+ # Mock InferenceClient to prevent real API calls
+ # This allows get_model() to create real HuggingFaceModel instances
+ # but prevents actual network requests
+ mock_client = MagicMock()
+ mock_client.text_generation = AsyncMock(return_value="Mocked response")
+ mock_client.chat_completion = AsyncMock(return_value={"choices": [{"message": {"content": "Mocked response"}}]})
+
+ # Patch InferenceClient at its source (huggingface_hub)
+ # This will affect all imports of InferenceClient, including in pydantic_ai
+ inference_client_patch = patch("huggingface_hub.InferenceClient", return_value=mock_client)
+ inference_client_patch.start()
+
+ yield
+
+ # Stop patch
+ inference_client_patch.stop()
\ No newline at end of file
diff --git a/tests/integration/test_rag_integration.py b/tests/integration/test_rag_integration.py
index 38d3f6ec..56f9339c 100644
--- a/tests/integration/test_rag_integration.py
+++ b/tests/integration/test_rag_integration.py
@@ -8,6 +8,31 @@
import pytest
+# Skip if sentence_transformers cannot be imported
+# Note: sentence-transformers is a required dependency, but may fail due to:
+# - Windows regex circular import bug
+# - PyTorch C extensions not loading properly
+try:
+ pytest.importorskip("sentence_transformers", exc_type=ImportError)
+except (ImportError, OSError) as e:
+ # Handle various import issues
+ error_msg = str(e).lower()
+ if "regex" in error_msg or "_regex" in error_msg:
+ pytest.skip(
+ "sentence_transformers import failed due to Windows regex circular import bug. "
+ "This is a known issue with the regex package on Windows. "
+ "Try: uv pip install --upgrade --force-reinstall regex",
+ allow_module_level=True,
+ )
+ elif "pytorch" in error_msg or "torch" in error_msg:
+ pytest.skip(
+ "sentence_transformers import failed due to PyTorch C extensions issue. "
+ "Try: uv pip install --upgrade --force-reinstall torch",
+ allow_module_level=True,
+ )
+ # Re-raise other import errors
+ raise
+
from src.services.llamaindex_rag import get_rag_service
from src.tools.rag_tool import create_rag_tool
from src.tools.search_handler import SearchHandler
diff --git a/tests/integration/test_rag_integration_hf.py b/tests/integration/test_rag_integration_hf.py
index cee61ba7..9c58fb1d 100644
--- a/tests/integration/test_rag_integration_hf.py
+++ b/tests/integration/test_rag_integration_hf.py
@@ -6,6 +6,31 @@
import pytest
+# Skip if sentence_transformers cannot be imported
+# Note: sentence-transformers is a required dependency, but may fail due to:
+# - Windows regex circular import bug
+# - PyTorch C extensions not loading properly
+try:
+ pytest.importorskip("sentence_transformers", exc_type=ImportError)
+except (ImportError, OSError) as e:
+ # Handle various import issues
+ error_msg = str(e).lower()
+ if "regex" in error_msg or "_regex" in error_msg:
+ pytest.skip(
+ "sentence_transformers import failed due to Windows regex circular import bug. "
+ "This is a known issue with the regex package on Windows. "
+ "Try: uv pip install --upgrade --force-reinstall regex",
+ allow_module_level=True,
+ )
+ elif "pytorch" in error_msg or "torch" in error_msg:
+ pytest.skip(
+ "sentence_transformers import failed due to PyTorch C extensions issue. "
+ "Try: uv pip install --upgrade --force-reinstall torch",
+ allow_module_level=True,
+ )
+ # Re-raise other import errors
+ raise
+
from src.services.llamaindex_rag import get_rag_service
from src.tools.rag_tool import create_rag_tool
from src.tools.search_handler import SearchHandler
diff --git a/tests/unit/agent_factory/test_judges.py b/tests/unit/agent_factory/test_judges.py
index c2075cda..90551d99 100644
--- a/tests/unit/agent_factory/test_judges.py
+++ b/tests/unit/agent_factory/test_judges.py
@@ -142,6 +142,135 @@ async def test_assess_handles_llm_failure(self):
assert result.recommendation == "continue"
assert "failed" in result.reasoning.lower()
+ @pytest.mark.asyncio
+ async def test_assess_handles_403_error(self):
+ """JudgeHandler should handle 403 Forbidden errors with error extraction."""
+ error_msg = "status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden"
+
+ with (
+ patch("src.agent_factory.judges.get_model") as mock_get_model,
+ patch("src.agent_factory.judges.Agent") as mock_agent_class,
+ patch("src.agent_factory.judges.logger") as mock_logger,
+ ):
+ mock_get_model.return_value = MagicMock()
+ mock_agent = AsyncMock()
+ mock_agent.run = AsyncMock(side_effect=Exception(error_msg))
+ mock_agent_class.return_value = mock_agent
+
+ handler = JudgeHandler()
+ handler.agent = mock_agent
+
+ evidence = [
+ Evidence(
+ content="Some content",
+ citation=Citation(
+ source="pubmed",
+ title="Title",
+ url="url",
+ date="2024",
+ ),
+ )
+ ]
+
+ result = await handler.assess("test question", evidence)
+
+ # Should return fallback
+ assert result.sufficient is False
+ assert result.recommendation == "continue"
+
+ # Should log error details
+ error_calls = [call for call in mock_logger.error.call_args_list if "Assessment failed" in str(call)]
+ assert len(error_calls) > 0
+
+ @pytest.mark.asyncio
+ async def test_assess_handles_422_error(self):
+ """JudgeHandler should handle 422 Unprocessable Entity errors."""
+ error_msg = "status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity"
+
+ with (
+ patch("src.agent_factory.judges.get_model") as mock_get_model,
+ patch("src.agent_factory.judges.Agent") as mock_agent_class,
+ patch("src.agent_factory.judges.logger") as mock_logger,
+ ):
+ mock_get_model.return_value = MagicMock()
+ mock_agent = AsyncMock()
+ mock_agent.run = AsyncMock(side_effect=Exception(error_msg))
+ mock_agent_class.return_value = mock_agent
+
+ handler = JudgeHandler()
+ handler.agent = mock_agent
+
+ evidence = [
+ Evidence(
+ content="Some content",
+ citation=Citation(
+ source="pubmed",
+ title="Title",
+ url="url",
+ date="2024",
+ ),
+ )
+ ]
+
+ result = await handler.assess("test question", evidence)
+
+ # Should return fallback
+ assert result.sufficient is False
+
+ # Should log warning with user-friendly message
+ warning_calls = [call for call in mock_logger.warning.call_args_list if "API error details" in str(call)]
+ assert len(warning_calls) > 0
+
+
+class TestGetModel:
+ """Tests for get_model function with token validation."""
+
+ @patch("src.agent_factory.judges.settings")
+ @patch("src.utils.hf_error_handler.log_token_info")
+ @patch("src.utils.hf_error_handler.validate_hf_token")
+ def test_get_model_validates_oauth_token(self, mock_validate, mock_log, mock_settings):
+ """Should validate and log OAuth token when provided."""
+ mock_settings.hf_token = None
+ mock_settings.huggingface_api_key = None
+ mock_settings.huggingface_model = "test-model"
+ mock_validate.return_value = (True, None)
+
+ with patch("src.agent_factory.judges.HuggingFaceProvider"), \
+ patch("src.agent_factory.judges.HuggingFaceModel") as mock_model_class:
+ mock_model_class.return_value = MagicMock()
+
+ from src.agent_factory.judges import get_model
+
+ get_model(oauth_token="hf_test_token")
+
+ # Should log token info
+ mock_log.assert_called_once_with("hf_test_token", context="get_model")
+ # Should validate token
+ mock_validate.assert_called_once_with("hf_test_token")
+
+ @patch("src.agent_factory.judges.settings")
+ @patch("src.utils.hf_error_handler.log_token_info")
+ @patch("src.utils.hf_error_handler.validate_hf_token")
+ @patch("src.agent_factory.judges.logger")
+ def test_get_model_warns_on_invalid_token(self, mock_logger, mock_validate, mock_log, mock_settings):
+ """Should warn when token validation fails."""
+ mock_settings.hf_token = None
+ mock_settings.huggingface_api_key = None
+ mock_settings.huggingface_model = "test-model"
+ mock_validate.return_value = (False, "Token too short")
+
+ with patch("src.agent_factory.judges.HuggingFaceProvider"), \
+ patch("src.agent_factory.judges.HuggingFaceModel") as mock_model_class:
+ mock_model_class.return_value = MagicMock()
+
+ from src.agent_factory.judges import get_model
+
+ get_model(oauth_token="short")
+
+ # Should warn about invalid token
+ warning_calls = [call for call in mock_logger.warning.call_args_list if "Token validation failed" in str(call)]
+ assert len(warning_calls) > 0
+
class TestMockJudgeHandler:
"""Tests for MockJudgeHandler."""
diff --git a/tests/unit/agent_factory/test_judges_factory.py b/tests/unit/agent_factory/test_judges_factory.py
index 3cc7e331..037ef87e 100644
--- a/tests/unit/agent_factory/test_judges_factory.py
+++ b/tests/unit/agent_factory/test_judges_factory.py
@@ -21,8 +21,13 @@ def mock_settings():
yield mock_settings
+@pytest.mark.openai
def test_get_model_openai(mock_settings):
- """Test that OpenAI model is returned when provider is openai."""
+ """Test that OpenAI model is returned when provider is openai.
+
+    This test instantiates an OpenAI model and requires an OpenAI API key.
+    It is marked with @pytest.mark.openai so it is excluded from pre-commit and CI runs.
+ """
mock_settings.llm_provider = "openai"
mock_settings.openai_api_key = "sk-test"
mock_settings.openai_model = "gpt-5.1"
@@ -37,6 +42,11 @@ def test_get_model_anthropic(mock_settings):
mock_settings.llm_provider = "anthropic"
mock_settings.anthropic_api_key = "sk-ant-test"
mock_settings.anthropic_model = "claude-sonnet-4-5-20250929"
+ # Ensure no HF token is set, otherwise get_model() will prefer HuggingFace
+ mock_settings.hf_token = None
+ mock_settings.huggingface_api_key = None
+ mock_settings.has_openai_key = False
+ mock_settings.has_anthropic_key = True
model = get_model()
assert isinstance(model, AnthropicModel)
diff --git a/tests/unit/agents/test_audio_refiner.py b/tests/unit/agents/test_audio_refiner.py
new file mode 100644
index 00000000..162241ab
--- /dev/null
+++ b/tests/unit/agents/test_audio_refiner.py
@@ -0,0 +1,306 @@
+"""Unit tests for AudioRefiner agent."""
+
+import pytest
+from unittest.mock import AsyncMock, Mock, patch
+
+from src.agents.audio_refiner import AudioRefiner, refine_text_for_audio
+
+
+class TestAudioRefiner:
+ """Test suite for AudioRefiner functionality."""
+
+ @pytest.fixture
+ def refiner(self):
+ """Create AudioRefiner instance."""
+ return AudioRefiner()
+
+ def test_remove_markdown_headers(self, refiner):
+ """Test removal of markdown headers."""
+ text = """# Main Title
+## Subtitle
+### Section
+Content here"""
+ result = refiner._remove_markdown_syntax(text)
+ assert "#" not in result
+ assert "Main Title" in result
+ assert "Subtitle" in result
+
+ def test_remove_bold_italic(self, refiner):
+ """Test removal of bold and italic formatting."""
+ text = "**Bold text** and *italic text* and __another bold__"
+ result = refiner._remove_markdown_syntax(text)
+ assert "**" not in result
+ assert "*" not in result
+ assert "__" not in result
+ assert "Bold text" in result
+ assert "italic text" in result
+
+ def test_remove_links(self, refiner):
+ """Test removal of markdown links."""
+ text = "Check [this link](https://example.com) for details"
+ result = refiner._remove_markdown_syntax(text)
+ assert "[" not in result
+ assert "]" not in result
+ assert "https://" not in result
+ assert "this link" in result
+
+ def test_remove_citations_numbered(self, refiner):
+ """Test removal of numbered citations."""
+ text = "Research shows [1] that metformin [2,3] works [4-6]."
+ result = refiner._remove_citations(text)
+ assert "[1]" not in result
+ assert "[2,3]" not in result
+ assert "[4-6]" not in result
+ assert "Research shows" in result
+
+ def test_remove_citations_author_year(self, refiner):
+ """Test removal of author-year citations."""
+ text = "Studies (Smith et al., 2023) and (Jones, 2022) confirm this."
+ result = refiner._remove_citations(text)
+ assert "(Smith et al., 2023)" not in result
+ assert "(Jones, 2022)" not in result
+ assert "Studies" in result
+ assert "confirm this" in result
+
+ def test_remove_first_references_section(self, refiner):
+ """Test that References sections are removed while preserving other content."""
+ text = """Main content here.
+
+# References
+[1] First reference
+[2] Second reference
+
+# More Content
+This should remain.
+
+## References
+This second References should also be removed."""
+
+ result = refiner._remove_references_sections(text)
+ assert "Main content here" in result
+ assert "References" not in result
+ assert "First reference" not in result
+ assert "More Content" in result # Content after References should be preserved
+ assert "This should remain" in result
+ assert "second References should also be removed" not in result # Second References section removed
+
+ def test_roman_to_int_conversion(self, refiner):
+ """Test roman numeral to integer conversion."""
+ assert refiner._roman_to_int("I") == 1
+ assert refiner._roman_to_int("II") == 2
+ assert refiner._roman_to_int("III") == 3
+ assert refiner._roman_to_int("IV") == 4
+ assert refiner._roman_to_int("V") == 5
+ assert refiner._roman_to_int("IX") == 9
+ assert refiner._roman_to_int("X") == 10
+ assert refiner._roman_to_int("XII") == 12
+ assert refiner._roman_to_int("XX") == 20
+
+ def test_int_to_word_conversion(self, refiner):
+ """Test integer to word conversion."""
+ assert refiner._int_to_word(1) == "One"
+ assert refiner._int_to_word(2) == "Two"
+ assert refiner._int_to_word(3) == "Three"
+ assert refiner._int_to_word(10) == "Ten"
+ assert refiner._int_to_word(20) == "Twenty"
+ assert refiner._int_to_word(25) == "25" # Falls back to digit
+
+ def test_convert_roman_numerals_with_context(self, refiner):
+ """Test roman numeral conversion with context words."""
+ test_cases = [
+ ("Phase I trial", "Phase One trial"),
+ ("Phase II study", "Phase Two study"),
+ ("Phase III data", "Phase Three data"),
+ ("Type I diabetes", "Type One diabetes"),
+ ("Type II error", "Type Two error"),
+ ("Stage IV cancer", "Stage Four cancer"),
+ ("Trial I results", "Trial One results"),
+ ]
+
+ for input_text, expected in test_cases:
+ result = refiner._convert_roman_numerals(input_text)
+ assert expected in result, f"Failed for: {input_text}"
+
+ def test_convert_standalone_roman_numerals(self, refiner):
+ """Test standalone roman numeral conversion."""
+ text = "Results for I, II, and III are positive."
+ result = refiner._convert_roman_numerals(text)
+ # Standalone roman numerals should be converted
+ assert "One" in result or "Two" in result or "Three" in result
+
+ def test_dont_convert_roman_in_words(self, refiner):
+ """Test that roman numerals inside words aren't converted."""
+ text = "INVALID data fromIXIN compound"
+ result = refiner._convert_roman_numerals(text)
+ # Should not break words containing I, V, X, etc.
+ assert "INVALID" in result or "Invalid" in result # May be case-normalized
+
+ def test_clean_special_characters(self, refiner):
+ """Test special character cleanup."""
+ # Using unicode escapes to avoid syntax issues
+ text = "Text with \u2014 em-dash and \u201csmart quotes\u201d and \u2018apostrophes\u2019."
+ result = refiner._clean_special_characters(text)
+ assert "\u2014" not in result # em-dash
+ assert "\u201c" not in result # smart quote open
+ assert "\u2018" not in result # smart apostrophe
+ assert "-" in result
+
+ def test_normalize_whitespace(self, refiner):
+ """Test whitespace normalization."""
+ text = "Text with multiple spaces\n\n\n\nand many newlines"
+ result = refiner._normalize_whitespace(text)
+ assert " " not in result # No double spaces
+ assert "\n\n\n" not in result # Max two newlines
+
+ async def test_full_refine_workflow(self, refiner):
+ """Test complete refinement workflow."""
+ markdown_text = """# Summary
+
+**Metformin** shows promise for *long COVID* treatment [1].
+
+## Phase I Trials
+
+Research (Smith et al., 2023) indicates [2,3]:
+- 50% improvement
+- Low side effects
+
+Check [this study](https://example.com) for details.
+
+# References
+[1] Smith, J. et al. (2023)
+[2] Jones, K. (2022)
+"""
+
+ result = await refiner.refine_for_audio(markdown_text)
+
+ # Check markdown removed
+ assert "#" not in result
+ assert "**" not in result
+ assert "*" not in result
+
+ # Check citations removed
+ assert "[1]" not in result
+ assert "(Smith et al., 2023)" not in result
+
+ # Check roman numerals converted
+ assert "Phase One" in result
+
+ # Check references section removed
+ assert "References" not in result
+ assert "Smith, J. et al." not in result
+
+ # Check content preserved
+ assert "Metformin" in result
+ assert "long COVID" in result
+
+ async def test_convenience_function(self):
+ """Test convenience function."""
+ text = "**Bold** text with [link](url)"
+ result = await refine_text_for_audio(text)
+ assert "**" not in result
+ assert "[link]" not in result
+ assert "Bold" in result
+
+ async def test_empty_text(self, refiner):
+ """Test handling of empty text."""
+ assert await refiner.refine_for_audio("") == ""
+ assert await refiner.refine_for_audio(" ") == ""
+
+ async def test_no_references_section(self, refiner):
+ """Test text without References section."""
+ text = "Main content without references."
+ result = await refiner.refine_for_audio(text)
+ assert "Main content without references" in result
+
+ def test_multiple_reference_formats(self, refiner):
+ """Test different References section formats."""
+ formats = [
+ ("# References\nContent", True), # Markdown header - will be removed
+ ("## References\nContent", True), # Markdown header - will be removed
+ ("**References**\nContent", True), # Bold heading - will be removed
+ ("References:\nContent", False), # Standalone without markers - NOT removed (edge case)
+ ]
+
+ for format_text, should_remove in formats:
+ text = f"Main content\n{format_text}"
+ result = refiner._remove_references_sections(text)
+ assert "Main content" in result
+ if should_remove:
+ assert "References" not in result or result.count("References") == 0
+ # Standalone "References:" without markers is an edge case we don't handle
+
+ def test_preserve_paragraph_structure(self, refiner):
+ """Test that paragraph structure is preserved."""
+ text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
+
+ result = refiner._normalize_whitespace(text)
+ # Should have paragraph breaks (double newlines)
+ assert "\n\n" in result
+ # But not excessive newlines
+ assert "\n\n\n" not in result
+
+ @patch('src.agents.audio_refiner.get_pydantic_ai_model')
+ async def test_llm_polish_disabled_by_default(self, mock_get_model, refiner):
+ """Test that LLM polish is not called by default."""
+ text = "Test text"
+ result = await refiner.refine_for_audio(text, use_llm_polish=False)
+
+ # LLM should not be called when disabled
+ mock_get_model.assert_not_called()
+ assert "Test text" in result
+
+ @patch('src.agents.audio_refiner.Agent')
+ @patch('src.agents.audio_refiner.get_pydantic_ai_model')
+ async def test_llm_polish_enabled(self, mock_get_model, mock_agent_class, refiner):
+ """Test that LLM polish is called when enabled."""
+ # Setup mock
+ mock_model = Mock()
+ mock_get_model.return_value = mock_model
+
+ mock_agent_instance = Mock()
+ mock_result = Mock()
+ mock_result.output = "Polished text"
+ mock_agent_instance.run = AsyncMock(return_value=mock_result)
+ mock_agent_class.return_value = mock_agent_instance
+
+ # Test with LLM polish enabled
+ text = "**Test** text"
+ result = await refiner.refine_for_audio(text, use_llm_polish=True)
+
+ # Verify LLM was called
+ mock_get_model.assert_called_once()
+ mock_agent_class.assert_called_once()
+ mock_agent_instance.run.assert_called_once()
+
+ assert result == "Polished text"
+
+ @patch('src.agents.audio_refiner.Agent')
+ @patch('src.agents.audio_refiner.get_pydantic_ai_model')
+ async def test_llm_polish_graceful_fallback(self, mock_get_model, mock_agent_class, refiner):
+ """Test graceful fallback when LLM polish fails."""
+ # Setup mock to raise exception
+ mock_get_model.return_value = Mock()
+ mock_agent_instance = Mock()
+ mock_agent_instance.run = AsyncMock(side_effect=Exception("API Error"))
+ mock_agent_class.return_value = mock_agent_instance
+
+ # Test with LLM polish enabled but failing
+ text = "Test text"
+ result = await refiner.refine_for_audio(text, use_llm_polish=True)
+
+ # Should fall back to rule-based output
+ assert "Test text" in result
+ assert result != "" # Should not be empty
+
+ async def test_convenience_function_with_llm_polish(self):
+ """Test convenience function with LLM polish parameter."""
+ with patch.object(AudioRefiner, 'refine_for_audio') as mock_refine:
+            # patch.object replaces the async method with an AsyncMock, so a plain
+            # return value is awaited correctly (no manually created coroutine needed).
+            mock_refine.return_value = "Refined text"
+
+ # Test without LLM polish
+ result = await refine_text_for_audio("Test", use_llm_polish=False)
+ mock_refine.assert_called_with("Test", use_llm_polish=False)
+
+ # Test with LLM polish
+ result = await refine_text_for_audio("Test", use_llm_polish=True)
+ mock_refine.assert_called_with("Test", use_llm_polish=True)
diff --git a/tests/unit/agents/test_input_parser.py b/tests/unit/agents/test_input_parser.py
index 95324026..f0c1075b 100644
--- a/tests/unit/agents/test_input_parser.py
+++ b/tests/unit/agents/test_input_parser.py
@@ -11,11 +11,13 @@
@pytest.fixture
-def mock_model() -> MagicMock:
- """Create a mock Pydantic AI model."""
- model = MagicMock()
- model.name = "test-model"
- return model
+def mock_model(mock_hf_model):
+ """Create a HuggingFace model for testing.
+
+    Uses the mock_hf_model fixture from conftest, which is a real HuggingFaceModel
+    instance whose InferenceClient is mocked to prevent real API calls.
+ """
+ return mock_hf_model
@pytest.fixture
diff --git a/tests/unit/middleware/test_budget_tracker_phase7.py b/tests/unit/middleware/test_budget_tracker_phase7.py
index 903addc1..3338ad03 100644
--- a/tests/unit/middleware/test_budget_tracker_phase7.py
+++ b/tests/unit/middleware/test_budget_tracker_phase7.py
@@ -159,4 +159,3 @@ def test_iteration_tokens_separate_per_loop(self) -> None:
assert budget2.iteration_tokens[1] == 200
-
diff --git a/tests/unit/middleware/test_state_machine.py b/tests/unit/middleware/test_state_machine.py
index 90fc3a4d..73a6008d 100644
--- a/tests/unit/middleware/test_state_machine.py
+++ b/tests/unit/middleware/test_state_machine.py
@@ -12,6 +12,14 @@
)
from src.utils.models import Citation, Conversation, Evidence, IterationData
+try:
+ from pydantic_ai import ModelRequest, ModelResponse
+ from pydantic_ai.messages import TextPart, UserPromptPart
+
+ _PYDANTIC_AI_AVAILABLE = True
+except ImportError:
+ _PYDANTIC_AI_AVAILABLE = False
+
@pytest.mark.unit
class TestWorkflowState:
@@ -143,6 +151,49 @@ async def test_search_related_handles_empty_authors(self) -> None:
assert len(results) == 1
assert results[0].citation.authors == []
+ @pytest.mark.skipif(not _PYDANTIC_AI_AVAILABLE, reason="pydantic_ai not available")
+ def test_user_message_history_initialization(self) -> None:
+ """WorkflowState should initialize with empty user_message_history."""
+ state = WorkflowState()
+ assert state.user_message_history == []
+
+ @pytest.mark.skipif(not _PYDANTIC_AI_AVAILABLE, reason="pydantic_ai not available")
+ def test_add_user_message(self) -> None:
+ """add_user_message should add messages to history."""
+ state = WorkflowState()
+ message = ModelRequest(parts=[UserPromptPart(content="Test message")])
+ state.add_user_message(message)
+ assert len(state.user_message_history) == 1
+ assert state.user_message_history[0] == message
+
+ @pytest.mark.skipif(not _PYDANTIC_AI_AVAILABLE, reason="pydantic_ai not available")
+ def test_get_user_history(self) -> None:
+ """get_user_history should return message history."""
+ state = WorkflowState()
+ for i in range(5):
+ message = ModelRequest(parts=[UserPromptPart(content=f"Message {i}")])
+ state.add_user_message(message)
+
+ # Get all history
+ all_history = state.get_user_history()
+ assert len(all_history) == 5
+
+ # Get limited history
+ limited = state.get_user_history(max_messages=3)
+ assert len(limited) == 3
+ # Should be most recent messages
+ assert limited[0].parts[0].content == "Message 2"
+
+ @pytest.mark.skipif(not _PYDANTIC_AI_AVAILABLE, reason="pydantic_ai not available")
+ def test_init_workflow_state_with_message_history(self) -> None:
+ """init_workflow_state should accept message_history parameter."""
+ messages = [
+ ModelRequest(parts=[UserPromptPart(content="Question")]),
+ ModelResponse(parts=[TextPart(content="Answer")]),
+ ]
+ state = init_workflow_state(message_history=messages)
+ assert len(state.user_message_history) == 2
+
@pytest.mark.unit
class TestConversation:
@@ -354,7 +405,3 @@ def context2():
assert len(state2.evidence) == 1
assert state1.evidence[0].citation.url == "https://example.com/1"
assert state2.evidence[0].citation.url == "https://example.com/2"
-
-
-
-
diff --git a/tests/unit/middleware/test_workflow_manager.py b/tests/unit/middleware/test_workflow_manager.py
index 3703390c..42d67658 100644
--- a/tests/unit/middleware/test_workflow_manager.py
+++ b/tests/unit/middleware/test_workflow_manager.py
@@ -285,4 +285,3 @@ async def test_get_shared_evidence(self, monkeypatch) -> None:
assert len(shared) == 1
assert shared[0].content == "Shared"
-
diff --git a/tests/unit/orchestrator/test_graph_orchestrator.py b/tests/unit/orchestrator/test_graph_orchestrator.py
index 4136663f..44732bf7 100644
--- a/tests/unit/orchestrator/test_graph_orchestrator.py
+++ b/tests/unit/orchestrator/test_graph_orchestrator.py
@@ -48,7 +48,86 @@ def test_visited_nodes_tracking(self):
context = GraphExecutionContext(WorkflowState(), BudgetTracker())
assert not context.has_visited("node1")
context.mark_visited("node1")
- assert context.has_visited("node1")
+
+ def test_message_history_initialization(self):
+ """Test message history initialization in context."""
+ from src.middleware.budget_tracker import BudgetTracker
+ from src.middleware.state_machine import WorkflowState
+
+ context = GraphExecutionContext(WorkflowState(), BudgetTracker())
+ assert context.message_history == []
+
+ def test_message_history_with_initial_history(self):
+ """Test context with initial message history."""
+ from src.middleware.budget_tracker import BudgetTracker
+ from src.middleware.state_machine import WorkflowState
+
+ try:
+ from pydantic_ai import ModelRequest
+ from pydantic_ai.messages import UserPromptPart
+
+ messages = [
+ ModelRequest(parts=[UserPromptPart(content="Test message")])
+ ]
+ context = GraphExecutionContext(
+ WorkflowState(), BudgetTracker(), message_history=messages
+ )
+ assert len(context.message_history) == 1
+ except ImportError:
+ pytest.skip("pydantic_ai not available")
+
+ def test_add_message(self):
+ """Test adding messages to context."""
+ from src.middleware.budget_tracker import BudgetTracker
+ from src.middleware.state_machine import WorkflowState
+
+ try:
+ from pydantic_ai import ModelRequest, ModelResponse
+ from pydantic_ai.messages import TextPart, UserPromptPart
+
+ context = GraphExecutionContext(WorkflowState(), BudgetTracker())
+ message1 = ModelRequest(parts=[UserPromptPart(content="Question")])
+ message2 = ModelResponse(parts=[TextPart(content="Answer")])
+
+ context.add_message(message1)
+ context.add_message(message2)
+
+ assert len(context.message_history) == 2
+ except ImportError:
+ pytest.skip("pydantic_ai not available")
+
+ def test_get_message_history(self):
+ """Test getting message history with limits."""
+ from src.middleware.budget_tracker import BudgetTracker
+ from src.middleware.state_machine import WorkflowState
+
+ try:
+ from pydantic_ai import ModelRequest
+ from pydantic_ai.messages import UserPromptPart
+
+ messages = [
+ ModelRequest(parts=[UserPromptPart(content=f"Message {i}")])
+ for i in range(10)
+ ]
+ context = GraphExecutionContext(
+ WorkflowState(), BudgetTracker(), message_history=messages
+ )
+
+ # Get all
+ all_messages = context.get_message_history()
+ assert len(all_messages) == 10
+
+ # Get limited
+ limited = context.get_message_history(max_messages=5)
+ assert len(limited) == 5
+ # Should be most recent
+ assert limited[0].parts[0].content == "Message 5"
+
+ # Visit a node to test has_visited
+ context.visited_nodes.add("node1")
+ assert context.has_visited("node1")
+ except ImportError:
+ pytest.skip("pydantic_ai not available")
class TestGraphOrchestrator:
@@ -177,7 +256,7 @@ async def mock_build_graph(mode: str):
orchestrator._build_graph = mock_build_graph
# Mock the graph execution
- async def mock_run_with_graph(query: str, mode: str):
+ async def mock_run_with_graph(query: str, research_mode: str, message_history: list | None = None):
yield AgentEvent(type="started", message="Starting", iteration=0)
yield AgentEvent(type="looping", message="Processing", iteration=1)
yield AgentEvent(type="complete", message="# Final Report\n\nContent", iteration=1)
diff --git a/tests/unit/services/test_embeddings.py b/tests/unit/services/test_embeddings.py
index f7e59e9b..aaa0c99e 100644
--- a/tests/unit/services/test_embeddings.py
+++ b/tests/unit/services/test_embeddings.py
@@ -6,15 +6,16 @@
import pytest
# Skip if embeddings dependencies are not installed
-# Handle Windows-specific scipy import issues
+# Handle Windows-specific scipy import issues and PyTorch C extensions issues
try:
pytest.importorskip("chromadb")
- pytest.importorskip("sentence_transformers")
-except OSError:
+ pytest.importorskip("sentence_transformers", exc_type=ImportError)
+except (OSError, ImportError):
# On Windows, scipy import can fail with OSError during collection
+ # PyTorch C extensions can also fail to load
# Skip the entire test module in this case
pytest.skip(
- "Embeddings dependencies not available (scipy import issue)", allow_module_level=True
+ "Embeddings dependencies not available (scipy/PyTorch import issue)", allow_module_level=True
)
from src.services.embeddings import EmbeddingService
diff --git a/tests/unit/test_app_oauth.py b/tests/unit/test_app_oauth.py
new file mode 100644
index 00000000..2a906a92
--- /dev/null
+++ b/tests/unit/test_app_oauth.py
@@ -0,0 +1,260 @@
+"""Unit tests for OAuth-related functions in app.py."""
+
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+# Mock gradio and its dependencies before importing app functions
+import sys
+from unittest.mock import MagicMock as Mock
+
+# Create a comprehensive gradio mock
+mock_gradio = Mock()
+mock_gradio.OAuthToken = Mock
+mock_gradio.OAuthProfile = Mock
+mock_gradio.Request = Mock
+mock_gradio.update = Mock(return_value={"choices": [], "value": ""})
+mock_gradio.Blocks = Mock
+mock_gradio.Markdown = Mock
+mock_gradio.LoginButton = Mock
+mock_gradio.Dropdown = Mock
+mock_gradio.Textbox = Mock
+mock_gradio.State = Mock
+
+# Mock gradio.data_classes
+mock_gradio.data_classes = Mock()
+mock_gradio.data_classes.FileData = Mock()
+
+sys.modules["gradio"] = mock_gradio
+sys.modules["gradio.data_classes"] = mock_gradio.data_classes
+
+# Mock other dependencies that might be imported
+sys.modules["neo4j"] = Mock()
+sys.modules["neo4j"].GraphDatabase = Mock()
+
+from src.app import extract_oauth_info, update_model_provider_dropdowns
+
+
+class TestExtractOAuthInfo:
+ """Tests for extract_oauth_info function."""
+
+ def test_extract_from_oauth_token_attribute(self) -> None:
+ """Should extract token from request.oauth_token.token."""
+ mock_request = MagicMock()
+ mock_oauth_token = MagicMock()
+ mock_oauth_token.token = "hf_test_token_123"
+ mock_request.oauth_token = mock_oauth_token
+ mock_request.username = "testuser"
+
+ token, username = extract_oauth_info(mock_request)
+
+ assert token == "hf_test_token_123"
+ assert username == "testuser"
+
+ def test_extract_from_string_oauth_token(self) -> None:
+ """Should extract token when oauth_token is a string."""
+ mock_request = MagicMock()
+ mock_request.oauth_token = "hf_test_token_123"
+ mock_request.username = "testuser"
+
+ token, username = extract_oauth_info(mock_request)
+
+ assert token == "hf_test_token_123"
+ assert username == "testuser"
+
+ def test_extract_from_authorization_header(self) -> None:
+ """Should extract token from Authorization header."""
+ mock_request = MagicMock()
+ mock_request.oauth_token = None
+ mock_request.headers = {"Authorization": "Bearer hf_test_token_123"}
+ mock_request.username = "testuser"
+
+ token, username = extract_oauth_info(mock_request)
+
+ assert token == "hf_test_token_123"
+ assert username == "testuser"
+
+ def test_extract_username_from_oauth_profile(self) -> None:
+ """Should extract username from oauth_profile."""
+ mock_request = MagicMock()
+ mock_request.oauth_token = None
+ mock_request.username = None
+ mock_oauth_profile = MagicMock()
+ mock_oauth_profile.username = "testuser"
+ mock_request.oauth_profile = mock_oauth_profile
+
+ token, username = extract_oauth_info(mock_request)
+
+ assert username == "testuser"
+
+ def test_extract_name_from_oauth_profile(self) -> None:
+ """Should extract name from oauth_profile when username not available."""
+ mock_request = MagicMock()
+ mock_request.oauth_token = None
+ # Ensure username attribute doesn't exist or is explicitly None
+ # Use delattr to remove it, then set oauth_profile
+ if hasattr(mock_request, "username"):
+ delattr(mock_request, "username")
+ mock_oauth_profile = MagicMock()
+ mock_oauth_profile.username = None
+ mock_oauth_profile.name = "Test User"
+ mock_request.oauth_profile = mock_oauth_profile
+
+ token, username = extract_oauth_info(mock_request)
+
+ assert username == "Test User"
+
+ def test_extract_none_request(self) -> None:
+ """Should return None for both when request is None."""
+ token, username = extract_oauth_info(None)
+
+ assert token is None
+ assert username is None
+
+ def test_extract_no_oauth_info(self) -> None:
+ """Should return None when no OAuth info is available."""
+ mock_request = MagicMock()
+ mock_request.oauth_token = None
+ mock_request.headers = {}
+ mock_request.username = None
+ mock_request.oauth_profile = None
+
+ token, username = extract_oauth_info(mock_request)
+
+ assert token is None
+ assert username is None
+
+
+class TestUpdateModelProviderDropdowns:
+ """Tests for update_model_provider_dropdowns function."""
+
+ @pytest.mark.asyncio
+ async def test_update_with_valid_token(self) -> None:
+ """Should update dropdowns with available models and providers."""
+ mock_oauth_token = MagicMock()
+ mock_oauth_token.token = "hf_test_token_123"
+ mock_oauth_profile = MagicMock()
+
+ mock_validation_result = {
+ "is_valid": True,
+ "has_inference_api_scope": True,
+ "available_models": ["model1", "model2"],
+ "available_providers": ["auto", "nebius"],
+ "username": "testuser",
+ }
+
+ with patch("src.utils.hf_model_validator.validate_oauth_token", return_value=mock_validation_result) as mock_validate, \
+ patch("src.utils.hf_model_validator.get_available_models", new_callable=AsyncMock) as mock_get_models, \
+ patch("src.utils.hf_model_validator.get_available_providers", new_callable=AsyncMock) as mock_get_providers, \
+ patch("src.app.gr") as mock_gr, \
+ patch("src.app.logger"):
+ mock_get_models.return_value = ["model1", "model2"]
+ mock_get_providers.return_value = ["auto", "nebius"]
+ mock_gr.update.return_value = {"choices": [], "value": ""}
+
+ result = await update_model_provider_dropdowns(mock_oauth_token, mock_oauth_profile)
+
+ assert len(result) == 3 # model_update, provider_update, status_msg
+ assert "testuser" in result[2] # Status message should contain username
+ mock_validate.assert_called_once_with("hf_test_token_123")
+
+ @pytest.mark.asyncio
+ async def test_update_with_no_token(self) -> None:
+ """Should return defaults when no token provided."""
+ with patch("src.app.gr") as mock_gr:
+ mock_gr.update.return_value = {"choices": [], "value": ""}
+
+ result = await update_model_provider_dropdowns(None, None)
+
+ assert len(result) == 3
+ assert "Not authenticated" in result[2] # Status message
+
+ @pytest.mark.asyncio
+ async def test_update_with_invalid_token(self) -> None:
+ """Should return error message for invalid token."""
+ mock_oauth_token = MagicMock()
+ mock_oauth_token.token = "invalid_token"
+
+ mock_validation_result = {
+ "is_valid": False,
+ "error": "Invalid token format",
+ }
+
+ with patch("src.utils.hf_model_validator.validate_oauth_token", return_value=mock_validation_result), \
+ patch("src.app.gr") as mock_gr:
+ mock_gr.update.return_value = {"choices": [], "value": ""}
+
+ result = await update_model_provider_dropdowns(mock_oauth_token, None)
+
+ assert len(result) == 3
+ assert "Token validation failed" in result[2]
+
+ @pytest.mark.asyncio
+ async def test_update_without_inference_scope(self) -> None:
+ """Should warn when token lacks inference-api scope."""
+ mock_oauth_token = MagicMock()
+ mock_oauth_token.token = "hf_token_without_scope"
+
+ mock_validation_result = {
+ "is_valid": True,
+ "has_inference_api_scope": False,
+ "available_models": [],
+ "available_providers": ["auto"],
+ "username": "testuser",
+ }
+
+ with patch("src.utils.hf_model_validator.validate_oauth_token", return_value=mock_validation_result), \
+ patch("src.utils.hf_model_validator.get_available_models", new_callable=AsyncMock) as mock_get_models, \
+ patch("src.utils.hf_model_validator.get_available_providers", new_callable=AsyncMock) as mock_get_providers, \
+ patch("src.app.gr") as mock_gr, \
+ patch("src.app.logger"):
+ mock_get_models.return_value = []
+ mock_get_providers.return_value = ["auto"]
+ mock_gr.update.return_value = {"choices": [], "value": ""}
+
+ result = await update_model_provider_dropdowns(mock_oauth_token, None)
+
+ assert len(result) == 3
+ assert "inference-api" in result[2] and "scope" in result[2]
+
+ @pytest.mark.asyncio
+ async def test_update_handles_exception(self) -> None:
+ """Should handle exceptions gracefully."""
+ mock_oauth_token = MagicMock()
+ mock_oauth_token.token = "hf_test_token"
+
+ with patch("src.utils.hf_model_validator.validate_oauth_token", side_effect=Exception("API error")), \
+ patch("src.app.gr") as mock_gr, \
+ patch("src.app.logger"):
+ mock_gr.update.return_value = {"choices": [], "value": ""}
+
+ result = await update_model_provider_dropdowns(mock_oauth_token, None)
+
+ assert len(result) == 3
+ assert "Failed to load models" in result[2]
+
+ @pytest.mark.asyncio
+ async def test_update_with_string_token(self) -> None:
+ """Should handle string token (edge case)."""
+ # Edge case: oauth_token is already a string
+ with patch("src.utils.hf_model_validator.validate_oauth_token") as mock_validate, \
+ patch("src.utils.hf_model_validator.get_available_models", new_callable=AsyncMock), \
+ patch("src.utils.hf_model_validator.get_available_providers", new_callable=AsyncMock), \
+ patch("src.app.gr") as mock_gr, \
+ patch("src.app.logger"):
+ mock_validation_result = {
+ "is_valid": True,
+ "has_inference_api_scope": True,
+ "available_models": ["model1"],
+ "available_providers": ["auto"],
+ "username": "testuser",
+ }
+ mock_validate.return_value = mock_validation_result
+ mock_gr.update.return_value = {"choices": [], "value": ""}
+
+ # Pass string directly (shouldn't happen but defensive)
+ result = await update_model_provider_dropdowns("hf_string_token", None)
+
+ # Should still work (extracts as string)
+ assert len(result) == 3
+
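Between file diffs, a hedged sketch of the (token, username) extraction behaviour that TestExtractOAuthInfo asserts, assuming a Gradio-style request object; the actual src.app.extract_oauth_info may be structured differently.

from typing import Any, Optional


def extract_oauth_info_sketch(request: Any) -> tuple[Optional[str], Optional[str]]:
    """Return (token, username) from a Gradio request, or (None, None)."""
    if request is None:
        return None, None

    token: Optional[str] = None
    oauth_token = getattr(request, "oauth_token", None)
    if isinstance(oauth_token, str):
        token = oauth_token
    elif oauth_token is not None:
        token = getattr(oauth_token, "token", None)
    if token is None:
        # Fall back to an "Authorization: Bearer <token>" header if present.
        headers = getattr(request, "headers", None) or {}
        auth = headers.get("Authorization", "")
        if auth.startswith("Bearer "):
            token = auth[len("Bearer "):]

    username = getattr(request, "username", None)
    if not username:
        profile = getattr(request, "oauth_profile", None)
        if profile is not None:
            username = getattr(profile, "username", None) or getattr(profile, "name", None)
    return token, username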
diff --git a/tests/unit/tools/test_web_search.py b/tests/unit/tools/test_web_search.py
new file mode 100644
index 00000000..d0a49ef8
--- /dev/null
+++ b/tests/unit/tools/test_web_search.py
@@ -0,0 +1,195 @@
+"""Unit tests for WebSearchTool."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+# Mock dependencies before importing to avoid import errors
+import sys
+sys.modules["neo4j"] = MagicMock()
+sys.modules["neo4j"].GraphDatabase = MagicMock()
+
+# Mock ddgs/duckduckgo_search
+# Create a proper mock structure to avoid "ddgs.ddgs" import errors
+mock_ddgs_module = MagicMock()
+mock_ddgs_submodule = MagicMock()
+# Create a mock DDGS class that can be instantiated
+class MockDDGS:
+ def __init__(self, *args, **kwargs):
+ pass
+ def text(self, *args, **kwargs):
+ return []
+
+mock_ddgs_submodule.DDGS = MockDDGS
+mock_ddgs_module.ddgs = mock_ddgs_submodule
+mock_ddgs_module.DDGS = MockDDGS
+sys.modules["ddgs"] = mock_ddgs_module
+sys.modules["ddgs.ddgs"] = mock_ddgs_submodule
+sys.modules["duckduckgo_search"] = MagicMock()
+sys.modules["duckduckgo_search"].DDGS = MockDDGS
+
+from src.tools.web_search import WebSearchTool
+from src.utils.exceptions import SearchError
+from src.utils.models import Evidence
+
+
+class TestWebSearchTool:
+ """Tests for WebSearchTool class."""
+
+ @pytest.fixture
+ def web_search_tool(self) -> WebSearchTool:
+ """Create a WebSearchTool instance."""
+ return WebSearchTool()
+
+ def test_name_property(self, web_search_tool: WebSearchTool) -> None:
+ """Should return correct tool name."""
+ assert web_search_tool.name == "duckduckgo"
+
+ @pytest.mark.asyncio
+ async def test_search_success(self, web_search_tool: WebSearchTool) -> None:
+ """Should return evidence from successful search."""
+ mock_results = [
+ {
+ "title": "Test Title 1",
+ "body": "Test content 1",
+ "href": "https://example.com/1",
+ },
+ {
+ "title": "Test Title 2",
+ "body": "Test content 2",
+ "href": "https://example.com/2",
+ },
+ ]
+
+ with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)):
+ with patch("src.tools.web_search.preprocess_query", return_value="clean query"):
+ evidence = await web_search_tool.search("test query", max_results=10)
+
+ assert len(evidence) == 2
+ assert isinstance(evidence[0], Evidence)
+ assert evidence[0].citation.title == "Test Title 1"
+ assert evidence[0].citation.url == "https://example.com/1"
+ assert evidence[0].citation.source == "web"
+
+ @pytest.mark.asyncio
+ async def test_search_title_truncation_500_chars(self, web_search_tool: WebSearchTool) -> None:
+ """Should truncate titles longer than 500 characters."""
+ long_title = "A" * 600 # 600 characters
+ mock_results = [
+ {
+ "title": long_title,
+ "body": "Test content",
+ "href": "https://example.com/1",
+ },
+ ]
+
+ with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)):
+ with patch("src.tools.web_search.preprocess_query", return_value="clean query"):
+ evidence = await web_search_tool.search("test query", max_results=10)
+
+ assert len(evidence) == 1
+ truncated_title = evidence[0].citation.title
+ assert len(truncated_title) == 500 # Should be exactly 500 chars
+ assert truncated_title.endswith("...") # Should end with ellipsis
+ assert truncated_title.startswith("A") # Should start with original content
+
+ @pytest.mark.asyncio
+ async def test_search_title_truncation_exactly_500_chars(self, web_search_tool: WebSearchTool) -> None:
+ """Should not truncate titles that are exactly 500 characters."""
+ exact_title = "A" * 500 # Exactly 500 characters
+ mock_results = [
+ {
+ "title": exact_title,
+ "body": "Test content",
+ "href": "https://example.com/1",
+ },
+ ]
+
+ with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)):
+ with patch("src.tools.web_search.preprocess_query", return_value="clean query"):
+ evidence = await web_search_tool.search("test query", max_results=10)
+
+ assert len(evidence) == 1
+ title = evidence[0].citation.title
+ assert len(title) == 500
+ assert title == exact_title # Should be unchanged
+
+ @pytest.mark.asyncio
+ async def test_search_title_truncation_501_chars(self, web_search_tool: WebSearchTool) -> None:
+ """Should truncate titles that are 501 characters."""
+ long_title = "A" * 501 # 501 characters
+ mock_results = [
+ {
+ "title": long_title,
+ "body": "Test content",
+ "href": "https://example.com/1",
+ },
+ ]
+
+ with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)):
+ with patch("src.tools.web_search.preprocess_query", return_value="clean query"):
+ evidence = await web_search_tool.search("test query", max_results=10)
+
+ assert len(evidence) == 1
+ truncated_title = evidence[0].citation.title
+ assert len(truncated_title) == 500
+ assert truncated_title.endswith("...")
+
+ @pytest.mark.asyncio
+ async def test_search_missing_title(self, web_search_tool: WebSearchTool) -> None:
+ """Should handle missing title gracefully."""
+ mock_results = [
+ {
+ "body": "Test content",
+ "href": "https://example.com/1",
+ },
+ ]
+
+ with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)):
+ with patch("src.tools.web_search.preprocess_query", return_value="clean query"):
+ evidence = await web_search_tool.search("test query", max_results=10)
+
+ assert len(evidence) == 1
+ assert evidence[0].citation.title == "No Title"
+
+ @pytest.mark.asyncio
+ async def test_search_empty_results(self, web_search_tool: WebSearchTool) -> None:
+ """Should return empty list for no results."""
+ with patch.object(web_search_tool._ddgs, "text", return_value=iter([])):
+ with patch("src.tools.web_search.preprocess_query", return_value="clean query"):
+ evidence = await web_search_tool.search("test query", max_results=10)
+
+ assert evidence == []
+
+ @pytest.mark.asyncio
+ async def test_search_raises_search_error(self, web_search_tool: WebSearchTool) -> None:
+ """Should raise SearchError on exception."""
+ with patch.object(web_search_tool._ddgs, "text", side_effect=Exception("API error")):
+ with patch("src.tools.web_search.preprocess_query", return_value="clean query"):
+ with pytest.raises(SearchError, match="DuckDuckGo search failed"):
+ await web_search_tool.search("test query", max_results=10)
+
+ @pytest.mark.asyncio
+ async def test_search_uses_preprocessed_query(self, web_search_tool: WebSearchTool) -> None:
+ """Should use preprocessed query for search."""
+ mock_results = [{"title": "Test", "body": "Content", "href": "https://example.com"}]
+
+ with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)) as mock_text:
+ with patch("src.tools.web_search.preprocess_query", return_value="preprocessed query"):
+ await web_search_tool.search("original query", max_results=5)
+
+ # Should call text with preprocessed query
+ mock_text.assert_called_once_with("preprocessed query", max_results=5)
+
+ @pytest.mark.asyncio
+ async def test_search_falls_back_to_original_query(self, web_search_tool: WebSearchTool) -> None:
+ """Should fall back to original query if preprocessing returns empty."""
+ mock_results = [{"title": "Test", "body": "Content", "href": "https://example.com"}]
+
+ with patch.object(web_search_tool._ddgs, "text", return_value=iter(mock_results)) as mock_text:
+ with patch("src.tools.web_search.preprocess_query", return_value=""):
+ await web_search_tool.search("original query", max_results=5)
+
+ # Should call text with original query
+ mock_text.assert_called_once_with("original query", max_results=5)
+
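The truncation tests above fix the title policy precisely: titles longer than 500 characters are cut to exactly 500 and end with an ellipsis, a 500-character title passes through untouched, and a missing title becomes "No Title". A minimal sketch of that policy (the constant and helper names here are assumptions, not WebSearchTool internals):

from typing import Optional

MAX_TITLE_LENGTH = 500  # assumed constant name


def normalize_title(raw_title: Optional[str]) -> str:
    title = raw_title or "No Title"
    if len(title) > MAX_TITLE_LENGTH:
        # Keep the total length at exactly 500, ending with "..."
        title = title[: MAX_TITLE_LENGTH - 3] + "..."
    return title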
diff --git a/tests/unit/utils/test_hf_error_handler.py b/tests/unit/utils/test_hf_error_handler.py
new file mode 100644
index 00000000..cd7df2c2
--- /dev/null
+++ b/tests/unit/utils/test_hf_error_handler.py
@@ -0,0 +1,239 @@
+"""Unit tests for HuggingFace error handling utilities."""
+
+from unittest.mock import patch
+
+import pytest
+
+from src.utils.hf_error_handler import (
+ extract_error_details,
+ get_fallback_models,
+ get_user_friendly_error_message,
+ log_token_info,
+ should_retry_with_fallback,
+ validate_hf_token,
+)
+
+
+class TestExtractErrorDetails:
+ """Tests for extract_error_details function."""
+
+ def test_extract_403_error(self) -> None:
+ """Should extract 403 error details correctly."""
+ error = Exception("status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden")
+ details = extract_error_details(error)
+
+ assert details["status_code"] == 403
+ assert details["model_name"] == "Qwen/Qwen3-Next-80B-A3B-Thinking"
+ assert details["body"] == "Forbidden"
+ assert details["error_type"] == "http_403"
+ assert details["is_auth_error"] is True
+ assert details["is_model_error"] is False
+
+ def test_extract_422_error(self) -> None:
+ """Should extract 422 error details correctly."""
+ error = Exception("status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity")
+ details = extract_error_details(error)
+
+ assert details["status_code"] == 422
+ assert details["model_name"] == "meta-llama/Llama-3.1-70B-Instruct"
+ assert details["body"] == "Unprocessable Entity"
+ assert details["error_type"] == "http_422"
+ assert details["is_auth_error"] is False
+ assert details["is_model_error"] is True
+
+ def test_extract_partial_error(self) -> None:
+ """Should handle partial error information."""
+ error = Exception("status_code: 500")
+ details = extract_error_details(error)
+
+ assert details["status_code"] == 500
+ assert details["model_name"] is None
+ assert details["body"] is None
+ assert details["error_type"] == "http_500"
+ assert details["is_auth_error"] is False
+ assert details["is_model_error"] is False
+
+ def test_extract_generic_error(self) -> None:
+ """Should handle generic errors without status codes."""
+ error = Exception("Something went wrong")
+ details = extract_error_details(error)
+
+ assert details["status_code"] is None
+ assert details["model_name"] is None
+ assert details["body"] is None
+ assert details["error_type"] == "unknown"
+ assert details["is_auth_error"] is False
+ assert details["is_model_error"] is False
+
+
+class TestGetUserFriendlyErrorMessage:
+ """Tests for get_user_friendly_error_message function."""
+
+ def test_403_error_message(self) -> None:
+ """Should generate user-friendly 403 error message."""
+ error = Exception("status_code: 403, model_name: Qwen/Qwen3-Next-80B-A3B-Thinking, body: Forbidden")
+ message = get_user_friendly_error_message(error)
+
+ assert "Authentication Error" in message
+ assert "inference-api" in message
+ assert "Qwen/Qwen3-Next-80B-A3B-Thinking" in message
+ assert "Forbidden" in message
+
+ def test_422_error_message(self) -> None:
+ """Should generate user-friendly 422 error message."""
+ error = Exception("status_code: 422, model_name: meta-llama/Llama-3.1-70B-Instruct, body: Unprocessable Entity")
+ message = get_user_friendly_error_message(error)
+
+ assert "Model Compatibility Error" in message
+ assert "meta-llama/Llama-3.1-70B-Instruct" in message
+ assert "Unprocessable Entity" in message
+
+ def test_generic_error_message(self) -> None:
+ """Should generate generic error message for unknown errors."""
+ error = Exception("Something went wrong")
+ message = get_user_friendly_error_message(error)
+
+ assert "API Error" in message
+ assert "Something went wrong" in message
+
+ def test_error_message_with_model_name_param(self) -> None:
+ """Should use provided model_name parameter when not in error."""
+ error = Exception("status_code: 403, body: Forbidden")
+ message = get_user_friendly_error_message(error, model_name="test-model")
+
+ assert "test-model" in message
+
+
+class TestValidateHfToken:
+ """Tests for validate_hf_token function."""
+
+ def test_valid_token(self) -> None:
+ """Should validate a valid token."""
+ token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
+ is_valid, error_msg = validate_hf_token(token)
+
+ assert is_valid is True
+ assert error_msg is None
+
+ def test_none_token(self) -> None:
+ """Should reject None token."""
+ is_valid, error_msg = validate_hf_token(None)
+
+ assert is_valid is False
+ assert "None or empty" in error_msg
+
+ def test_empty_token(self) -> None:
+ """Should reject empty token."""
+ is_valid, error_msg = validate_hf_token("")
+
+ assert is_valid is False
+ assert "None or empty" in error_msg
+
+ def test_non_string_token(self) -> None:
+ """Should reject non-string token."""
+ is_valid, error_msg = validate_hf_token(123) # type: ignore[arg-type]
+
+ assert is_valid is False
+ assert "not a string" in error_msg
+
+ def test_short_token(self) -> None:
+ """Should reject token that's too short."""
+ is_valid, error_msg = validate_hf_token("hf_123")
+
+ assert is_valid is False
+ assert "too short" in error_msg
+
+ def test_oauth_token_format(self) -> None:
+ """Should accept OAuth tokens (may not start with hf_)."""
+ # OAuth tokens may have different formats
+ token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ"
+ is_valid, error_msg = validate_hf_token(token)
+
+ assert is_valid is True
+ assert error_msg is None
+
+
+class TestLogTokenInfo:
+ """Tests for log_token_info function."""
+
+ @patch("src.utils.hf_error_handler.logger")
+ def test_log_valid_token(self, mock_logger) -> None:
+ """Should log token info for valid token."""
+ token = "hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
+ log_token_info(token, context="test")
+
+ mock_logger.debug.assert_called_once()
+ call_args = mock_logger.debug.call_args
+ assert call_args[0][0] == "Token validation"
+ assert call_args[1]["context"] == "test"
+ assert call_args[1]["has_token"] is True
+ assert call_args[1]["is_valid"] is True
+ assert call_args[1]["token_length"] == len(token)
+ assert "token_prefix" in call_args[1]
+
+ @patch("src.utils.hf_error_handler.logger")
+ def test_log_none_token(self, mock_logger) -> None:
+ """Should log None token info."""
+ log_token_info(None, context="test")
+
+ mock_logger.debug.assert_called_once()
+ call_args = mock_logger.debug.call_args
+ assert call_args[0][0] == "Token validation"
+ assert call_args[1]["context"] == "test"
+ assert call_args[1]["has_token"] is False
+
+
+class TestShouldRetryWithFallback:
+ """Tests for should_retry_with_fallback function."""
+
+ def test_403_error_should_retry(self) -> None:
+ """Should retry for 403 errors."""
+ error = Exception("status_code: 403, model_name: test-model, body: Forbidden")
+ assert should_retry_with_fallback(error) is True
+
+ def test_422_error_should_retry(self) -> None:
+ """Should retry for 422 errors."""
+ error = Exception("status_code: 422, model_name: test-model, body: Unprocessable Entity")
+ assert should_retry_with_fallback(error) is True
+
+ def test_model_specific_error_should_retry(self) -> None:
+ """Should retry for model-specific errors."""
+ error = Exception("status_code: 500, model_name: test-model, body: Error")
+ assert should_retry_with_fallback(error) is True
+
+ def test_generic_error_should_not_retry(self) -> None:
+ """Should not retry for generic errors without model info."""
+ error = Exception("Something went wrong")
+ assert should_retry_with_fallback(error) is False
+
+
+class TestGetFallbackModels:
+ """Tests for get_fallback_models function."""
+
+ def test_get_fallback_models_default(self) -> None:
+ """Should return default fallback models."""
+ fallbacks = get_fallback_models()
+
+ assert len(fallbacks) > 0
+ assert "meta-llama/Llama-3.1-8B-Instruct" in fallbacks
+ assert isinstance(fallbacks, list)
+
+ def test_get_fallback_models_excludes_original(self) -> None:
+ """Should exclude original model from fallbacks."""
+ original = "meta-llama/Llama-3.1-8B-Instruct"
+ fallbacks = get_fallback_models(original_model=original)
+
+ assert original not in fallbacks
+ assert len(fallbacks) > 0
+
+ def test_get_fallback_models_with_unknown_original(self) -> None:
+ """Should return all fallbacks if original is not in list."""
+ original = "unknown/model"
+ fallbacks = get_fallback_models(original_model=original)
+
+ # Should still have all fallbacks since original is not in the list
+ assert len(fallbacks) >= 3 # At least 3 fallback models
+
+
+
+
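A rough sketch of the parsing behaviour TestExtractErrorDetails asserts; the regexes and flag logic are assumptions, not the actual src.utils.hf_error_handler implementation.

import re
from typing import Any


def extract_error_details_sketch(error: Exception) -> dict[str, Any]:
    text = str(error)
    status = re.search(r"status_code:\s*(\d+)", text)
    model = re.search(r"model_name:\s*([^,]+)", text)
    body = re.search(r"body:\s*(.+)$", text)
    status_code = int(status.group(1)) if status else None
    return {
        "status_code": status_code,
        "model_name": model.group(1).strip() if model else None,
        "body": body.group(1).strip() if body else None,
        "error_type": f"http_{status_code}" if status_code else "unknown",
        "is_auth_error": status_code in (401, 403),  # 403 -> True per the tests
        "is_model_error": status_code == 422,        # 422 -> True per the tests
    }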
diff --git a/tests/unit/utils/test_hf_model_validator.py b/tests/unit/utils/test_hf_model_validator.py
new file mode 100644
index 00000000..2027ff72
--- /dev/null
+++ b/tests/unit/utils/test_hf_model_validator.py
@@ -0,0 +1,416 @@
+"""Unit tests for HuggingFace model and provider validator."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from src.utils.hf_model_validator import (
+ extract_oauth_token,
+ get_available_models,
+ get_available_providers,
+ get_models_for_provider,
+ validate_model_provider_combination,
+ validate_oauth_token,
+)
+
+
+class TestExtractOAuthToken:
+ """Tests for extract_oauth_token function."""
+
+ def test_extract_from_oauth_token_object(self) -> None:
+ """Should extract token from OAuthToken object with .token attribute."""
+ mock_oauth_token = MagicMock()
+ mock_oauth_token.token = "hf_test_token_123"
+
+ result = extract_oauth_token(mock_oauth_token)
+
+ assert result == "hf_test_token_123"
+
+ def test_extract_from_string(self) -> None:
+ """Should return string token as-is."""
+ token = "hf_test_token_123"
+ result = extract_oauth_token(token)
+
+ assert result == token
+
+ def test_extract_none(self) -> None:
+ """Should return None for None input."""
+ result = extract_oauth_token(None)
+
+ assert result is None
+
+ def test_extract_invalid_object(self) -> None:
+ """Should return None for object without .token attribute."""
+ invalid_object = MagicMock()
+ del invalid_object.token # Remove token attribute
+
+ with patch("src.utils.hf_model_validator.logger") as mock_logger:
+ result = extract_oauth_token(invalid_object)
+
+ assert result is None
+ mock_logger.warning.assert_called_once()
+
+
+class TestGetAvailableProviders:
+ """Tests for get_available_providers function."""
+
+ @pytest.mark.asyncio
+ async def test_get_providers_with_cache(self) -> None:
+ """Should return cached providers if available."""
+ # First call - should query API
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+
+ # Mock model_info to return provider mapping
+ mock_model_info = MagicMock()
+ mock_model_info.inference_provider_mapping = {
+ "hf-inference": MagicMock(),
+ "nebius": MagicMock(),
+ }
+ mock_api.model_info.return_value = mock_model_info
+
+ # Mock settings
+ with patch("src.utils.hf_model_validator.settings") as mock_settings:
+ mock_settings.get_hf_fallback_models_list.return_value = [
+ "meta-llama/Llama-3.1-8B-Instruct"
+ ]
+
+ providers = await get_available_providers(token="test_token")
+
+ assert "auto" in providers
+ assert len(providers) > 1
+
+ @pytest.mark.asyncio
+ async def test_get_providers_fallback_to_known(self) -> None:
+ """Should fall back to known providers if discovery fails."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+ mock_api.model_info.side_effect = Exception("API error")
+
+ with patch("src.utils.hf_model_validator.settings") as mock_settings:
+ mock_settings.get_hf_fallback_models_list.return_value = [
+ "meta-llama/Llama-3.1-8B-Instruct"
+ ]
+
+ providers = await get_available_providers(token="test_token")
+
+ # Should return known providers as fallback
+ assert "auto" in providers
+ assert len(providers) > 0
+
+ @pytest.mark.asyncio
+ async def test_get_providers_no_token(self) -> None:
+ """Should work without token."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+ mock_api.model_info.side_effect = Exception("API error")
+
+ with patch("src.utils.hf_model_validator.settings") as mock_settings:
+ mock_settings.get_hf_fallback_models_list.return_value = [
+ "meta-llama/Llama-3.1-8B-Instruct"
+ ]
+
+ providers = await get_available_providers(token=None)
+
+ # Should return known providers as fallback
+ assert "auto" in providers
+
+
+class TestGetAvailableModels:
+ """Tests for get_available_models function."""
+
+ @pytest.mark.asyncio
+ async def test_get_models_with_token(self) -> None:
+ """Should fetch models with token."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+
+ # Mock list_models to return model objects
+ mock_model1 = MagicMock()
+ mock_model1.id = "model1"
+ mock_model2 = MagicMock()
+ mock_model2.id = "model2"
+ mock_api.list_models.return_value = [mock_model1, mock_model2]
+
+ models = await get_available_models(token="test_token", limit=10)
+
+ assert len(models) == 2
+ assert "model1" in models
+ assert "model2" in models
+ mock_api.list_models.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_get_models_with_provider_filter(self) -> None:
+ """Should filter models by provider."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+
+ mock_model = MagicMock()
+ mock_model.id = "model1"
+ mock_api.list_models.return_value = [mock_model]
+
+            await get_available_models(
+ token="test_token",
+ inference_provider="nebius",
+ limit=10,
+ )
+
+ # Check that inference_provider was passed to list_models
+ call_kwargs = mock_api.list_models.call_args[1]
+ assert call_kwargs.get("inference_provider") == "nebius"
+
+ @pytest.mark.asyncio
+ async def test_get_models_fallback_on_error(self) -> None:
+ """Should return fallback models on error."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+ mock_api.list_models.side_effect = Exception("API error")
+
+ models = await get_available_models(token="test_token", limit=10)
+
+ # Should return fallback models
+ assert len(models) > 0
+ assert "meta-llama/Llama-3.1-8B-Instruct" in models
+
+
+class TestValidateModelProviderCombination:
+ """Tests for validate_model_provider_combination function."""
+
+ @pytest.mark.asyncio
+ async def test_validate_auto_provider(self) -> None:
+ """Should always validate 'auto' provider."""
+ is_valid, error_msg = await validate_model_provider_combination(
+ model_id="test-model",
+ provider="auto",
+ token="test_token",
+ )
+
+ assert is_valid is True
+ assert error_msg is None
+
+ @pytest.mark.asyncio
+ async def test_validate_none_provider(self) -> None:
+ """Should validate None provider as auto."""
+ is_valid, error_msg = await validate_model_provider_combination(
+ model_id="test-model",
+ provider=None,
+ token="test_token",
+ )
+
+ assert is_valid is True
+ assert error_msg is None
+
+ @pytest.mark.asyncio
+ async def test_validate_valid_combination(self) -> None:
+ """Should validate valid model/provider combination."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+
+ # Mock model_info with provider mapping
+ mock_model_info = MagicMock()
+ mock_model_info.inference_provider_mapping = {
+ "nebius": MagicMock(),
+ "hf-inference": MagicMock(),
+ }
+ mock_api.model_info.return_value = mock_model_info
+
+ is_valid, error_msg = await validate_model_provider_combination(
+ model_id="test-model",
+ provider="nebius",
+ token="test_token",
+ )
+
+ assert is_valid is True
+ assert error_msg is None
+
+ @pytest.mark.asyncio
+ async def test_validate_invalid_combination(self) -> None:
+ """Should reject invalid model/provider combination."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+
+ # Mock model_info with provider mapping (without requested provider)
+ mock_model_info = MagicMock()
+ mock_model_info.inference_provider_mapping = {
+ "hf-inference": MagicMock(),
+ }
+ mock_api.model_info.return_value = mock_model_info
+
+ is_valid, error_msg = await validate_model_provider_combination(
+ model_id="test-model",
+ provider="nebius",
+ token="test_token",
+ )
+
+ assert is_valid is False
+ assert error_msg is not None
+ assert "nebius" in error_msg
+
+ @pytest.mark.asyncio
+ async def test_validate_fireworks_variants(self) -> None:
+ """Should handle fireworks/fireworks-ai name variants."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+
+ # Mock model_info with fireworks-ai in mapping
+ mock_model_info = MagicMock()
+ mock_model_info.inference_provider_mapping = {
+ "fireworks-ai": MagicMock(),
+ }
+ mock_api.model_info.return_value = mock_model_info
+
+ # Should accept "fireworks" when mapping has "fireworks-ai"
+ is_valid, error_msg = await validate_model_provider_combination(
+ model_id="test-model",
+ provider="fireworks",
+ token="test_token",
+ )
+
+ assert is_valid is True
+ assert error_msg is None
+
+ @pytest.mark.asyncio
+ async def test_validate_graceful_on_error(self) -> None:
+ """Should return valid on error (graceful degradation)."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+ mock_api.model_info.side_effect = Exception("API error")
+
+ is_valid, error_msg = await validate_model_provider_combination(
+ model_id="test-model",
+ provider="nebius",
+ token="test_token",
+ )
+
+ # Should return True to allow actual request to determine validity
+ assert is_valid is True
+
+
+class TestGetModelsForProvider:
+ """Tests for get_models_for_provider function."""
+
+ @pytest.mark.asyncio
+ async def test_get_models_for_provider(self) -> None:
+ """Should get models for specific provider."""
+ with patch("src.utils.hf_model_validator.get_available_models") as mock_get_models:
+ mock_get_models.return_value = ["model1", "model2"]
+
+ models = await get_models_for_provider(
+ provider="nebius",
+ token="test_token",
+ limit=10,
+ )
+
+ assert len(models) == 2
+ mock_get_models.assert_called_once_with(
+ token="test_token",
+ task="text-generation",
+ limit=10,
+ inference_provider="nebius",
+ )
+
+ @pytest.mark.asyncio
+ async def test_get_models_normalize_fireworks(self) -> None:
+ """Should normalize 'fireworks' to 'fireworks-ai'."""
+ with patch("src.utils.hf_model_validator.get_available_models") as mock_get_models:
+ mock_get_models.return_value = ["model1"]
+
+            await get_models_for_provider(
+ provider="fireworks",
+ token="test_token",
+ )
+
+ # Should call with "fireworks-ai" not "fireworks"
+ call_kwargs = mock_get_models.call_args[1]
+ assert call_kwargs["inference_provider"] == "fireworks-ai"
+
+
+class TestValidateOAuthToken:
+ """Tests for validate_oauth_token function."""
+
+ @pytest.mark.asyncio
+ async def test_validate_none_token(self) -> None:
+ """Should return invalid for None token."""
+ result = await validate_oauth_token(None)
+
+ assert result["is_valid"] is False
+ assert result["error"] == "No token provided"
+
+ @pytest.mark.asyncio
+ async def test_validate_invalid_format(self) -> None:
+ """Should return invalid for malformed token."""
+ result = await validate_oauth_token("short")
+
+ assert result["is_valid"] is False
+ assert "Invalid token format" in result["error"]
+
+ @pytest.mark.asyncio
+ async def test_validate_valid_token(self) -> None:
+ """Should validate valid token and return resources."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+
+ # Mock whoami to return user info
+ mock_api.whoami.return_value = {"name": "testuser", "fullname": "Test User"}
+
+ # Mock get_available_models and get_available_providers
+ with patch("src.utils.hf_model_validator.get_available_models") as mock_get_models, \
+ patch("src.utils.hf_model_validator.get_available_providers") as mock_get_providers:
+ mock_get_models.return_value = ["model1", "model2"]
+ mock_get_providers.return_value = ["auto", "nebius"]
+
+ result = await validate_oauth_token("hf_valid_token_123")
+
+ assert result["is_valid"] is True
+ assert result["username"] == "testuser"
+ assert result["has_inference_api_scope"] is True
+ assert len(result["available_models"]) == 2
+ assert len(result["available_providers"]) == 2
+
+ @pytest.mark.asyncio
+ async def test_validate_token_without_scope(self) -> None:
+ """Should detect missing inference-api scope."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+ mock_api.whoami.return_value = {"name": "testuser"}
+
+ # Mock get_available_models to fail (no scope)
+ with patch("src.utils.hf_model_validator.get_available_models") as mock_get_models, \
+ patch("src.utils.hf_model_validator.get_available_providers") as mock_get_providers:
+ mock_get_models.side_effect = Exception("403 Forbidden")
+ mock_get_providers.return_value = ["auto"]
+
+ result = await validate_oauth_token("hf_token_without_scope")
+
+ assert result["is_valid"] is True # Token is valid
+ assert result["has_inference_api_scope"] is False # But no scope
+ assert "inference-api scope" in result["error"]
+
+ @pytest.mark.asyncio
+ async def test_validate_invalid_token(self) -> None:
+ """Should return invalid for token that fails authentication."""
+ with patch("src.utils.hf_model_validator.HfApi") as mock_api_class:
+ mock_api = MagicMock()
+ mock_api_class.return_value = mock_api
+ mock_api.whoami.side_effect = Exception("401 Unauthorized")
+
+ result = await validate_oauth_token("hf_invalid_token")
+
+ assert result["is_valid"] is False
+ assert "could not authenticate" in result["error"]
+
+
+
+
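The TestExtractOAuthToken cases reduce to a small normalisation rule: strings pass through, OAuthToken-like objects are unwrapped via .token, and anything else is logged and dropped. A minimal sketch under those assumptions (the real module's logger may be configured differently):

import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)


def extract_oauth_token_sketch(oauth_token: Any) -> Optional[str]:
    if oauth_token is None:
        return None
    if isinstance(oauth_token, str):
        return oauth_token
    token = getattr(oauth_token, "token", None)
    if token is None:
        logger.warning("OAuth token object of type %r has no .token attribute", type(oauth_token))
    return token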
diff --git a/tests/unit/utils/test_llm_factory_token_validation.py b/tests/unit/utils/test_llm_factory_token_validation.py
new file mode 100644
index 00000000..e663958f
--- /dev/null
+++ b/tests/unit/utils/test_llm_factory_token_validation.py
@@ -0,0 +1,142 @@
+"""Unit tests for token validation in llm_factory.py."""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from src.utils.exceptions import ConfigurationError
+
+
+class TestGetPydanticAiModelTokenValidation:
+ """Tests for get_pydantic_ai_model function with token validation."""
+
+ @patch("src.utils.llm_factory.settings")
+ @patch("src.utils.hf_error_handler.log_token_info")
+ @patch("src.utils.hf_error_handler.validate_hf_token")
+ @patch("pydantic_ai.providers.huggingface.HuggingFaceProvider")
+ @patch("pydantic_ai.models.huggingface.HuggingFaceModel")
+ def test_validates_oauth_token(
+ self,
+ mock_model_class,
+ mock_provider_class,
+ mock_validate,
+ mock_log,
+ mock_settings,
+ ) -> None:
+ """Should validate and log OAuth token when provided."""
+ mock_settings.hf_token = None
+ mock_settings.huggingface_api_key = None
+ mock_settings.huggingface_model = "test-model"
+ mock_validate.return_value = (True, None)
+ mock_model_class.return_value = MagicMock()
+
+ from src.utils.llm_factory import get_pydantic_ai_model
+
+ get_pydantic_ai_model(oauth_token="hf_test_token")
+
+ # Should log token info
+ mock_log.assert_called_once_with("hf_test_token", context="get_pydantic_ai_model")
+ # Should validate token
+ mock_validate.assert_called_once_with("hf_test_token")
+
+ @patch("src.utils.llm_factory.settings")
+ @patch("src.utils.hf_error_handler.log_token_info")
+ @patch("src.utils.hf_error_handler.validate_hf_token")
+ @patch("src.utils.llm_factory.logger")
+ @patch("pydantic_ai.providers.huggingface.HuggingFaceProvider")
+ @patch("pydantic_ai.models.huggingface.HuggingFaceModel")
+ def test_warns_on_invalid_token(
+ self,
+ mock_model_class,
+ mock_provider_class,
+ mock_logger,
+ mock_validate,
+ mock_log,
+ mock_settings,
+ ) -> None:
+ """Should warn when token validation fails."""
+ mock_settings.hf_token = None
+ mock_settings.huggingface_api_key = None
+ mock_settings.huggingface_model = "test-model"
+ mock_validate.return_value = (False, "Token too short")
+ mock_model_class.return_value = MagicMock()
+
+ from src.utils.llm_factory import get_pydantic_ai_model
+
+ get_pydantic_ai_model(oauth_token="short")
+
+ # Should warn about invalid token
+ warning_calls = [
+ call
+ for call in mock_logger.warning.call_args_list
+ if "Token validation failed" in str(call)
+ ]
+ assert len(warning_calls) > 0
+
+ @patch("src.utils.llm_factory.settings")
+ @patch("src.utils.hf_error_handler.log_token_info")
+ @patch("src.utils.hf_error_handler.validate_hf_token")
+ @patch("pydantic_ai.providers.huggingface.HuggingFaceProvider")
+ @patch("pydantic_ai.models.huggingface.HuggingFaceModel")
+ def test_uses_env_token_when_oauth_not_provided(
+ self,
+ mock_model_class,
+ mock_provider_class,
+ mock_validate,
+ mock_log,
+ mock_settings,
+ ) -> None:
+ """Should use environment token when OAuth token not provided."""
+ mock_settings.hf_token = "hf_env_token"
+ mock_settings.huggingface_api_key = None
+ mock_settings.huggingface_model = "test-model"
+ mock_validate.return_value = (True, None)
+ mock_model_class.return_value = MagicMock()
+
+ from src.utils.llm_factory import get_pydantic_ai_model
+
+ get_pydantic_ai_model(oauth_token=None)
+
+ # Should log and validate env token
+ mock_log.assert_called_once_with("hf_env_token", context="get_pydantic_ai_model")
+ mock_validate.assert_called_once_with("hf_env_token")
+
+ @patch("src.utils.llm_factory.settings")
+ def test_raises_when_no_token_available(self, mock_settings) -> None:
+ """Should raise ConfigurationError when no token is available."""
+ mock_settings.hf_token = None
+ mock_settings.huggingface_api_key = None
+
+ from src.utils.llm_factory import get_pydantic_ai_model
+
+ with pytest.raises(ConfigurationError, match="HuggingFace token required"):
+ get_pydantic_ai_model(oauth_token=None)
+
+ @patch("src.utils.llm_factory.settings")
+ @patch("src.utils.hf_error_handler.log_token_info")
+ @patch("src.utils.hf_error_handler.validate_hf_token")
+ @patch("pydantic_ai.providers.huggingface.HuggingFaceProvider")
+ @patch("pydantic_ai.models.huggingface.HuggingFaceModel")
+ def test_oauth_token_priority_over_env(
+ self,
+ mock_model_class,
+ mock_provider_class,
+ mock_validate,
+ mock_log,
+ mock_settings,
+ ) -> None:
+ """Should prefer OAuth token over environment token."""
+ mock_settings.hf_token = "hf_env_token"
+ mock_settings.huggingface_api_key = None
+ mock_settings.huggingface_model = "test-model"
+ mock_validate.return_value = (True, None)
+ mock_model_class.return_value = MagicMock()
+
+ from src.utils.llm_factory import get_pydantic_ai_model
+
+ get_pydantic_ai_model(oauth_token="hf_oauth_token")
+
+ # Should use OAuth token, not env token
+ mock_log.assert_called_once_with("hf_oauth_token", context="get_pydantic_ai_model")
+ mock_validate.assert_called_once_with("hf_oauth_token")
+
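These tests pin down the token-selection order in get_pydantic_ai_model: an explicit OAuth token beats settings.hf_token and settings.huggingface_api_key, the chosen token is logged and validated (a failed validation only warns), and a ConfigurationError is raised when no token is configured. A hedged sketch of that precedence, reusing the helpers the tests patch; the real factory does more than this.

from src.utils.exceptions import ConfigurationError
from src.utils.hf_error_handler import log_token_info, validate_hf_token


def resolve_hf_token_sketch(oauth_token, settings, logger):
    """Pick the HuggingFace token the way the tests above expect (sketch only)."""
    token = oauth_token or settings.hf_token or settings.huggingface_api_key
    if not token:
        raise ConfigurationError("HuggingFace token required")
    log_token_info(token, context="get_pydantic_ai_model")
    is_valid, error_msg = validate_hf_token(token)
    if not is_valid:
        # Warn but continue; the API call itself decides whether the token works.
        logger.warning("Token validation failed: %s", error_msg)
    return token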
diff --git a/tests/unit/utils/test_message_history.py b/tests/unit/utils/test_message_history.py
new file mode 100644
index 00000000..19344e58
--- /dev/null
+++ b/tests/unit/utils/test_message_history.py
@@ -0,0 +1,151 @@
+"""Unit tests for message history utilities."""
+
+import pytest
+
+pytestmark = pytest.mark.unit
+
+from src.utils.message_history import (
+ convert_gradio_to_message_history,
+ create_relevance_processor,
+ create_truncation_processor,
+ message_history_to_string,
+)
+
+
+def test_convert_gradio_to_message_history_empty():
+ """Test conversion with empty history."""
+ result = convert_gradio_to_message_history([])
+ assert result == []
+
+
+def test_convert_gradio_to_message_history_single_turn():
+ """Test conversion with a single turn."""
+ gradio_history = [
+ {"role": "user", "content": "What is AI?"},
+ {"role": "assistant", "content": "AI is artificial intelligence."},
+ ]
+ result = convert_gradio_to_message_history(gradio_history)
+ assert len(result) == 2
+
+
+def test_convert_gradio_to_message_history_multiple_turns():
+ """Test conversion with multiple turns."""
+ gradio_history = [
+ {"role": "user", "content": "What is AI?"},
+ {"role": "assistant", "content": "AI is artificial intelligence."},
+ {"role": "user", "content": "Tell me more"},
+ {"role": "assistant", "content": "AI includes machine learning..."},
+ ]
+ result = convert_gradio_to_message_history(gradio_history)
+ assert len(result) == 4
+
+
+def test_convert_gradio_to_message_history_max_messages():
+ """Test conversion with max_messages limit."""
+ gradio_history = []
+ for i in range(15): # Create 15 turns
+ gradio_history.append({"role": "user", "content": f"Message {i}"})
+ gradio_history.append({"role": "assistant", "content": f"Response {i}"})
+
+ result = convert_gradio_to_message_history(gradio_history, max_messages=10)
+ # Should only include most recent 10 messages
+ assert len(result) <= 10
+
+
+def test_convert_gradio_to_message_history_filters_invalid():
+ """Test that invalid entries are filtered out."""
+ gradio_history = [
+ {"role": "user", "content": "Valid message"},
+ {"role": "system", "content": "Should be filtered"},
+ {"role": "assistant", "content": ""}, # Empty content should be filtered
+ {"role": "assistant", "content": "Valid response"},
+ ]
+ result = convert_gradio_to_message_history(gradio_history)
+ # Should only have 2 valid messages (user + assistant)
+ assert len(result) == 2
+
+
+def test_message_history_to_string_empty():
+ """Test string conversion with empty history."""
+ result = message_history_to_string([])
+ assert result == ""
+
+
+def test_message_history_to_string_format():
+ """Test string conversion format."""
+ # Create mock message history
+ try:
+ from pydantic_ai import ModelRequest, ModelResponse
+ from pydantic_ai.messages import TextPart, UserPromptPart
+
+ messages = [
+ ModelRequest(parts=[UserPromptPart(content="Question 1")]),
+ ModelResponse(parts=[TextPart(content="Answer 1")]),
+ ]
+ result = message_history_to_string(messages)
+ assert "PREVIOUS CONVERSATION" in result
+ assert "User:" in result
+ assert "Assistant:" in result
+ except ImportError:
+ # Skip if pydantic_ai not available
+ pytest.skip("pydantic_ai not available")
+
+
+def test_message_history_to_string_max_messages():
+ """Test string conversion with max_messages limit."""
+ try:
+ from pydantic_ai import ModelRequest, ModelResponse
+ from pydantic_ai.messages import TextPart, UserPromptPart
+
+ messages = []
+ for i in range(10): # Create 10 turns
+ messages.append(ModelRequest(parts=[UserPromptPart(content=f"Question {i}")]))
+ messages.append(ModelResponse(parts=[TextPart(content=f"Answer {i}")]))
+
+ result = message_history_to_string(messages, max_messages=3)
+ # Should only include most recent 3 messages (1.5 turns)
+ assert result != ""
+ except ImportError:
+ pytest.skip("pydantic_ai not available")
+
+
+def test_create_truncation_processor():
+ """Test truncation processor factory."""
+ processor = create_truncation_processor(max_messages=5)
+ assert callable(processor)
+
+ try:
+ from pydantic_ai import ModelRequest
+ from pydantic_ai.messages import UserPromptPart
+
+ messages = [
+ ModelRequest(parts=[UserPromptPart(content=f"Message {i}")])
+ for i in range(10)
+ ]
+ result = processor(messages)
+ assert len(result) == 5
+ except ImportError:
+ pytest.skip("pydantic_ai not available")
+
+
+def test_create_relevance_processor():
+ """Test relevance processor factory."""
+ processor = create_relevance_processor(min_length=10)
+ assert callable(processor)
+
+ try:
+ from pydantic_ai import ModelRequest, ModelResponse
+ from pydantic_ai.messages import TextPart, UserPromptPart
+
+ messages = [
+ ModelRequest(parts=[UserPromptPart(content="Short")]), # Too short
+ ModelRequest(parts=[UserPromptPart(content="This is a longer message")]), # Valid
+ ModelResponse(parts=[TextPart(content="OK")]), # Too short
+ ModelResponse(parts=[TextPart(content="This is a valid response")]), # Valid
+ ]
+ result = processor(messages)
+ # Should only keep messages with length >= 10
+ assert len(result) == 2
+ except ImportError:
+ pytest.skip("pydantic_ai not available")
+
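Finally, the processor-factory tests above imply closures of roughly this shape; a minimal sketch, assuming the real src.utils.message_history versions add further handling.

def create_truncation_processor_sketch(max_messages: int):
    def process(messages: list) -> list:
        # Keep only the most recent max_messages entries.
        return messages[-max_messages:]
    return process


def create_relevance_processor_sketch(min_length: int):
    def process(messages: list) -> list:
        # Drop messages whose combined part content is shorter than min_length.
        return [
            message
            for message in messages
            if sum(len(getattr(part, "content", "") or "") for part in message.parts) >= min_length
        ]
    return process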