From 3f25e9724f6be84f5041f7abfe811508a213f5ff Mon Sep 17 00:00:00 2001 From: Nikita Matsko Date: Tue, 17 Feb 2026 17:27:33 +0000 Subject: [PATCH 01/12] feat(search): add multi-provider search services with BaseSearchService, Brave and Perplexity Add BaseSearchService base class with factory pattern routing by engine config. Add BraveSearchService (httpx, native offset) and PerplexitySearchService (httpx, over-fetch+slice). Extend SearchConfig with engine, brave_*, perplexity_* fields. Refactor TavilySearchService to inherit BaseSearchService. --- sgr_agent_core/agent_definition.py | 15 +++ sgr_agent_core/services/__init__.py | 6 ++ sgr_agent_core/services/base_search.py | 79 ++++++++++++++ sgr_agent_core/services/brave_search.py | 98 +++++++++++++++++ sgr_agent_core/services/perplexity_search.py | 106 +++++++++++++++++++ sgr_agent_core/services/tavily_search.py | 41 +++---- 6 files changed, 317 insertions(+), 28 deletions(-) create mode 100644 sgr_agent_core/services/base_search.py create mode 100644 sgr_agent_core/services/brave_search.py create mode 100644 sgr_agent_core/services/perplexity_search.py diff --git a/sgr_agent_core/agent_definition.py b/sgr_agent_core/agent_definition.py index 67a2795d..5f8b0967 100644 --- a/sgr_agent_core/agent_definition.py +++ b/sgr_agent_core/agent_definition.py @@ -72,6 +72,21 @@ class SearchConfig(BaseModel, extra="allow"): max_results: int = Field(default=10, ge=1, description="Maximum number of search results") content_limit: int = Field(default=3500, gt=0, description="Content character limit per source") + engine: Literal["tavily", "brave", "perplexity"] = Field( + default="tavily", + description="Search engine provider to use", + ) + brave_api_key: str | None = Field(default=None, description="Brave Search API key") + brave_api_base_url: str = Field( + default="https://api.search.brave.com/res/v1/web/search", + description="Brave Search API base URL", + ) + perplexity_api_key: str | None = Field(default=None, 
description="Perplexity API key") + perplexity_api_base_url: str = Field( + default="https://api.perplexity.ai/search", + description="Perplexity Search API base URL", + ) + class PromptsConfig(BaseModel, extra="allow"): system_prompt_file: FilePath | None = Field( diff --git a/sgr_agent_core/services/__init__.py b/sgr_agent_core/services/__init__.py index b96a69b4..a3fd26fe 100644 --- a/sgr_agent_core/services/__init__.py +++ b/sgr_agent_core/services/__init__.py @@ -1,12 +1,18 @@ """Services module for external integrations and business logic.""" +from sgr_agent_core.services.base_search import BaseSearchService +from sgr_agent_core.services.brave_search import BraveSearchService from sgr_agent_core.services.mcp_service import MCP2ToolConverter +from sgr_agent_core.services.perplexity_search import PerplexitySearchService from sgr_agent_core.services.prompt_loader import PromptLoader from sgr_agent_core.services.registry import AgentRegistry, StreamingGeneratorRegistry, ToolRegistry from sgr_agent_core.services.tavily_search import TavilySearchService from sgr_agent_core.services.tool_instantiator import ToolInstantiator __all__ = [ + "BaseSearchService", + "BraveSearchService", + "PerplexitySearchService", "TavilySearchService", "MCP2ToolConverter", "ToolRegistry", diff --git a/sgr_agent_core/services/base_search.py b/sgr_agent_core/services/base_search.py new file mode 100644 index 00000000..f7443e4c --- /dev/null +++ b/sgr_agent_core/services/base_search.py @@ -0,0 +1,79 @@ +import logging +from typing import TYPE_CHECKING + +from sgr_agent_core.models import SourceData + +if TYPE_CHECKING: + from sgr_agent_core.agent_definition import SearchConfig + +logger = logging.getLogger(__name__) + + +class BaseSearchService: + """Base class for search service providers. + + Subclasses must implement the `search` method. 
+ """ + + def __init__(self, search_config: "SearchConfig"): + self._config = search_config + + async def search( + self, + query: str, + max_results: int | None = None, + offset: int = 0, + include_raw_content: bool = True, + ) -> list[SourceData]: + """Perform a search and return results as SourceData list. + + Each provider handles offset internally: + - Brave: uses native API offset parameter + - Tavily/Perplexity: over-fetch+slice + + Args: + query: Search query string + max_results: Maximum number of results to return (after offset) + offset: Number of results to skip + include_raw_content: Whether to include raw page content + + Returns: + List of SourceData results (at most max_results items) + """ + raise NotImplementedError("Subclasses must implement search()") + + @staticmethod + def rearrange_sources(sources: list[SourceData], starting_number: int = 1) -> list[SourceData]: + """Renumber sources sequentially starting from given number.""" + for i, source in enumerate(sources, starting_number): + source.number = i + return sources + + @classmethod + def create(cls, config: "SearchConfig") -> "BaseSearchService": + """Factory method to create a search service based on config.engine. 
+ + Args: + config: SearchConfig with engine and API keys + + Returns: + Appropriate search service instance + + Raises: + ValueError: If engine is not supported + """ + from sgr_agent_core.services.brave_search import BraveSearchService + from sgr_agent_core.services.perplexity_search import PerplexitySearchService + from sgr_agent_core.services.tavily_search import TavilySearchService + + engine = config.engine + logger.debug(f"Creating search service for engine: {engine}") + if engine == "tavily": + return TavilySearchService(config) + elif engine == "brave": + return BraveSearchService(config) + elif engine == "perplexity": + return PerplexitySearchService(config) + else: + logger.error(f"Unsupported search engine requested: {engine}") + raise ValueError(f"Unsupported search engine: {engine}") diff --git a/sgr_agent_core/services/brave_search.py b/sgr_agent_core/services/brave_search.py new file mode 100644 index 00000000..dcd7ace0 --- /dev/null +++ b/sgr_agent_core/services/brave_search.py @@ -0,0 +1,98 @@ +import logging +from typing import Any + +import httpx + +from sgr_agent_core.agent_definition import SearchConfig +from sgr_agent_core.models import SourceData +from sgr_agent_core.services.base_search import BaseSearchService + +logger = logging.getLogger(__name__) + + +class BraveSearchService(BaseSearchService): + """Search service using Brave Search API. + + Uses httpx.AsyncClient for HTTP requests. + Auth: X-Subscription-Token header. + Brave API supports native offset for pagination. + """ + + def __init__(self, search_config: SearchConfig): + super().__init__(search_config) + if not search_config.brave_api_key: + raise ValueError("brave_api_key is required for BraveSearchService") + + async def search( + self, + query: str, + max_results: int | None = None, + offset: int = 0, + include_raw_content: bool = True, + ) -> list[SourceData]: + """Perform search through Brave Search API. + + Brave supports native offset parameter for efficient pagination. 
+ + Args: + query: Search query string + max_results: Maximum number of results (max 20 per Brave API) + offset: Number of results to skip (native Brave API support) + include_raw_content: Ignored for Brave (no raw content extraction) + + Returns: + List of SourceData results + """ + max_results = min(max_results or self._config.max_results, 20) + logger.info(f"🔍 Brave search: '{query}' (max_results={max_results}, offset={offset})") + + headers = { + "Accept": "application/json", + "Accept-Encoding": "gzip", + "X-Subscription-Token": self._config.brave_api_key, + } + params: dict[str, Any] = { + "q": query, + "count": max_results, + } + if offset > 0: + params["offset"] = offset + + try: + async with httpx.AsyncClient() as client: + response = await client.get( + self._config.brave_api_base_url, + headers=headers, + params=params, + timeout=30.0, + ) + response.raise_for_status() + data = response.json() + except httpx.HTTPStatusError as e: + logger.error(f"Brave API HTTP error: {e.response.status_code} — {e.response.text[:200]}") + raise + except httpx.RequestError as e: + logger.error(f"Brave API request error: {e}") + raise + + return self._convert_to_source_data(data) + + def _convert_to_source_data(self, response: dict) -> list[SourceData]: + """Convert Brave Search API response to SourceData list.""" + sources = [] + web_results = response.get("web", {}).get("results", []) + + for i, result in enumerate(web_results): + url = result.get("url", "") + if not url: + continue + + source = SourceData( + number=i, + title=result.get("title", ""), + url=url, + snippet=result.get("description", ""), + ) + sources.append(source) + + return sources diff --git a/sgr_agent_core/services/perplexity_search.py b/sgr_agent_core/services/perplexity_search.py new file mode 100644 index 00000000..753bed93 --- /dev/null +++ b/sgr_agent_core/services/perplexity_search.py @@ -0,0 +1,106 @@ +import logging +from typing import Any + +import httpx + +from 
sgr_agent_core.agent_definition import SearchConfig +from sgr_agent_core.models import SourceData +from sgr_agent_core.services.base_search import BaseSearchService + +logger = logging.getLogger(__name__) + + +class PerplexitySearchService(BaseSearchService): + """Search service using Perplexity Search API. + + Uses httpx.AsyncClient for HTTP requests to the dedicated Search API + endpoint (POST /search) which returns ranked web results with titles, + URLs, and snippets. + Auth: Authorization Bearer header. + Offset is handled by over-fetch+slice (no native offset support). + """ + + def __init__(self, search_config: SearchConfig): + super().__init__(search_config) + if not search_config.perplexity_api_key: + raise ValueError("perplexity_api_key is required for PerplexitySearchService") + + async def search( + self, + query: str, + max_results: int | None = None, + offset: int = 0, + include_raw_content: bool = True, + ) -> list[SourceData]: + """Perform search through Perplexity Search API. + + Perplexity does not support native offset — over-fetch+slice + is applied internally when offset > 0. 
+ + Args: + query: Search query string + max_results: Maximum number of results to return + offset: Number of results to skip (over-fetch+slice internally) + include_raw_content: Ignored for Perplexity + + Returns: + List of SourceData results + """ + max_results = max_results or self._config.max_results + fetch_count = max_results + offset if offset > 0 else max_results + logger.info(f"🔍 Perplexity search: '{query}' (max_results={max_results}, offset={offset})") + + headers = { + "Authorization": f"Bearer {self._config.perplexity_api_key}", + "Content-Type": "application/json", + } + payload: dict[str, Any] = { + "query": query, + "max_results": fetch_count, + } + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + self._config.perplexity_api_base_url, + headers=headers, + json=payload, + timeout=30.0, + ) + response.raise_for_status() + data = response.json() + except httpx.HTTPStatusError as e: + logger.error(f"Perplexity API HTTP error: {e.response.status_code} — {e.response.text[:200]}") + raise + except httpx.RequestError as e: + logger.error(f"Perplexity API request error: {e}") + raise + + sources = self._convert_to_source_data(data) + if offset > 0: + sources = sources[offset:] + return sources[:max_results] + + def _convert_to_source_data(self, response: dict) -> list[SourceData]: + """Convert Perplexity Search API response to SourceData list. + + Perplexity Search API returns results[] with title, url, + snippet. 
+ """ + sources = [] + results = response.get("results", []) + + for i, result in enumerate(results): + url = result.get("url", "") + if not url: + continue + + source = SourceData( + number=i, + title=result.get("title", ""), + url=url, + snippet=result.get("snippet", ""), + ) + sources.append(source) + + return sources diff --git a/sgr_agent_core/services/tavily_search.py b/sgr_agent_core/services/tavily_search.py index 83d74437..ec2f417c 100644 --- a/sgr_agent_core/services/tavily_search.py +++ b/sgr_agent_core/services/tavily_search.py @@ -4,63 +4,48 @@ from sgr_agent_core.agent_definition import SearchConfig from sgr_agent_core.models import SourceData +from sgr_agent_core.services.base_search import BaseSearchService logger = logging.getLogger(__name__) -class TavilySearchService: +class TavilySearchService(BaseSearchService): def __init__(self, search_config: SearchConfig): + super().__init__(search_config) self._client = AsyncTavilyClient( api_key=search_config.tavily_api_key, api_base_url=search_config.tavily_api_base_url ) - self._config = search_config - - @staticmethod - def rearrange_sources(sources: list[SourceData], starting_number=1) -> list[SourceData]: - for i, source in enumerate(sources, starting_number): - source.number = i - return sources async def search( self, query: str, max_results: int | None = None, + offset: int = 0, include_raw_content: bool = True, ) -> list[SourceData]: """Perform search through Tavily API and return results with SourceData. - Args: - query: Search query - max_results: Maximum number of results (default from config) - include_raw_content: Include raw page content - - Returns: - Tuple with tavily answer and list of SourceData + Tavily API does not support native offset — over-fetch+slice is + applied internally when offset > 0. 
""" max_results = max_results or self._config.max_results - logger.info(f"🔍 Tavily search: '{query}' (max_results={max_results})") + fetch_count = max_results + offset if offset > 0 else max_results + logger.info(f"🔍 Tavily search: '{query}' (max_results={max_results}, offset={offset})") - # Execute search through Tavily response = await self._client.search( query=query, - max_results=max_results, + max_results=fetch_count, include_raw_content=include_raw_content, ) - # Convert results to SourceData sources = self._convert_to_source_data(response) - return sources + if offset > 0: + sources = sources[offset:] + return sources[:max_results] async def extract(self, urls: list[str]) -> list[SourceData]: - """Extract full content from specific URLs using Tavily Extract API. - - Args: - urls: List of URLs to extract content from - - Returns: - List of SourceData with extracted content - """ + """Extract full content from specific URLs using Tavily Extract API.""" logger.info(f"📄 Tavily extract: {len(urls)} URLs") response = await self._client.extract(urls=urls) From d610f91f762b16b2772a86c997202f568cb9d4fe Mon Sep 17 00:00:00 2001 From: Nikita Matsko Date: Tue, 17 Feb 2026 17:29:25 +0000 Subject: [PATCH 02/12] feat(search): add standalone search tools (Tavily, Brave, Perplexity) with shared BaseSearchTool base Add _BaseSearchTool base class with shared fields and __call__ logic Refactor WebSearchTool to inherit from _BaseSearchTool Add TavilySearchTool, BraveSearchTool, PerplexitySearchTool standalone tools Fix description ClassVar inheritance in __init_subclass__ Add Tavily key guard to ExtractPageContentTool Update exports and config examples with multi-provider settings --- config.yaml.example | 15 ++- .../sgr_deep_research/config.yaml.example | 1 + .../config.yaml.example | 1 + sgr_agent_core/__init__.py | 17 ++- sgr_agent_core/base_tool.py | 2 +- sgr_agent_core/tools/__init__.py | 6 + sgr_agent_core/tools/base_search_tool.py | 114 ++++++++++++++++++ 
sgr_agent_core/tools/brave_search_tool.py | 18 +++ .../tools/extract_page_content_tool.py | 16 ++- .../tools/perplexity_search_tool.py | 18 +++ sgr_agent_core/tools/tavily_search_tool.py | 18 +++ sgr_agent_core/tools/web_search_tool.py | 93 +------------- 12 files changed, 217 insertions(+), 102 deletions(-) create mode 100644 sgr_agent_core/tools/base_search_tool.py create mode 100644 sgr_agent_core/tools/brave_search_tool.py create mode 100644 sgr_agent_core/tools/perplexity_search_tool.py create mode 100644 sgr_agent_core/tools/tavily_search_tool.py diff --git a/config.yaml.example b/config.yaml.example index 10fbec21..1e535e14 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -57,17 +57,28 @@ tools: base_class: path.to.my.tools.CustomTool my_other_tool: base_class: "name_of_tool_class_in_registry" - # Search tools: configure Tavily API key and search limits per tool + # Search tools: configure search provider and API keys per tool # (can be overridden per-agent in tools list) web_search_tool: + engine: "tavily" # Search engine: "tavily" (default), "brave", or "perplexity" tavily_api_key: "your-tavily-api-key-here" # Tavily API key (get at tavily.com) tavily_api_base_url: "https://api.tavily.com" # Tavily API URL max_results: 12 max_searches: 6 extract_page_content_tool: - tavily_api_key: "your-tavily-api-key-here" # Same Tavily API key + tavily_api_key: "your-tavily-api-key-here" # Same Tavily API key (Tavily-only feature) tavily_api_base_url: "https://api.tavily.com" content_limit: 2000 + # Standalone search tools (for multi-engine setups where LLM picks the engine) + brave_search_tool: + brave_api_key: "your-brave-api-key-here" # Brave Search API key + brave_api_base_url: "https://api.search.brave.com/res/v1/web/search" + perplexity_search_tool: + perplexity_api_key: "your-perplexity-api-key-here" # Perplexity API key + perplexity_api_base_url: "https://api.perplexity.ai/search" + tavily_search_tool: + tavily_api_key: "your-tavily-api-key-here" + 
tavily_api_base_url: "https://api.tavily.com" agents: custom_research_agent: diff --git a/examples/sgr_deep_research/config.yaml.example b/examples/sgr_deep_research/config.yaml.example index d6e765a0..40c73eed 100644 --- a/examples/sgr_deep_research/config.yaml.example +++ b/examples/sgr_deep_research/config.yaml.example @@ -30,6 +30,7 @@ tools: # Core tools (base_class defaults to sgr_agent_core.tools.*) # Search tools: configure Tavily API key and search limits per tool web_search_tool: + engine: "tavily" # Search engine: "tavily" (default), "brave", or "perplexity" tavily_api_key: "your-tavily-api-key-here" # Tavily API key (get at tavily.com) tavily_api_base_url: "https://api.tavily.com" # Tavily API URL max_searches: 4 # Max search operations diff --git a/examples/sgr_deep_research_without_reporting/config.yaml.example b/examples/sgr_deep_research_without_reporting/config.yaml.example index d84ba528..41953831 100644 --- a/examples/sgr_deep_research_without_reporting/config.yaml.example +++ b/examples/sgr_deep_research_without_reporting/config.yaml.example @@ -29,6 +29,7 @@ tools: # Core tools (base_class defaults to sgr_agent_core.tools.*) # Search tools: configure Tavily API key and search limits per tool web_search_tool: + engine: "tavily" # Search engine: "tavily" (default), "brave", or "perplexity" tavily_api_key: "your-tavily-api-key-here" # Tavily API key (get at tavily.com) tavily_api_base_url: "https://api.tavily.com" # Tavily API URL max_searches: 4 # Max search operations diff --git a/sgr_agent_core/__init__.py b/sgr_agent_core/__init__.py index 2c9c2995..653c19d2 100644 --- a/sgr_agent_core/__init__.py +++ b/sgr_agent_core/__init__.py @@ -30,7 +30,15 @@ SourceData, ) from sgr_agent_core.next_step_tool import NextStepToolsBuilder, NextStepToolStub -from sgr_agent_core.services import AgentRegistry, MCP2ToolConverter, PromptLoader, ToolRegistry +from sgr_agent_core.services import ( + AgentRegistry, + BaseSearchService, + BraveSearchService, + 
MCP2ToolConverter, + PerplexitySearchService, + PromptLoader, + ToolRegistry, +) from sgr_agent_core.tools import * # noqa: F403 __all__ = [ @@ -49,9 +57,12 @@ "SourceData", # Services "AgentRegistry", - "ToolRegistry", - "PromptLoader", + "BaseSearchService", + "BraveSearchService", "MCP2ToolConverter", + "PerplexitySearchService", + "PromptLoader", + "ToolRegistry", # Configuration "AgentConfig", "AgentDefinition", diff --git a/sgr_agent_core/base_tool.py b/sgr_agent_core/base_tool.py index 463c3c1d..578d06a9 100644 --- a/sgr_agent_core/base_tool.py +++ b/sgr_agent_core/base_tool.py @@ -20,7 +20,7 @@ class ToolRegistryMixin: def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs) - if cls.__name__ not in ("BaseTool", "MCPBaseTool"): + if cls.__name__ not in ("BaseTool", "MCPBaseTool", "_BaseSearchTool"): ToolRegistry.register(cls, name=cls.tool_name) diff --git a/sgr_agent_core/tools/__init__.py b/sgr_agent_core/tools/__init__.py index 786e5182..07487c95 100644 --- a/sgr_agent_core/tools/__init__.py +++ b/sgr_agent_core/tools/__init__.py @@ -6,12 +6,15 @@ ) from sgr_agent_core.tools.adapt_plan_tool import AdaptPlanTool from sgr_agent_core.tools.answer_tool import AnswerTool +from sgr_agent_core.tools.brave_search_tool import BraveSearchTool from sgr_agent_core.tools.clarification_tool import ClarificationTool from sgr_agent_core.tools.create_report_tool import CreateReportTool from sgr_agent_core.tools.extract_page_content_tool import ExtractPageContentTool from sgr_agent_core.tools.final_answer_tool import FinalAnswerTool from sgr_agent_core.tools.generate_plan_tool import GeneratePlanTool +from sgr_agent_core.tools.perplexity_search_tool import PerplexitySearchTool from sgr_agent_core.tools.reasoning_tool import ReasoningTool +from sgr_agent_core.tools.tavily_search_tool import TavilySearchTool from sgr_agent_core.tools.web_search_tool import WebSearchTool __all__ = [ @@ -22,6 +25,7 @@ "ToolNameSelectorStub", "NextStepToolsBuilder", # 
Individual tools + "BraveSearchTool", "ClarificationTool", "GeneratePlanTool", "WebSearchTool", @@ -30,7 +34,9 @@ "CreateReportTool", "AnswerTool", "FinalAnswerTool", + "PerplexitySearchTool", "ReasoningTool", + "TavilySearchTool", # Tool lists "NextStepToolStub", "NextStepToolsBuilder", diff --git a/sgr_agent_core/tools/base_search_tool.py b/sgr_agent_core/tools/base_search_tool.py new file mode 100644 index 00000000..547435ff --- /dev/null +++ b/sgr_agent_core/tools/base_search_tool.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import logging +from datetime import datetime +from typing import TYPE_CHECKING, Any, ClassVar + +from pydantic import Field + +from sgr_agent_core.agent_definition import AgentConfig, SearchConfig +from sgr_agent_core.base_tool import BaseTool +from sgr_agent_core.models import SearchResult +from sgr_agent_core.services.base_search import BaseSearchService +from sgr_agent_core.utils import config_from_kwargs + +if TYPE_CHECKING: + from sgr_agent_core.models import AgentContext + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class _BaseSearchTool(BaseTool): + """Base class for all search tools. + + Provides shared fields (reasoning, query, max_results, offset) and + common __call__ logic. Concrete tools override _default_engine and + docstring. 
+ """ + + _default_engine: ClassVar[str | None] = None + + def __init_subclass__(cls, **kwargs: Any) -> None: + # Reset tool_name and description so each concrete subclass gets its + # own values instead of inheriting "_basesearchtool" / base docstring via MRO + if "tool_name" not in cls.__dict__: + cls.tool_name = cls.__name__.lower() + if "description" not in cls.__dict__: + cls.description = cls.__doc__ or "" + super().__init_subclass__(**kwargs) + + config_model = SearchConfig + base_config_attr = "search" + + reasoning: str = Field(description="Why this search is needed and what to expect") + query: str = Field(description="Search query in same language as user request") + max_results: int = Field( + description="Maximum results. How much of the web results selection you want to retrieve", + default=5, + ge=1, + le=20, + ) + offset: int = Field( + default=0, + ge=0, + description=( + "Number of results to skip from the beginning." + " Use for pagination: first call offset=0, next call offset=5, etc." + ), + ) + + async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: Any) -> str: + """Execute web search using the configured search engine. + + Search settings are taken from kwargs (tool config) with + fallback to config.search. 
+ """ + # If this tool has a hardcoded engine, force it + if self._default_engine is not None: + kwargs.setdefault("engine", self._default_engine) + + search_config = config_from_kwargs( + SearchConfig, + config.search if config else None, + dict(kwargs), + ) + logger.info(f"Search query: '{self.query}' (engine={search_config.engine})") + + service = BaseSearchService.create(search_config) + + max_results_limit = search_config.max_results + effective_limit = min(self.max_results, max_results_limit) + + # Each service handles offset internally: + # Brave uses native API offset, Tavily/Perplexity use over-fetch+slice + sources = await service.search( + query=self.query, + max_results=effective_limit, + offset=self.offset, + include_raw_content=False, + ) + + sources = BaseSearchService.rearrange_sources(sources, starting_number=len(context.sources) + 1) + + for source in sources: + context.sources[source.url] = source + + search_result = SearchResult( + query=self.query, + answer=None, + citations=sources, + timestamp=datetime.now(), + ) + context.searches.append(search_result) + + formatted_result = f"Search Query: {search_result.query}\n\n" + formatted_result += "Search Results (titles, links, short snippets):\n\n" + + for source in sources: + snippet = source.snippet[:100] + "..." if len(source.snippet) > 100 else source.snippet + formatted_result += f"{str(source)}\n{snippet}\n\n" + + context.searches_used += 1 + logger.debug(formatted_result) + return formatted_result diff --git a/sgr_agent_core/tools/brave_search_tool.py b/sgr_agent_core/tools/brave_search_tool.py new file mode 100644 index 00000000..be33c20a --- /dev/null +++ b/sgr_agent_core/tools/brave_search_tool.py @@ -0,0 +1,18 @@ +from sgr_agent_core.tools.base_search_tool import _BaseSearchTool + + +class BraveSearchTool(_BaseSearchTool): + """Search the web using Brave search engine. Brave Search provides privacy- + focused search results with native pagination support. 
Use this tool when + you specifically want to search with Brave. + + Returns: Page titles, URLs, and short snippets + Best for: Privacy-focused search, efficient pagination via native offset + + Usage: + - Use SPECIFIC terms and context in queries + - Search queries in SAME LANGUAGE as user request + - Brave supports efficient pagination with offset parameter + """ + + _default_engine = "brave" diff --git a/sgr_agent_core/tools/extract_page_content_tool.py b/sgr_agent_core/tools/extract_page_content_tool.py index d8e56d89..a5156e7d 100644 --- a/sgr_agent_core/tools/extract_page_content_tool.py +++ b/sgr_agent_core/tools/extract_page_content_tool.py @@ -20,11 +20,12 @@ class ExtractPageContentTool(BaseTool): """Extract full detailed content from specific web pages. - Use for: Getting complete page content from URLs found in web search Returns: - Full page content in readable format (via Tavily Extract API) - Best for: Deep analysis of specific pages, extracting structured data - Usage: Call after WebSearchTool to get detailed information from promising URLs + Use for: Getting complete page content from URLs found in web search. + Returns: Full page content in readable format (via Tavily Extract API). + Best for: Deep analysis of specific pages, extracting structured data. + + Usage: Call after WebSearchTool to get detailed information from promising URLs. CRITICAL WARNINGS: - Extracted pages may show data from DIFFERENT years/time periods than asked @@ -51,7 +52,12 @@ async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: A config.search if config else None, dict(kwargs), ) - logger.info(f"📄 Extracting content from {len(self.urls)} URLs") + if not search_config.tavily_api_key: + return ( + "Error: tavily_api_key is required for ExtractPageContentTool." + " Tavily is the only provider that supports content extraction." 
+ ) + logger.info(f"Extracting content from {len(self.urls)} URLs") self._search_service = TavilySearchService(search_config) sources = await self._search_service.extract(urls=self.urls) diff --git a/sgr_agent_core/tools/perplexity_search_tool.py b/sgr_agent_core/tools/perplexity_search_tool.py new file mode 100644 index 00000000..86ee4edc --- /dev/null +++ b/sgr_agent_core/tools/perplexity_search_tool.py @@ -0,0 +1,18 @@ +from sgr_agent_core.tools.base_search_tool import _BaseSearchTool + + +class PerplexitySearchTool(_BaseSearchTool): + """Search the web using Perplexity AI search engine. Perplexity provides + AI-powered search with synthesized answers and source citations. Use this + tool when you specifically want to search with Perplexity. + + Returns: Page titles, URLs, and AI-synthesized snippets + Best for: Getting AI-synthesized answers with source citations + + Usage: + - Use SPECIFIC terms and context in queries + - Search queries in SAME LANGUAGE as user request + - Results include AI-generated summary alongside source URLs + """ + + _default_engine = "perplexity" diff --git a/sgr_agent_core/tools/tavily_search_tool.py b/sgr_agent_core/tools/tavily_search_tool.py new file mode 100644 index 00000000..cdd2e5b4 --- /dev/null +++ b/sgr_agent_core/tools/tavily_search_tool.py @@ -0,0 +1,18 @@ +from sgr_agent_core.tools.base_search_tool import _BaseSearchTool + + +class TavilySearchTool(_BaseSearchTool): + """Search the web using Tavily search engine. Tavily provides high-quality + search results with optional raw content extraction. Use this tool when you + specifically want to search with Tavily. 
+ + Returns: Page titles, URLs, and short snippets + Best for: General web search, research queries + + Usage: + - Use SPECIFIC terms and context in queries + - Search queries in SAME LANGUAGE as user request + - Use ExtractPageContentTool to get full content from found URLs + """ + + _default_engine = "tavily" diff --git a/sgr_agent_core/tools/web_search_tool.py b/sgr_agent_core/tools/web_search_tool.py index e7fd3e63..90fdae27 100644 --- a/sgr_agent_core/tools/web_search_tool.py +++ b/sgr_agent_core/tools/web_search_tool.py @@ -1,25 +1,7 @@ -from __future__ import annotations +from sgr_agent_core.tools.base_search_tool import _BaseSearchTool -import logging -from datetime import datetime -from typing import TYPE_CHECKING, Any -from pydantic import Field - -from sgr_agent_core.agent_definition import AgentConfig, SearchConfig -from sgr_agent_core.base_tool import BaseTool -from sgr_agent_core.models import SearchResult -from sgr_agent_core.services.tavily_search import TavilySearchService -from sgr_agent_core.utils import config_from_kwargs - -if TYPE_CHECKING: - from sgr_agent_core.models import AgentContext - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class WebSearchTool(BaseTool): +class WebSearchTool(_BaseSearchTool): """Search the web for real-time information about any topic. Use this tool when you need up-to-date information that might not be available in your training data, or when you need to verify current facts. 
@@ -43,74 +25,3 @@ class WebSearchTool(BaseTool): - For questions with specific dates/numbers, snippets may be more accurate than full pages - If the snippet directly answers the question, you may not need to extract the full page """ - - config_model = SearchConfig - base_config_attr = "search" - - reasoning: str = Field(description="Why this search is needed and what to expect") - query: str = Field(description="Search query in same language as user request") - max_results: int = Field( - description="Maximum results. How much of the web results selection you want to retrieve", - default=5, - ge=1, - le=10, - ) - offset: int = Field( - default=0, - ge=0, - description=( - "Number of results to skip from the beginning." - " Use for pagination: first call offset=0, next call offset=5, etc." - ), - ) - - async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: Any) -> str: - """Execute web search using TavilySearchService. - - Search settings are taken from kwargs (tool config) with - fallback to config.search. 
- """ - search_config = config_from_kwargs( - SearchConfig, - config.search if config else None, - dict(kwargs), - ) - logger.info(f"🔍 Search query: '{self.query}'") - self._search_service = TavilySearchService(search_config) - - max_results_limit = search_config.max_results - effective_limit = min(self.max_results, max_results_limit) - fetch_count = effective_limit + self.offset - - sources = await self._search_service.search( - query=self.query, - max_results=fetch_count, - include_raw_content=False, - ) - - if self.offset: - sources = sources[self.offset :] - - sources = TavilySearchService.rearrange_sources(sources, starting_number=len(context.sources) + 1) - - for source in sources: - context.sources[source.url] = source - - search_result = SearchResult( - query=self.query, - answer=None, - citations=sources, - timestamp=datetime.now(), - ) - context.searches.append(search_result) - - formatted_result = f"Search Query: {search_result.query}\n\n" - formatted_result += "Search Results (titles, links, short snippets):\n\n" - - for source in sources: - snippet = source.snippet[:100] + "..." 
"""Tests for search services (BaseSearchService, TavilySearchService,
BraveSearchService, PerplexitySearchService)."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from sgr_agent_core.agent_definition import SearchConfig
from sgr_agent_core.models import SourceData


class TestBaseSearchService:
    """Factory routing and shared helpers on BaseSearchService."""

    def test_factory_creates_tavily_service(self):
        from sgr_agent_core.services.base_search import BaseSearchService
        from sgr_agent_core.services.tavily_search import TavilySearchService

        cfg = SearchConfig(engine="tavily", tavily_api_key="test-key")
        assert isinstance(BaseSearchService.create(cfg), TavilySearchService)

    def test_factory_creates_brave_service(self):
        from sgr_agent_core.services.base_search import BaseSearchService
        from sgr_agent_core.services.brave_search import BraveSearchService

        cfg = SearchConfig(engine="brave", brave_api_key="test-key")
        assert isinstance(BaseSearchService.create(cfg), BraveSearchService)

    def test_factory_creates_perplexity_service(self):
        from sgr_agent_core.services.base_search import BaseSearchService
        from sgr_agent_core.services.perplexity_search import PerplexitySearchService

        cfg = SearchConfig(engine="perplexity", perplexity_api_key="test-key")
        assert isinstance(BaseSearchService.create(cfg), PerplexitySearchService)

    def test_factory_raises_for_unknown_engine(self):
        from sgr_agent_core.services.base_search import BaseSearchService

        # model_construct bypasses the Literal validation, so an invalid
        # engine value can actually reach the factory.
        cfg = SearchConfig.model_construct(engine="unknown", tavily_api_key="k")
        with pytest.raises(ValueError, match="Unsupported search engine"):
            BaseSearchService.create(cfg)

    def test_rearrange_sources(self):
        from sgr_agent_core.services.base_search import BaseSearchService

        items = [
            SourceData(number=0, url="https://a.com", title="A", snippet="a"),
            SourceData(number=0, url="https://b.com", title="B", snippet="b"),
        ]
        renumbered = BaseSearchService.rearrange_sources(items, starting_number=5)
        assert renumbered[0].number == 5
        assert renumbered[1].number == 6

    @pytest.mark.asyncio
    async def test_base_search_raises_not_implemented(self):
        from sgr_agent_core.services.base_search import BaseSearchService

        svc = BaseSearchService(SearchConfig(tavily_api_key="k"))
        with pytest.raises(NotImplementedError):
            await svc.search("test")


class TestBraveSearchService:
    """Construction, response conversion, and API-call tests for Brave."""

    def test_raises_without_api_key(self):
        from sgr_agent_core.services.brave_search import BraveSearchService

        with pytest.raises(ValueError, match="brave_api_key is required"):
            BraveSearchService(SearchConfig(engine="brave"))

    def test_convert_to_source_data(self):
        from sgr_agent_core.services.brave_search import BraveSearchService

        svc = BraveSearchService(SearchConfig(engine="brave", brave_api_key="test-key"))
        payload = {
            "web": {
                "results": [
                    {"title": "Test", "url": "https://example.com", "description": "A test result"},
                    {"title": "Test2", "url": "https://example2.com", "description": "Another result"},
                    {"title": "No URL", "url": "", "description": "Skipped"},
                ]
            }
        }
        converted = svc._convert_to_source_data(payload)
        # The empty-URL entry is dropped; the rest map field-for-field.
        assert len(converted) == 2
        assert converted[0].title == "Test"
        assert converted[0].url == "https://example.com"
        assert converted[0].snippet == "A test result"

    @pytest.mark.asyncio
    async def test_search_calls_brave_api(self):
        from sgr_agent_core.services.brave_search import BraveSearchService

        svc = BraveSearchService(SearchConfig(engine="brave", brave_api_key="test-key", max_results=10))

        fake_response = MagicMock()
        fake_response.raise_for_status = MagicMock()
        fake_response.json.return_value = {
            "web": {
                "results": [
                    {"title": "Result", "url": "https://example.com", "description": "desc"},
                ]
            }
        }

        with patch("sgr_agent_core.services.brave_search.httpx.AsyncClient") as client_cls:
            client = AsyncMock()
            client.get = AsyncMock(return_value=fake_response)
            client.__aenter__ = AsyncMock(return_value=client)
            client.__aexit__ = AsyncMock(return_value=False)
            client_cls.return_value = client

            found = await svc.search("test query", max_results=5)

            client.get.assert_called_once()
            params = client.get.call_args.kwargs["params"]
            assert params["q"] == "test query"
            assert params["count"] == 5
            assert len(found) == 1


class TestPerplexitySearchService:
    """Construction, response conversion, and API-call tests for Perplexity."""

    def test_raises_without_api_key(self):
        from sgr_agent_core.services.perplexity_search import PerplexitySearchService

        with pytest.raises(ValueError, match="perplexity_api_key is required"):
            PerplexitySearchService(SearchConfig(engine="perplexity"))

    def test_convert_to_source_data(self):
        from sgr_agent_core.services.perplexity_search import PerplexitySearchService

        svc = PerplexitySearchService(SearchConfig(engine="perplexity", perplexity_api_key="test-key"))
        payload = {
            "results": [
                {"title": "Page 1", "url": "https://example.com/page1", "snippet": "First result snippet"},
                {"title": "Page 2", "url": "https://example.com/page2", "snippet": "Second result snippet"},
                {"title": "No URL", "url": "", "snippet": "Skipped"},
            ],
        }
        converted = svc._convert_to_source_data(payload)
        # The empty-URL entry is dropped; the rest map field-for-field.
        assert len(converted) == 2
        assert converted[0].url == "https://example.com/page1"
        assert converted[0].title == "Page 1"
        assert converted[0].snippet == "First result snippet"
        assert converted[1].snippet == "Second result snippet"

    @pytest.mark.asyncio
    async def test_search_calls_perplexity_api(self):
        from sgr_agent_core.services.perplexity_search import PerplexitySearchService

        svc = PerplexitySearchService(
            SearchConfig(engine="perplexity", perplexity_api_key="test-key", max_results=10)
        )

        fake_response = MagicMock()
        fake_response.raise_for_status = MagicMock()
        fake_response.json.return_value = {
            "results": [
                {"title": "Result", "url": "https://example.com", "snippet": "desc"},
            ],
        }

        with patch("sgr_agent_core.services.perplexity_search.httpx.AsyncClient") as client_cls:
            client = AsyncMock()
            client.post = AsyncMock(return_value=fake_response)
            client.__aenter__ = AsyncMock(return_value=client)
            client.__aexit__ = AsyncMock(return_value=False)
            client_cls.return_value = client

            found = await svc.search("test query", max_results=5)

            client.post.assert_called_once()
            body = client.post.call_args.kwargs["json"]
            assert body["query"] == "test query"
            assert body["max_results"] == 5
            assert len(found) == 1


class TestTavilySearchService:
    """TavilySearchService behaviour after the BaseSearchService refactor."""

    def test_inherits_rearrange_sources(self):
        """rearrange_sources must be the shared BaseSearchService helper,
        not a Tavily-specific override."""
        from sgr_agent_core.services.base_search import BaseSearchService
        from sgr_agent_core.services.tavily_search import TavilySearchService

        assert TavilySearchService.rearrange_sources is BaseSearchService.rearrange_sources

    def test_convert_to_source_data(self):
        from sgr_agent_core.services.tavily_search import TavilySearchService

        svc = TavilySearchService(SearchConfig(tavily_api_key="test-key"))
        payload = {
            "results": [
                {"title": "Test", "url": "https://example.com", "content": "Snippet", "raw_content": "Full content"},
            ]
        }
        converted = svc._convert_to_source_data(payload)
        assert len(converted) == 1
        assert converted[0].title == "Test"
        assert converted[0].snippet == "Snippet"
        assert converted[0].full_content == "Full content"
tool.intermediate_result == "Found 3 relevant sources so far." assert tool.continue_research is True + def test_tavily_search_tool_initialization(self): + """Test TavilySearchTool initialization.""" + tool = TavilySearchTool(reasoning="Test", query="test query") + assert tool.tool_name == "tavilysearchtool" + assert tool._default_engine == "tavily" + + def test_brave_search_tool_initialization(self): + """Test BraveSearchTool initialization.""" + tool = BraveSearchTool(reasoning="Test", query="test query") + assert tool.tool_name == "bravesearchtool" + assert tool._default_engine == "brave" + + def test_perplexity_search_tool_initialization(self): + """Test PerplexitySearchTool initialization.""" + tool = PerplexitySearchTool(reasoning="Test", query="test query") + assert tool.tool_name == "perplexitysearchtool" + assert tool._default_engine == "perplexity" + class TestAnswerToolExecution: """Tests for AnswerTool execution.""" @@ -195,12 +216,13 @@ async def test_web_search_tool_uses_kwargs_over_config_search(self): context = AgentContext() config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=10) - with patch("sgr_agent_core.tools.web_search_tool.TavilySearchService") as mock_svc_class: + with patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: mock_svc = AsyncMock() mock_svc.search = AsyncMock(return_value=[]) - mock_svc_class.return_value = mock_svc + mock_svc_class.create.return_value = mock_svc + mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources await tool(context, config, max_results=3) - call_args = mock_svc_class.call_args[0][0] + call_args = mock_svc_class.create.call_args[0][0] assert call_args.max_results == 3 @pytest.mark.asyncio @@ -213,48 +235,47 @@ async def test_web_search_tool_fallback_to_config_search(self): context = AgentContext() config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=10) - with 
patch("sgr_agent_core.tools.web_search_tool.TavilySearchService") as mock_svc_class: + with patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: mock_svc = AsyncMock() mock_svc.search = AsyncMock(return_value=[]) - mock_svc_class.return_value = mock_svc + mock_svc_class.create.return_value = mock_svc + mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources await tool(context, config) - call_args = mock_svc_class.call_args[0][0] + call_args = mock_svc_class.create.call_args[0][0] assert call_args.max_results == 10 @pytest.mark.asyncio async def test_web_search_tool_with_offset(self): - """WebSearchTool with offset fetches offset+max_results from Tavily and - slices.""" + """WebSearchTool passes offset to service which handles it + internally.""" tool = WebSearchTool(reasoning="r", query="test", max_results=3, offset=2) context = AgentContext() config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=10) + # Service returns already-offset results (3 items after skipping 2) mock_sources = [ SourceData(number=i, url=f"https://example.com/{i}", title=f"Result {i}", snippet=f"Snippet {i}") - for i in range(5) + for i in range(2, 5) ] - with patch("sgr_agent_core.tools.web_search_tool.TavilySearchService") as mock_svc_class: + with patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: mock_svc = AsyncMock() mock_svc.search = AsyncMock(return_value=mock_sources) - mock_svc_class.return_value = mock_svc - mock_svc_class.rearrange_sources = TavilySearchService.rearrange_sources + mock_svc_class.create.return_value = mock_svc + mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources result = await tool(context, config) - # Tavily should receive max_results = 3 + 2 = 5 - mock_svc.search.assert_called_once_with(query="test", max_results=5, include_raw_content=False) - # After slicing [2:], 3 sources should remain + # Offset is delegated to the service + 
mock_svc.search.assert_called_once_with(query="test", max_results=3, offset=2, include_raw_content=False) assert len(context.searches) == 1 assert len(context.searches[0].citations) == 3 assert "Result 2" in result - assert "Result 0" not in result @pytest.mark.asyncio async def test_web_search_tool_offset_default_zero(self): - """WebSearchTool without offset behaves identically to current - logic.""" + """WebSearchTool without offset passes offset=0 to service.""" tool = WebSearchTool(reasoning="r", query="test", max_results=3) assert tool.offset == 0 @@ -267,16 +288,15 @@ async def test_web_search_tool_offset_default_zero(self): for i in range(3) ] - with patch("sgr_agent_core.tools.web_search_tool.TavilySearchService") as mock_svc_class: + with patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: mock_svc = AsyncMock() mock_svc.search = AsyncMock(return_value=mock_sources) - mock_svc_class.return_value = mock_svc - mock_svc_class.rearrange_sources = TavilySearchService.rearrange_sources + mock_svc_class.create.return_value = mock_svc + mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources await tool(context, config) - # Tavily should receive max_results = 3 (no offset added) - mock_svc.search.assert_called_once_with(query="test", max_results=3, include_raw_content=False) + mock_svc.search.assert_called_once_with(query="test", max_results=3, offset=0, include_raw_content=False) assert len(context.searches[0].citations) == 3 @pytest.mark.asyncio @@ -288,25 +308,37 @@ async def test_web_search_tool_offset_exceeds_results(self): config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=20) - mock_sources = [ - SourceData(number=i, url=f"https://example.com/{i}", title=f"Result {i}", snippet=f"Snippet {i}") - for i in range(5) - ] - - with patch("sgr_agent_core.tools.web_search_tool.TavilySearchService") as mock_svc_class: + # Service returns empty list (offset exceeded available results) + with 
patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=mock_sources) - mock_svc_class.return_value = mock_svc - mock_svc_class.rearrange_sources = TavilySearchService.rearrange_sources + mock_svc.search = AsyncMock(return_value=[]) + mock_svc_class.create.return_value = mock_svc + mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources result = await tool(context, config) - # Tavily should receive max_results = 3 + 10 = 13 - mock_svc.search.assert_called_once_with(query="test", max_results=13, include_raw_content=False) - # After slicing [10:] on 5 results, empty list + mock_svc.search.assert_called_once_with(query="test", max_results=3, offset=10, include_raw_content=False) assert len(context.searches[0].citations) == 0 assert "Search Query: test" in result + @pytest.mark.asyncio + async def test_brave_search_tool_forces_engine(self): + """BraveSearchTool forces engine='brave' regardless of config.""" + from sgr_agent_core.models import AgentContext + + tool = BraveSearchTool(reasoning="r", query="test", max_results=5) + context = AgentContext() + config = MagicMock() + config.search = SearchConfig(tavily_api_key="k", max_results=10, engine="tavily") + with patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: + mock_svc = AsyncMock() + mock_svc.search = AsyncMock(return_value=[]) + mock_svc_class.create.return_value = mock_svc + mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources + await tool(context, config) + call_args = mock_svc_class.create.call_args[0][0] + assert call_args.engine == "brave" + @pytest.mark.asyncio async def test_extract_page_content_tool_uses_content_limit_from_kwargs(self): """ExtractPageContentTool uses content_limit from kwargs.""" From 2844f2fe5850dad12220c10a961dc77926a0df4e Mon Sep 17 00:00:00 2001 From: Nikita Matsko Date: Tue, 17 Feb 2026 17:42:59 +0000 Subject: [PATCH 04/12] 
refactor(exports): reorder and update exports to include TavilySearchService and adjust tool exports for consistency --- sgr_agent_core/__init__.py | 7 ++----- sgr_agent_core/tools/__init__.py | 13 +++++-------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/sgr_agent_core/__init__.py b/sgr_agent_core/__init__.py index 653c19d2..e3f805b6 100644 --- a/sgr_agent_core/__init__.py +++ b/sgr_agent_core/__init__.py @@ -37,6 +37,7 @@ MCP2ToolConverter, PerplexitySearchService, PromptLoader, + TavilySearchService, ToolRegistry, ) from sgr_agent_core.tools import * # noqa: F403 @@ -61,6 +62,7 @@ "BraveSearchService", "MCP2ToolConverter", "PerplexitySearchService", + "TavilySearchService", "PromptLoader", "ToolRegistry", # Configuration @@ -74,11 +76,6 @@ # Next step tools "NextStepToolStub", "NextStepToolsBuilder", - # Models - "AgentStatesEnum", - "AgentContext", - "SearchResult", - "SourceData", # Factory "AgentFactory", ] diff --git a/sgr_agent_core/tools/__init__.py b/sgr_agent_core/tools/__init__.py index 07487c95..e3b0c838 100644 --- a/sgr_agent_core/tools/__init__.py +++ b/sgr_agent_core/tools/__init__.py @@ -25,19 +25,16 @@ "ToolNameSelectorStub", "NextStepToolsBuilder", # Individual tools + "AdaptPlanTool", + "AnswerTool", "BraveSearchTool", "ClarificationTool", - "GeneratePlanTool", - "WebSearchTool", - "ExtractPageContentTool", - "AdaptPlanTool", "CreateReportTool", - "AnswerTool", + "ExtractPageContentTool", "FinalAnswerTool", + "GeneratePlanTool", "PerplexitySearchTool", "ReasoningTool", "TavilySearchTool", - # Tool lists - "NextStepToolStub", - "NextStepToolsBuilder", + "WebSearchTool", ] From 1c53809a76e39b670bed76cf7a20b4c20f15c7a1 Mon Sep 17 00:00:00 2001 From: Nikita Matsko Date: Tue, 17 Feb 2026 17:43:20 +0000 Subject: [PATCH 05/12] docs(tavily_search): add class docstring explaining usage, auth, and offset handling --- sgr_agent_core/services/tavily_search.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git 
a/sgr_agent_core/services/tavily_search.py b/sgr_agent_core/services/tavily_search.py index ec2f417c..bd146e2d 100644 --- a/sgr_agent_core/services/tavily_search.py +++ b/sgr_agent_core/services/tavily_search.py @@ -10,6 +10,14 @@ class TavilySearchService(BaseSearchService): + """Search service using Tavily Search API. + + Uses AsyncTavilyClient for HTTP requests. + Auth: API key passed to client constructor. + Offset is handled by over-fetch+slice (no native offset support). + Also provides extract() for full page content retrieval. + """ + def __init__(self, search_config: SearchConfig): super().__init__(search_config) self._client = AsyncTavilyClient( From c83dec6dba7c1d8717e91e17dcfb71d42538bc53 Mon Sep 17 00:00:00 2001 From: Nikita Matsko Date: Thu, 19 Feb 2026 11:58:18 +0000 Subject: [PATCH 06/12] refactor(search): remove deprecated BaseSearchService and inline provider logic into tools with unified registry and tests --- sgr_agent_core/__init__.py | 8 - sgr_agent_core/services/__init__.py | 8 - sgr_agent_core/services/base_search.py | 79 ---------- sgr_agent_core/services/brave_search.py | 98 ------------ sgr_agent_core/services/perplexity_search.py | 106 ------------- sgr_agent_core/services/tavily_search.py | 100 ------------ sgr_agent_core/tools/base_search_tool.py | 28 +++- sgr_agent_core/tools/brave_search_tool.py | 86 ++++++++++- .../tools/extract_page_content_tool.py | 5 +- .../tools/perplexity_search_tool.py | 87 ++++++++++- sgr_agent_core/tools/tavily_search_tool.py | 88 ++++++++++- ...h_services.py => test_search_providers.py} | 146 ++++++++---------- tests/test_tools.py | 91 +++-------- 13 files changed, 366 insertions(+), 564 deletions(-) delete mode 100644 sgr_agent_core/services/base_search.py delete mode 100644 sgr_agent_core/services/brave_search.py delete mode 100644 sgr_agent_core/services/perplexity_search.py delete mode 100644 sgr_agent_core/services/tavily_search.py rename tests/{test_search_services.py => test_search_providers.py} 
(50%) diff --git a/sgr_agent_core/__init__.py b/sgr_agent_core/__init__.py index e3f805b6..b5886e11 100644 --- a/sgr_agent_core/__init__.py +++ b/sgr_agent_core/__init__.py @@ -32,12 +32,8 @@ from sgr_agent_core.next_step_tool import NextStepToolsBuilder, NextStepToolStub from sgr_agent_core.services import ( AgentRegistry, - BaseSearchService, - BraveSearchService, MCP2ToolConverter, - PerplexitySearchService, PromptLoader, - TavilySearchService, ToolRegistry, ) from sgr_agent_core.tools import * # noqa: F403 @@ -58,11 +54,7 @@ "SourceData", # Services "AgentRegistry", - "BaseSearchService", - "BraveSearchService", "MCP2ToolConverter", - "PerplexitySearchService", - "TavilySearchService", "PromptLoader", "ToolRegistry", # Configuration diff --git a/sgr_agent_core/services/__init__.py b/sgr_agent_core/services/__init__.py index a3fd26fe..7c77722e 100644 --- a/sgr_agent_core/services/__init__.py +++ b/sgr_agent_core/services/__init__.py @@ -1,19 +1,11 @@ """Services module for external integrations and business logic.""" -from sgr_agent_core.services.base_search import BaseSearchService -from sgr_agent_core.services.brave_search import BraveSearchService from sgr_agent_core.services.mcp_service import MCP2ToolConverter -from sgr_agent_core.services.perplexity_search import PerplexitySearchService from sgr_agent_core.services.prompt_loader import PromptLoader from sgr_agent_core.services.registry import AgentRegistry, StreamingGeneratorRegistry, ToolRegistry -from sgr_agent_core.services.tavily_search import TavilySearchService from sgr_agent_core.services.tool_instantiator import ToolInstantiator __all__ = [ - "BaseSearchService", - "BraveSearchService", - "PerplexitySearchService", - "TavilySearchService", "MCP2ToolConverter", "ToolRegistry", "StreamingGeneratorRegistry", diff --git a/sgr_agent_core/services/base_search.py b/sgr_agent_core/services/base_search.py deleted file mode 100644 index f7443e4c..00000000 --- a/sgr_agent_core/services/base_search.py +++ 
/dev/null @@ -1,79 +0,0 @@ -import logging -from typing import TYPE_CHECKING - -from sgr_agent_core.models import SourceData - -if TYPE_CHECKING: - from sgr_agent_core.agent_definition import SearchConfig - -logger = logging.getLogger(__name__) - - -class BaseSearchService: - """Base class for search service providers. - - Subclasses must implement the `search` method. - """ - - def __init__(self, search_config: "SearchConfig"): - self._config = search_config - - async def search( - self, - query: str, - max_results: int | None = None, - offset: int = 0, - include_raw_content: bool = True, - ) -> list[SourceData]: - """Perform a search and return results as SourceData list. - - Each provider handles offset internally: - - Brave: uses native API offset parameter - - Tavily/Perplexity: over-fetch+slice - - Args: - query: Search query string - max_results: Maximum number of results to return (after offset) - offset: Number of results to skip - include_raw_content: Whether to include raw page content - - Returns: - List of SourceData results (at most max_results items) - """ - raise NotImplementedError("Subclasses must implement search()") - - @staticmethod - def rearrange_sources(sources: list[SourceData], starting_number: int = 1) -> list[SourceData]: - """Renumber sources sequentially starting from given number.""" - for i, source in enumerate(sources, starting_number): - source.number = i - return sources - - @classmethod - def create(cls, config: "SearchConfig") -> "BaseSearchService": - """Factory method to create a search service based on config.engine. 
- - Args: - config: SearchConfig with engine and API keys - - Returns: - Appropriate search service instance - - Raises: - ValueError: If engine is not supported - """ - from sgr_agent_core.services.brave_search import BraveSearchService - from sgr_agent_core.services.perplexity_search import PerplexitySearchService - from sgr_agent_core.services.tavily_search import TavilySearchService - - engine = config.engine - logger.debug(f"Creating search service for engine: {engine}") - if engine == "tavily": - return TavilySearchService(config) - elif engine == "brave": - return BraveSearchService(config) - elif engine == "perplexity": - return PerplexitySearchService(config) - else: - logger.error(f"Unsupported search engine requested: {engine}") - raise ValueError(f"Unsupported search engine: {engine}") diff --git a/sgr_agent_core/services/brave_search.py b/sgr_agent_core/services/brave_search.py deleted file mode 100644 index dcd7ace0..00000000 --- a/sgr_agent_core/services/brave_search.py +++ /dev/null @@ -1,98 +0,0 @@ -import logging -from typing import Any - -import httpx - -from sgr_agent_core.agent_definition import SearchConfig -from sgr_agent_core.models import SourceData -from sgr_agent_core.services.base_search import BaseSearchService - -logger = logging.getLogger(__name__) - - -class BraveSearchService(BaseSearchService): - """Search service using Brave Search API. - - Uses httpx.AsyncClient for HTTP requests. - Auth: X-Subscription-Token header. - Brave API supports native offset for pagination. - """ - - def __init__(self, search_config: SearchConfig): - super().__init__(search_config) - if not search_config.brave_api_key: - raise ValueError("brave_api_key is required for BraveSearchService") - - async def search( - self, - query: str, - max_results: int | None = None, - offset: int = 0, - include_raw_content: bool = True, - ) -> list[SourceData]: - """Perform search through Brave Search API. 
- - Brave supports native offset parameter for efficient pagination. - - Args: - query: Search query string - max_results: Maximum number of results (max 20 per Brave API) - offset: Number of results to skip (native Brave API support) - include_raw_content: Ignored for Brave (no raw content extraction) - - Returns: - List of SourceData results - """ - max_results = min(max_results or self._config.max_results, 20) - logger.info(f"🔍 Brave search: '{query}' (max_results={max_results}, offset={offset})") - - headers = { - "Accept": "application/json", - "Accept-Encoding": "gzip", - "X-Subscription-Token": self._config.brave_api_key, - } - params: dict[str, Any] = { - "q": query, - "count": max_results, - } - if offset > 0: - params["offset"] = offset - - try: - async with httpx.AsyncClient() as client: - response = await client.get( - self._config.brave_api_base_url, - headers=headers, - params=params, - timeout=30.0, - ) - response.raise_for_status() - data = response.json() - except httpx.HTTPStatusError as e: - logger.error(f"Brave API HTTP error: {e.response.status_code} — {e.response.text[:200]}") - raise - except httpx.RequestError as e: - logger.error(f"Brave API request error: {e}") - raise - - return self._convert_to_source_data(data) - - def _convert_to_source_data(self, response: dict) -> list[SourceData]: - """Convert Brave Search API response to SourceData list.""" - sources = [] - web_results = response.get("web", {}).get("results", []) - - for i, result in enumerate(web_results): - url = result.get("url", "") - if not url: - continue - - source = SourceData( - number=i, - title=result.get("title", ""), - url=url, - snippet=result.get("description", ""), - ) - sources.append(source) - - return sources diff --git a/sgr_agent_core/services/perplexity_search.py b/sgr_agent_core/services/perplexity_search.py deleted file mode 100644 index 753bed93..00000000 --- a/sgr_agent_core/services/perplexity_search.py +++ /dev/null @@ -1,106 +0,0 @@ -import logging 
-from typing import Any - -import httpx - -from sgr_agent_core.agent_definition import SearchConfig -from sgr_agent_core.models import SourceData -from sgr_agent_core.services.base_search import BaseSearchService - -logger = logging.getLogger(__name__) - - -class PerplexitySearchService(BaseSearchService): - """Search service using Perplexity Search API. - - Uses httpx.AsyncClient for HTTP requests to the dedicated Search API - endpoint (POST /search) which returns ranked web results with titles, - URLs, and snippets. - Auth: Authorization Bearer header. - Offset is handled by over-fetch+slice (no native offset support). - """ - - def __init__(self, search_config: SearchConfig): - super().__init__(search_config) - if not search_config.perplexity_api_key: - raise ValueError("perplexity_api_key is required for PerplexitySearchService") - - async def search( - self, - query: str, - max_results: int | None = None, - offset: int = 0, - include_raw_content: bool = True, - ) -> list[SourceData]: - """Perform search through Perplexity Search API. - - Perplexity does not support native offset — over-fetch+slice - is applied internally when offset > 0. 
- - Args: - query: Search query string - max_results: Maximum number of results to return - offset: Number of results to skip (over-fetch+slice internally) - include_raw_content: Ignored for Perplexity - - Returns: - List of SourceData results - """ - max_results = max_results or self._config.max_results - fetch_count = max_results + offset if offset > 0 else max_results - logger.info(f"🔍 Perplexity search: '{query}' (max_results={max_results}, offset={offset})") - - headers = { - "Authorization": f"Bearer {self._config.perplexity_api_key}", - "Content-Type": "application/json", - } - payload: dict[str, Any] = { - "query": query, - "max_results": fetch_count, - } - - try: - async with httpx.AsyncClient() as client: - response = await client.post( - self._config.perplexity_api_base_url, - headers=headers, - json=payload, - timeout=30.0, - ) - response.raise_for_status() - data = response.json() - except httpx.HTTPStatusError as e: - logger.error(f"Perplexity API HTTP error: {e.response.status_code} — {e.response.text[:200]}") - raise - except httpx.RequestError as e: - logger.error(f"Perplexity API request error: {e}") - raise - - sources = self._convert_to_source_data(data) - if offset > 0: - sources = sources[offset:] - return sources[:max_results] - - def _convert_to_source_data(self, response: dict) -> list[SourceData]: - """Convert Perplexity Search API response to SourceData list. - - Perplexity Search API returns results[] with title, url, - snippet. 
- """ - sources = [] - results = response.get("results", []) - - for i, result in enumerate(results): - url = result.get("url", "") - if not url: - continue - - source = SourceData( - number=i, - title=result.get("title", ""), - url=url, - snippet=result.get("snippet", ""), - ) - sources.append(source) - - return sources diff --git a/sgr_agent_core/services/tavily_search.py b/sgr_agent_core/services/tavily_search.py deleted file mode 100644 index bd146e2d..00000000 --- a/sgr_agent_core/services/tavily_search.py +++ /dev/null @@ -1,100 +0,0 @@ -import logging - -from tavily import AsyncTavilyClient - -from sgr_agent_core.agent_definition import SearchConfig -from sgr_agent_core.models import SourceData -from sgr_agent_core.services.base_search import BaseSearchService - -logger = logging.getLogger(__name__) - - -class TavilySearchService(BaseSearchService): - """Search service using Tavily Search API. - - Uses AsyncTavilyClient for HTTP requests. - Auth: API key passed to client constructor. - Offset is handled by over-fetch+slice (no native offset support). - Also provides extract() for full page content retrieval. - """ - - def __init__(self, search_config: SearchConfig): - super().__init__(search_config) - self._client = AsyncTavilyClient( - api_key=search_config.tavily_api_key, api_base_url=search_config.tavily_api_base_url - ) - - async def search( - self, - query: str, - max_results: int | None = None, - offset: int = 0, - include_raw_content: bool = True, - ) -> list[SourceData]: - """Perform search through Tavily API and return results with - SourceData. - - Tavily API does not support native offset — over-fetch+slice is - applied internally when offset > 0. 
- """ - max_results = max_results or self._config.max_results - fetch_count = max_results + offset if offset > 0 else max_results - logger.info(f"🔍 Tavily search: '{query}' (max_results={max_results}, offset={offset})") - - response = await self._client.search( - query=query, - max_results=fetch_count, - include_raw_content=include_raw_content, - ) - - sources = self._convert_to_source_data(response) - if offset > 0: - sources = sources[offset:] - return sources[:max_results] - - async def extract(self, urls: list[str]) -> list[SourceData]: - """Extract full content from specific URLs using Tavily Extract API.""" - logger.info(f"📄 Tavily extract: {len(urls)} URLs") - - response = await self._client.extract(urls=urls) - - sources = [] - for i, result in enumerate(response.get("results", [])): - if not result.get("url"): - continue - - source = SourceData( - number=i, - title=result.get("url", "").split("/")[-1] or "Extracted Content", - url=result.get("url", ""), - snippet="", - full_content=result.get("raw_content", ""), - char_count=len(result.get("raw_content", "")), - ) - sources.append(source) - - failed_urls = response.get("failed_results", []) - if failed_urls: - logger.warning(f"⚠️ Failed to extract {len(failed_urls)} URLs: {failed_urls}") - - return sources - - def _convert_to_source_data(self, response: dict) -> list[SourceData]: - """Convert Tavily response to SourceData list.""" - sources = [] - - for i, result in enumerate(response.get("results", [])): - if not result.get("url", ""): - continue - - source = SourceData( - number=i, - title=result.get("title", ""), - url=result.get("url", ""), - snippet=result.get("content", ""), - ) - if result.get("raw_content", ""): - source.full_content = result["raw_content"] - source.char_count = len(source.full_content) - sources.append(source) - return sources diff --git a/sgr_agent_core/tools/base_search_tool.py b/sgr_agent_core/tools/base_search_tool.py index 547435ff..043e9a85 100644 --- 
a/sgr_agent_core/tools/base_search_tool.py +++ b/sgr_agent_core/tools/base_search_tool.py @@ -8,8 +8,7 @@ from sgr_agent_core.agent_definition import AgentConfig, SearchConfig from sgr_agent_core.base_tool import BaseTool -from sgr_agent_core.models import SearchResult -from sgr_agent_core.services.base_search import BaseSearchService +from sgr_agent_core.models import SearchResult, SourceData from sgr_agent_core.utils import config_from_kwargs if TYPE_CHECKING: @@ -18,6 +17,10 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +# Engine name -> tool class with _search() staticmethod. +# Populated explicitly at module level in each provider tool file. +_search_registry: dict[str, type] = {} + class _BaseSearchTool(BaseTool): """Base class for all search tools. @@ -25,6 +28,9 @@ class _BaseSearchTool(BaseTool): Provides shared fields (reasoning, query, max_results, offset) and common __call__ logic. Concrete tools override _default_engine and docstring. + + Provider-specific API logic lives in concrete tools as @staticmethod + _search() methods, dispatched via _search_registry by engine name. """ _default_engine: ClassVar[str | None] = None @@ -58,6 +64,13 @@ def __init_subclass__(cls, **kwargs: Any) -> None: ), ) + @staticmethod + def _rearrange_sources(sources: list[SourceData], starting_number: int = 1) -> list[SourceData]: + """Renumber sources sequentially starting from given number.""" + for i, source in enumerate(sources, starting_number): + source.number = i + return sources + async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: Any) -> str: """Execute web search using the configured search engine. 
@@ -75,21 +88,24 @@ async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: A ) logger.info(f"Search query: '{self.query}' (engine={search_config.engine})") - service = BaseSearchService.create(search_config) + provider_cls = _search_registry.get(search_config.engine) + if provider_cls is None: + raise ValueError(f"Unsupported search engine: {search_config.engine}") max_results_limit = search_config.max_results effective_limit = min(self.max_results, max_results_limit) - # Each service handles offset internally: + # Each provider handles offset internally: # Brave uses native API offset, Tavily/Perplexity use over-fetch+slice - sources = await service.search( + sources = await provider_cls._search( + config=search_config, query=self.query, max_results=effective_limit, offset=self.offset, include_raw_content=False, ) - sources = BaseSearchService.rearrange_sources(sources, starting_number=len(context.sources) + 1) + sources = self._rearrange_sources(sources, starting_number=len(context.sources) + 1) for source in sources: context.sources[source.url] = source diff --git a/sgr_agent_core/tools/brave_search_tool.py b/sgr_agent_core/tools/brave_search_tool.py index be33c20a..0f718ece 100644 --- a/sgr_agent_core/tools/brave_search_tool.py +++ b/sgr_agent_core/tools/brave_search_tool.py @@ -1,4 +1,15 @@ -from sgr_agent_core.tools.base_search_tool import _BaseSearchTool +from __future__ import annotations + +import logging +from typing import Any + +import httpx + +from sgr_agent_core.agent_definition import SearchConfig +from sgr_agent_core.models import SourceData +from sgr_agent_core.tools.base_search_tool import _BaseSearchTool, _search_registry + +logger = logging.getLogger(__name__) class BraveSearchTool(_BaseSearchTool): @@ -16,3 +27,76 @@ class BraveSearchTool(_BaseSearchTool): """ _default_engine = "brave" + + @staticmethod + async def _search( + config: SearchConfig, + query: str, + max_results: int, + offset: int = 0, + 
include_raw_content: bool = True, + ) -> list[SourceData]: + """Perform search via Brave Search API. + + Brave supports native offset parameter for efficient pagination. + """ + if not config.brave_api_key: + raise ValueError("brave_api_key is required for BraveSearchTool") + + max_results = min(max_results, 20) + logger.info(f"Brave search: '{query}' (max_results={max_results}, offset={offset})") + + headers = { + "Accept": "application/json", + "Accept-Encoding": "gzip", + "X-Subscription-Token": config.brave_api_key, + } + params: dict[str, Any] = { + "q": query, + "count": max_results, + } + if offset > 0: + params["offset"] = offset + + try: + async with httpx.AsyncClient() as client: + response = await client.get( + config.brave_api_base_url, + headers=headers, + params=params, + timeout=30.0, + ) + response.raise_for_status() + data = response.json() + except httpx.HTTPStatusError as e: + logger.error(f"Brave API HTTP error: {e.response.status_code} — {e.response.text[:200]}") + raise + except httpx.RequestError as e: + logger.error(f"Brave API request error: {e}") + raise + + return BraveSearchTool._convert_to_source_data(data) + + @staticmethod + def _convert_to_source_data(response: dict) -> list[SourceData]: + """Convert Brave Search API response to SourceData list.""" + sources = [] + web_results = response.get("web", {}).get("results", []) + + for i, result in enumerate(web_results): + url = result.get("url", "") + if not url: + continue + + source = SourceData( + number=i, + title=result.get("title", ""), + url=url, + snippet=result.get("description", ""), + ) + sources.append(source) + + return sources + + +_search_registry["brave"] = BraveSearchTool diff --git a/sgr_agent_core/tools/extract_page_content_tool.py b/sgr_agent_core/tools/extract_page_content_tool.py index a5156e7d..7aaa1ecd 100644 --- a/sgr_agent_core/tools/extract_page_content_tool.py +++ b/sgr_agent_core/tools/extract_page_content_tool.py @@ -7,7 +7,7 @@ from 
sgr_agent_core.agent_definition import SearchConfig from sgr_agent_core.base_tool import BaseTool -from sgr_agent_core.services import TavilySearchService +from sgr_agent_core.tools.tavily_search_tool import TavilySearchTool from sgr_agent_core.utils import config_from_kwargs if TYPE_CHECKING: @@ -59,8 +59,7 @@ async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: A ) logger.info(f"Extracting content from {len(self.urls)} URLs") - self._search_service = TavilySearchService(search_config) - sources = await self._search_service.extract(urls=self.urls) + sources = await TavilySearchTool._extract(search_config, urls=self.urls) # Update existing sources instead of overwriting for source in sources: diff --git a/sgr_agent_core/tools/perplexity_search_tool.py b/sgr_agent_core/tools/perplexity_search_tool.py index 86ee4edc..fd5368c2 100644 --- a/sgr_agent_core/tools/perplexity_search_tool.py +++ b/sgr_agent_core/tools/perplexity_search_tool.py @@ -1,4 +1,15 @@ -from sgr_agent_core.tools.base_search_tool import _BaseSearchTool +from __future__ import annotations + +import logging +from typing import Any + +import httpx + +from sgr_agent_core.agent_definition import SearchConfig +from sgr_agent_core.models import SourceData +from sgr_agent_core.tools.base_search_tool import _BaseSearchTool, _search_registry + +logger = logging.getLogger(__name__) class PerplexitySearchTool(_BaseSearchTool): @@ -16,3 +27,77 @@ class PerplexitySearchTool(_BaseSearchTool): """ _default_engine = "perplexity" + + @staticmethod + async def _search( + config: SearchConfig, + query: str, + max_results: int, + offset: int = 0, + include_raw_content: bool = True, + ) -> list[SourceData]: + """Perform search via Perplexity Search API. + + Perplexity does not support native offset — over-fetch+slice is + applied internally when offset > 0. 
+ """ + if not config.perplexity_api_key: + raise ValueError("perplexity_api_key is required for PerplexitySearchTool") + + fetch_count = max_results + offset if offset > 0 else max_results + logger.info(f"Perplexity search: '{query}' (max_results={max_results}, offset={offset})") + + headers = { + "Authorization": f"Bearer {config.perplexity_api_key}", + "Content-Type": "application/json", + } + payload: dict[str, Any] = { + "query": query, + "max_results": fetch_count, + } + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + config.perplexity_api_base_url, + headers=headers, + json=payload, + timeout=30.0, + ) + response.raise_for_status() + data = response.json() + except httpx.HTTPStatusError as e: + logger.error(f"Perplexity API HTTP error: {e.response.status_code} — {e.response.text[:200]}") + raise + except httpx.RequestError as e: + logger.error(f"Perplexity API request error: {e}") + raise + + sources = PerplexitySearchTool._convert_to_source_data(data) + if offset > 0: + sources = sources[offset:] + return sources[:max_results] + + @staticmethod + def _convert_to_source_data(response: dict) -> list[SourceData]: + """Convert Perplexity Search API response to SourceData list.""" + sources = [] + results = response.get("results", []) + + for i, result in enumerate(results): + url = result.get("url", "") + if not url: + continue + + source = SourceData( + number=i, + title=result.get("title", ""), + url=url, + snippet=result.get("snippet", ""), + ) + sources.append(source) + + return sources + + +_search_registry["perplexity"] = PerplexitySearchTool diff --git a/sgr_agent_core/tools/tavily_search_tool.py b/sgr_agent_core/tools/tavily_search_tool.py index cdd2e5b4..d7319165 100644 --- a/sgr_agent_core/tools/tavily_search_tool.py +++ b/sgr_agent_core/tools/tavily_search_tool.py @@ -1,4 +1,14 @@ -from sgr_agent_core.tools.base_search_tool import _BaseSearchTool +from __future__ import annotations + +import logging + +from tavily 
import AsyncTavilyClient + +from sgr_agent_core.agent_definition import SearchConfig +from sgr_agent_core.models import SourceData +from sgr_agent_core.tools.base_search_tool import _BaseSearchTool, _search_registry + +logger = logging.getLogger(__name__) class TavilySearchTool(_BaseSearchTool): @@ -16,3 +26,79 @@ class TavilySearchTool(_BaseSearchTool): """ _default_engine = "tavily" + + @staticmethod + async def _search( + config: SearchConfig, + query: str, + max_results: int, + offset: int = 0, + include_raw_content: bool = True, + ) -> list[SourceData]: + """Perform search via Tavily API. + + Offset: over-fetch+slice. + """ + fetch_count = max_results + offset if offset > 0 else max_results + logger.info(f"Tavily search: '{query}' (max_results={max_results}, offset={offset})") + + client = AsyncTavilyClient(api_key=config.tavily_api_key, api_base_url=config.tavily_api_base_url) + response = await client.search(query=query, max_results=fetch_count, include_raw_content=include_raw_content) + + sources = TavilySearchTool._convert_to_source_data(response) + if offset > 0: + sources = sources[offset:] + return sources[:max_results] + + @staticmethod + async def _extract(config: SearchConfig, urls: list[str]) -> list[SourceData]: + """Extract full content from URLs via Tavily Extract API.""" + logger.info(f"Tavily extract: {len(urls)} URLs") + + client = AsyncTavilyClient(api_key=config.tavily_api_key, api_base_url=config.tavily_api_base_url) + response = await client.extract(urls=urls) + + sources = [] + for i, result in enumerate(response.get("results", [])): + if not result.get("url"): + continue + + source = SourceData( + number=i, + title=result.get("url", "").split("/")[-1] or "Extracted Content", + url=result.get("url", ""), + snippet="", + full_content=result.get("raw_content", ""), + char_count=len(result.get("raw_content", "")), + ) + sources.append(source) + + failed_urls = response.get("failed_results", []) + if failed_urls: + logger.warning(f"Failed to 
extract {len(failed_urls)} URLs: {failed_urls}") + + return sources + + @staticmethod + def _convert_to_source_data(response: dict) -> list[SourceData]: + """Convert Tavily response to SourceData list.""" + sources = [] + + for i, result in enumerate(response.get("results", [])): + if not result.get("url", ""): + continue + + source = SourceData( + number=i, + title=result.get("title", ""), + url=result.get("url", ""), + snippet=result.get("content", ""), + ) + if result.get("raw_content", ""): + source.full_content = result["raw_content"] + source.char_count = len(source.full_content) + sources.append(source) + return sources + + +_search_registry["tavily"] = TavilySearchTool diff --git a/tests/test_search_services.py b/tests/test_search_providers.py similarity index 50% rename from tests/test_search_services.py rename to tests/test_search_providers.py index 41bafd40..f65ecc21 100644 --- a/tests/test_search_services.py +++ b/tests/test_search_providers.py @@ -1,5 +1,5 @@ -"""Tests for search services (BaseSearchService, TavilySearchService, -BraveSearchService, PerplexitySearchService).""" +"""Tests for search provider logic inlined into tools (TavilySearchTool, +BraveSearchTool, PerplexitySearchTool).""" from unittest.mock import AsyncMock, MagicMock, patch @@ -9,77 +9,45 @@ from sgr_agent_core.models import SourceData -class TestBaseSearchService: - """Tests for BaseSearchService.""" +class TestSearchToolRegistry: + """Tests for _BaseSearchTool registry and shared helpers.""" - def test_factory_creates_tavily_service(self): - from sgr_agent_core.services.base_search import BaseSearchService - from sgr_agent_core.services.tavily_search import TavilySearchService + def test_registry_contains_all_engines(self): + from sgr_agent_core.tools.base_search_tool import _search_registry - config = SearchConfig(engine="tavily", tavily_api_key="test-key") - service = BaseSearchService.create(config) - assert isinstance(service, TavilySearchService) + # Ensure concrete tools 
are imported so registry is populated + from sgr_agent_core.tools.brave_search_tool import BraveSearchTool # noqa: F401 + from sgr_agent_core.tools.perplexity_search_tool import PerplexitySearchTool # noqa: F401 + from sgr_agent_core.tools.tavily_search_tool import TavilySearchTool # noqa: F401 - def test_factory_creates_brave_service(self): - from sgr_agent_core.services.base_search import BaseSearchService - from sgr_agent_core.services.brave_search import BraveSearchService - - config = SearchConfig(engine="brave", brave_api_key="test-key") - service = BaseSearchService.create(config) - assert isinstance(service, BraveSearchService) - - def test_factory_creates_perplexity_service(self): - from sgr_agent_core.services.base_search import BaseSearchService - from sgr_agent_core.services.perplexity_search import PerplexitySearchService - - config = SearchConfig(engine="perplexity", perplexity_api_key="test-key") - service = BaseSearchService.create(config) - assert isinstance(service, PerplexitySearchService) - - def test_factory_raises_for_unknown_engine(self): - from sgr_agent_core.services.base_search import BaseSearchService - - # Use model_construct to bypass Literal validation and force an invalid engine - config = SearchConfig.model_construct(engine="unknown", tavily_api_key="k") - with pytest.raises(ValueError, match="Unsupported search engine"): - BaseSearchService.create(config) + assert set(_search_registry) == {"tavily", "brave", "perplexity"} def test_rearrange_sources(self): - from sgr_agent_core.services.base_search import BaseSearchService + from sgr_agent_core.tools.base_search_tool import _BaseSearchTool sources = [ SourceData(number=0, url="https://a.com", title="A", snippet="a"), SourceData(number=0, url="https://b.com", title="B", snippet="b"), ] - result = BaseSearchService.rearrange_sources(sources, starting_number=5) + result = _BaseSearchTool._rearrange_sources(sources, starting_number=5) assert result[0].number == 5 assert result[1].number 
== 6 - @pytest.mark.asyncio - async def test_base_search_raises_not_implemented(self): - from sgr_agent_core.services.base_search import BaseSearchService - - config = SearchConfig(tavily_api_key="k") - service = BaseSearchService(config) - with pytest.raises(NotImplementedError): - await service.search("test") - -class TestBraveSearchService: - """Tests for BraveSearchService.""" +class TestBraveSearchProvider: + """Tests for BraveSearchTool provider logic.""" - def test_raises_without_api_key(self): - from sgr_agent_core.services.brave_search import BraveSearchService + @pytest.mark.asyncio + async def test_raises_without_api_key(self): + from sgr_agent_core.tools.brave_search_tool import BraveSearchTool config = SearchConfig(engine="brave") with pytest.raises(ValueError, match="brave_api_key is required"): - BraveSearchService(config) + await BraveSearchTool._search(config, query="test", max_results=5) def test_convert_to_source_data(self): - from sgr_agent_core.services.brave_search import BraveSearchService + from sgr_agent_core.tools.brave_search_tool import BraveSearchTool - config = SearchConfig(engine="brave", brave_api_key="test-key") - service = BraveSearchService(config) response = { "web": { "results": [ @@ -89,7 +57,7 @@ def test_convert_to_source_data(self): ] } } - sources = service._convert_to_source_data(response) + sources = BraveSearchTool._convert_to_source_data(response) assert len(sources) == 2 assert sources[0].title == "Test" assert sources[0].url == "https://example.com" @@ -97,10 +65,9 @@ def test_convert_to_source_data(self): @pytest.mark.asyncio async def test_search_calls_brave_api(self): - from sgr_agent_core.services.brave_search import BraveSearchService + from sgr_agent_core.tools.brave_search_tool import BraveSearchTool config = SearchConfig(engine="brave", brave_api_key="test-key", max_results=10) - service = BraveSearchService(config) mock_response = MagicMock() mock_response.json.return_value = { @@ -112,14 +79,14 @@ async def 
test_search_calls_brave_api(self): } mock_response.raise_for_status = MagicMock() - with patch("sgr_agent_core.services.brave_search.httpx.AsyncClient") as mock_client_cls: + with patch("sgr_agent_core.tools.brave_search_tool.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) mock_client_cls.return_value = mock_client - sources = await service.search("test query", max_results=5) + sources = await BraveSearchTool._search(config, query="test query", max_results=5) mock_client.get.assert_called_once() call_kwargs = mock_client.get.call_args @@ -128,21 +95,20 @@ async def test_search_calls_brave_api(self): assert len(sources) == 1 -class TestPerplexitySearchService: - """Tests for PerplexitySearchService.""" +class TestPerplexitySearchProvider: + """Tests for PerplexitySearchTool provider logic.""" - def test_raises_without_api_key(self): - from sgr_agent_core.services.perplexity_search import PerplexitySearchService + @pytest.mark.asyncio + async def test_raises_without_api_key(self): + from sgr_agent_core.tools.perplexity_search_tool import PerplexitySearchTool config = SearchConfig(engine="perplexity") with pytest.raises(ValueError, match="perplexity_api_key is required"): - PerplexitySearchService(config) + await PerplexitySearchTool._search(config, query="test", max_results=5) def test_convert_to_source_data(self): - from sgr_agent_core.services.perplexity_search import PerplexitySearchService + from sgr_agent_core.tools.perplexity_search_tool import PerplexitySearchTool - config = SearchConfig(engine="perplexity", perplexity_api_key="test-key") - service = PerplexitySearchService(config) response = { "results": [ {"title": "Page 1", "url": "https://example.com/page1", "snippet": "First result snippet"}, @@ -150,7 +116,7 @@ def test_convert_to_source_data(self): {"title": "No URL", 
"url": "", "snippet": "Skipped"}, ], } - sources = service._convert_to_source_data(response) + sources = PerplexitySearchTool._convert_to_source_data(response) assert len(sources) == 2 assert sources[0].url == "https://example.com/page1" assert sources[0].title == "Page 1" @@ -159,10 +125,9 @@ def test_convert_to_source_data(self): @pytest.mark.asyncio async def test_search_calls_perplexity_api(self): - from sgr_agent_core.services.perplexity_search import PerplexitySearchService + from sgr_agent_core.tools.perplexity_search_tool import PerplexitySearchTool config = SearchConfig(engine="perplexity", perplexity_api_key="test-key", max_results=10) - service = PerplexitySearchService(config) mock_response = MagicMock() mock_response.json.return_value = { @@ -172,14 +137,14 @@ async def test_search_calls_perplexity_api(self): } mock_response.raise_for_status = MagicMock() - with patch("sgr_agent_core.services.perplexity_search.httpx.AsyncClient") as mock_client_cls: + with patch("sgr_agent_core.tools.perplexity_search_tool.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) mock_client_cls.return_value = mock_client - sources = await service.search("test query", max_results=5) + sources = await PerplexitySearchTool._search(config, query="test query", max_results=5) mock_client.post.assert_called_once() call_kwargs = mock_client.post.call_args @@ -188,29 +153,42 @@ async def test_search_calls_perplexity_api(self): assert len(sources) == 1 -class TestTavilySearchService: - """Tests for TavilySearchService with BaseSearchService inheritance.""" - - def test_inherits_rearrange_sources(self): - """TavilySearchService should have rearrange_sources from - BaseSearchService.""" - from sgr_agent_core.services.base_search import BaseSearchService - from sgr_agent_core.services.tavily_search 
import TavilySearchService - - assert TavilySearchService.rearrange_sources is BaseSearchService.rearrange_sources +class TestTavilySearchProvider: + """Tests for TavilySearchTool provider logic.""" def test_convert_to_source_data(self): - from sgr_agent_core.services.tavily_search import TavilySearchService + from sgr_agent_core.tools.tavily_search_tool import TavilySearchTool - config = SearchConfig(tavily_api_key="test-key") - service = TavilySearchService(config) response = { "results": [ {"title": "Test", "url": "https://example.com", "content": "Snippet", "raw_content": "Full content"}, ] } - sources = service._convert_to_source_data(response) + sources = TavilySearchTool._convert_to_source_data(response) assert len(sources) == 1 assert sources[0].title == "Test" assert sources[0].snippet == "Snippet" assert sources[0].full_content == "Full content" + + @pytest.mark.asyncio + async def test_extract_calls_tavily_api(self): + from sgr_agent_core.tools.tavily_search_tool import TavilySearchTool + + config = SearchConfig(tavily_api_key="test-key") + + mock_client = AsyncMock() + mock_client.extract = AsyncMock( + return_value={ + "results": [ + {"url": "https://example.com/page", "raw_content": "Full page content"}, + ], + "failed_results": [], + } + ) + + with patch("sgr_agent_core.tools.tavily_search_tool.AsyncTavilyClient", return_value=mock_client): + sources = await TavilySearchTool._extract(config, urls=["https://example.com/page"]) + + assert len(sources) == 1 + assert sources[0].url == "https://example.com/page" + assert sources[0].full_content == "Full page content" diff --git a/tests/test_tools.py b/tests/test_tools.py index 9e52992b..e9763f6b 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -11,7 +11,6 @@ from sgr_agent_core.agent_definition import SearchConfig from sgr_agent_core.models import AgentContext, SourceData -from sgr_agent_core.services.base_search import BaseSearchService from sgr_agent_core.tools import ( AdaptPlanTool, 
AnswerTool, @@ -102,13 +101,12 @@ def test_web_search_tool_initialization(self): def test_extract_page_content_tool_initialization(self): """Test ExtractPageContentTool initialization.""" - with patch("sgr_agent_core.tools.extract_page_content_tool.TavilySearchService"): - tool = ExtractPageContentTool( - reasoning="Test", - urls=["https://example.com"], - ) - assert tool.tool_name == "extractpagecontenttool" - assert len(tool.urls) == 1 + tool = ExtractPageContentTool( + reasoning="Test", + urls=["https://example.com"], + ) + assert tool.tool_name == "extractpagecontenttool" + assert len(tool.urls) == 1 def test_create_report_tool_initialization(self): """Test CreateReportTool initialization.""" @@ -210,72 +208,51 @@ class TestSearchToolsKwargs: @pytest.mark.asyncio async def test_web_search_tool_uses_kwargs_over_config_search(self): """WebSearchTool uses max_results from kwargs when provided.""" - from sgr_agent_core.models import AgentContext - tool = WebSearchTool(reasoning="r", query="test", max_results=5) context = AgentContext() config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=10) - with patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=[]) - mock_svc_class.create.return_value = mock_svc - mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources + with patch.object(TavilySearchTool, "_search", new_callable=AsyncMock, return_value=[]) as mock_search: await tool(context, config, max_results=3) - call_args = mock_svc_class.create.call_args[0][0] - assert call_args.max_results == 3 + assert mock_search.call_args.kwargs["config"].max_results == 3 @pytest.mark.asyncio async def test_web_search_tool_fallback_to_config_search(self): """WebSearchTool uses config.search when kwargs do not set max_results.""" - from sgr_agent_core.models import AgentContext - tool = WebSearchTool(reasoning="r", query="test", max_results=5) 
context = AgentContext() config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=10) - with patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=[]) - mock_svc_class.create.return_value = mock_svc - mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources + with patch.object(TavilySearchTool, "_search", new_callable=AsyncMock, return_value=[]) as mock_search: await tool(context, config) - call_args = mock_svc_class.create.call_args[0][0] - assert call_args.max_results == 10 + assert mock_search.call_args.kwargs["config"].max_results == 10 @pytest.mark.asyncio async def test_web_search_tool_with_offset(self): - """WebSearchTool passes offset to service which handles it + """WebSearchTool passes offset to provider which handles it internally.""" tool = WebSearchTool(reasoning="r", query="test", max_results=3, offset=2) context = AgentContext() config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=10) - # Service returns already-offset results (3 items after skipping 2) + # Provider returns already-offset results (3 items after skipping 2) mock_sources = [ SourceData(number=i, url=f"https://example.com/{i}", title=f"Result {i}", snippet=f"Snippet {i}") for i in range(2, 5) ] - with patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=mock_sources) - mock_svc_class.create.return_value = mock_svc - mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources - + with patch.object(TavilySearchTool, "_search", new_callable=AsyncMock, return_value=mock_sources): result = await tool(context, config) - # Offset is delegated to the service - mock_svc.search.assert_called_once_with(query="test", max_results=3, offset=2, include_raw_content=False) assert len(context.searches) == 1 assert 
len(context.searches[0].citations) == 3 assert "Result 2" in result @pytest.mark.asyncio async def test_web_search_tool_offset_default_zero(self): - """WebSearchTool without offset passes offset=0 to service.""" + """WebSearchTool without offset passes offset=0 to provider.""" tool = WebSearchTool(reasoning="r", query="test", max_results=3) assert tool.offset == 0 @@ -288,15 +265,9 @@ async def test_web_search_tool_offset_default_zero(self): for i in range(3) ] - with patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=mock_sources) - mock_svc_class.create.return_value = mock_svc - mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources - + with patch.object(TavilySearchTool, "_search", new_callable=AsyncMock, return_value=mock_sources): await tool(context, config) - mock_svc.search.assert_called_once_with(query="test", max_results=3, offset=0, include_raw_content=False) assert len(context.searches[0].citations) == 3 @pytest.mark.asyncio @@ -308,50 +279,32 @@ async def test_web_search_tool_offset_exceeds_results(self): config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=20) - # Service returns empty list (offset exceeded available results) - with patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=[]) - mock_svc_class.create.return_value = mock_svc - mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources - + # Provider returns empty list (offset exceeded available results) + with patch.object(TavilySearchTool, "_search", new_callable=AsyncMock, return_value=[]): result = await tool(context, config) - mock_svc.search.assert_called_once_with(query="test", max_results=3, offset=10, include_raw_content=False) assert len(context.searches[0].citations) == 0 assert "Search Query: test" in result @pytest.mark.asyncio 
async def test_brave_search_tool_forces_engine(self): """BraveSearchTool forces engine='brave' regardless of config.""" - from sgr_agent_core.models import AgentContext - tool = BraveSearchTool(reasoning="r", query="test", max_results=5) context = AgentContext() config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=10, engine="tavily") - with patch("sgr_agent_core.tools.base_search_tool.BaseSearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.search = AsyncMock(return_value=[]) - mock_svc_class.create.return_value = mock_svc - mock_svc_class.rearrange_sources = BaseSearchService.rearrange_sources + with patch.object(BraveSearchTool, "_search", new_callable=AsyncMock, return_value=[]) as mock_search: await tool(context, config) - call_args = mock_svc_class.create.call_args[0][0] - assert call_args.engine == "brave" + assert mock_search.call_args.kwargs["config"].engine == "brave" @pytest.mark.asyncio async def test_extract_page_content_tool_uses_content_limit_from_kwargs(self): """ExtractPageContentTool uses content_limit from kwargs.""" - from sgr_agent_core.models import AgentContext - tool = ExtractPageContentTool(reasoning="r", urls=["https://example.com"]) context = AgentContext() config = MagicMock() config.search = SearchConfig(tavily_api_key="k", content_limit=1000) - with patch("sgr_agent_core.tools.extract_page_content_tool.TavilySearchService") as mock_svc_class: - mock_svc = AsyncMock() - mock_svc.extract = AsyncMock(return_value=[]) - mock_svc_class.return_value = mock_svc + with patch.object(TavilySearchTool, "_extract", new_callable=AsyncMock, return_value=[]) as mock_extract: await tool(context, config, content_limit=500) - call_args = mock_svc_class.call_args[0][0] - assert call_args.content_limit == 500 + # search_config is passed as first positional arg + assert mock_extract.call_args[0][0].content_limit == 500 From 6c89cdbe3ae5045e6ce6d6b6ffa235bc56348c07 Mon Sep 17 00:00:00 2001 From: Nikita Matsko 
Date: Fri, 20 Feb 2026 15:37:09 +0000 Subject: [PATCH 07/12] refactor(search): replace _search_registry with SearchProviderRegistry for search tools to centralize and standardize search provider registration and lookup --- sgr_agent_core/__init__.py | 2 ++ sgr_agent_core/services/__init__.py | 8 +++++++- sgr_agent_core/services/registry.py | 4 ++++ sgr_agent_core/tools/base_search_tool.py | 19 ++++--------------- sgr_agent_core/tools/brave_search_tool.py | 5 +++-- .../tools/perplexity_search_tool.py | 5 +++-- sgr_agent_core/tools/tavily_search_tool.py | 5 +++-- tests/test_search_providers.py | 6 ++++-- 8 files changed, 30 insertions(+), 24 deletions(-) diff --git a/sgr_agent_core/__init__.py b/sgr_agent_core/__init__.py index e398ccec..f1141893 100644 --- a/sgr_agent_core/__init__.py +++ b/sgr_agent_core/__init__.py @@ -34,6 +34,7 @@ AgentRegistry, MCP2ToolConverter, PromptLoader, + SearchProviderRegistry, ToolRegistry, ) from sgr_agent_core.tools import * # noqa: F403 @@ -57,6 +58,7 @@ "AgentRegistry", "MCP2ToolConverter", "PromptLoader", + "SearchProviderRegistry", "ToolRegistry", # Configuration "AgentConfig", diff --git a/sgr_agent_core/services/__init__.py b/sgr_agent_core/services/__init__.py index 7c77722e..7abc997d 100644 --- a/sgr_agent_core/services/__init__.py +++ b/sgr_agent_core/services/__init__.py @@ -2,13 +2,19 @@ from sgr_agent_core.services.mcp_service import MCP2ToolConverter from sgr_agent_core.services.prompt_loader import PromptLoader -from sgr_agent_core.services.registry import AgentRegistry, StreamingGeneratorRegistry, ToolRegistry +from sgr_agent_core.services.registry import ( + AgentRegistry, + SearchProviderRegistry, + StreamingGeneratorRegistry, + ToolRegistry, +) from sgr_agent_core.services.tool_instantiator import ToolInstantiator __all__ = [ "MCP2ToolConverter", "ToolRegistry", "StreamingGeneratorRegistry", + "SearchProviderRegistry", "AgentRegistry", "PromptLoader", "ToolInstantiator", diff --git 
a/sgr_agent_core/services/registry.py b/sgr_agent_core/services/registry.py index a1149e77..a966cd9d 100644 --- a/sgr_agent_core/services/registry.py +++ b/sgr_agent_core/services/registry.py @@ -130,3 +130,7 @@ class ToolRegistry(Registry["BaseTool"]): class StreamingGeneratorRegistry(Registry["BaseStreamingGenerator"]): """Registry for streaming generator classes (openai, open_webui, custom).""" + + +class SearchProviderRegistry(Registry["BaseTool"]): + """Registry for search engine providers (tavily, brave, perplexity).""" diff --git a/sgr_agent_core/tools/base_search_tool.py b/sgr_agent_core/tools/base_search_tool.py index 043e9a85..602869f4 100644 --- a/sgr_agent_core/tools/base_search_tool.py +++ b/sgr_agent_core/tools/base_search_tool.py @@ -9,6 +9,7 @@ from sgr_agent_core.agent_definition import AgentConfig, SearchConfig from sgr_agent_core.base_tool import BaseTool from sgr_agent_core.models import SearchResult, SourceData +from sgr_agent_core.services.registry import SearchProviderRegistry from sgr_agent_core.utils import config_from_kwargs if TYPE_CHECKING: @@ -17,10 +18,6 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -# Engine name -> tool class with _search() staticmethod. -# Populated explicitly at module level in each provider tool file. -_search_registry: dict[str, type] = {} - class _BaseSearchTool(BaseTool): """Base class for all search tools. @@ -30,20 +27,12 @@ class _BaseSearchTool(BaseTool): docstring. Provider-specific API logic lives in concrete tools as @staticmethod - _search() methods, dispatched via _search_registry by engine name. + _search() methods, dispatched via SearchProviderRegistry by engine + name. 
""" _default_engine: ClassVar[str | None] = None - def __init_subclass__(cls, **kwargs: Any) -> None: - # Reset tool_name and description so each concrete subclass gets its - # own values instead of inheriting "_basesearchtool" / base docstring via MRO - if "tool_name" not in cls.__dict__: - cls.tool_name = cls.__name__.lower() - if "description" not in cls.__dict__: - cls.description = cls.__doc__ or "" - super().__init_subclass__(**kwargs) - config_model = SearchConfig base_config_attr = "search" @@ -88,7 +77,7 @@ async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: A ) logger.info(f"Search query: '{self.query}' (engine={search_config.engine})") - provider_cls = _search_registry.get(search_config.engine) + provider_cls = SearchProviderRegistry.get(search_config.engine) if provider_cls is None: raise ValueError(f"Unsupported search engine: {search_config.engine}") diff --git a/sgr_agent_core/tools/brave_search_tool.py b/sgr_agent_core/tools/brave_search_tool.py index 0f718ece..6f12b458 100644 --- a/sgr_agent_core/tools/brave_search_tool.py +++ b/sgr_agent_core/tools/brave_search_tool.py @@ -7,7 +7,8 @@ from sgr_agent_core.agent_definition import SearchConfig from sgr_agent_core.models import SourceData -from sgr_agent_core.tools.base_search_tool import _BaseSearchTool, _search_registry +from sgr_agent_core.services.registry import SearchProviderRegistry +from sgr_agent_core.tools.base_search_tool import _BaseSearchTool logger = logging.getLogger(__name__) @@ -99,4 +100,4 @@ def _convert_to_source_data(response: dict) -> list[SourceData]: return sources -_search_registry["brave"] = BraveSearchTool +SearchProviderRegistry.register(BraveSearchTool, name="brave") diff --git a/sgr_agent_core/tools/perplexity_search_tool.py b/sgr_agent_core/tools/perplexity_search_tool.py index fd5368c2..dc7658f0 100644 --- a/sgr_agent_core/tools/perplexity_search_tool.py +++ b/sgr_agent_core/tools/perplexity_search_tool.py @@ -7,7 +7,8 @@ from 
sgr_agent_core.agent_definition import SearchConfig from sgr_agent_core.models import SourceData -from sgr_agent_core.tools.base_search_tool import _BaseSearchTool, _search_registry +from sgr_agent_core.services.registry import SearchProviderRegistry +from sgr_agent_core.tools.base_search_tool import _BaseSearchTool logger = logging.getLogger(__name__) @@ -100,4 +101,4 @@ def _convert_to_source_data(response: dict) -> list[SourceData]: return sources -_search_registry["perplexity"] = PerplexitySearchTool +SearchProviderRegistry.register(PerplexitySearchTool, name="perplexity") diff --git a/sgr_agent_core/tools/tavily_search_tool.py b/sgr_agent_core/tools/tavily_search_tool.py index d7319165..bb5c6719 100644 --- a/sgr_agent_core/tools/tavily_search_tool.py +++ b/sgr_agent_core/tools/tavily_search_tool.py @@ -6,7 +6,8 @@ from sgr_agent_core.agent_definition import SearchConfig from sgr_agent_core.models import SourceData -from sgr_agent_core.tools.base_search_tool import _BaseSearchTool, _search_registry +from sgr_agent_core.services.registry import SearchProviderRegistry +from sgr_agent_core.tools.base_search_tool import _BaseSearchTool logger = logging.getLogger(__name__) @@ -101,4 +102,4 @@ def _convert_to_source_data(response: dict) -> list[SourceData]: return sources -_search_registry["tavily"] = TavilySearchTool +SearchProviderRegistry.register(TavilySearchTool, name="tavily") diff --git a/tests/test_search_providers.py b/tests/test_search_providers.py index f65ecc21..188eb6ee 100644 --- a/tests/test_search_providers.py +++ b/tests/test_search_providers.py @@ -13,14 +13,16 @@ class TestSearchToolRegistry: """Tests for _BaseSearchTool registry and shared helpers.""" def test_registry_contains_all_engines(self): - from sgr_agent_core.tools.base_search_tool import _search_registry + from sgr_agent_core.services.registry import SearchProviderRegistry # Ensure concrete tools are imported so registry is populated from sgr_agent_core.tools.brave_search_tool import 
BraveSearchTool # noqa: F401 from sgr_agent_core.tools.perplexity_search_tool import PerplexitySearchTool # noqa: F401 from sgr_agent_core.tools.tavily_search_tool import TavilySearchTool # noqa: F401 - assert set(_search_registry) == {"tavily", "brave", "perplexity"} + assert SearchProviderRegistry.get("tavily") is not None + assert SearchProviderRegistry.get("brave") is not None + assert SearchProviderRegistry.get("perplexity") is not None def test_rearrange_sources(self): from sgr_agent_core.tools.base_search_tool import _BaseSearchTool From 3dc9da16575f24741c586a3d376fc276b5b6faee Mon Sep 17 00:00:00 2001 From: Nikita Matsko Date: Fri, 20 Feb 2026 15:37:54 +0000 Subject: [PATCH 08/12] refactor(extract_page_content_tool): move Tavily extract logic from tavily_search_tool to improve separation of concerns and update tests accordingly --- .../tools/extract_page_content_tool.py | 34 +++++++++++++++++-- sgr_agent_core/tools/tavily_search_tool.py | 29 ---------------- tests/test_search_providers.py | 6 ++-- tests/test_tools.py | 2 +- 4 files changed, 36 insertions(+), 35 deletions(-) diff --git a/sgr_agent_core/tools/extract_page_content_tool.py b/sgr_agent_core/tools/extract_page_content_tool.py index 7aaa1ecd..b7856075 100644 --- a/sgr_agent_core/tools/extract_page_content_tool.py +++ b/sgr_agent_core/tools/extract_page_content_tool.py @@ -4,10 +4,11 @@ from typing import TYPE_CHECKING, Any from pydantic import Field +from tavily import AsyncTavilyClient from sgr_agent_core.agent_definition import SearchConfig from sgr_agent_core.base_tool import BaseTool -from sgr_agent_core.tools.tavily_search_tool import TavilySearchTool +from sgr_agent_core.models import SourceData from sgr_agent_core.utils import config_from_kwargs if TYPE_CHECKING: @@ -41,6 +42,35 @@ class ExtractPageContentTool(BaseTool): reasoning: str = Field(description="Why extract these specific pages") urls: list[str] = Field(description="List of URLs to extract full content from", min_length=1, 
max_length=5) + @staticmethod + async def _extract(config: SearchConfig, urls: list[str]) -> list[SourceData]: + """Extract full content from URLs via Tavily Extract API.""" + logger.info(f"Tavily extract: {len(urls)} URLs") + + client = AsyncTavilyClient(api_key=config.tavily_api_key, api_base_url=config.tavily_api_base_url) + response = await client.extract(urls=urls) + + sources = [] + for i, result in enumerate(response.get("results", [])): + if not result.get("url"): + continue + + source = SourceData( + number=i, + title=result.get("url", "").split("/")[-1] or "Extracted Content", + url=result.get("url", ""), + snippet="", + full_content=result.get("raw_content", ""), + char_count=len(result.get("raw_content", "")), + ) + sources.append(source) + + failed_urls = response.get("failed_results", []) + if failed_urls: + logger.warning(f"Failed to extract {len(failed_urls)} URLs: {failed_urls}") + + return sources + async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: Any) -> str: """Extract full content from specified URLs. 
@@ -59,7 +89,7 @@ async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: A ) logger.info(f"Extracting content from {len(self.urls)} URLs") - sources = await TavilySearchTool._extract(search_config, urls=self.urls) + sources = await self._extract(search_config, urls=self.urls) # Update existing sources instead of overwriting for source in sources: diff --git a/sgr_agent_core/tools/tavily_search_tool.py b/sgr_agent_core/tools/tavily_search_tool.py index bb5c6719..3b5ea10e 100644 --- a/sgr_agent_core/tools/tavily_search_tool.py +++ b/sgr_agent_core/tools/tavily_search_tool.py @@ -51,35 +51,6 @@ async def _search( sources = sources[offset:] return sources[:max_results] - @staticmethod - async def _extract(config: SearchConfig, urls: list[str]) -> list[SourceData]: - """Extract full content from URLs via Tavily Extract API.""" - logger.info(f"Tavily extract: {len(urls)} URLs") - - client = AsyncTavilyClient(api_key=config.tavily_api_key, api_base_url=config.tavily_api_base_url) - response = await client.extract(urls=urls) - - sources = [] - for i, result in enumerate(response.get("results", [])): - if not result.get("url"): - continue - - source = SourceData( - number=i, - title=result.get("url", "").split("/")[-1] or "Extracted Content", - url=result.get("url", ""), - snippet="", - full_content=result.get("raw_content", ""), - char_count=len(result.get("raw_content", "")), - ) - sources.append(source) - - failed_urls = response.get("failed_results", []) - if failed_urls: - logger.warning(f"Failed to extract {len(failed_urls)} URLs: {failed_urls}") - - return sources - @staticmethod def _convert_to_source_data(response: dict) -> list[SourceData]: """Convert Tavily response to SourceData list.""" diff --git a/tests/test_search_providers.py b/tests/test_search_providers.py index 188eb6ee..4fba7cba 100644 --- a/tests/test_search_providers.py +++ b/tests/test_search_providers.py @@ -174,7 +174,7 @@ def test_convert_to_source_data(self): 
@pytest.mark.asyncio async def test_extract_calls_tavily_api(self): - from sgr_agent_core.tools.tavily_search_tool import TavilySearchTool + from sgr_agent_core.tools.extract_page_content_tool import ExtractPageContentTool config = SearchConfig(tavily_api_key="test-key") @@ -188,8 +188,8 @@ async def test_extract_calls_tavily_api(self): } ) - with patch("sgr_agent_core.tools.tavily_search_tool.AsyncTavilyClient", return_value=mock_client): - sources = await TavilySearchTool._extract(config, urls=["https://example.com/page"]) + with patch("sgr_agent_core.tools.extract_page_content_tool.AsyncTavilyClient", return_value=mock_client): + sources = await ExtractPageContentTool._extract(config, urls=["https://example.com/page"]) assert len(sources) == 1 assert sources[0].url == "https://example.com/page" diff --git a/tests/test_tools.py b/tests/test_tools.py index e9763f6b..c7b96e6d 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -304,7 +304,7 @@ async def test_extract_page_content_tool_uses_content_limit_from_kwargs(self): context = AgentContext() config = MagicMock() config.search = SearchConfig(tavily_api_key="k", content_limit=1000) - with patch.object(TavilySearchTool, "_extract", new_callable=AsyncMock, return_value=[]) as mock_extract: + with patch.object(ExtractPageContentTool, "_extract", new_callable=AsyncMock, return_value=[]) as mock_extract: await tool(context, config, content_limit=500) # search_config is passed as first positional arg assert mock_extract.call_args[0][0].content_limit == 500 From 968239c822f9f9c96ac91528399b745b47cb741c Mon Sep 17 00:00:00 2001 From: Nikita Matsko Date: Fri, 20 Feb 2026 15:38:07 +0000 Subject: [PATCH 09/12] refactor(agent_definition): reorganize SearchConfig fields by provider and add section comments for clarity --- sgr_agent_core/agent_definition.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/sgr_agent_core/agent_definition.py b/sgr_agent_core/agent_definition.py 
index 5f8b0967..8b89828b 100644 --- a/sgr_agent_core/agent_definition.py +++ b/sgr_agent_core/agent_definition.py @@ -65,22 +65,27 @@ def to_openai_client_kwargs(self) -> dict[str, Any]: class SearchConfig(BaseModel, extra="allow"): - tavily_api_key: str | None = Field(default=None, description="Tavily API key") - tavily_api_base_url: str = Field(default="https://api.tavily.com", description="Tavily API base URL") - - max_searches: int = Field(default=4, ge=0, description="Maximum number of searches") - max_results: int = Field(default=10, ge=1, description="Maximum number of search results") - content_limit: int = Field(default=3500, gt=0, description="Content character limit per source") - + # General search settings engine: Literal["tavily", "brave", "perplexity"] = Field( default="tavily", description="Search engine provider to use", ) + max_searches: int = Field(default=4, ge=0, description="Maximum number of searches") + max_results: int = Field(default=10, ge=1, description="Maximum number of search results") + content_limit: int = Field(default=3500, gt=0, description="Content character limit per source") + + # Tavily provider + tavily_api_key: str | None = Field(default=None, description="Tavily API key") + tavily_api_base_url: str = Field(default="https://api.tavily.com", description="Tavily API base URL") + + # Brave provider brave_api_key: str | None = Field(default=None, description="Brave Search API key") brave_api_base_url: str = Field( default="https://api.search.brave.com/res/v1/web/search", description="Brave Search API base URL", ) + + # Perplexity provider perplexity_api_key: str | None = Field(default=None, description="Perplexity API key") perplexity_api_base_url: str = Field( default="https://api.perplexity.ai/search", From c262eaf2ecac868dadf2e8067542788ab17f759f Mon Sep 17 00:00:00 2001 From: Nikita Matsko Date: Fri, 20 Feb 2026 15:38:18 +0000 Subject: [PATCH 10/12] docs(web_search_tool): expand class docstring to clarify engine selection 
and usage options --- sgr_agent_core/tools/web_search_tool.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sgr_agent_core/tools/web_search_tool.py b/sgr_agent_core/tools/web_search_tool.py index 90fdae27..0b937e25 100644 --- a/sgr_agent_core/tools/web_search_tool.py +++ b/sgr_agent_core/tools/web_search_tool.py @@ -2,7 +2,12 @@ class WebSearchTool(_BaseSearchTool): - """Search the web for real-time information about any topic. + """Primary user-facing search tool (not legacy). Engine is selected via + SearchConfig.engine — switch providers by changing one config line. + Standalone tools (TavilySearchTool, BraveSearchTool, PerplexitySearchTool) + are an alternative for multi-engine setups where the LLM picks the engine. + + Search the web for real-time information about any topic. Use this tool when you need up-to-date information that might not be available in your training data, or when you need to verify current facts. The search results will include relevant snippets and URLs from web pages. 
From 72381b05a6173c32c7365d292c31eec3d809419b Mon Sep 17 00:00:00 2001 From: Nikita Matsko Date: Sun, 15 Mar 2026 11:53:32 +0000 Subject: [PATCH 11/12] refactor(search): consolidate search providers into unified web search tool - consolidate search providers into a unified web search tool and config - remove SearchProviderRegistry and legacy base_search_tool and provider modules - update example configs to use generic api_key and optional api_base_url - adjust exports and registries to reflect removed search-specific classes --- config.yaml.example | 16 +- .../sgr_deep_research/config.yaml.example | 6 +- .../config.yaml.example | 6 +- sgr_agent_core/__init__.py | 2 - sgr_agent_core/agent_definition.py | 21 -- sgr_agent_core/base_tool.py | 2 +- sgr_agent_core/services/__init__.py | 2 - sgr_agent_core/services/registry.py | 4 - sgr_agent_core/tools/__init__.py | 9 +- sgr_agent_core/tools/base_search_tool.py | 119 ------ sgr_agent_core/tools/brave_search_tool.py | 103 ------ .../tools/perplexity_search_tool.py | 104 ------ sgr_agent_core/tools/tavily_search_tool.py | 76 ---- sgr_agent_core/tools/web_search_tool.py | 342 +++++++++++++++++- tests/test_base_agent.py | 8 +- tests/test_search_providers.py | 99 ++--- tests/test_tools.py | 63 +--- 17 files changed, 403 insertions(+), 579 deletions(-) delete mode 100644 sgr_agent_core/tools/base_search_tool.py delete mode 100644 sgr_agent_core/tools/brave_search_tool.py delete mode 100644 sgr_agent_core/tools/perplexity_search_tool.py delete mode 100644 sgr_agent_core/tools/tavily_search_tool.py diff --git a/config.yaml.example b/config.yaml.example index 1e535e14..20dddf80 100644 --- a/config.yaml.example +++ b/config.yaml.example @@ -61,24 +61,14 @@ tools: # (can be overridden per-agent in tools list) web_search_tool: engine: "tavily" # Search engine: "tavily" (default), "brave", or "perplexity" - tavily_api_key: "your-tavily-api-key-here" # Tavily API key (get at tavily.com) - tavily_api_base_url: "https://api.tavily.com" 
# Tavily API URL + api_key: "your-search-api-key-here" # API key for the selected engine + # api_base_url: "https://custom-url" # Optional, uses engine default max_results: 12 max_searches: 6 extract_page_content_tool: - tavily_api_key: "your-tavily-api-key-here" # Same Tavily API key (Tavily-only feature) + tavily_api_key: "your-tavily-api-key-here" # Tavily API key (Tavily-only feature) tavily_api_base_url: "https://api.tavily.com" content_limit: 2000 - # Standalone search tools (for multi-engine setups where LLM picks the engine) - brave_search_tool: - brave_api_key: "your-brave-api-key-here" # Brave Search API key - brave_api_base_url: "https://api.search.brave.com/res/v1/web/search" - perplexity_search_tool: - perplexity_api_key: "your-perplexity-api-key-here" # Perplexity API key - perplexity_api_base_url: "https://api.perplexity.ai/search" - tavily_search_tool: - tavily_api_key: "your-tavily-api-key-here" - tavily_api_base_url: "https://api.tavily.com" agents: custom_research_agent: diff --git a/examples/sgr_deep_research/config.yaml.example b/examples/sgr_deep_research/config.yaml.example index 40c73eed..d26ebab2 100644 --- a/examples/sgr_deep_research/config.yaml.example +++ b/examples/sgr_deep_research/config.yaml.example @@ -31,13 +31,11 @@ tools: # Search tools: configure Tavily API key and search limits per tool web_search_tool: engine: "tavily" # Search engine: "tavily" (default), "brave", or "perplexity" - tavily_api_key: "your-tavily-api-key-here" # Tavily API key (get at tavily.com) - tavily_api_base_url: "https://api.tavily.com" # Tavily API URL + api_key: "your-tavily-api-key-here" # API key for the selected engine max_searches: 4 # Max search operations max_results: 10 # Max results in search query extract_page_content_tool: - tavily_api_key: "your-tavily-api-key-here" # Same Tavily API key - tavily_api_base_url: "https://api.tavily.com" + tavily_api_key: "your-tavily-api-key-here" # Tavily API key (Tavily extract only) content_limit: 1500 # 
Content char limit per source create_report_tool: # base_class defaults to sgr_agent_core.tools.CreateReportTool diff --git a/examples/sgr_deep_research_without_reporting/config.yaml.example b/examples/sgr_deep_research_without_reporting/config.yaml.example index 41953831..ed8e29d9 100644 --- a/examples/sgr_deep_research_without_reporting/config.yaml.example +++ b/examples/sgr_deep_research_without_reporting/config.yaml.example @@ -30,13 +30,11 @@ tools: # Search tools: configure Tavily API key and search limits per tool web_search_tool: engine: "tavily" # Search engine: "tavily" (default), "brave", or "perplexity" - tavily_api_key: "your-tavily-api-key-here" # Tavily API key (get at tavily.com) - tavily_api_base_url: "https://api.tavily.com" # Tavily API URL + api_key: "your-tavily-api-key-here" # API key for the selected engine max_searches: 4 # Max search operations max_results: 10 # Max results in search query extract_page_content_tool: - tavily_api_key: "your-tavily-api-key-here" # Same Tavily API key - tavily_api_base_url: "https://api.tavily.com" + tavily_api_key: "your-tavily-api-key-here" # Tavily API key (Tavily extract only) content_limit: 1500 # Content char limit per source final_answer_tool: # base_class defaults to sgr_agent_core.tools.FinalAnswerTool diff --git a/sgr_agent_core/__init__.py b/sgr_agent_core/__init__.py index f1141893..e398ccec 100644 --- a/sgr_agent_core/__init__.py +++ b/sgr_agent_core/__init__.py @@ -34,7 +34,6 @@ AgentRegistry, MCP2ToolConverter, PromptLoader, - SearchProviderRegistry, ToolRegistry, ) from sgr_agent_core.tools import * # noqa: F403 @@ -58,7 +57,6 @@ "AgentRegistry", "MCP2ToolConverter", "PromptLoader", - "SearchProviderRegistry", "ToolRegistry", # Configuration "AgentConfig", diff --git a/sgr_agent_core/agent_definition.py b/sgr_agent_core/agent_definition.py index 8b89828b..e1313110 100644 --- a/sgr_agent_core/agent_definition.py +++ b/sgr_agent_core/agent_definition.py @@ -65,33 +65,12 @@ def 
to_openai_client_kwargs(self) -> dict[str, Any]: class SearchConfig(BaseModel, extra="allow"): - # General search settings - engine: Literal["tavily", "brave", "perplexity"] = Field( - default="tavily", - description="Search engine provider to use", - ) max_searches: int = Field(default=4, ge=0, description="Maximum number of searches") max_results: int = Field(default=10, ge=1, description="Maximum number of search results") content_limit: int = Field(default=3500, gt=0, description="Content character limit per source") - - # Tavily provider tavily_api_key: str | None = Field(default=None, description="Tavily API key") tavily_api_base_url: str = Field(default="https://api.tavily.com", description="Tavily API base URL") - # Brave provider - brave_api_key: str | None = Field(default=None, description="Brave Search API key") - brave_api_base_url: str = Field( - default="https://api.search.brave.com/res/v1/web/search", - description="Brave Search API base URL", - ) - - # Perplexity provider - perplexity_api_key: str | None = Field(default=None, description="Perplexity API key") - perplexity_api_base_url: str = Field( - default="https://api.perplexity.ai/search", - description="Perplexity Search API base URL", - ) - class PromptsConfig(BaseModel, extra="allow"): system_prompt_file: FilePath | None = Field( diff --git a/sgr_agent_core/base_tool.py b/sgr_agent_core/base_tool.py index 6528664a..510b24e6 100644 --- a/sgr_agent_core/base_tool.py +++ b/sgr_agent_core/base_tool.py @@ -20,7 +20,7 @@ class ToolRegistryMixin: def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs) - if cls.__name__ not in ("BaseTool", "MCPBaseTool", "_BaseSearchTool", "SystemBaseTool"): + if cls.__name__ not in ("BaseTool", "MCPBaseTool", "SystemBaseTool"): ToolRegistry.register(cls, name=cls.tool_name) diff --git a/sgr_agent_core/services/__init__.py b/sgr_agent_core/services/__init__.py index 7abc997d..87eec181 100644 --- a/sgr_agent_core/services/__init__.py +++ 
b/sgr_agent_core/services/__init__.py @@ -4,7 +4,6 @@ from sgr_agent_core.services.prompt_loader import PromptLoader from sgr_agent_core.services.registry import ( AgentRegistry, - SearchProviderRegistry, StreamingGeneratorRegistry, ToolRegistry, ) @@ -14,7 +13,6 @@ "MCP2ToolConverter", "ToolRegistry", "StreamingGeneratorRegistry", - "SearchProviderRegistry", "AgentRegistry", "PromptLoader", "ToolInstantiator", diff --git a/sgr_agent_core/services/registry.py b/sgr_agent_core/services/registry.py index a966cd9d..a1149e77 100644 --- a/sgr_agent_core/services/registry.py +++ b/sgr_agent_core/services/registry.py @@ -130,7 +130,3 @@ class ToolRegistry(Registry["BaseTool"]): class StreamingGeneratorRegistry(Registry["BaseStreamingGenerator"]): """Registry for streaming generator classes (openai, open_webui, custom).""" - - -class SearchProviderRegistry(Registry["BaseTool"]): - """Registry for search engine providers (tavily, brave, perplexity).""" diff --git a/sgr_agent_core/tools/__init__.py b/sgr_agent_core/tools/__init__.py index a2647d24..bc810d7c 100644 --- a/sgr_agent_core/tools/__init__.py +++ b/sgr_agent_core/tools/__init__.py @@ -6,16 +6,13 @@ ) from sgr_agent_core.tools.adapt_plan_tool import AdaptPlanTool from sgr_agent_core.tools.answer_tool import AnswerTool -from sgr_agent_core.tools.brave_search_tool import BraveSearchTool from sgr_agent_core.tools.clarification_tool import ClarificationTool from sgr_agent_core.tools.create_report_tool import CreateReportTool from sgr_agent_core.tools.extract_page_content_tool import ExtractPageContentTool from sgr_agent_core.tools.final_answer_tool import FinalAnswerTool from sgr_agent_core.tools.generate_plan_tool import GeneratePlanTool -from sgr_agent_core.tools.perplexity_search_tool import PerplexitySearchTool from sgr_agent_core.tools.reasoning_tool import ReasoningTool -from sgr_agent_core.tools.tavily_search_tool import TavilySearchTool -from sgr_agent_core.tools.web_search_tool import WebSearchTool +from 
sgr_agent_core.tools.web_search_tool import WebSearchConfig, WebSearchTool __all__ = [ # Base classes @@ -28,14 +25,12 @@ # Individual tools "AdaptPlanTool", "AnswerTool", - "BraveSearchTool", "ClarificationTool", "CreateReportTool", "ExtractPageContentTool", "FinalAnswerTool", "GeneratePlanTool", - "PerplexitySearchTool", "ReasoningTool", - "TavilySearchTool", + "WebSearchConfig", "WebSearchTool", ] diff --git a/sgr_agent_core/tools/base_search_tool.py b/sgr_agent_core/tools/base_search_tool.py deleted file mode 100644 index 602869f4..00000000 --- a/sgr_agent_core/tools/base_search_tool.py +++ /dev/null @@ -1,119 +0,0 @@ -from __future__ import annotations - -import logging -from datetime import datetime -from typing import TYPE_CHECKING, Any, ClassVar - -from pydantic import Field - -from sgr_agent_core.agent_definition import AgentConfig, SearchConfig -from sgr_agent_core.base_tool import BaseTool -from sgr_agent_core.models import SearchResult, SourceData -from sgr_agent_core.services.registry import SearchProviderRegistry -from sgr_agent_core.utils import config_from_kwargs - -if TYPE_CHECKING: - from sgr_agent_core.models import AgentContext - -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) - - -class _BaseSearchTool(BaseTool): - """Base class for all search tools. - - Provides shared fields (reasoning, query, max_results, offset) and - common __call__ logic. Concrete tools override _default_engine and - docstring. - - Provider-specific API logic lives in concrete tools as @staticmethod - _search() methods, dispatched via SearchProviderRegistry by engine - name. - """ - - _default_engine: ClassVar[str | None] = None - - config_model = SearchConfig - base_config_attr = "search" - - reasoning: str = Field(description="Why this search is needed and what to expect") - query: str = Field(description="Search query in same language as user request") - max_results: int = Field( - description="Maximum results. 
How much of the web results selection you want to retrieve", - default=5, - ge=1, - le=20, - ) - offset: int = Field( - default=0, - ge=0, - description=( - "Number of results to skip from the beginning." - " Use for pagination: first call offset=0, next call offset=5, etc." - ), - ) - - @staticmethod - def _rearrange_sources(sources: list[SourceData], starting_number: int = 1) -> list[SourceData]: - """Renumber sources sequentially starting from given number.""" - for i, source in enumerate(sources, starting_number): - source.number = i - return sources - - async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: Any) -> str: - """Execute web search using the configured search engine. - - Search settings are taken from kwargs (tool config) with - fallback to config.search. - """ - # If this tool has a hardcoded engine, force it - if self._default_engine is not None: - kwargs.setdefault("engine", self._default_engine) - - search_config = config_from_kwargs( - SearchConfig, - config.search if config else None, - dict(kwargs), - ) - logger.info(f"Search query: '{self.query}' (engine={search_config.engine})") - - provider_cls = SearchProviderRegistry.get(search_config.engine) - if provider_cls is None: - raise ValueError(f"Unsupported search engine: {search_config.engine}") - - max_results_limit = search_config.max_results - effective_limit = min(self.max_results, max_results_limit) - - # Each provider handles offset internally: - # Brave uses native API offset, Tavily/Perplexity use over-fetch+slice - sources = await provider_cls._search( - config=search_config, - query=self.query, - max_results=effective_limit, - offset=self.offset, - include_raw_content=False, - ) - - sources = self._rearrange_sources(sources, starting_number=len(context.sources) + 1) - - for source in sources: - context.sources[source.url] = source - - search_result = SearchResult( - query=self.query, - answer=None, - citations=sources, - timestamp=datetime.now(), - ) - 
context.searches.append(search_result) - - formatted_result = f"Search Query: {search_result.query}\n\n" - formatted_result += "Search Results (titles, links, short snippets):\n\n" - - for source in sources: - snippet = source.snippet[:100] + "..." if len(source.snippet) > 100 else source.snippet - formatted_result += f"{str(source)}\n{snippet}\n\n" - - context.searches_used += 1 - logger.debug(formatted_result) - return formatted_result diff --git a/sgr_agent_core/tools/brave_search_tool.py b/sgr_agent_core/tools/brave_search_tool.py deleted file mode 100644 index 6f12b458..00000000 --- a/sgr_agent_core/tools/brave_search_tool.py +++ /dev/null @@ -1,103 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Any - -import httpx - -from sgr_agent_core.agent_definition import SearchConfig -from sgr_agent_core.models import SourceData -from sgr_agent_core.services.registry import SearchProviderRegistry -from sgr_agent_core.tools.base_search_tool import _BaseSearchTool - -logger = logging.getLogger(__name__) - - -class BraveSearchTool(_BaseSearchTool): - """Search the web using Brave search engine. Brave Search provides privacy- - focused search results with native pagination support. Use this tool when - you specifically want to search with Brave. - - Returns: Page titles, URLs, and short snippets - Best for: Privacy-focused search, efficient pagination via native offset - - Usage: - - Use SPECIFIC terms and context in queries - - Search queries in SAME LANGUAGE as user request - - Brave supports efficient pagination with offset parameter - """ - - _default_engine = "brave" - - @staticmethod - async def _search( - config: SearchConfig, - query: str, - max_results: int, - offset: int = 0, - include_raw_content: bool = True, - ) -> list[SourceData]: - """Perform search via Brave Search API. - - Brave supports native offset parameter for efficient pagination. 
- """ - if not config.brave_api_key: - raise ValueError("brave_api_key is required for BraveSearchTool") - - max_results = min(max_results, 20) - logger.info(f"Brave search: '{query}' (max_results={max_results}, offset={offset})") - - headers = { - "Accept": "application/json", - "Accept-Encoding": "gzip", - "X-Subscription-Token": config.brave_api_key, - } - params: dict[str, Any] = { - "q": query, - "count": max_results, - } - if offset > 0: - params["offset"] = offset - - try: - async with httpx.AsyncClient() as client: - response = await client.get( - config.brave_api_base_url, - headers=headers, - params=params, - timeout=30.0, - ) - response.raise_for_status() - data = response.json() - except httpx.HTTPStatusError as e: - logger.error(f"Brave API HTTP error: {e.response.status_code} — {e.response.text[:200]}") - raise - except httpx.RequestError as e: - logger.error(f"Brave API request error: {e}") - raise - - return BraveSearchTool._convert_to_source_data(data) - - @staticmethod - def _convert_to_source_data(response: dict) -> list[SourceData]: - """Convert Brave Search API response to SourceData list.""" - sources = [] - web_results = response.get("web", {}).get("results", []) - - for i, result in enumerate(web_results): - url = result.get("url", "") - if not url: - continue - - source = SourceData( - number=i, - title=result.get("title", ""), - url=url, - snippet=result.get("description", ""), - ) - sources.append(source) - - return sources - - -SearchProviderRegistry.register(BraveSearchTool, name="brave") diff --git a/sgr_agent_core/tools/perplexity_search_tool.py b/sgr_agent_core/tools/perplexity_search_tool.py deleted file mode 100644 index dc7658f0..00000000 --- a/sgr_agent_core/tools/perplexity_search_tool.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import annotations - -import logging -from typing import Any - -import httpx - -from sgr_agent_core.agent_definition import SearchConfig -from sgr_agent_core.models import SourceData -from 
sgr_agent_core.services.registry import SearchProviderRegistry -from sgr_agent_core.tools.base_search_tool import _BaseSearchTool - -logger = logging.getLogger(__name__) - - -class PerplexitySearchTool(_BaseSearchTool): - """Search the web using Perplexity AI search engine. Perplexity provides - AI-powered search with synthesized answers and source citations. Use this - tool when you specifically want to search with Perplexity. - - Returns: Page titles, URLs, and AI-synthesized snippets - Best for: Getting AI-synthesized answers with source citations - - Usage: - - Use SPECIFIC terms and context in queries - - Search queries in SAME LANGUAGE as user request - - Results include AI-generated summary alongside source URLs - """ - - _default_engine = "perplexity" - - @staticmethod - async def _search( - config: SearchConfig, - query: str, - max_results: int, - offset: int = 0, - include_raw_content: bool = True, - ) -> list[SourceData]: - """Perform search via Perplexity Search API. - - Perplexity does not support native offset — over-fetch+slice is - applied internally when offset > 0. 
- """ - if not config.perplexity_api_key: - raise ValueError("perplexity_api_key is required for PerplexitySearchTool") - - fetch_count = max_results + offset if offset > 0 else max_results - logger.info(f"Perplexity search: '{query}' (max_results={max_results}, offset={offset})") - - headers = { - "Authorization": f"Bearer {config.perplexity_api_key}", - "Content-Type": "application/json", - } - payload: dict[str, Any] = { - "query": query, - "max_results": fetch_count, - } - - try: - async with httpx.AsyncClient() as client: - response = await client.post( - config.perplexity_api_base_url, - headers=headers, - json=payload, - timeout=30.0, - ) - response.raise_for_status() - data = response.json() - except httpx.HTTPStatusError as e: - logger.error(f"Perplexity API HTTP error: {e.response.status_code} — {e.response.text[:200]}") - raise - except httpx.RequestError as e: - logger.error(f"Perplexity API request error: {e}") - raise - - sources = PerplexitySearchTool._convert_to_source_data(data) - if offset > 0: - sources = sources[offset:] - return sources[:max_results] - - @staticmethod - def _convert_to_source_data(response: dict) -> list[SourceData]: - """Convert Perplexity Search API response to SourceData list.""" - sources = [] - results = response.get("results", []) - - for i, result in enumerate(results): - url = result.get("url", "") - if not url: - continue - - source = SourceData( - number=i, - title=result.get("title", ""), - url=url, - snippet=result.get("snippet", ""), - ) - sources.append(source) - - return sources - - -SearchProviderRegistry.register(PerplexitySearchTool, name="perplexity") diff --git a/sgr_agent_core/tools/tavily_search_tool.py b/sgr_agent_core/tools/tavily_search_tool.py deleted file mode 100644 index 3b5ea10e..00000000 --- a/sgr_agent_core/tools/tavily_search_tool.py +++ /dev/null @@ -1,76 +0,0 @@ -from __future__ import annotations - -import logging - -from tavily import AsyncTavilyClient - -from sgr_agent_core.agent_definition 
import SearchConfig -from sgr_agent_core.models import SourceData -from sgr_agent_core.services.registry import SearchProviderRegistry -from sgr_agent_core.tools.base_search_tool import _BaseSearchTool - -logger = logging.getLogger(__name__) - - -class TavilySearchTool(_BaseSearchTool): - """Search the web using Tavily search engine. Tavily provides high-quality - search results with optional raw content extraction. Use this tool when you - specifically want to search with Tavily. - - Returns: Page titles, URLs, and short snippets - Best for: General web search, research queries - - Usage: - - Use SPECIFIC terms and context in queries - - Search queries in SAME LANGUAGE as user request - - Use ExtractPageContentTool to get full content from found URLs - """ - - _default_engine = "tavily" - - @staticmethod - async def _search( - config: SearchConfig, - query: str, - max_results: int, - offset: int = 0, - include_raw_content: bool = True, - ) -> list[SourceData]: - """Perform search via Tavily API. - - Offset: over-fetch+slice. 
- """ - fetch_count = max_results + offset if offset > 0 else max_results - logger.info(f"Tavily search: '{query}' (max_results={max_results}, offset={offset})") - - client = AsyncTavilyClient(api_key=config.tavily_api_key, api_base_url=config.tavily_api_base_url) - response = await client.search(query=query, max_results=fetch_count, include_raw_content=include_raw_content) - - sources = TavilySearchTool._convert_to_source_data(response) - if offset > 0: - sources = sources[offset:] - return sources[:max_results] - - @staticmethod - def _convert_to_source_data(response: dict) -> list[SourceData]: - """Convert Tavily response to SourceData list.""" - sources = [] - - for i, result in enumerate(response.get("results", [])): - if not result.get("url", ""): - continue - - source = SourceData( - number=i, - title=result.get("title", ""), - url=result.get("url", ""), - snippet=result.get("content", ""), - ) - if result.get("raw_content", ""): - source.full_content = result["raw_content"] - source.char_count = len(source.full_content) - sources.append(source) - return sources - - -SearchProviderRegistry.register(TavilySearchTool, name="tavily") diff --git a/sgr_agent_core/tools/web_search_tool.py b/sgr_agent_core/tools/web_search_tool.py index 0b937e25..a9b1f53f 100644 --- a/sgr_agent_core/tools/web_search_tool.py +++ b/sgr_agent_core/tools/web_search_tool.py @@ -1,18 +1,263 @@ -from sgr_agent_core.tools.base_search_tool import _BaseSearchTool +from __future__ import annotations +import logging +from datetime import datetime +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Literal -class WebSearchTool(_BaseSearchTool): - """Primary user-facing search tool (not legacy). Engine is selected via - SearchConfig.engine — switch providers by changing one config line. - Standalone tools (TavilySearchTool, BraveSearchTool, PerplexitySearchTool) - are an alternative for multi-engine setups where the LLM picks the engine. 
+import httpx +from pydantic import BaseModel, Field +from tavily import AsyncTavilyClient - Search the web for real-time information about any topic. - Use this tool when you need up-to-date information that might not be available in your training data, - or when you need to verify current facts. +from sgr_agent_core.base_tool import BaseTool +from sgr_agent_core.models import SearchResult, SourceData +from sgr_agent_core.utils import config_from_kwargs + +if TYPE_CHECKING: + from sgr_agent_core.agent_definition import AgentConfig + from sgr_agent_core.models import AgentContext + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +_TAVILY_DEFAULT_URL = "https://api.tavily.com" +_BRAVE_DEFAULT_URL = "https://api.search.brave.com/res/v1/web/search" +_PERPLEXITY_DEFAULT_URL = "https://api.perplexity.ai/search" + +_ENGINE_DEFAULT_URLS: dict[str, str] = { + "tavily": _TAVILY_DEFAULT_URL, + "brave": _BRAVE_DEFAULT_URL, + "perplexity": _PERPLEXITY_DEFAULT_URL, +} + + +class WebSearchConfig(BaseModel, extra="allow"): + """Configuration for WebSearchTool. + + Defines the search engine, credentials, and limits. 
+ """ + + engine: Literal["tavily", "brave", "perplexity"] = Field( + default="tavily", + description="Search engine provider to use", + ) + api_key: str | None = Field(default=None, description="API key for the selected engine") + api_base_url: str | None = Field( + default=None, + description="API base URL for the selected engine (None = engine default)", + ) + max_searches: int = Field(default=4, ge=0, description="Maximum number of searches") + max_results: int = Field(default=10, ge=1, description="Maximum number of search results") + + +# --------------------------------------------------------------------------- +# Tavily +# --------------------------------------------------------------------------- + + +def _convert_tavily_response(response: dict) -> list[SourceData]: + """Convert Tavily API response to SourceData list.""" + sources = [] + for i, result in enumerate(response.get("results", [])): + if not result.get("url", ""): + continue + source = SourceData( + number=i, + title=result.get("title", ""), + url=result.get("url", ""), + snippet=result.get("content", ""), + ) + if result.get("raw_content", ""): + source.full_content = result["raw_content"] + source.char_count = len(source.full_content) + sources.append(source) + return sources + + +async def _search_tavily( + api_key: str, + api_base_url: str, + query: str, + max_results: int, + offset: int, +) -> list[SourceData]: + """Perform search via Tavily API. + + Offset: over-fetch + slice. 
+ """ + fetch_count = max_results + offset if offset > 0 else max_results + logger.info(f"Tavily search: '{query}' (max_results={max_results}, offset={offset})") + + client = AsyncTavilyClient(api_key=api_key, api_base_url=api_base_url) + response = await client.search(query=query, max_results=fetch_count, include_raw_content=False) + + sources = _convert_tavily_response(response) + if offset > 0: + sources = sources[offset:] + return sources[:max_results] + + +# --------------------------------------------------------------------------- +# Brave +# --------------------------------------------------------------------------- + + +def _convert_brave_response(response: dict) -> list[SourceData]: + """Convert Brave Search API response to SourceData list.""" + sources = [] + web_results = response.get("web", {}).get("results", []) + for i, result in enumerate(web_results): + url = result.get("url", "") + if not url: + continue + source = SourceData( + number=i, + title=result.get("title", ""), + url=url, + snippet=result.get("description", ""), + ) + sources.append(source) + return sources + + +async def _search_brave( + api_key: str, + api_base_url: str, + query: str, + max_results: int, + offset: int, +) -> list[SourceData]: + """Perform search via Brave Search API. + + Native offset support. 
+ """ + capped = min(max_results, 20) + logger.info(f"Brave search: '{query}' (max_results={capped}, offset={offset})") + + headers = { + "Accept": "application/json", + "Accept-Encoding": "gzip", + "X-Subscription-Token": api_key, + } + params: dict[str, Any] = {"q": query, "count": capped} + if offset > 0: + params["offset"] = offset + + try: + async with httpx.AsyncClient() as client: + response = await client.get(api_base_url, headers=headers, params=params, timeout=30.0) + response.raise_for_status() + data = response.json() + except httpx.HTTPStatusError as e: + logger.error(f"Brave API HTTP error: {e.response.status_code} — {e.response.text[:200]}") + raise + except httpx.RequestError as e: + logger.error(f"Brave API request error: {e}") + raise + + return _convert_brave_response(data) + + +# --------------------------------------------------------------------------- +# Perplexity +# --------------------------------------------------------------------------- + + +def _convert_perplexity_response(response: dict) -> list[SourceData]: + """Convert Perplexity Search API response to SourceData list.""" + sources = [] + for i, result in enumerate(response.get("results", [])): + url = result.get("url", "") + if not url: + continue + source = SourceData( + number=i, + title=result.get("title", ""), + url=url, + snippet=result.get("snippet", ""), + ) + sources.append(source) + return sources + + +async def _search_perplexity( + api_key: str, + api_base_url: str, + query: str, + max_results: int, + offset: int, +) -> list[SourceData]: + """Perform search via Perplexity API. + + Offset: over-fetch + slice. 
+ """ + fetch_count = max_results + offset if offset > 0 else max_results + logger.info(f"Perplexity search: '{query}' (max_results={max_results}, offset={offset})") + + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + payload: dict[str, Any] = {"query": query, "max_results": fetch_count} + + try: + async with httpx.AsyncClient() as client: + response = await client.post(api_base_url, headers=headers, json=payload, timeout=30.0) + response.raise_for_status() + data = response.json() + except httpx.HTTPStatusError as e: + logger.error(f"Perplexity API HTTP error: {e.response.status_code} — {e.response.text[:200]}") + raise + except httpx.RequestError as e: + logger.error(f"Perplexity API request error: {e}") + raise + + sources = _convert_perplexity_response(data) + if offset > 0: + sources = sources[offset:] + return sources[:max_results] + + +# --------------------------------------------------------------------------- +# Engine handler mapping +# --------------------------------------------------------------------------- + +SearchHandler = Callable[..., Awaitable[list[SourceData]]] + +_ENGINE_HANDLERS: dict[str, SearchHandler] = { + "tavily": _search_tavily, + "brave": _search_brave, + "perplexity": _search_perplexity, +} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _rearrange_sources(sources: list[SourceData], starting_number: int = 1) -> list[SourceData]: + """Renumber sources sequentially starting from given number.""" + for i, source in enumerate(sources, starting_number): + source.number = i + return sources + + +# --------------------------------------------------------------------------- +# WebSearchTool +# --------------------------------------------------------------------------- + + +class WebSearchTool(BaseTool): + """Search the web for real-time information about any 
topic. + + Single search tool with pluggable engine (tavily, brave, perplexity). + Engine is selected via tool config ``engine`` field. + + Use this tool when you need up-to-date information that might not be + available in your training data, or when you need to verify current facts. The search results will include relevant snippets and URLs from web pages. - This is particularly useful for questions about current events, technology updates, - or any topic that requires recent information. + This is particularly useful for questions about current events, technology + updates, or any topic that requires recent information. Use for: Public information, news, market trends, external APIs, general knowledge Returns: Page titles, URLs, and short snippets (100 characters) Best for: Quick overview, finding relevant pages @@ -30,3 +275,78 @@ class WebSearchTool(_BaseSearchTool): - For questions with specific dates/numbers, snippets may be more accurate than full pages - If the snippet directly answers the question, you may not need to extract the full page """ + + config_model = WebSearchConfig + base_config_attr = "search" + + reasoning: str = Field(description="Why this search is needed and what to expect") + query: str = Field(description="Search query in same language as user request") + max_results: int = Field( + description="Maximum results. How much of the web results selection you want to retrieve", + default=5, + ge=1, + le=20, + ) + offset: int = Field( + default=0, + ge=0, + description=( + "Number of results to skip from the beginning." + " Use for pagination: first call offset=0, next call offset=5, etc." 
+ ), + ) + + async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: Any) -> str: + """Execute web search using the configured search engine.""" + search_config = config_from_kwargs( + WebSearchConfig, + config.search if config else None, + dict(kwargs), + ) + + engine = search_config.engine + logger.info(f"Search query: '{self.query}' (engine={engine})") + + handler = _ENGINE_HANDLERS.get(engine) + if handler is None: + raise ValueError(f"Unsupported search engine: {engine}") + + api_key = search_config.api_key + if not api_key: + raise ValueError(f"api_key is required for engine '{engine}'") + + api_base_url = search_config.api_base_url or _ENGINE_DEFAULT_URLS[engine] + + effective_limit = min(self.max_results, search_config.max_results) + + sources = await handler( + api_key=api_key, + api_base_url=api_base_url, + query=self.query, + max_results=effective_limit, + offset=self.offset, + ) + + sources = _rearrange_sources(sources, starting_number=len(context.sources) + 1) + + for source in sources: + context.sources[source.url] = source + + search_result = SearchResult( + query=self.query, + answer=None, + citations=sources, + timestamp=datetime.now(), + ) + context.searches.append(search_result) + + formatted_result = f"Search Query: {search_result.query}\n\n" + formatted_result += "Search Results (titles, links, short snippets):\n\n" + + for source in sources: + snippet = source.snippet[:100] + "..." 
if len(source.snippet) > 100 else source.snippet + formatted_result += f"{str(source)}\n{snippet}\n\n" + + context.searches_used += 1 + logger.debug(formatted_result) + return formatted_result diff --git a/tests/test_base_agent.py b/tests/test_base_agent.py index 61150fc4..38796f65 100644 --- a/tests/test_base_agent.py +++ b/tests/test_base_agent.py @@ -351,7 +351,9 @@ def test_get_tool_config_returns_model_when_tool_has_config_model(self): search=None, ) out = agent.get_tool_config(WebSearchTool) - assert isinstance(out, SearchConfig) + from sgr_agent_core.tools.web_search_tool import WebSearchConfig + + assert isinstance(out, WebSearchConfig) assert out.max_searches == 6 def test_get_tool_config_merges_base_from_agent_config(self): @@ -370,7 +372,9 @@ def test_get_tool_config_merges_base_from_agent_config(self): search=SearchConfig(tavily_api_key="key", max_searches=10), ) out = agent.get_tool_config(WebSearchTool) - assert isinstance(out, SearchConfig) + from sgr_agent_core.tools.web_search_tool import WebSearchConfig + + assert isinstance(out, WebSearchConfig) assert out.max_searches == 10 assert out.tavily_api_key == "key" diff --git a/tests/test_search_providers.py b/tests/test_search_providers.py index 4fba7cba..43df7b11 100644 --- a/tests/test_search_providers.py +++ b/tests/test_search_providers.py @@ -1,5 +1,4 @@ -"""Tests for search provider logic inlined into tools (TavilySearchTool, -BraveSearchTool, PerplexitySearchTool).""" +"""Tests for search engine handler functions.""" from unittest.mock import AsyncMock, MagicMock, patch @@ -9,46 +8,26 @@ from sgr_agent_core.models import SourceData -class TestSearchToolRegistry: - """Tests for _BaseSearchTool registry and shared helpers.""" +class TestRearrangeSources: + """Tests for _rearrange_sources helper.""" - def test_registry_contains_all_engines(self): - from sgr_agent_core.services.registry import SearchProviderRegistry - - # Ensure concrete tools are imported so registry is populated - from 
sgr_agent_core.tools.brave_search_tool import BraveSearchTool # noqa: F401 - from sgr_agent_core.tools.perplexity_search_tool import PerplexitySearchTool # noqa: F401 - from sgr_agent_core.tools.tavily_search_tool import TavilySearchTool # noqa: F401 - - assert SearchProviderRegistry.get("tavily") is not None - assert SearchProviderRegistry.get("brave") is not None - assert SearchProviderRegistry.get("perplexity") is not None - - def test_rearrange_sources(self): - from sgr_agent_core.tools.base_search_tool import _BaseSearchTool + def test_renumbers_from_starting_number(self): + from sgr_agent_core.tools.web_search_tool import _rearrange_sources sources = [ SourceData(number=0, url="https://a.com", title="A", snippet="a"), SourceData(number=0, url="https://b.com", title="B", snippet="b"), ] - result = _BaseSearchTool._rearrange_sources(sources, starting_number=5) + result = _rearrange_sources(sources, starting_number=5) assert result[0].number == 5 assert result[1].number == 6 -class TestBraveSearchProvider: - """Tests for BraveSearchTool provider logic.""" - - @pytest.mark.asyncio - async def test_raises_without_api_key(self): - from sgr_agent_core.tools.brave_search_tool import BraveSearchTool - - config = SearchConfig(engine="brave") - with pytest.raises(ValueError, match="brave_api_key is required"): - await BraveSearchTool._search(config, query="test", max_results=5) +class TestBraveSearchHandler: + """Tests for Brave search handler function.""" - def test_convert_to_source_data(self): - from sgr_agent_core.tools.brave_search_tool import BraveSearchTool + def test_convert_brave_response(self): + from sgr_agent_core.tools.web_search_tool import _convert_brave_response response = { "web": { @@ -59,7 +38,7 @@ def test_convert_to_source_data(self): ] } } - sources = BraveSearchTool._convert_to_source_data(response) + sources = _convert_brave_response(response) assert len(sources) == 2 assert sources[0].title == "Test" assert sources[0].url == 
"https://example.com" @@ -67,9 +46,7 @@ def test_convert_to_source_data(self): @pytest.mark.asyncio async def test_search_calls_brave_api(self): - from sgr_agent_core.tools.brave_search_tool import BraveSearchTool - - config = SearchConfig(engine="brave", brave_api_key="test-key", max_results=10) + from sgr_agent_core.tools.web_search_tool import _search_brave mock_response = MagicMock() mock_response.json.return_value = { @@ -81,14 +58,20 @@ async def test_search_calls_brave_api(self): } mock_response.raise_for_status = MagicMock() - with patch("sgr_agent_core.tools.brave_search_tool.httpx.AsyncClient") as mock_client_cls: + with patch("sgr_agent_core.tools.web_search_tool.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.get = AsyncMock(return_value=mock_response) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) mock_client_cls.return_value = mock_client - sources = await BraveSearchTool._search(config, query="test query", max_results=5) + sources = await _search_brave( + api_key="test-key", + api_base_url="https://api.search.brave.com/res/v1/web/search", + query="test query", + max_results=5, + offset=0, + ) mock_client.get.assert_called_once() call_kwargs = mock_client.get.call_args @@ -97,19 +80,11 @@ async def test_search_calls_brave_api(self): assert len(sources) == 1 -class TestPerplexitySearchProvider: - """Tests for PerplexitySearchTool provider logic.""" - - @pytest.mark.asyncio - async def test_raises_without_api_key(self): - from sgr_agent_core.tools.perplexity_search_tool import PerplexitySearchTool +class TestPerplexitySearchHandler: + """Tests for Perplexity search handler function.""" - config = SearchConfig(engine="perplexity") - with pytest.raises(ValueError, match="perplexity_api_key is required"): - await PerplexitySearchTool._search(config, query="test", max_results=5) - - def test_convert_to_source_data(self): - from 
sgr_agent_core.tools.perplexity_search_tool import PerplexitySearchTool + def test_convert_perplexity_response(self): + from sgr_agent_core.tools.web_search_tool import _convert_perplexity_response response = { "results": [ @@ -118,7 +93,7 @@ def test_convert_to_source_data(self): {"title": "No URL", "url": "", "snippet": "Skipped"}, ], } - sources = PerplexitySearchTool._convert_to_source_data(response) + sources = _convert_perplexity_response(response) assert len(sources) == 2 assert sources[0].url == "https://example.com/page1" assert sources[0].title == "Page 1" @@ -127,9 +102,7 @@ def test_convert_to_source_data(self): @pytest.mark.asyncio async def test_search_calls_perplexity_api(self): - from sgr_agent_core.tools.perplexity_search_tool import PerplexitySearchTool - - config = SearchConfig(engine="perplexity", perplexity_api_key="test-key", max_results=10) + from sgr_agent_core.tools.web_search_tool import _search_perplexity mock_response = MagicMock() mock_response.json.return_value = { @@ -139,14 +112,20 @@ async def test_search_calls_perplexity_api(self): } mock_response.raise_for_status = MagicMock() - with patch("sgr_agent_core.tools.perplexity_search_tool.httpx.AsyncClient") as mock_client_cls: + with patch("sgr_agent_core.tools.web_search_tool.httpx.AsyncClient") as mock_client_cls: mock_client = AsyncMock() mock_client.post = AsyncMock(return_value=mock_response) mock_client.__aenter__ = AsyncMock(return_value=mock_client) mock_client.__aexit__ = AsyncMock(return_value=False) mock_client_cls.return_value = mock_client - sources = await PerplexitySearchTool._search(config, query="test query", max_results=5) + sources = await _search_perplexity( + api_key="test-key", + api_base_url="https://api.perplexity.ai/search", + query="test query", + max_results=5, + offset=0, + ) mock_client.post.assert_called_once() call_kwargs = mock_client.post.call_args @@ -155,18 +134,18 @@ async def test_search_calls_perplexity_api(self): assert len(sources) == 1 -class 
TestTavilySearchProvider: - """Tests for TavilySearchTool provider logic.""" +class TestTavilySearchHandler: + """Tests for Tavily search handler function.""" - def test_convert_to_source_data(self): - from sgr_agent_core.tools.tavily_search_tool import TavilySearchTool + def test_convert_tavily_response(self): + from sgr_agent_core.tools.web_search_tool import _convert_tavily_response response = { "results": [ {"title": "Test", "url": "https://example.com", "content": "Snippet", "raw_content": "Full content"}, ] } - sources = TavilySearchTool._convert_to_source_data(response) + sources = _convert_tavily_response(response) assert len(sources) == 1 assert sources[0].title == "Test" assert sources[0].snippet == "Snippet" diff --git a/tests/test_tools.py b/tests/test_tools.py index c7b96e6d..397e1010 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -14,15 +14,12 @@ from sgr_agent_core.tools import ( AdaptPlanTool, AnswerTool, - BraveSearchTool, ClarificationTool, CreateReportTool, ExtractPageContentTool, FinalAnswerTool, GeneratePlanTool, - PerplexitySearchTool, ReasoningTool, - TavilySearchTool, WebSearchTool, ) @@ -132,24 +129,6 @@ def test_answer_tool_initialization(self): assert tool.intermediate_result == "Found 3 relevant sources so far." 
assert tool.continue_research is True - def test_tavily_search_tool_initialization(self): - """Test TavilySearchTool initialization.""" - tool = TavilySearchTool(reasoning="Test", query="test query") - assert tool.tool_name == "tavilysearchtool" - assert tool._default_engine == "tavily" - - def test_brave_search_tool_initialization(self): - """Test BraveSearchTool initialization.""" - tool = BraveSearchTool(reasoning="Test", query="test query") - assert tool.tool_name == "bravesearchtool" - assert tool._default_engine == "brave" - - def test_perplexity_search_tool_initialization(self): - """Test PerplexitySearchTool initialization.""" - tool = PerplexitySearchTool(reasoning="Test", query="test query") - assert tool.tool_name == "perplexitysearchtool" - assert tool._default_engine == "perplexity" - class TestAnswerToolExecution: """Tests for AnswerTool execution.""" @@ -212,9 +191,10 @@ async def test_web_search_tool_uses_kwargs_over_config_search(self): context = AgentContext() config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=10) - with patch.object(TavilySearchTool, "_search", new_callable=AsyncMock, return_value=[]) as mock_search: - await tool(context, config, max_results=3) - assert mock_search.call_args.kwargs["config"].max_results == 3 + mock_handler = AsyncMock(return_value=[]) + with patch.dict("sgr_agent_core.tools.web_search_tool._ENGINE_HANDLERS", {"tavily": mock_handler}): + await tool(context, config, api_key="k", max_results=3) + assert mock_handler.call_args.kwargs["max_results"] == 3 @pytest.mark.asyncio async def test_web_search_tool_fallback_to_config_search(self): @@ -224,9 +204,10 @@ async def test_web_search_tool_fallback_to_config_search(self): context = AgentContext() config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=10) - with patch.object(TavilySearchTool, "_search", new_callable=AsyncMock, return_value=[]) as mock_search: - await tool(context, config) - assert 
mock_search.call_args.kwargs["config"].max_results == 10 + mock_handler = AsyncMock(return_value=[]) + with patch.dict("sgr_agent_core.tools.web_search_tool._ENGINE_HANDLERS", {"tavily": mock_handler}): + await tool(context, config, api_key="k") + assert mock_handler.call_args.kwargs["max_results"] == 5 @pytest.mark.asyncio async def test_web_search_tool_with_offset(self): @@ -237,14 +218,14 @@ async def test_web_search_tool_with_offset(self): config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=10) - # Provider returns already-offset results (3 items after skipping 2) mock_sources = [ SourceData(number=i, url=f"https://example.com/{i}", title=f"Result {i}", snippet=f"Snippet {i}") for i in range(2, 5) ] - with patch.object(TavilySearchTool, "_search", new_callable=AsyncMock, return_value=mock_sources): - result = await tool(context, config) + mock_handler = AsyncMock(return_value=mock_sources) + with patch.dict("sgr_agent_core.tools.web_search_tool._ENGINE_HANDLERS", {"tavily": mock_handler}): + result = await tool(context, config, api_key="k") assert len(context.searches) == 1 assert len(context.searches[0].citations) == 3 @@ -265,8 +246,9 @@ async def test_web_search_tool_offset_default_zero(self): for i in range(3) ] - with patch.object(TavilySearchTool, "_search", new_callable=AsyncMock, return_value=mock_sources): - await tool(context, config) + mock_handler = AsyncMock(return_value=mock_sources) + with patch.dict("sgr_agent_core.tools.web_search_tool._ENGINE_HANDLERS", {"tavily": mock_handler}): + await tool(context, config, api_key="k") assert len(context.searches[0].citations) == 3 @@ -279,24 +261,13 @@ async def test_web_search_tool_offset_exceeds_results(self): config = MagicMock() config.search = SearchConfig(tavily_api_key="k", max_results=20) - # Provider returns empty list (offset exceeded available results) - with patch.object(TavilySearchTool, "_search", new_callable=AsyncMock, return_value=[]): - result = await 
tool(context, config) + mock_handler = AsyncMock(return_value=[]) + with patch.dict("sgr_agent_core.tools.web_search_tool._ENGINE_HANDLERS", {"tavily": mock_handler}): + result = await tool(context, config, api_key="k") assert len(context.searches[0].citations) == 0 assert "Search Query: test" in result - @pytest.mark.asyncio - async def test_brave_search_tool_forces_engine(self): - """BraveSearchTool forces engine='brave' regardless of config.""" - tool = BraveSearchTool(reasoning="r", query="test", max_results=5) - context = AgentContext() - config = MagicMock() - config.search = SearchConfig(tavily_api_key="k", max_results=10, engine="tavily") - with patch.object(BraveSearchTool, "_search", new_callable=AsyncMock, return_value=[]) as mock_search: - await tool(context, config) - assert mock_search.call_args.kwargs["config"].engine == "brave" - @pytest.mark.asyncio async def test_extract_page_content_tool_uses_content_limit_from_kwargs(self): """ExtractPageContentTool uses content_limit from kwargs.""" From 358550a3644fdc6f2a02c47aca33506a7936182e Mon Sep 17 00:00:00 2001 From: Nikita Matsko Date: Mon, 16 Mar 2026 09:06:29 +0000 Subject: [PATCH 12/12] refactor(tools): add ExtractPageContentConfig and refactor extraction tool - extract SearchConfig into ExtractPageContentConfig in tools - update ExtractPageContentTool to use new config and validate tavily_api_key - adjust package exports and imports to expose new config class - update docs/comments and tests to reference ExtractPageContentConfig --- sgr_agent_core/__init__.py | 2 - sgr_agent_core/agent_definition.py | 8 --- sgr_agent_core/tools/__init__.py | 3 +- .../tools/extract_page_content_tool.py | 52 ++++++++++++------- sgr_agent_core/utils.py | 2 +- tests/test_base_agent.py | 2 +- tests/test_search_providers.py | 5 +- 7 files changed, 40 insertions(+), 34 deletions(-) diff --git a/sgr_agent_core/__init__.py b/sgr_agent_core/__init__.py index e398ccec..6778bd87 100644 --- a/sgr_agent_core/__init__.py +++ 
b/sgr_agent_core/__init__.py @@ -16,7 +16,6 @@ ExecutionConfig, LLMConfig, PromptsConfig, - SearchConfig, ) from sgr_agent_core.agent_factory import AgentFactory from sgr_agent_core.agents import * # noqa: F403 @@ -63,7 +62,6 @@ "AgentDefinition", "LLMConfig", "PromptsConfig", - "SearchConfig", "ExecutionConfig", "GlobalConfig", # Next step tools diff --git a/sgr_agent_core/agent_definition.py b/sgr_agent_core/agent_definition.py index 47670246..43971826 100644 --- a/sgr_agent_core/agent_definition.py +++ b/sgr_agent_core/agent_definition.py @@ -62,14 +62,6 @@ def to_openai_client_kwargs(self) -> dict[str, Any]: return self.model_dump(exclude={"api_key", "base_url", "proxy"}) -class SearchConfig(BaseModel, extra="allow"): - max_searches: int = Field(default=4, ge=0, description="Maximum number of searches") - max_results: int = Field(default=10, ge=1, description="Maximum number of search results") - content_limit: int = Field(default=3500, gt=0, description="Content character limit per source") - tavily_api_key: str | None = Field(default=None, description="Tavily API key") - tavily_api_base_url: str = Field(default="https://api.tavily.com", description="Tavily API base URL") - - class PromptsConfig(BaseModel, extra="allow"): system_prompt_file: FilePath | None = Field( default=os.path.join(os.path.dirname(__file__), "prompts/system_prompt.txt"), diff --git a/sgr_agent_core/tools/__init__.py b/sgr_agent_core/tools/__init__.py index 2e71cec4..258f63eb 100644 --- a/sgr_agent_core/tools/__init__.py +++ b/sgr_agent_core/tools/__init__.py @@ -8,7 +8,7 @@ from sgr_agent_core.tools.answer_tool import AnswerTool from sgr_agent_core.tools.clarification_tool import ClarificationTool from sgr_agent_core.tools.create_report_tool import CreateReportTool -from sgr_agent_core.tools.extract_page_content_tool import ExtractPageContentTool +from sgr_agent_core.tools.extract_page_content_tool import ExtractPageContentConfig, ExtractPageContentTool from 
sgr_agent_core.tools.final_answer_tool import FinalAnswerTool from sgr_agent_core.tools.generate_plan_tool import GeneratePlanTool from sgr_agent_core.tools.reasoning_tool import ReasoningTool @@ -28,6 +28,7 @@ "AnswerTool", "ClarificationTool", "CreateReportTool", + "ExtractPageContentConfig", "ExtractPageContentTool", "FinalAnswerTool", "GeneratePlanTool", diff --git a/sgr_agent_core/tools/extract_page_content_tool.py b/sgr_agent_core/tools/extract_page_content_tool.py index ff547ad9..1f883ee4 100644 --- a/sgr_agent_core/tools/extract_page_content_tool.py +++ b/sgr_agent_core/tools/extract_page_content_tool.py @@ -3,10 +3,9 @@ import logging from typing import TYPE_CHECKING, Any -from pydantic import Field +from pydantic import BaseModel, Field, model_validator from tavily import AsyncTavilyClient -from sgr_agent_core.agent_definition import SearchConfig from sgr_agent_core.base_tool import BaseTool from sgr_agent_core.models import SourceData @@ -18,6 +17,23 @@ logger.setLevel(logging.INFO) +class ExtractPageContentConfig(BaseModel, extra="allow"): + """Configuration for ExtractPageContentTool (Tavily Extract API).""" + + tavily_api_key: str | None = Field(default=None, description="Tavily API key") + tavily_api_base_url: str = Field(default="https://api.tavily.com", description="Tavily API base URL") + content_limit: int = Field(default=3500, gt=0, description="Content character limit per source") + + @model_validator(mode="after") + def validate_api_key(self): + if not self.tavily_api_key: + raise ValueError( + "tavily_api_key is required for ExtractPageContentTool." + " Tavily is the only provider that supports content extraction." + ) + return self + + class ExtractPageContentTool(BaseTool): """Extract full detailed content from specific web pages. 
@@ -35,13 +51,13 @@ class ExtractPageContentTool(BaseTool): - For date/number questions, cross-check extracted values with search snippets """ - config_model = SearchConfig + config_model = ExtractPageContentConfig reasoning: str = Field(description="Why extract these specific pages") urls: list[str] = Field(description="List of URLs to extract full content from", min_length=1, max_length=5) @staticmethod - async def _extract(config: SearchConfig, urls: list[str]) -> list[SourceData]: + async def _extract(config: ExtractPageContentConfig, urls: list[str]) -> list[SourceData]: """Extract full content from URLs via Tavily Extract API.""" logger.info(f"Tavily extract: {len(urls)} URLs") @@ -53,13 +69,15 @@ async def _extract(config: SearchConfig, urls: list[str]) -> list[SourceData]: if not result.get("url"): continue + url = result.get("url", "") + raw_content = result.get("raw_content", "") source = SourceData( number=i, - title=result.get("url", "").split("/")[-1] or "Extracted Content", - url=result.get("url", ""), + title=url.split("/")[-1] or "Extracted Content", + url=url, snippet="", - full_content=result.get("raw_content", ""), - char_count=len(result.get("raw_content", "")), + full_content=raw_content, + char_count=len(raw_content), ) sources.append(source) @@ -71,15 +89,13 @@ async def _extract(config: SearchConfig, urls: list[str]) -> list[SourceData]: async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: Any) -> str: """Extract full content from specified URLs.""" - search_config = SearchConfig(**kwargs) - if not search_config.tavily_api_key: - return ( - "Error: tavily_api_key is required for ExtractPageContentTool." - " Tavily is the only provider that supports content extraction." 
- ) + try: + extract_config = ExtractPageContentConfig(**kwargs) + except ValueError as e: + return f"Error: {e}" logger.info(f"Extracting content from {len(self.urls)} URLs") - sources = await self._extract(search_config, urls=self.urls) + sources = await self._extract(extract_config, urls=self.urls) # Update existing sources instead of overwriting for source in sources: @@ -97,10 +113,10 @@ async def __call__(self, context: AgentContext, config: AgentConfig, **kwargs: A # Format results using sources from context (to get correct numbers) for url in self.urls: - if url in context.sources: - source = context.sources[url] + source = context.sources.get(url) + if source is not None: if source.full_content: - content_preview = source.full_content[: search_config.content_limit] + content_preview = source.full_content[: extract_config.content_limit] formatted_result += ( f"{str(source)}\n\n**Full Content:**\n" f"{content_preview}\n\n" diff --git a/sgr_agent_core/utils.py b/sgr_agent_core/utils.py index e193fb28..1a7a8d8c 100644 --- a/sgr_agent_core/utils.py +++ b/sgr_agent_core/utils.py @@ -27,7 +27,7 @@ def config_from_kwargs(config_class: type[T], base: T | None, kwargs: dict[str, agent-level config with per-tool kwargs from the tools array (global or inline). Args: - config_class: Pydantic model class to instantiate (e.g. SearchConfig). + config_class: Pydantic model class to instantiate (e.g. WebSearchConfig). base: Existing config instance, or None to use only kwargs (with model defaults). kwargs: Overrides; keys present here override base. None values are skipped. 
diff --git a/tests/test_base_agent.py b/tests/test_base_agent.py index 7c249bde..227f628e 100644 --- a/tests/test_base_agent.py +++ b/tests/test_base_agent.py @@ -357,7 +357,7 @@ def test_get_tool_config_returns_model_when_tool_has_config_model(self): assert out.max_searches == 6 def test_get_tool_config_returns_model_from_tool_configs_only(self): - """get_tool_config builds SearchConfig exclusively from tool_configs + """get_tool_config builds config model exclusively from tool_configs (search settings are per-tool, not in AgentConfig).""" agent = create_test_agent( BaseAgent, diff --git a/tests/test_search_providers.py b/tests/test_search_providers.py index 43df7b11..49cdb464 100644 --- a/tests/test_search_providers.py +++ b/tests/test_search_providers.py @@ -4,7 +4,6 @@ import pytest -from sgr_agent_core.agent_definition import SearchConfig from sgr_agent_core.models import SourceData @@ -153,9 +152,9 @@ def test_convert_tavily_response(self): @pytest.mark.asyncio async def test_extract_calls_tavily_api(self): - from sgr_agent_core.tools.extract_page_content_tool import ExtractPageContentTool + from sgr_agent_core.tools.extract_page_content_tool import ExtractPageContentConfig, ExtractPageContentTool - config = SearchConfig(tavily_api_key="test-key") + config = ExtractPageContentConfig(tavily_api_key="test-key") mock_client = AsyncMock() mock_client.extract = AsyncMock(