diff --git a/config.yaml.example b/config.yaml.example index 10fbec21..20dddf80 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -57,15 +57,16 @@ tools: base_class: path.to.my.tools.CustomTool my_other_tool: base_class: "name_of_tool_class_in_registry" - # Search tools: configure Tavily API key and search limits per tool + # Search tools: configure search provider and API keys per tool # (can be overridden per-agent in tools list) web_search_tool: - tavily_api_key: "your-tavily-api-key-here" # Tavily API key (get at tavily.com) - tavily_api_base_url: "https://api.tavily.com" # Tavily API URL + engine: "tavily" # Search engine: "tavily" (default), "brave", or "perplexity" + api_key: "your-search-api-key-here" # API key for the selected engine + # api_base_url: "https://custom-url" # Optional, uses engine default max_results: 12 max_searches: 6 extract_page_content_tool: - tavily_api_key: "your-tavily-api-key-here" # Same Tavily API key + tavily_api_key: "your-tavily-api-key-here" # Tavily API key (Tavily-only feature) tavily_api_base_url: "https://api.tavily.com" content_limit: 2000 diff --git a/examples/sgr_deep_research/config.yaml.example b/examples/sgr_deep_research/config.yaml.example index d6e765a0..d26ebab2 100644 --- a/examples/sgr_deep_research/config.yaml.example +++ b/examples/sgr_deep_research/config.yaml.example @@ -30,13 +30,12 @@ tools: # Core tools (base_class defaults to sgr_agent_core.tools.*) # Search tools: configure Tavily API key and search limits per tool web_search_tool: - tavily_api_key: "your-tavily-api-key-here" # Tavily API key (get at tavily.com) - tavily_api_base_url: "https://api.tavily.com" # Tavily API URL + engine: "tavily" # Search engine: "tavily" (default), "brave", or "perplexity" + api_key: "your-tavily-api-key-here" # API key for the selected engine max_searches: 4 # Max search operations max_results: 10 # Max results in search query extract_page_content_tool: - tavily_api_key: "your-tavily-api-key-here" # Same Tavily API key - tavily_api_base_url: "https://api.tavily.com" + tavily_api_key: "your-tavily-api-key-here" # Tavily API key (Tavily extract only) content_limit: 1500 # Content char limit per source create_report_tool: # base_class defaults to sgr_agent_core.tools.CreateReportTool diff --git a/examples/sgr_deep_research_without_reporting/config.yaml.example b/examples/sgr_deep_research_without_reporting/config.yaml.example index d84ba528..ed8e29d9 100644 --- a/examples/sgr_deep_research_without_reporting/config.yaml.example +++ b/examples/sgr_deep_research_without_reporting/config.yaml.example @@ -29,13 +29,12 @@ tools: # Core tools (base_class defaults to sgr_agent_core.tools.*) # Search tools: configure Tavily API key and search limits per tool web_search_tool: - tavily_api_key: "your-tavily-api-key-here" # Tavily API key (get at tavily.com) - tavily_api_base_url: "https://api.tavily.com" # Tavily API URL + engine: "tavily" # Search engine: "tavily" (default), "brave", or "perplexity" + api_key: "your-tavily-api-key-here" # API key for the selected engine max_searches: 4 # Max search operations max_results: 10 # Max results in search query extract_page_content_tool: - tavily_api_key: "your-tavily-api-key-here" # Same Tavily API key - tavily_api_base_url: "https://api.tavily.com" + tavily_api_key: "your-tavily-api-key-here" # Tavily API key (Tavily extract only) content_limit: 1500 # Content char limit per source final_answer_tool: # base_class defaults to sgr_agent_core.tools.FinalAnswerTool diff --git a/sgr_agent_core/__init__.py b/sgr_agent_core/__init__.py index 5f3a9198..6778bd87 100644 --- a/sgr_agent_core/__init__.py +++ b/sgr_agent_core/__init__.py @@ -16,7 +16,6 @@ ExecutionConfig, LLMConfig, PromptsConfig, - SearchConfig, ) from sgr_agent_core.agent_factory import AgentFactory from sgr_agent_core.agents import * # noqa: F403 @@ -30,7 +29,12 @@ SourceData, ) from sgr_agent_core.next_step_tool import NextStepToolsBuilder, NextStepToolStub -from sgr_agent_core.services import AgentRegistry, MCP2ToolConverter, PromptLoader, ToolRegistry +from sgr_agent_core.services import ( + AgentRegistry, + MCP2ToolConverter, + PromptLoader, + ToolRegistry, +) from sgr_agent_core.tools import * # noqa: F403 __all__ = [ @@ -50,25 +54,19 @@ "SourceData", # Services "AgentRegistry", - "ToolRegistry", - "PromptLoader", "MCP2ToolConverter", + "PromptLoader", + "ToolRegistry", # Configuration "AgentConfig", "AgentDefinition", "LLMConfig", "PromptsConfig", - "SearchConfig", "ExecutionConfig", "GlobalConfig", # Next step tools "NextStepToolStub", "NextStepToolsBuilder", - # Models - "AgentStatesEnum", - "AgentContext", - "SearchResult", - "SourceData", # Factory "AgentFactory", ] diff --git a/sgr_agent_core/agent_definition.py b/sgr_agent_core/agent_definition.py index 35e6f718..43971826 100644 --- a/sgr_agent_core/agent_definition.py +++ b/sgr_agent_core/agent_definition.py @@ -62,15 +62,6 @@ def to_openai_client_kwargs(self) -> dict[str, Any]: return self.model_dump(exclude={"api_key", "base_url", "proxy"}) -class SearchConfig(BaseModel, extra="allow"): - tavily_api_key: str | None = Field(default=None, description="Tavily API key") - tavily_api_base_url: str = Field(default="https://api.tavily.com", description="Tavily API base URL") - - max_searches: int = Field(default=4, ge=0, description="Maximum number of searches") - max_results: int = Field(default=10, ge=1, description="Maximum number of search results") - content_limit: int = Field(default=3500, gt=0, description="Content character limit per source") - - class PromptsConfig(BaseModel, extra="allow"): system_prompt_file: FilePath | None = Field( default=os.path.join(os.path.dirname(__file__), "prompts/system_prompt.txt"), diff --git a/sgr_agent_core/services/__init__.py b/sgr_agent_core/services/__init__.py index b96a69b4..87eec181 100644 --- a/sgr_agent_core/services/__init__.py +++ b/sgr_agent_core/services/__init__.py @@ -2,12 +2,14 @@ from sgr_agent_core.services.mcp_service import MCP2ToolConverter from sgr_agent_core.services.prompt_loader import PromptLoader -from sgr_agent_core.services.registry import AgentRegistry, StreamingGeneratorRegistry, ToolRegistry -from sgr_agent_core.services.tavily_search import TavilySearchService +from sgr_agent_core.services.registry import ( + AgentRegistry, + StreamingGeneratorRegistry, + ToolRegistry, +) from sgr_agent_core.services.tool_instantiator import ToolInstantiator __all__ = [ - "TavilySearchService", "MCP2ToolConverter", "ToolRegistry", "StreamingGeneratorRegistry", diff --git a/sgr_agent_core/services/tavily_search.py b/sgr_agent_core/services/tavily_search.py deleted file mode 100644 index 83d74437..00000000 --- a/sgr_agent_core/services/tavily_search.py +++ /dev/null @@ -1,107 +0,0 @@ -import logging - -from tavily import AsyncTavilyClient - -from sgr_agent_core.agent_definition import SearchConfig -from sgr_agent_core.models import SourceData - -logger = logging.getLogger(__name__) - - -class TavilySearchService: - def __init__(self, search_config: SearchConfig): - self._client = AsyncTavilyClient( - api_key=search_config.tavily_api_key, api_base_url=search_config.tavily_api_base_url - ) - self._config = search_config - - @staticmethod - def rearrange_sources(sources: list[SourceData], starting_number=1) -> list[SourceData]: - for i, source in enumerate(sources, starting_number): - source.number = i - return sources - - async def search( - self, - query: str, - max_results: int | None = None, - include_raw_content: bool = True, - ) -> list[SourceData]: - """Perform search through Tavily API and return results with - SourceData. - - Args: - query: Search query - max_results: Maximum number of results (default from config) - include_raw_content: Include raw page content - - Returns: - Tuple with tavily answer and list of SourceData - """ - max_results = max_results or self._config.max_results - logger.info(f"🔍 Tavily search: '{query}' (max_results={max_results})") - - # Execute search through Tavily - response = await self._client.search( - query=query, - max_results=max_results, - include_raw_content=include_raw_content, - ) - - # Convert results to SourceData - sources = self._convert_to_source_data(response) - return sources - - async def extract(self, urls: list[str]) -> list[SourceData]: - """Extract full content from specific URLs using Tavily Extract API. - - Args: - urls: List of URLs to extract content from - - Returns: - List of SourceData with extracted content - """ - logger.info(f"📄 Tavily extract: {len(urls)} URLs") - - response = await self._client.extract(urls=urls) - - sources = [] - for i, result in enumerate(response.get("results", [])): - if not result.get("url"): - continue - - source = SourceData( - number=i, - title=result.get("url", "").split("/")[-1] or "Extracted Content", - url=result.get("url", ""), - snippet="", - full_content=result.get("raw_content", ""), - char_count=len(result.get("raw_content", "")), - ) - sources.append(source) - - failed_urls = response.get("failed_results", []) - if failed_urls: - logger.warning(f"⚠️ Failed to extract {len(failed_urls)} URLs: {failed_urls}") - - return sources - - def _convert_to_source_data(self, response: dict) -> list[SourceData]: - """Convert Tavily response to SourceData list.""" - sources = [] - - for i, result in enumerate(response.get("results", [])): - if not result.get("url", ""): - continue - - source = SourceData( - number=i, - title=result.get("title", ""), - url=result.get("url", ""), - snippet=result.get("content", ""), - ) - if result.get("raw_content", ""): - source.full_content = result["raw_content"] - source.char_count = len(source.full_content) - sources.append(source) - return sources diff --git a/sgr_agent_core/tools/__init__.py b/sgr_agent_core/tools/__init__.py index 5cdc14dc..258f63eb 100644 --- a/sgr_agent_core/tools/__init__.py +++ b/sgr_agent_core/tools/__init__.py @@ -8,11 +8,11 @@ from sgr_agent_core.tools.answer_tool import AnswerTool from sgr_agent_core.tools.clarification_tool import ClarificationTool from sgr_agent_core.tools.create_report_tool import CreateReportTool -from sgr_agent_core.tools.extract_page_content_tool import ExtractPageContentTool +from sgr_agent_core.tools.extract_page_content_tool import ExtractPageContentConfig, ExtractPageContentTool from sgr_agent_core.tools.final_answer_tool import FinalAnswerTool from sgr_agent_core.tools.generate_plan_tool import GeneratePlanTool from sgr_agent_core.tools.reasoning_tool import ReasoningTool -from sgr_agent_core.tools.web_search_tool import WebSearchTool +from sgr_agent_core.tools.web_search_tool import WebSearchConfig, WebSearchTool __all__ = [ # Base classes @@ -24,16 +24,15 @@ "ToolNameSelectorStub", "NextStepToolsBuilder", # Individual tools - "ClarificationTool", - "GeneratePlanTool", - "WebSearchTool", - "ExtractPageContentTool", "AdaptPlanTool", - "CreateReportTool", "AnswerTool", + "ClarificationTool", + "CreateReportTool", + "ExtractPageContentConfig", + "ExtractPageContentTool", "FinalAnswerTool", + "GeneratePlanTool", "ReasoningTool", - # Tool lists - "NextStepToolStub", - "NextStepToolsBuilder", + "WebSearchConfig", + "WebSearchTool", ] diff --git a/sgr_agent_core/tools/extract_page_content_tool.py b/sgr_agent_core/tools/extract_page_content_tool.py index fdc6ff32..1f883ee4 100644 --- a/sgr_agent_core/tools/extract_page_content_tool.py +++ b/sgr_agent_core/tools/extract_page_content_tool.py @@ -3,11 +3,11 @@ import logging from typing import TYPE_CHECKING, Any -from pydantic import Field +from pydantic import BaseModel, Field, model_validator +from tavily import AsyncTavilyClient -from sgr_agent_core.agent_definition import SearchConfig from sgr_agent_core.base_tool import BaseTool -from sgr_agent_core.services import TavilySearchService +from sgr_agent_core.models import SourceData if TYPE_CHECKING: from sgr_agent_core.agent_definition import AgentConfig @@ -17,13 +17,31 @@ logger.setLevel(logging.INFO) +class ExtractPageContentConfig(BaseModel, extra="allow"): + """Configuration for ExtractPageContentTool (Tavily Extract API).""" + + tavily_api_key: str | None = Field(default=None, description="Tavily API key") + tavily_api_base_url: str = Field(default="https://api.tavily.com", description="Tavily API base URL") + content_limit: int = Field(default=3500, gt=0, description="Content character limit per source") + + @model_validator(mode="after") + def validate_api_key(self): + if not self.tavily_api_key: + raise ValueError( + "tavily_api_key is required for ExtractPageContentTool." + " Tavily is the only provider that supports content extraction." + ) + return self + + class ExtractPageContentTool(BaseTool): """Extract full detailed content from specific web pages. - Use for: Getting complete page content from URLs found in web search Returns: - Full page content in readable format (via Tavily Extract API) - Best for: Deep analysis of specific pages, extracting structured data - Usage: Call after WebSearchTool to get detailed information from promising URLs + Use for: Getting complete page content from URLs found in web search. + Returns: Full page content in readable format (via Tavily Extract API). + Best for: Deep analysis of specific pages, extracting structured data. + + Usage: Call after WebSearchTool to get detailed information from promising URLs. CRITICAL WARNINGS: - Extracted pages may show data from DIFFERENT years/time periods than asked @@ -33,18 +51,51 @@ class ExtractPageContentTool(BaseTool): - For date/number questions, cross-check extracted values with search snippets """ - config_model = SearchConfig + config_model = ExtractPageContentConfig reasoning: str = Field(description="Why extract these specific pages") urls: list[str] = Field(description="List of URLs to extract full content from", min_length=1, max_length=5) + @staticmethod + async def _extract(config: ExtractPageContentConfig, urls: list[str]) -> list[SourceData]: + """Extract full content from URLs via Tavily Extract API.""" + logger.info(f"Tavily extract: {len(urls)} URLs") + + client = AsyncTavilyClient(api_key=config.tavily_api_key, api_base_url=config.tavily_api_base_url) + response = await client.extract(urls=urls) + + sources = [] + for i, result in enumerate(response.get("results", [])): + if not result.get("url"): + continue + + url = result.get("url", "") + raw_content = result.get("raw_content", "") + source = SourceData( + number=i, + title=url.split("/")[-1] or "Extracted Content", + url=url, + snippet="", + full_content=raw_content, + char_count=len(raw_content), + ) + sources.append(source) + + failed_urls = response.get("failed_results", []) + if failed_urls: + logger.warning(f"Failed to extract {len(failed_urls)} URLs: {failed_urls}") + + return sources + async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: Any) -> str: """Extract full content from specified URLs.""" - search_config = SearchConfig(**kwargs) - logger.info(f"📄 Extracting content from {len(self.urls)} URLs") + try: + extract_config = ExtractPageContentConfig(**kwargs) + except ValueError as e: + return f"Error: {e}" + logger.info(f"Extracting content from {len(self.urls)} URLs") - self._search_service = TavilySearchService(search_config) - sources = await self._search_service.extract(urls=self.urls) + sources = await self._extract(extract_config, urls=self.urls) # Update existing sources instead of overwriting for source in sources: @@ -62,10 +113,10 @@ async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: A # Format results using sources from context (to get correct numbers) for url in self.urls: - if url in context.sources: - source = context.sources[url] + source = context.sources.get(url) + if source is not None: if source.full_content: - content_preview = source.full_content[: search_config.content_limit] + content_preview = source.full_content[: extract_config.content_limit] formatted_result += ( f"{str(source)}\n\n**Full Content:**\n" f"{content_preview}\n\n" diff --git a/sgr_agent_core/tools/web_search_tool.py b/sgr_agent_core/tools/web_search_tool.py index 7c992845..cda23415 100644 --- a/sgr_agent_core/tools/web_search_tool.py +++ b/sgr_agent_core/tools/web_search_tool.py @@ -2,29 +2,261 @@ import logging from datetime import datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Literal -from pydantic import Field +import httpx +from pydantic import BaseModel, Field +from tavily import AsyncTavilyClient -from sgr_agent_core.agent_definition import AgentConfig, SearchConfig from sgr_agent_core.base_tool import BaseTool -from sgr_agent_core.models import SearchResult -from sgr_agent_core.services.tavily_search import TavilySearchService +from sgr_agent_core.models import SearchResult, SourceData if TYPE_CHECKING: + from sgr_agent_core.agent_definition import AgentConfig from sgr_agent_core.models import AgentContext logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +_TAVILY_DEFAULT_URL = "https://api.tavily.com" +_BRAVE_DEFAULT_URL = "https://api.search.brave.com/res/v1/web/search" +_PERPLEXITY_DEFAULT_URL = "https://api.perplexity.ai/search" + +_ENGINE_DEFAULT_URLS: dict[str, str] = { + "tavily": _TAVILY_DEFAULT_URL, + "brave": _BRAVE_DEFAULT_URL, + "perplexity": _PERPLEXITY_DEFAULT_URL, +} + + +class WebSearchConfig(BaseModel, extra="allow"): + """Configuration for WebSearchTool. + + Defines the search engine, credentials, and limits. + """ + + engine: Literal["tavily", "brave", "perplexity"] = Field( + default="tavily", + description="Search engine provider to use", + ) + api_key: str | None = Field(default=None, description="API key for the selected engine") + api_base_url: str | None = Field( + default=None, + description="API base URL for the selected engine (None = engine default)", + ) + max_searches: int = Field(default=4, ge=0, description="Maximum number of searches") + max_results: int = Field(default=10, ge=1, description="Maximum number of search results") + + +# --------------------------------------------------------------------------- +# Tavily +# --------------------------------------------------------------------------- + + +def _convert_tavily_response(response: dict) -> list[SourceData]: + """Convert Tavily API response to SourceData list.""" + sources = [] + for i, result in enumerate(response.get("results", [])): + if not result.get("url", ""): + continue + source = SourceData( + number=i, + title=result.get("title", ""), + url=result.get("url", ""), + snippet=result.get("content", ""), + ) + if result.get("raw_content", ""): + source.full_content = result["raw_content"] + source.char_count = len(source.full_content) + sources.append(source) + return sources + + +async def _search_tavily( + api_key: str, + api_base_url: str, + query: str, + max_results: int, + offset: int, +) -> list[SourceData]: + """Perform search via Tavily API. + + Offset: over-fetch + slice. + """ + fetch_count = max_results + offset if offset > 0 else max_results + logger.info(f"Tavily search: '{query}' (max_results={max_results}, offset={offset})") + + client = AsyncTavilyClient(api_key=api_key, api_base_url=api_base_url) + response = await client.search(query=query, max_results=fetch_count, include_raw_content=False) + + sources = _convert_tavily_response(response) + if offset > 0: + sources = sources[offset:] + return sources[:max_results] + + +# --------------------------------------------------------------------------- +# Brave +# --------------------------------------------------------------------------- + + +def _convert_brave_response(response: dict) -> list[SourceData]: + """Convert Brave Search API response to SourceData list.""" + sources = [] + web_results = response.get("web", {}).get("results", []) + for i, result in enumerate(web_results): + url = result.get("url", "") + if not url: + continue + source = SourceData( + number=i, + title=result.get("title", ""), + url=url, + snippet=result.get("description", ""), + ) + sources.append(source) + return sources + + +async def _search_brave( + api_key: str, + api_base_url: str, + query: str, + max_results: int, + offset: int, +) -> list[SourceData]: + """Perform search via Brave Search API. + + Native offset support. + """ + capped = min(max_results, 20) + logger.info(f"Brave search: '{query}' (max_results={capped}, offset={offset})") + + headers = { + "Accept": "application/json", + "Accept-Encoding": "gzip", + "X-Subscription-Token": api_key, + } + params: dict[str, Any] = {"q": query, "count": capped} + if offset > 0: + params["offset"] = offset + + try: + async with httpx.AsyncClient() as client: + response = await client.get(api_base_url, headers=headers, params=params, timeout=30.0) + response.raise_for_status() + data = response.json() + except httpx.HTTPStatusError as e: + logger.error(f"Brave API HTTP error: {e.response.status_code} — {e.response.text[:200]}") + raise + except httpx.RequestError as e: + logger.error(f"Brave API request error: {e}") + raise + + return _convert_brave_response(data) + + +# --------------------------------------------------------------------------- +# Perplexity +# --------------------------------------------------------------------------- + + +def _convert_perplexity_response(response: dict) -> list[SourceData]: + """Convert Perplexity Search API response to SourceData list.""" + sources = [] + for i, result in enumerate(response.get("results", [])): + url = result.get("url", "") + if not url: + continue + source = SourceData( + number=i, + title=result.get("title", ""), + url=url, + snippet=result.get("snippet", ""), + ) + sources.append(source) + return sources + + +async def _search_perplexity( + api_key: str, + api_base_url: str, + query: str, + max_results: int, + offset: int, +) -> list[SourceData]: + """Perform search via Perplexity API. + + Offset: over-fetch + slice. + """ + fetch_count = max_results + offset if offset > 0 else max_results + logger.info(f"Perplexity search: '{query}' (max_results={max_results}, offset={offset})") + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + payload: dict[str, Any] = {"query": query, "max_results": fetch_count} + + try: + async with httpx.AsyncClient() as client: + response = await client.post(api_base_url, headers=headers, json=payload, timeout=30.0) + response.raise_for_status() + data = response.json() + except httpx.HTTPStatusError as e: + logger.error(f"Perplexity API HTTP error: {e.response.status_code} — {e.response.text[:200]}") + raise + except httpx.RequestError as e: + logger.error(f"Perplexity API request error: {e}") + raise + + sources = _convert_perplexity_response(data) + if offset > 0: + sources = sources[offset:] + return sources[:max_results] + + +# --------------------------------------------------------------------------- +# Engine handler mapping +# --------------------------------------------------------------------------- + +SearchHandler = Callable[..., Awaitable[list[SourceData]]] + +_ENGINE_HANDLERS: dict[str, SearchHandler] = { + "tavily": _search_tavily, + "brave": _search_brave, + "perplexity": _search_perplexity, +} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _rearrange_sources(sources: list[SourceData], starting_number: int = 1) -> list[SourceData]: + """Renumber sources sequentially starting from given number.""" + for i, source in enumerate(sources, starting_number): + source.number = i + return sources + + +# --------------------------------------------------------------------------- +# WebSearchTool +# --------------------------------------------------------------------------- + class WebSearchTool(BaseTool): """Search the web for real-time information about any topic. - Use this tool when you need up-to-date information that might not be available in your training data, - or when you need to verify current facts. + + Single search tool with pluggable engine (tavily, brave, perplexity). + Engine is selected via tool config ``engine`` field. + + Use this tool when you need up-to-date information that might not be + available in your training data, or when you need to verify current facts. The search results will include relevant snippets and URLs from web pages. - This is particularly useful for questions about current events, technology updates, - or any topic that requires recent information. + This is particularly useful for questions about current events, technology + updates, or any topic that requires recent information. Use for: Public information, news, market trends, external APIs, general knowledge Returns: Page titles, URLs, and short snippets (100 characters) Best for: Quick overview, finding relevant pages @@ -43,7 +275,7 @@ class WebSearchTool(BaseTool): - If the snippet directly answers the question, you may not need to extract the full page """ - config_model = SearchConfig + config_model = WebSearchConfig reasoning: str = Field(description="Why this search is needed and what to expect") query: str = Field(description="Search query in same language as user request") @@ -51,7 +283,7 @@ class WebSearchTool(BaseTool): description="Maximum results. How much of the web results selection you want to retrieve", default=5, ge=1, - le=10, + le=20, ) offset: int = Field( default=0, @@ -63,25 +295,33 @@ class WebSearchTool(BaseTool): ) async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: Any) -> str: - """Execute web search using TavilySearchService.""" - search_config = SearchConfig(**kwargs) - logger.info(f"🔍 Search query: '{self.query}'") - self._search_service = TavilySearchService(search_config) + """Execute web search using the configured search engine.""" + search_config = WebSearchConfig(**kwargs) + + engine = search_config.engine + logger.info(f"Search query: '{self.query}' (engine={engine})") + + handler = _ENGINE_HANDLERS.get(engine) + if handler is None: + raise ValueError(f"Unsupported search engine: {engine}") - max_results_limit = search_config.max_results - effective_limit = min(self.max_results, max_results_limit) - fetch_count = effective_limit + self.offset + api_key = search_config.api_key + if not api_key: + raise ValueError(f"api_key is required for engine '{engine}'") - sources = await self._search_service.search( + api_base_url = search_config.api_base_url or _ENGINE_DEFAULT_URLS[engine] + + effective_limit = min(self.max_results, search_config.max_results) + + sources = await handler( + api_key=api_key, + api_base_url=api_base_url, query=self.query, - max_results=fetch_count, - include_raw_content=False, + max_results=effective_limit, + offset=self.offset, ) - if self.offset: - sources = sources[self.offset :] - - sources = TavilySearchService.rearrange_sources(sources, starting_number=len(context.sources) + 1) + sources = _rearrange_sources(sources, starting_number=len(context.sources) + 1) for source in sources: context.sources[source.url] = source diff --git a/sgr_agent_core/utils.py b/sgr_agent_core/utils.py index e193fb28..1a7a8d8c 100644 --- a/sgr_agent_core/utils.py +++ b/sgr_agent_core/utils.py @@ -27,7 +27,7 @@ def config_from_kwargs(config_class: type[T], base: T | None, kwargs: dict[str, agent-level config with per-tool kwargs from the tools array (global or inline). Args: - config_class: Pydantic model class to instantiate (e.g. SearchConfig). + config_class: Pydantic model class to instantiate (e.g. WebSearchConfig). base: Existing config instance, or None to use only kwargs (with model defaults). kwargs: Overrides; keys present here override base. None values are skipped. diff --git a/tests/test_base_agent.py b/tests/test_base_agent.py index 94d38962..227f628e 100644 --- a/tests/test_base_agent.py +++ b/tests/test_base_agent.py @@ -12,7 +12,7 @@ import pytest -from sgr_agent_core.agent_definition import AgentConfig, ExecutionConfig, LLMConfig, PromptsConfig, SearchConfig +from sgr_agent_core.agent_definition import AgentConfig, ExecutionConfig, LLMConfig, PromptsConfig from sgr_agent_core.base_agent import BaseAgent from sgr_agent_core.models import AgentContext, AgentStatesEnum from sgr_agent_core.tools import BaseTool, ReasoningTool, WebSearchTool @@ -351,11 +351,13 @@ def test_get_tool_config_returns_model_when_tool_has_config_model(self): execution=ExecutionConfig(), ) out = agent.get_tool_config(WebSearchTool) - assert isinstance(out, SearchConfig) + from sgr_agent_core.tools.web_search_tool import WebSearchConfig + + assert isinstance(out, WebSearchConfig) assert out.max_searches == 6 def test_get_tool_config_returns_model_from_tool_configs_only(self): - """get_tool_config builds SearchConfig exclusively from tool_configs + """get_tool_config builds config model exclusively from tool_configs (search settings are per-tool, not in AgentConfig).""" agent = create_test_agent( BaseAgent, @@ -369,7 +371,9 @@ def test_get_tool_config_returns_model_from_tool_configs_only(self): execution=ExecutionConfig(), ) out = agent.get_tool_config(WebSearchTool) - assert isinstance(out, SearchConfig) + from sgr_agent_core.tools.web_search_tool import WebSearchConfig + + assert isinstance(out, WebSearchConfig) assert out.max_searches == 10 assert out.tavily_api_key == "key" diff --git a/tests/test_search_providers.py b/tests/test_search_providers.py new file mode 100644 index 00000000..49cdb464 --- /dev/null +++ b/tests/test_search_providers.py @@ -0,0 +1,174 @@ +"""Tests for search engine handler functions.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from sgr_agent_core.models import SourceData + + +class TestRearrangeSources: + """Tests for _rearrange_sources helper.""" + + def test_renumbers_from_starting_number(self): + from sgr_agent_core.tools.web_search_tool import _rearrange_sources + + sources = [ + SourceData(number=0, url="https://a.com", title="A", snippet="a"), + SourceData(number=0, url="https://b.com", title="B", snippet="b"), + ] + result = _rearrange_sources(sources, starting_number=5) + assert result[0].number == 5 + assert result[1].number == 6 + + +class TestBraveSearchHandler: + """Tests for Brave search handler function.""" + + def test_convert_brave_response(self): + from sgr_agent_core.tools.web_search_tool import _convert_brave_response + + response = { + "web": { + "results": [ + {"title": "Test", "url": "https://example.com", "description": "A test result"}, + {"title": "Test2", "url": "https://example2.com", "description": "Another result"}, + {"title": "No URL", "url": "", "description": "Skipped"}, + ] + } + } + sources = _convert_brave_response(response) + assert len(sources) == 2 + assert sources[0].title == "Test" + assert sources[0].url == "https://example.com" + assert sources[0].snippet == "A test result" + + @pytest.mark.asyncio + async def test_search_calls_brave_api(self): + from sgr_agent_core.tools.web_search_tool import _search_brave + + mock_response = MagicMock() + mock_response.json.return_value = { + "web": { + "results": [ + {"title": "Result", "url": "https://example.com", "description": "desc"}, + ] + } + } + mock_response.raise_for_status = MagicMock() + + with patch("sgr_agent_core.tools.web_search_tool.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + sources = await _search_brave( + api_key="test-key", + api_base_url="https://api.search.brave.com/res/v1/web/search", + query="test query", + max_results=5, + offset=0, + ) + + mock_client.get.assert_called_once() + call_kwargs = mock_client.get.call_args + assert call_kwargs.kwargs["params"]["q"] == "test query" + assert call_kwargs.kwargs["params"]["count"] == 5 + assert len(sources) == 1 + + +class TestPerplexitySearchHandler: + """Tests for Perplexity search handler function.""" + + def test_convert_perplexity_response(self): + from sgr_agent_core.tools.web_search_tool import _convert_perplexity_response + + response = { + "results": [ + {"title": "Page 1", "url": "https://example.com/page1", "snippet": "First result snippet"}, + {"title": "Page 2", "url": "https://example.com/page2", "snippet": "Second result snippet"}, + {"title": "No URL", "url": "", "snippet": "Skipped"}, + ], + } + sources = _convert_perplexity_response(response) + assert len(sources) == 2 + assert sources[0].url == "https://example.com/page1" + assert sources[0].title == "Page 1" + assert sources[0].snippet == "First result snippet" + assert sources[1].snippet == "Second result snippet" + + @pytest.mark.asyncio + async def test_search_calls_perplexity_api(self): + from sgr_agent_core.tools.web_search_tool import _search_perplexity + + mock_response = MagicMock() + mock_response.json.return_value = { + "results": [ + {"title": "Result", "url": "https://example.com", "snippet": "desc"}, + ], + } + mock_response.raise_for_status = MagicMock() + + with patch("sgr_agent_core.tools.web_search_tool.httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + sources = await _search_perplexity( + api_key="test-key", + api_base_url="https://api.perplexity.ai/search", + query="test query", + max_results=5, + offset=0, + ) + + mock_client.post.assert_called_once() + call_kwargs = mock_client.post.call_args + assert call_kwargs.kwargs["json"]["query"] == "test query" + assert call_kwargs.kwargs["json"]["max_results"] == 5 + assert len(sources) == 1 + + +class TestTavilySearchHandler: + """Tests for Tavily search handler function.""" + + def test_convert_tavily_response(self): + from sgr_agent_core.tools.web_search_tool import _convert_tavily_response + + response = { + "results": [ + {"title": "Test", "url": "https://example.com", "content": "Snippet", "raw_content": "Full content"}, + ] + } + sources = _convert_tavily_response(response) + assert len(sources) == 1 + assert sources[0].title == "Test" + assert sources[0].snippet == "Snippet" + assert sources[0].full_content == "Full content" + + @pytest.mark.asyncio + async def test_extract_calls_tavily_api(self): + from sgr_agent_core.tools.extract_page_content_tool import ExtractPageContentConfig, ExtractPageContentTool + + config = ExtractPageContentConfig(tavily_api_key="test-key") + + mock_client = AsyncMock() + mock_client.extract = AsyncMock( + return_value={ + "results": [ + {"url": "https://example.com/page", "raw_content": "Full page content"}, + ], + "failed_results": [], + } + ) + + with patch("sgr_agent_core.tools.extract_page_content_tool.AsyncTavilyClient", return_value=mock_client): + sources = await ExtractPageContentTool._extract(config, urls=["https://example.com/page"]) + + assert len(sources) == 1 + assert sources[0].url == "https://example.com/page" + assert sources[0].full_content == "Full page content" diff --git a/tests/test_tools.py b/tests/test_tools.py index f96ab9e2..7f0d3bce 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -10,7 +10,6 @@ import pytest from sgr_agent_core.models import AgentContext, SourceData -from sgr_agent_core.services.tavily_search import TavilySearchService from sgr_agent_core.tools import ( AdaptPlanTool, AnswerTool, @@ -98,13 +97,12 @@ def test_web_search_tool_initialization(self): def test_extract_page_content_tool_initialization(self): """Test ExtractPageContentTool initialization.""" - with patch("sgr_agent_core.tools.extract_page_content_tool.TavilySearchService"): - tool = ExtractPageContentTool( - reasoning="Test", - urls=["https://example.com"], - ) - assert tool.tool_name == "extractpagecontenttool" - assert len(tool.urls) == 1 + tool = ExtractPageContentTool( + reasoning="Test", + urls=["https://example.com"], + ) + assert tool.tool_name == "extractpagecontenttool" + assert len(tool.urls) == 1 def test_create_report_tool_initialization(self): """Test CreateReportTool initialization.""" @@ -186,71 +184,52 @@ class TestSearchToolsKwargs: settings.""" @pytest.mark.asyncio - async def test_web_search_tool_uses_max_results_from_kwargs(self): - """WebSearchTool uses max_results from tool config kwargs.""" - from sgr_agent_core.models import AgentContext - + async def test_web_search_tool_uses_kwargs_max_results(self): + """WebSearchTool uses max_results from kwargs when provided.""" tool = WebSearchTool(reasoning="r", query="test", max_results=5) context = AgentContext() config = MagicMock() - with patch("sgr_agent_core.tools.web_search_tool.TavilySearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=[]) - mock_svc_class.return_value = mock_svc - await tool(context, config, tavily_api_key="k", max_results=3) - call_args = mock_svc_class.call_args[0][0] - assert call_args.max_results == 3 + mock_handler = AsyncMock(return_value=[]) + with patch.dict("sgr_agent_core.tools.web_search_tool._ENGINE_HANDLERS", {"tavily": mock_handler}): + await tool(context, config, api_key="k", max_results=3) + assert mock_handler.call_args.kwargs["max_results"] == 3 @pytest.mark.asyncio - async def test_web_search_tool_uses_tool_config_kwargs(self): - """WebSearchTool reads max_results from tool config kwargs (not - config.search).""" - from sgr_agent_core.models import AgentContext - + async def test_web_search_tool_default_max_results(self): + """WebSearchTool uses default max_results when not overridden in + kwargs.""" tool = WebSearchTool(reasoning="r", query="test", max_results=5) context = AgentContext() config = MagicMock() - with patch("sgr_agent_core.tools.web_search_tool.TavilySearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=[]) - mock_svc_class.return_value = mock_svc - await tool(context, config, tavily_api_key="k", max_results=10) - call_args = mock_svc_class.call_args[0][0] - assert call_args.max_results == 10 + mock_handler = AsyncMock(return_value=[]) + with patch.dict("sgr_agent_core.tools.web_search_tool._ENGINE_HANDLERS", {"tavily": mock_handler}): + await tool(context, config, api_key="k") + assert mock_handler.call_args.kwargs["max_results"] == 5 @pytest.mark.asyncio async def test_web_search_tool_with_offset(self): - """WebSearchTool with offset fetches offset+max_results from Tavily and - slices.""" + """WebSearchTool passes offset to provider which handles it + internally.""" tool = WebSearchTool(reasoning="r", query="test", max_results=3, offset=2) context = AgentContext() config = MagicMock() mock_sources = [ SourceData(number=i, url=f"https://example.com/{i}", title=f"Result {i}", snippet=f"Snippet {i}") - for i in range(5) + for i in range(2, 5) ] - with patch("sgr_agent_core.tools.web_search_tool.TavilySearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=mock_sources) - mock_svc_class.return_value = mock_svc - mock_svc_class.rearrange_sources = TavilySearchService.rearrange_sources - - result = await tool(context, config, tavily_api_key="k", max_results=10) + mock_handler = AsyncMock(return_value=mock_sources) + with patch.dict("sgr_agent_core.tools.web_search_tool._ENGINE_HANDLERS", {"tavily": mock_handler}): + result = await tool(context, config, api_key="k") - # Tavily should receive max_results = 3 + 2 = 5 - mock_svc.search.assert_called_once_with(query="test", max_results=5, include_raw_content=False) - # After slicing [2:], 3 sources should remain assert len(context.searches) == 1 assert len(context.searches[0].citations) == 3 assert "Result 2" in result - assert "Result 0" not in result @pytest.mark.asyncio async def test_web_search_tool_offset_default_zero(self): - """WebSearchTool without offset behaves identically to current - logic.""" + """WebSearchTool without offset passes offset=0 to provider.""" tool = WebSearchTool(reasoning="r", query="test", max_results=3) assert tool.offset == 0 @@ -262,16 +241,10 @@ async def test_web_search_tool_offset_default_zero(self): for i in range(3) ] - with patch("sgr_agent_core.tools.web_search_tool.TavilySearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=mock_sources) - mock_svc_class.return_value = mock_svc - mock_svc_class.rearrange_sources = TavilySearchService.rearrange_sources + mock_handler = AsyncMock(return_value=mock_sources) + with patch.dict("sgr_agent_core.tools.web_search_tool._ENGINE_HANDLERS", {"tavily": mock_handler}): + await tool(context, config, api_key="k") - await tool(context, config, tavily_api_key="k", max_results=10) - - # Tavily should receive max_results = 3 (no offset added) - mock_svc.search.assert_called_once_with(query="test", max_results=3, include_raw_content=False) assert len(context.searches[0].citations) == 3 @pytest.mark.asyncio @@ -282,38 +255,20 @@ async def test_web_search_tool_offset_exceeds_results(self): context = AgentContext() config = MagicMock() - mock_sources = [ - SourceData(number=i, url=f"https://example.com/{i}", title=f"Result {i}", snippet=f"Snippet {i}") - for i in range(5) - ] - - with patch("sgr_agent_core.tools.web_search_tool.TavilySearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=mock_sources) - mock_svc_class.return_value = mock_svc - mock_svc_class.rearrange_sources = TavilySearchService.rearrange_sources - - result = await tool(context, config, tavily_api_key="k", max_results=20) + mock_handler = AsyncMock(return_value=[]) + with patch.dict("sgr_agent_core.tools.web_search_tool._ENGINE_HANDLERS", {"tavily": mock_handler}): + result = await tool(context, config, api_key="k") - # Tavily should receive max_results = 3 + 10 = 13 - mock_svc.search.assert_called_once_with(query="test", max_results=13, include_raw_content=False) - # After slicing [10:] on 5 results, empty list assert len(context.searches[0].citations) == 0 assert "Search Query: test" in result @pytest.mark.asyncio async def test_extract_page_content_tool_uses_content_limit_from_kwargs(self): - """ExtractPageContentTool uses content_limit from tool config - kwargs.""" - from sgr_agent_core.models import AgentContext - + """ExtractPageContentTool uses content_limit from kwargs.""" tool = ExtractPageContentTool(reasoning="r", urls=["https://example.com"]) context = AgentContext() config = MagicMock() - with patch("sgr_agent_core.tools.extract_page_content_tool.TavilySearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.extract = AsyncMock(return_value=[]) - mock_svc_class.return_value = mock_svc + with patch.object(ExtractPageContentTool, "_extract", new_callable=AsyncMock, return_value=[]) as mock_extract: await tool(context, config, tavily_api_key="k", content_limit=500) - call_args = mock_svc_class.call_args[0][0] - assert call_args.content_limit == 500 + # search_config is passed as first positional arg + assert mock_extract.call_args[0][0].content_limit == 500