diff --git a/hud/agents/__init__.py b/hud/agents/__init__.py
index 7470adb3..f69d4ed5 100644
--- a/hud/agents/__init__.py
+++ b/hud/agents/__init__.py
@@ -2,11 +2,13 @@
 
 from .base import MCPAgent
 from .claude import ClaudeAgent
+from .gemini import GeminiAgent
 from .openai import OperatorAgent
 from .openai_chat_generic import GenericOpenAIChatAgent
 
 __all__ = [
     "ClaudeAgent",
+    "GeminiAgent",
     "GenericOpenAIChatAgent",
     "MCPAgent",
     "OperatorAgent",
diff --git a/hud/agents/gemini.py b/hud/agents/gemini.py
new file mode 100644
index 00000000..bd5dc2c5
--- /dev/null
+++ b/hud/agents/gemini.py
@@ -0,0 +1,489 @@
+"""Gemini MCP Agent implementation."""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any, ClassVar, cast
+
+from google import genai
+from google.genai import types as genai_types
+
+import hud
+
+if TYPE_CHECKING:
+    from hud.datasets import Task
+
+import mcp.types as types
+
+from hud.settings import settings
+from hud.tools.computer.settings import computer_settings
+from hud.types import AgentResponse, MCPToolCall, MCPToolResult
+from hud.utils.hud_console import HUDConsole
+
+from .base import MCPAgent
+
+logger = logging.getLogger(__name__)
+
+# Maximum number of recent turns to keep screenshots for
+MAX_RECENT_TURN_WITH_SCREENSHOTS = 3
+
+# Predefined Gemini computer use functions
+PREDEFINED_COMPUTER_USE_FUNCTIONS = [
+    "open_web_browser",
+    "click_at",
+    "hover_at",
+    "type_text_at",
+    "scroll_document",
+    "scroll_at",
+    "wait_5_seconds",
+    "go_back",
+    "go_forward",
+    "search",
+    "navigate",
+    "key_combination",
+    "drag_and_drop",
+]
+
+
+class GeminiAgent(MCPAgent):
+    """
+    Gemini agent that uses MCP servers for tool execution.
+
+    This agent uses Gemini's native computer use capabilities but executes
+    tools through MCP servers instead of direct implementation.
+    """
+
+    metadata: ClassVar[dict[str, Any]] = {
+        "display_width": computer_settings.GEMINI_COMPUTER_WIDTH,
+        "display_height": computer_settings.GEMINI_COMPUTER_HEIGHT,
+    }
+
+    def __init__(
+        self,
+        model_client: genai.Client | None = None,
+        model: str = "gemini-2.5-computer-use-preview-10-2025",
+        temperature: float = 1.0,
+        top_p: float = 0.95,
+        top_k: int = 40,
+        max_output_tokens: int = 8192,
+        validate_api_key: bool = True,
+        excluded_predefined_functions: list[str] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Initialize Gemini MCP agent.
+
+        Args:
+            model_client: Gemini client (created if not provided)
+            model: Gemini model to use
+            temperature: Temperature for response generation
+            top_p: Top-p sampling parameter
+            top_k: Top-k sampling parameter
+            max_output_tokens: Maximum tokens for response
+            validate_api_key: Whether to validate API key on initialization
+            excluded_predefined_functions: List of predefined functions to exclude
+            **kwargs: Additional arguments passed to BaseMCPAgent (including mcp_client)
+        """
+        super().__init__(**kwargs)
+
+        # Initialize client if not provided
+        if model_client is None:
+            api_key = settings.gemini_api_key
+            if not api_key:
+                raise ValueError("Gemini API key not found. 
Set GEMINI_API_KEY.") + model_client = genai.Client(api_key=api_key) + + # Validate API key if requested + if validate_api_key: + try: + # Simple validation - try to list models + list(model_client.models.list(config=genai_types.ListModelsConfig(page_size=1))) + except Exception as e: + raise ValueError(f"Gemini API key is invalid: {e}") from e + + self.gemini_client = model_client + self.model = model + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.max_output_tokens = max_output_tokens + self.excluded_predefined_functions = excluded_predefined_functions or [] + self.hud_console = HUDConsole(logger=logger) + + self.model_name = self.model + + # Track mapping from Gemini tool names to MCP tool names + self._gemini_to_mcp_tool_map: dict[str, str] = {} + self.gemini_tools: list[genai_types.Tool] = [] + + # Append Gemini-specific instructions to the base system prompt + gemini_instructions = "\n".join( + [ + "You are Gemini, a helpful AI assistant created by Google.", + "You can interact with computer interfaces.", + "", + "When working on tasks:", + "1. Be thorough and systematic in your approach", + "2. Complete tasks autonomously without asking for confirmation", + "3. Use available tools efficiently to accomplish your goals", + "4. Verify your actions and ensure task completion", + "5. Be precise and accurate in all operations", + "6. Adapt to the environment and the task at hand", + "", + "Remember: You are expected to complete tasks autonomously.", + "The user trusts you to accomplish what they asked.", + ] + ) + + # Append Gemini instructions to any base system prompt + if self.system_prompt: + self.system_prompt = f"{self.system_prompt}\n\n{gemini_instructions}" + else: + self.system_prompt = gemini_instructions + + async def initialize(self, task: str | Task | None = None) -> None: + """Initialize the agent and build tool mappings.""" + await super().initialize(task) + # Build tool mappings after tools are discovered + self._convert_tools_for_gemini() + + async def get_system_messages(self) -> list[Any]: + """No system messages for Gemini because applied in get_response""" + return [] + + async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[genai_types.Content]: + """Format messages for Gemini.""" + # Convert MCP content types to Gemini content types + gemini_parts: list[genai_types.Part] = [] + + for block in blocks: + if isinstance(block, types.TextContent): + gemini_parts.append(genai_types.Part(text=block.text)) + elif isinstance(block, types.ImageContent): + # Convert MCP ImageContent to Gemini format + # Need to decode base64 string to bytes + import base64 + + image_bytes = base64.b64decode(block.data) + gemini_parts.append( + genai_types.Part.from_bytes(data=image_bytes, mime_type=block.mimeType) + ) + else: + # For other types, try to handle but log a warning + self.hud_console.log(f"Unknown content block type: {type(block)}", level="warning") + + return [genai_types.Content(role="user", parts=gemini_parts)] + + @hud.instrument( + span_type="agent", + record_args=False, # Messages can be large + record_result=True, + ) + async def get_response(self, messages: list[genai_types.Content]) -> AgentResponse: + """Get response from Gemini including any tool calls.""" + + # Build generate content config + generate_config = genai_types.GenerateContentConfig( + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + max_output_tokens=self.max_output_tokens, + tools=cast("Any", self.gemini_tools), + 
system_instruction=self.system_prompt, + ) + + # Trim screenshots from older turns to manage context growth + self._remove_old_screenshots(messages) + + # Make API call - using a simpler call pattern + response = self.gemini_client.models.generate_content( + model=self.model, + contents=cast("Any", messages), + config=generate_config, + ) + + # Append assistant response (including any function_call) so that + # subsequent FunctionResponse messages correspond to a prior FunctionCall + if response.candidates and len(response.candidates) > 0 and response.candidates[0].content: + cast("list[genai_types.Content]", messages).append(response.candidates[0].content) + + # Process response + result = AgentResponse(content="", tool_calls=[], done=True) + collected_tool_calls: list[MCPToolCall] = [] + + if not response.candidates: + self.hud_console.warning("Response has no candidates") + return result + + candidate = response.candidates[0] + + # Extract text content and function calls + text_content = "" + thinking_content = "" + + if candidate.content and candidate.content.parts: + for part in candidate.content.parts: + if part.function_call: + # Map Gemini tool name back to MCP tool name + func_name = part.function_call.name or "" + mcp_tool_name = self._gemini_to_mcp_tool_map.get(func_name, func_name) + + # Create MCPToolCall object with Gemini metadata + raw_args = dict(part.function_call.args) if part.function_call.args else {} + + # Normalize Gemini Computer Use calls to MCP tool schema + if part.function_call.name in PREDEFINED_COMPUTER_USE_FUNCTIONS: + # Ensure 'action' is present and equals the Gemini function name + normalized_args: dict[str, Any] = {"action": part.function_call.name} + + # Map common argument shapes used by Gemini Computer Use + # 1) Coordinate arrays → x/y + coord = raw_args.get("coordinate") or raw_args.get("coordinates") + if isinstance(coord, (list, tuple)) and len(coord) >= 2: + try: + normalized_args["x"] = int(coord[0]) + normalized_args["y"] = int(coord[1]) + except (TypeError, ValueError): + # Fall back to raw if casting fails + pass + + # Destination coordinate arrays → destination_x/destination_y + dest = ( + raw_args.get("destination") + or raw_args.get("destination_coordinate") + or raw_args.get("destinationCoordinate") + ) + if isinstance(dest, (list, tuple)) and len(dest) >= 2: + try: + normalized_args["destination_x"] = int(dest[0]) + normalized_args["destination_y"] = int(dest[1]) + except (TypeError, ValueError): + pass + + # Pass through supported fields if present (including direct coords) + for key in ( + "text", + "press_enter", + "clear_before_typing", + "safety_decision", + "direction", + "magnitude", + "url", + "keys", + "x", + "y", + "destination_x", + "destination_y", + ): + if key in raw_args: + normalized_args[key] = raw_args[key] + + # Use normalized args for computer tool calls + final_args = normalized_args + else: + # Non-computer tools: pass args as-is + final_args = raw_args + + tool_call = MCPToolCall( + name=mcp_tool_name, + arguments=final_args, + gemini_name=func_name, # type: ignore[arg-type] + ) + collected_tool_calls.append(tool_call) + elif part.text: + text_content += part.text + elif hasattr(part, "thought") and part.thought: + thinking_content += f"Thinking: {part.thought}\n" + + # Assign collected tool calls and mark done status + if collected_tool_calls: + result.tool_calls = collected_tool_calls + result.done = False + + # Combine text and thinking for final content + if thinking_content: + result.content = thinking_content 
+ text_content + else: + result.content = text_content + + return result + + async def format_tool_results( + self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] + ) -> list[genai_types.Content]: + """Format tool results into Gemini messages.""" + # Process each tool result + function_responses = [] + + for tool_call, result in zip(tool_calls, tool_results, strict=True): + # Get the Gemini function name from metadata + gemini_name = getattr(tool_call, "gemini_name", tool_call.name) + + # Convert MCP tool results to Gemini format + response_dict: dict[str, Any] = {} + url = None + + if result.isError: + # Extract error message from content + error_msg = "Tool execution failed" + for content in result.content: + if isinstance(content, types.TextContent): + # Check if this is a URL metadata block + if content.text.startswith("__URL__:"): + url = content.text.replace("__URL__:", "") + else: + error_msg = content.text + break + response_dict["error"] = error_msg + else: + # Process success content + response_dict["success"] = True + + # Extract URL and screenshot from content + screenshot_parts = [] + for content in result.content: + if isinstance(content, types.TextContent): + # Check if this is a URL metadata block + if content.text.startswith("__URL__:"): + url = content.text.replace("__URL__:", "") + elif isinstance(content, types.ImageContent): + # Decode base64 string to bytes for FunctionResponseBlob + import base64 + + image_bytes = base64.b64decode(content.data) + screenshot_parts.append( + genai_types.FunctionResponsePart( + inline_data=genai_types.FunctionResponseBlob( + mime_type=content.mimeType or "image/png", + data=image_bytes, + ) + ) + ) + + # Add URL to response dict (required by Gemini Computer Use model) + # URL must ALWAYS be present per Gemini API requirements + response_dict["url"] = url if url else "about:blank" + + # For Gemini Computer Use actions, always acknowledge safety decisions + requires_ack = False + if tool_call.arguments: + requires_ack = bool(tool_call.arguments.get("safety_decision")) + if gemini_name in PREDEFINED_COMPUTER_USE_FUNCTIONS and requires_ack: + response_dict["safety_acknowledgement"] = True + + # Create function response + function_response = genai_types.FunctionResponse( + name=gemini_name, + response=response_dict, + parts=screenshot_parts if screenshot_parts else None, + ) + function_responses.append(function_response) + + # Return as a user message containing all function responses + return [ + genai_types.Content( + role="user", + parts=[genai_types.Part(function_response=fr) for fr in function_responses], + ) + ] + + async def create_user_message(self, text: str) -> genai_types.Content: + """Create a user message in Gemini's format.""" + return genai_types.Content(role="user", parts=[genai_types.Part(text=text)]) + + def _convert_tools_for_gemini(self) -> list[genai_types.Tool]: + """Convert MCP tools to Gemini tool format.""" + gemini_tools = [] + self._gemini_to_mcp_tool_map = {} # Reset mapping + + # Find computer tool by priority + computer_tool_priority = ["gemini_computer", "computer_gemini", "computer"] + selected_computer_tool = None + + for priority_name in computer_tool_priority: + for tool in self.get_available_tools(): + # Check both exact match and suffix match (for prefixed tools) + if tool.name == priority_name or tool.name.endswith(f"_{priority_name}"): + selected_computer_tool = tool + break + if selected_computer_tool: + break + + # Add the selected computer tool if found + if selected_computer_tool: 
+ gemini_tool = genai_types.Tool( + computer_use=genai_types.ComputerUse( + environment=genai_types.Environment.ENVIRONMENT_BROWSER, + excluded_predefined_functions=self.excluded_predefined_functions, + ) + ) + # Map Gemini's computer use functions back to the actual MCP tool name + for func_name in PREDEFINED_COMPUTER_USE_FUNCTIONS: + if func_name not in self.excluded_predefined_functions: + self._gemini_to_mcp_tool_map[func_name] = selected_computer_tool.name + + gemini_tools.append(gemini_tool) + self.hud_console.debug( + f"Using {selected_computer_tool.name} as computer tool for Gemini" + ) + + # Add other non-computer tools as custom functions + for tool in self.get_available_tools(): + # Skip computer tools (already handled) + if any( + tool.name == priority_name or tool.name.endswith(f"_{priority_name}") + for priority_name in computer_tool_priority + ): + continue + + # Convert MCP tool schema to Gemini function declaration + try: + # Ensure parameters have proper Schema format + params = tool.inputSchema or {"type": "object", "properties": {}} + function_decl = genai_types.FunctionDeclaration( + name=tool.name, + description=tool.description or f"Execute {tool.name}", + parameters=genai_types.Schema(**params) if isinstance(params, dict) else params, # type: ignore + ) + custom_tool = genai_types.Tool(function_declarations=[function_decl]) + gemini_tools.append(custom_tool) + # Direct mapping for non-computer tools + self._gemini_to_mcp_tool_map[tool.name] = tool.name + except Exception: + self.hud_console.warning(f"Failed to convert tool {tool.name} to Gemini format") + + self.gemini_tools = gemini_tools + return gemini_tools + + def _remove_old_screenshots(self, messages: list[genai_types.Content]) -> None: + """ + Remove screenshots from old turns to manage context length. + Keeps only the last MAX_RECENT_TURN_WITH_SCREENSHOTS turns with screenshots. 
+ """ + turn_with_screenshots_found = 0 + + for content in reversed(messages): + if content.role == "user" and content.parts: + # Check if content has screenshots (function responses with images) + has_screenshot = False + for part in content.parts: + if ( + part.function_response + and part.function_response.parts + and part.function_response.name in PREDEFINED_COMPUTER_USE_FUNCTIONS + ): + has_screenshot = True + break + + if has_screenshot: + turn_with_screenshots_found += 1 + # Remove the screenshot image if the number of screenshots exceeds the limit + if turn_with_screenshots_found > MAX_RECENT_TURN_WITH_SCREENSHOTS: + for part in content.parts: + if ( + part.function_response + and part.function_response.parts + and part.function_response.name in PREDEFINED_COMPUTER_USE_FUNCTIONS + ): + # Clear the parts (screenshots) + part.function_response.parts = None diff --git a/hud/agents/tests/test_gemini.py b/hud/agents/tests/test_gemini.py new file mode 100644 index 00000000..0164f568 --- /dev/null +++ b/hud/agents/tests/test_gemini.py @@ -0,0 +1,372 @@ +"""Tests for Gemini MCP Agent implementation.""" + +from __future__ import annotations + +import base64 +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from google.genai import types as genai_types +from mcp import types + +from hud.agents.gemini import GeminiAgent +from hud.types import MCPToolCall, MCPToolResult + + +class TestGeminiAgent: + """Test GeminiAgent class.""" + + @pytest.fixture + def mock_mcp_client(self): + """Create a mock MCP client.""" + mcp_client = AsyncMock() + # Set up the mcp_config attribute as a regular dict, not a coroutine + mcp_client.mcp_config = {"test_server": {"url": "http://test"}} + # Mock list_tools to return gemini_computer tool + mcp_client.list_tools = AsyncMock( + return_value=[ + types.Tool( + name="gemini_computer", + description="Gemini computer use tool", + inputSchema={}, + ) + ] + ) + mcp_client.initialize = AsyncMock() + return mcp_client + + @pytest.fixture + def mock_gemini_client(self): + """Create a mock Gemini client.""" + client = MagicMock() + client.api_key = "test_key" + # Mock models.list for validation + client.models = MagicMock() + client.models.list = MagicMock(return_value=iter([])) + return client + + @pytest.mark.asyncio + async def test_init(self, mock_mcp_client, mock_gemini_client): + """Test agent initialization.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + model="gemini-2.5-computer-use-preview-10-2025", + validate_api_key=False, # Skip validation in tests + ) + + assert agent.model_name == "gemini-2.5-computer-use-preview-10-2025" + assert agent.model == "gemini-2.5-computer-use-preview-10-2025" + assert agent.gemini_client == mock_gemini_client + + @pytest.mark.asyncio + async def test_init_without_model_client(self, mock_mcp_client): + """Test agent initialization without model client.""" + with ( + patch("hud.settings.settings.gemini_api_key", "test_key"), + patch("hud.agents.gemini.genai.Client") as mock_client_class, + ): + mock_client = MagicMock() + mock_client.api_key = "test_key" + mock_client.models = MagicMock() + mock_client.models.list = MagicMock(return_value=iter([])) + mock_client_class.return_value = mock_client + + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model="gemini-2.5-computer-use-preview-10-2025", + validate_api_key=False, + ) + + assert agent.model_name == "gemini-2.5-computer-use-preview-10-2025" + assert agent.gemini_client is not None + + @pytest.mark.asyncio 
+ async def test_format_blocks(self, mock_mcp_client, mock_gemini_client): + """Test formatting content blocks into Gemini messages.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + # Test with text only + text_blocks: list[types.ContentBlock] = [ + types.TextContent(type="text", text="Hello, Gemini!") + ] + messages = await agent.format_blocks(text_blocks) + assert len(messages) == 1 + assert messages[0].role == "user" + parts = messages[0].parts + assert parts is not None + assert len(parts) == 1 + assert parts[0].text == "Hello, Gemini!" + + # Test with screenshot + image_blocks: list[types.ContentBlock] = [ + types.TextContent(type="text", text="Look at this"), + types.ImageContent( + type="image", + data=base64.b64encode(b"fakeimage").decode("utf-8"), + mimeType="image/png", + ), + ] + messages = await agent.format_blocks(image_blocks) + assert len(messages) == 1 + assert messages[0].role == "user" + parts = messages[0].parts + assert parts is not None + assert len(parts) == 2 + # First part is text + assert parts[0].text == "Look at this" + # Second part is image - check that it was created from bytes + assert parts[1].inline_data is not None + + @pytest.mark.asyncio + async def test_format_tool_results(self, mock_mcp_client, mock_gemini_client): + """Test the agent's format_tool_results method.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + tool_calls = [ + MCPToolCall( + name="gemini_computer", + arguments={"action": "click_at", "x": 100, "y": 200}, + id="call_1", # type: ignore + gemini_name="click_at", # type: ignore + ), + ] + + tool_results = [ + MCPToolResult( + content=[ + types.TextContent(type="text", text="Clicked successfully"), + types.ImageContent( + type="image", + data=base64.b64encode(b"screenshot").decode("utf-8"), + mimeType="image/png", + ), + ], + isError=False, + ), + ] + + messages = await agent.format_tool_results(tool_calls, tool_results) + + # format_tool_results returns a single user message with function responses + assert len(messages) == 1 + assert messages[0].role == "user" + # The content contains function response parts + parts = messages[0].parts + assert parts is not None + assert len(parts) == 1 + function_response = parts[0].function_response + assert function_response is not None + assert function_response.name == "click_at" + response_payload = function_response.response or {} + assert response_payload.get("success") is True + + @pytest.mark.asyncio + async def test_format_tool_results_with_error(self, mock_mcp_client, mock_gemini_client): + """Test formatting tool results with errors.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + tool_calls = [ + MCPToolCall( + name="gemini_computer", + arguments={"action": "invalid"}, + id="call_error", # type: ignore + gemini_name="invalid_action", # type: ignore + ), + ] + + tool_results = [ + MCPToolResult( + content=[types.TextContent(type="text", text="Action failed: invalid action")], + isError=True, + ), + ] + + messages = await agent.format_tool_results(tool_calls, tool_results) + + # Check that error is in the response + assert len(messages) == 1 + assert messages[0].role == "user" + parts = messages[0].parts + assert parts is not None + function_response = parts[0].function_response + assert function_response is not None + response_payload = function_response.response or 
{} + assert "error" in response_payload + + @pytest.mark.asyncio + async def test_get_response(self, mock_mcp_client, mock_gemini_client): + """Test getting model response from Gemini API.""" + # Disable telemetry for this test + with patch("hud.settings.settings.telemetry_enabled", False): + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + # Set up available tools + agent._available_tools = [ + types.Tool(name="gemini_computer", description="Computer tool", inputSchema={}) + ] + + # Mock the API response + mock_response = MagicMock() + mock_candidate = MagicMock() + + # Create text part + text_part = MagicMock() + text_part.text = "I will click at coordinates" + text_part.function_call = None + + # Create function call part + function_call_part = MagicMock() + function_call_part.text = None + function_call_part.function_call = MagicMock() + function_call_part.function_call.name = "click_at" + function_call_part.function_call.args = {"x": 100, "y": 200} + + mock_candidate.content = MagicMock() + mock_candidate.content.parts = [text_part, function_call_part] + + mock_response.candidates = [mock_candidate] + + mock_gemini_client.models = MagicMock() + mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) + + messages = [genai_types.Content(role="user", parts=[genai_types.Part(text="Click")])] + response = await agent.get_response(messages) + + assert response.content == "I will click at coordinates" + assert len(response.tool_calls) == 1 + assert response.tool_calls[0].arguments == {"action": "click_at", "x": 100, "y": 200} + assert response.done is False + + @pytest.mark.asyncio + async def test_get_response_text_only(self, mock_mcp_client, mock_gemini_client): + """Test getting text-only response.""" + # Disable telemetry for this test + with patch("hud.settings.settings.telemetry_enabled", False): + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + # Mock the API response with text only + mock_response = MagicMock() + mock_candidate = MagicMock() + + text_part = MagicMock() + text_part.text = "Task completed successfully" + text_part.function_call = None + + mock_candidate.content = MagicMock() + mock_candidate.content.parts = [text_part] + + mock_response.candidates = [mock_candidate] + + mock_gemini_client.models = MagicMock() + mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) + + messages = [genai_types.Content(role="user", parts=[genai_types.Part(text="Status?")])] + response = await agent.get_response(messages) + + assert response.content == "Task completed successfully" + assert response.tool_calls == [] + assert response.done is True + + @pytest.mark.asyncio + async def test_convert_tools_for_gemini(self, mock_mcp_client, mock_gemini_client): + """Test converting MCP tools to Gemini format.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + # Set up available tools + agent._available_tools = [ + types.Tool( + name="gemini_computer", + description="Computer tool", + inputSchema={"type": "object"}, + ), + types.Tool( + name="calculator", + description="Calculator tool", + inputSchema={ + "type": "object", + "properties": {"operation": {"type": "string"}}, + }, + ), + ] + + gemini_tools = agent._convert_tools_for_gemini() + + # Should have 2 tools: computer_use and calculator + assert len(gemini_tools) == 2 + + 
# First should be computer use tool + assert gemini_tools[0].computer_use is not None + assert ( + gemini_tools[0].computer_use.environment == genai_types.Environment.ENVIRONMENT_BROWSER + ) + + # Second should be calculator as function declaration + assert gemini_tools[1].function_declarations is not None + assert len(gemini_tools[1].function_declarations) == 1 + assert gemini_tools[1].function_declarations[0].name == "calculator" + + @pytest.mark.asyncio + async def test_create_user_message(self, mock_mcp_client, mock_gemini_client): + """Test creating a user message.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + message = await agent.create_user_message("Hello Gemini") + + assert message.role == "user" + parts = message.parts + assert parts is not None + assert len(parts) == 1 + assert parts[0].text == "Hello Gemini" + + @pytest.mark.asyncio + async def test_handle_empty_response(self, mock_mcp_client, mock_gemini_client): + """Test handling empty response from API.""" + with patch("hud.settings.settings.telemetry_enabled", False): + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + # Mock empty response + mock_response = MagicMock() + mock_response.candidates = [] + + mock_gemini_client.models = MagicMock() + mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) + + messages = [genai_types.Content(role="user", parts=[genai_types.Part(text="Hi")])] + response = await agent.get_response(messages) + + assert response.content == "" + assert response.tool_calls == [] + assert response.done is True diff --git a/hud/cli/__init__.py b/hud/cli/__init__.py index 8148e450..3428afbe 100644 --- a/hud/cli/__init__.py +++ b/hud/cli/__init__.py @@ -917,6 +917,7 @@ def eval( [ {"name": "Claude 4 Sonnet", "value": AgentType.CLAUDE}, {"name": "OpenAI Computer Use", "value": AgentType.OPENAI}, + {"name": "Gemini Computer Use", "value": AgentType.GEMINI}, {"name": "vLLM (Local Server)", "value": AgentType.VLLM}, {"name": "LiteLLM (Multi-provider)", "value": AgentType.LITELLM}, ] diff --git a/hud/cli/eval.py b/hud/cli/eval.py index d7d70034..642fbaa3 100644 --- a/hud/cli/eval.py +++ b/hud/cli/eval.py @@ -188,6 +188,24 @@ def build_agent( else: return OperatorAgent(verbose=verbose) + elif agent_type == AgentType.GEMINI: + try: + from hud.agents import GeminiAgent + except ImportError as e: + hud_console.error( + "Gemini agent dependencies are not installed. 
" + "Please install with: pip install 'hud-python[agent]'" + ) + raise typer.Exit(1) from e + + gemini_kwargs: dict[str, Any] = { + "model": model or "gemini-2.5-computer-use-preview-10-2025", + "verbose": verbose, + } + if allowed_tools: + gemini_kwargs["allowed_tools"] = allowed_tools + return GeminiAgent(**gemini_kwargs) + elif agent_type == AgentType.LITELLM: try: from hud.agents.lite_llm import LiteAgent @@ -344,6 +362,17 @@ async def run_single_task( agent_config = {"verbose": verbose} if allowed_tools: agent_config["allowed_tools"] = allowed_tools + elif agent_type == AgentType.GEMINI: + from hud.agents import GeminiAgent + + agent_class = GeminiAgent + agent_config = { + "model": model or "gemini-2.5-computer-use-preview-10-2025", + "verbose": verbose, + "validate_api_key": False, + } + if allowed_tools: + agent_config["allowed_tools"] = allowed_tools elif agent_type == AgentType.LITELLM: from hud.agents.lite_llm import LiteAgent @@ -534,6 +563,26 @@ async def run_full_dataset( if allowed_tools: agent_config["allowed_tools"] = allowed_tools + elif agent_type == AgentType.GEMINI: + try: + from hud.agents import GeminiAgent + + agent_class = GeminiAgent + except ImportError as e: + hud_console.error( + "Gemini agent dependencies are not installed. " + "Please install with: pip install 'hud-python[agent]'" + ) + raise typer.Exit(1) from e + + agent_config = { + "model": model or "gemini-2.5-computer-use-preview-10-2025", + "verbose": verbose, + "validate_api_key": False, + } + if allowed_tools: + agent_config["allowed_tools"] = allowed_tools + elif agent_type == AgentType.LITELLM: try: from hud.agents.lite_llm import LiteAgent @@ -641,7 +690,7 @@ def eval_command( agent: AgentType = typer.Option( # noqa: B008 AgentType.CLAUDE, "--agent", - help="Agent backend to use (claude, openai, vllm for local server, or litellm)", + help="Agent backend to use (claude, gemini, openai, vllm for local servers, or litellm)", ), model: str | None = typer.Option( None, @@ -757,6 +806,13 @@ def eval_command( "Set it in your environment or run: hud set ANTHROPIC_API_KEY=your-key-here" ) raise typer.Exit(1) + elif agent == AgentType.GEMINI: + if not settings.gemini_api_key: + hud_console.error("GEMINI_API_KEY is required for Gemini agent") + hud_console.info( + "Set it in your environment or run: hud set GEMINI_API_KEY=your-key-here" + ) + raise typer.Exit(1) elif agent == AgentType.OPENAI and not settings.openai_api_key: hud_console.error("OPENAI_API_KEY is required for OpenAI agent") hud_console.info("Set it in your environment or run: hud set OPENAI_API_KEY=your-key-here") diff --git a/hud/cli/tests/test_eval.py b/hud/cli/tests/test_eval.py index 4b4a18af..6f7e313b 100644 --- a/hud/cli/tests/test_eval.py +++ b/hud/cli/tests/test_eval.py @@ -68,6 +68,26 @@ def test_builds_claude_agent_with_custom_model_and_allowed_tools(self) -> None: ) assert result == mock_instance + def test_builds_gemini_agent(self) -> None: + """Test building a Gemini agent.""" + with patch("hud.agents.GeminiAgent") as mock_runner: + mock_instance = Mock() + mock_runner.return_value = mock_instance + + result = build_agent( + AgentType.GEMINI, + model="gemini-test", + allowed_tools=["gemini_computer"], + verbose=True, + ) + + mock_runner.assert_called_once_with( + model="gemini-test", + verbose=True, + allowed_tools=["gemini_computer"], + ) + assert result == mock_instance + class TestRunSingleTask: """Test the run_single_task function.""" diff --git a/hud/settings.py b/hud/settings.py index 4e8787c9..bb2013a1 100644 --- 
a/hud/settings.py
+++ b/hud/settings.py
@@ -94,6 +94,12 @@ def settings_customise_sources(
         validation_alias="OPENAI_API_KEY",
     )
 
+    gemini_api_key: str | None = Field(
+        default=None,
+        description="API key for Google Gemini models",
+        validation_alias="GEMINI_API_KEY",
+    )
+
     openrouter_api_key: str | None = Field(
         default=None,
         description="API key for OpenRouter models",
diff --git a/hud/tools/__init__.py b/hud/tools/__init__.py
index c3b6bb96..57f70e99 100644
--- a/hud/tools/__init__.py
+++ b/hud/tools/__init__.py
@@ -12,7 +12,12 @@
 from .submit import SubmitTool
 
 if TYPE_CHECKING:
-    from .computer import AnthropicComputerTool, HudComputerTool, OpenAIComputerTool
+    from .computer import (
+        AnthropicComputerTool,
+        GeminiComputerTool,
+        HudComputerTool,
+        OpenAIComputerTool,
+    )
 
 __all__ = [
     "AnthropicComputerTool",
@@ -20,6 +25,7 @@
     "BaseTool",
     "BashTool",
     "EditTool",
+    "GeminiComputerTool",
     "HudComputerTool",
     "OpenAIComputerTool",
     "PlaywrightTool",
@@ -30,7 +36,12 @@
 
 def __getattr__(name: str) -> Any:
     """Lazy import computer tools to avoid importing pyautogui unless needed."""
-    if name in ("AnthropicComputerTool", "HudComputerTool", "OpenAIComputerTool"):
+    if name in (
+        "AnthropicComputerTool",
+        "HudComputerTool",
+        "OpenAIComputerTool",
+        "GeminiComputerTool",
+    ):
         from . import computer
 
         return getattr(computer, name)
diff --git a/hud/tools/computer/__init__.py b/hud/tools/computer/__init__.py
index 67814548..9697488a 100644
--- a/hud/tools/computer/__init__.py
+++ b/hud/tools/computer/__init__.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from .anthropic import AnthropicComputerTool
+from .gemini import GeminiComputerTool
 from .hud import HudComputerTool
 from .openai import OpenAIComputerTool
 from .qwen import QwenComputerTool
@@ -10,6 +11,7 @@
 
 __all__ = [
     "AnthropicComputerTool",
+    "GeminiComputerTool",
     "HudComputerTool",
    "OpenAIComputerTool",
     "QwenComputerTool",
diff --git a/hud/tools/computer/gemini.py b/hud/tools/computer/gemini.py
new file mode 100644
index 00000000..0b62d3a9
--- /dev/null
+++ b/hud/tools/computer/gemini.py
@@ -0,0 +1,385 @@
+from __future__ import annotations
+
+import logging
+import platform
+from typing import TYPE_CHECKING, Any, Literal
+
+from mcp import ErrorData, McpError
+from mcp.types import INVALID_PARAMS, ContentBlock
+from pydantic import Field
+
+from hud.tools.types import ContentResult
+
+from .hud import HudComputerTool
+from .settings import computer_settings
+
+if TYPE_CHECKING:
+    from hud.tools.executors.base import BaseExecutor
+
+logger = logging.getLogger(__name__)
+
+
+ACTION_FIELD = Field(..., description="Gemini Computer Use action to perform")
+X_FIELD = Field(None, description="X coordinate (pixels in agent space)")
+Y_FIELD = Field(None, description="Y coordinate (pixels in agent space)")
+TEXT_FIELD = Field(None, description="Text to type")
+PRESS_ENTER_FIELD = Field(None, description="Whether to press Enter after typing (type_text_at)")
+CLEAR_BEFORE_TYPING_FIELD = Field(
+    None, description="Whether to select-all before typing (type_text_at)"
+)
+DIRECTION_FIELD = Field(None, description="Scroll direction for scroll_document/scroll_at")
+MAGNITUDE_FIELD = Field(None, description="Scroll magnitude (pixels in agent space)")
+URL_FIELD = Field(None, description="Target URL for navigate")
+KEYS_FIELD = Field(None, description="Keys for key_combination")
+DESTINATION_X_FIELD = Field(None, description="Destination X for drag_and_drop (agent space)")
+DESTINATION_Y_FIELD = Field(None, description="Destination Y for 
drag_and_drop (agent space)") +TAKE_SCREENSHOT_ON_CLICK_FIELD = Field( + True, description="Whether to include a screenshot for interactive actions" +) + + +class GeminiComputerTool(HudComputerTool): + """ + Gemini Computer Use tool for interacting with a computer via MCP. + + Maps Gemini's predefined function names (open_web_browser, click_at, hover_at, + type_text_at, scroll_document, scroll_at, wait_5_seconds, go_back, go_forward, + search, navigate, key_combination, drag_and_drop) to executor actions. + """ + + def __init__( + self, + # Define within environment based on platform + executor: BaseExecutor | None = None, + platform_type: Literal["auto", "xdo", "pyautogui"] = "auto", + display_num: int | None = None, + # Overrides for what dimensions the agent thinks it operates in + width: int = computer_settings.GEMINI_COMPUTER_WIDTH, + height: int = computer_settings.GEMINI_COMPUTER_HEIGHT, + rescale_images: bool = computer_settings.GEMINI_RESCALE_IMAGES, + # What the agent sees as the tool's name, title, and description + name: str | None = None, + title: str | None = None, + description: str | None = None, + **kwargs: Any, + ) -> None: + """ + Initialize with Gemini's default dimensions. + """ + super().__init__( + executor=executor, + platform_type=platform_type, + display_num=display_num, + width=width, + height=height, + rescale_images=rescale_images, + name=name or "gemini_computer", + title=title or "Gemini Computer Tool", + description=description or "Control computer with mouse, keyboard, and screenshots", + **kwargs, + ) + + async def __call__( + self, + action: str = ACTION_FIELD, + # Common coordinates + x: int | None = X_FIELD, + y: int | None = Y_FIELD, + # Text input + text: str | None = TEXT_FIELD, + press_enter: bool | None = PRESS_ENTER_FIELD, + clear_before_typing: bool | None = CLEAR_BEFORE_TYPING_FIELD, + # Scroll parameters + direction: Literal["up", "down", "left", "right"] | None = DIRECTION_FIELD, + magnitude: int | None = MAGNITUDE_FIELD, + # Navigation + url: str | None = URL_FIELD, + # Key combos + keys: list[str] | str | None = KEYS_FIELD, + # Drag parameters + destination_x: int | None = DESTINATION_X_FIELD, + destination_y: int | None = DESTINATION_Y_FIELD, + # Behavior + take_screenshot_on_click: bool = TAKE_SCREENSHOT_ON_CLICK_FIELD, + ) -> list[ContentBlock]: + """ + Handle Gemini Computer Use API calls by mapping to executor actions. + + Returns: + List of MCP content blocks + """ + logger.info("GeminiComputerTool received action: %s", action) + + # Helper to finalize ContentResult: rescale if requested and ensure URL metadata + async def _finalize( + result: ContentResult, requested_url: str | None = None + ) -> list[ContentBlock]: + if result.base64_image and self.rescale_images: + try: + result.base64_image = await self._rescale_screenshot(result.base64_image) + except Exception as e: + logger.warning("Failed to rescale screenshot: %s", e) + # Always include URL metadata if provided; otherwise default to about:blank + result.url = requested_url or result.url or "about:blank" + return result.to_content_blocks() + + # Scale coordinates helper + def _scale(xv: int | None, yv: int | None) -> tuple[int | None, int | None]: + return self._scale_coordinates(xv, yv) + + # Gemini emits coordinates/magnitudes in a 0-1000 normalized space. 
+ def _denormalize(value: float | None, axis: Literal["x", "y"]) -> int | None: + if value is None: + return None + try: + numeric = float(value) + except (TypeError, ValueError): + try: + return int(value) # type: ignore[arg-type] + except (TypeError, ValueError): + return None + + # Treat values within the normalized range (including defaults like 800). + if 0 <= numeric <= 1000: + target = self.width if axis == "x" else self.height + numeric = numeric / 1000 * target + + return round(numeric) + + def _scale_distance(value: int | None, axis: Literal["x", "y"]) -> int | None: + if value is None: + return None + scale = self.scale_x if axis == "x" else self.scale_y + if scale != 1.0: + return round(value / scale) + return value + + # Map actions + if action == "open_web_browser": + screenshot = await self.executor.screenshot() + if screenshot: + result = ContentResult(base64_image=screenshot, url="about:blank") + else: + result = ContentResult(error="Failed to take screenshot", url="about:blank") + return await _finalize(result) + + elif action == "click_at": + if x is None or y is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="x and y are required")) + dx = _denormalize(x, "x") + dy = _denormalize(y, "y") + sx, sy = _scale(dx, dy) + result = await self.executor.click(x=sx, y=sy) + return await _finalize(result) + + elif action == "hover_at": + if x is None or y is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="x and y are required")) + dx = _denormalize(x, "x") + dy = _denormalize(y, "y") + sx, sy = _scale(dx, dy) + result = await self.executor.move(x=sx, y=sy) + return await _finalize(result) + + elif action == "type_text_at": + if x is None or y is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="x and y are required")) + if text is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="text is required")) + + dx = _denormalize(x, "x") + dy = _denormalize(y, "y") + sx, sy = _scale(dx, dy) + + # Focus the field + await self.executor.move(x=sx, y=sy, take_screenshot=False) + await self.executor.click(x=sx, y=sy, take_screenshot=False) + + # Clear existing text if requested + if clear_before_typing is None or clear_before_typing: + is_mac = platform.system().lower() == "darwin" + combo = ["cmd", "a"] if is_mac else ["ctrl", "a"] + await self.executor.press(keys=combo, take_screenshot=False) + delete_key = "backspace" if is_mac else "delete" + await self.executor.press(keys=[delete_key], take_screenshot=False) + + # Type (optionally press enter after) + result = await self.executor.write(text=text, enter_after=bool(press_enter)) + return await _finalize(result) + + elif action == "scroll_document": + if direction is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="direction is required")) + # Default magnitude similar to reference implementation + mag = magnitude if magnitude is not None else 800 + # Convert to environment units while preserving sign + if direction in ("down", "up"): + distance = _denormalize(mag, "y") + if distance is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, message="Unable to determine scroll magnitude" + ) + ) + distance = _scale_distance(distance, "y") + if distance is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message="Unable to determine scroll magnitude", + ) + ) + scroll_y = distance if direction == "down" else -distance + scroll_x = None + elif direction in ("right", "left"): + distance = _denormalize(mag, "x") + if distance is None: + raise McpError( + 
ErrorData( + code=INVALID_PARAMS, message="Unable to determine scroll magnitude" + ) + ) + distance = _scale_distance(distance, "x") + if distance is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message="Unable to determine scroll magnitude", + ) + ) + scroll_x = distance if direction == "right" else -distance + scroll_y = None + else: + raise McpError( + ErrorData(code=INVALID_PARAMS, message=f"Invalid direction: {direction}") + ) + result = await self.executor.scroll(scroll_x=scroll_x, scroll_y=scroll_y) + return await _finalize(result) + + elif action == "scroll_at": + if direction is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="direction is required")) + if x is None or y is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="x and y are required")) + mag = magnitude if magnitude is not None else 800 + dx = _denormalize(x, "x") + dy = _denormalize(y, "y") + sx, sy = _scale(dx, dy) + if direction in ("down", "up"): + distance = _denormalize(mag, "y") + if distance is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, message="Unable to determine scroll magnitude" + ) + ) + distance = _scale_distance(distance, "y") + if distance is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message="Unable to determine scroll magnitude", + ) + ) + scroll_y = distance if direction == "down" else -distance + scroll_x = None + elif direction in ("right", "left"): + distance = _denormalize(mag, "x") + if distance is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, message="Unable to determine scroll magnitude" + ) + ) + distance = _scale_distance(distance, "x") + if distance is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message="Unable to determine scroll magnitude", + ) + ) + scroll_x = distance if direction == "right" else -distance + scroll_y = None + else: + raise McpError( + ErrorData(code=INVALID_PARAMS, message=f"Invalid direction: {direction}") + ) + result = await self.executor.scroll(x=sx, y=sy, scroll_x=scroll_x, scroll_y=scroll_y) + return await _finalize(result) + + elif action == "wait_5_seconds": + result = await self.executor.wait(time=5000) + return await _finalize(result) + + elif action == "go_back": + is_mac = platform.system().lower() == "darwin" + combo = ["cmd", "["] if is_mac else ["alt", "left"] + result = await self.executor.press(keys=combo) + return await _finalize(result) + + elif action == "go_forward": + is_mac = platform.system().lower() == "darwin" + combo = ["cmd", "]"] if is_mac else ["alt", "right"] + result = await self.executor.press(keys=combo) + return await _finalize(result) + + elif action == "search": + # Best-effort navigate to a default search page + target = url or "https://www.google.com" + is_mac = platform.system().lower() == "darwin" + await self.executor.press( + keys=["cmd", "l"] if is_mac else ["ctrl", "l"], take_screenshot=False + ) + result = await self.executor.write(text=target, enter_after=True) + return await _finalize(result, requested_url=target) + + elif action == "navigate": + if not url: + raise McpError(ErrorData(code=INVALID_PARAMS, message="url is required")) + is_mac = platform.system().lower() == "darwin" + await self.executor.press( + keys=["cmd", "l"] if is_mac else ["ctrl", "l"], take_screenshot=False + ) + result = await self.executor.write(text=url, enter_after=True) + return await _finalize(result, requested_url=url) + + elif action == "key_combination": + if keys is None: + raise McpError(ErrorData(code=INVALID_PARAMS, 
message="keys is required")) + if isinstance(keys, str): + # Accept formats like "ctrl+c" or "ctrl+shift+t" + key_list = [k.strip() for k in keys.split("+") if k.strip()] + else: + key_list = keys + result = await self.executor.press(keys=key_list) + return await _finalize(result) + + elif action == "drag_and_drop": + if x is None or y is None or destination_x is None or destination_y is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message="x, y, destination_x, and destination_y are required", + ) + ) + sx_norm = _denormalize(x, "x") + sy_norm = _denormalize(y, "y") + dx_norm = _denormalize(destination_x, "x") + dy_norm = _denormalize(destination_y, "y") + sx, sy = _scale(sx_norm, sy_norm) + dx_scaled, dy_scaled = _scale(dx_norm, dy_norm) + # Build a two-point path + path = [] # type: list[tuple[int, int]] + if ( + sx is not None + and sy is not None + and dx_scaled is not None + and dy_scaled is not None + ): + path = [(sx, sy), (dx_scaled, dy_scaled)] + result = await self.executor.drag(path=path) + return await _finalize(result) + + else: + raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Unknown action: {action}")) diff --git a/hud/tools/computer/settings.py b/hud/tools/computer/settings.py index 9a122af7..94dfdfc0 100644 --- a/hud/tools/computer/settings.py +++ b/hud/tools/computer/settings.py @@ -94,5 +94,21 @@ class ComputerSettings(BaseSettings): validation_alias="QWEN_RESCALE_IMAGES", ) + GEMINI_COMPUTER_WIDTH: int = Field( + default=1440, + description="Width of the display to use for the Gemini computer tools", + validation_alias="GEMINI_COMPUTER_WIDTH", + ) + GEMINI_COMPUTER_HEIGHT: int = Field( + default=900, + description="Height of the display to use for the Gemini computer tools", + validation_alias="GEMINI_COMPUTER_HEIGHT", + ) + GEMINI_RESCALE_IMAGES: bool = Field( + default=True, + description="Whether to rescale images to the agent width and height", + validation_alias="GEMINI_RESCALE_IMAGES", + ) + computer_settings = ComputerSettings() diff --git a/hud/tools/playwright.py b/hud/tools/playwright.py index e170f3af..5d85405a 100644 --- a/hud/tools/playwright.py +++ b/hud/tools/playwright.py @@ -84,6 +84,9 @@ async def __call__( code=INVALID_PARAMS, message="url parameter is required for navigate" ) ) + # Guard against pydantic FieldInfo default leaking through + if not isinstance(wait_for_load_state, str): + wait_for_load_state = None result = await self.navigate(url, wait_for_load_state or "networkidle") elif action == "screenshot": @@ -179,11 +182,16 @@ async def _ensure_browser(self) -> None: if self._browser is None: raise RuntimeError("Failed to connect to remote browser") - # Use existing context or create new one + # Reuse existing context and page where possible to avoid spawning new windows contexts = self._browser.contexts if contexts: self._browser_context = contexts[0] + # Prefer the first existing page to keep using the already visible window/tab + existing_pages = self._browser_context.pages + if existing_pages: + self.page = existing_pages[0] else: + # As a fallback, create a new context self._browser_context = await self._browser.new_context( viewport={"width": 1920, "height": 1080}, ignore_https_errors=True, diff --git a/hud/tools/types.py b/hud/tools/types.py index faf92efc..f3285258 100644 --- a/hud/tools/types.py +++ b/hud/tools/types.py @@ -28,6 +28,7 @@ class ContentResult(BaseModel): error: str | None = Field(default=None, description="Error message") base64_image: str | None = Field(default=None, description="Base64-encoded 
image") system: str | None = Field(default=None, description="System message") + url: str | None = Field(default=None, description="Current page URL (for browser automation)") def __add__(self, other: ContentResult) -> ContentResult: def combine_fields( @@ -44,6 +45,7 @@ def combine_fields( error=combine_fields(self.error, other.error), base64_image=combine_fields(self.base64_image, other.base64_image, False), system=combine_fields(self.system, other.system), + url=combine_fields(self.url, other.url, False), ) def to_content_blocks(self) -> list[ContentBlock]: @@ -55,7 +57,7 @@ def to_content_blocks(self) -> list[ContentBlock]: result: ContentResult to convert Returns: - List of ContentBlock + List of ContentBlock with URL embedded as metadata if available """ blocks: list[ContentBlock] = [] @@ -65,6 +67,12 @@ def to_content_blocks(self) -> list[ContentBlock]: blocks.append(TextContent(text=self.error, type="text")) if self.base64_image: blocks.append(ImageContent(data=self.base64_image, mimeType="image/png", type="image")) + + # Add URL as a special metadata text block (for Gemini Computer Use) + # Always include URL if set, even if it's a placeholder like "about:blank" + if self.url: + blocks.append(TextContent(text=f"__URL__:{self.url}", type="text")) + return blocks diff --git a/hud/types.py b/hud/types.py index 08cef5b7..3919408a 100644 --- a/hud/types.py +++ b/hud/types.py @@ -25,6 +25,7 @@ class AgentType(str, Enum): CLAUDE = "claude" OPENAI = "openai" + GEMINI = "gemini" VLLM = "vllm" LITELLM = "litellm" INTEGRATION_TEST = "integration_test" diff --git a/pyproject.toml b/pyproject.toml index 693e4bb1..c500ddb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ # AI providers "anthropic", "openai", + "google-genai", ] classifiers = [ "Development Status :: 4 - Beta",