diff --git a/environments/remote_browser/src/hud_controller/server.py b/environments/remote_browser/src/hud_controller/server.py index af7af53d..87b07bcb 100644 --- a/environments/remote_browser/src/hud_controller/server.py +++ b/environments/remote_browser/src/hud_controller/server.py @@ -25,6 +25,7 @@ AnthropicComputerTool, OpenAIComputerTool, HudComputerTool, + GeminiComputerTool, ) # Import setup and evaluate hubs @@ -283,6 +284,7 @@ async def send_progress(progress: int, message: str): mcp.add_tool(HudComputerTool(executor=browser_executor)) mcp.add_tool(AnthropicComputerTool(executor=browser_executor)) mcp.add_tool(OpenAIComputerTool(executor=browser_executor)) + mcp.add_tool(GeminiComputerTool(executor=browser_executor)) await send_progress(80, "Registered hud computer tools") diff --git a/environments/remote_browser/src/hud_controller/tools/executor.py b/environments/remote_browser/src/hud_controller/tools/executor.py index 94196832..1a54569a 100644 --- a/environments/remote_browser/src/hud_controller/tools/executor.py +++ b/environments/remote_browser/src/hud_controller/tools/executor.py @@ -94,6 +94,14 @@ def __init__(self, playwright_tool, display_num: int | None = None): self.playwright_tool = playwright_tool logger.info("BrowserExecutor initialized with Playwright backend") + async def _current_url(self) -> str | None: + """Return current page URL if available.""" + try: + page = await self._ensure_page() + return page.url + except Exception: + return None + def _map_key(self, key: str) -> str: """Map a key name to Playwright format.""" key = key.strip() @@ -172,6 +180,9 @@ async def click( logger.debug(f"Clicked at ({x}, {y}) with button {button}") result = ContentResult(output=f"Clicked at ({x}, {y})") + current = await self._current_url() + if current: + result.url = current if take_screenshot: result = result + ContentResult(base64_image=await self.screenshot()) @@ -213,6 +224,9 @@ async def write( logger.debug(f"Typed text: {text[:50]}...") result = ContentResult(output=f"Typed: {text}") + current = await self._current_url() + if current: + result.url = current if take_screenshot: result = result + ContentResult(base64_image=await self.screenshot()) @@ -253,6 +267,9 @@ async def press( logger.debug(f"Pressed keys: {keys} (mapped to: {mapped_keys})") result = ContentResult(output=f"Pressed: {key_combination}") + current = await self._current_url() + if current: + result.url = current if take_screenshot: result = result + ContentResult(base64_image=await self.screenshot()) @@ -292,6 +309,9 @@ async def scroll( logger.debug(f"Scrolled at ({x}, {y}) by ({delta_x}, {delta_y})") result = ContentResult(output=f"Scrolled by ({delta_x}, {delta_y})") + current = await self._current_url() + if current: + result.url = current if take_screenshot: result = result + ContentResult(base64_image=await self.screenshot()) @@ -319,6 +339,9 @@ async def move( logger.debug(f"Moved mouse to ({x}, {y})") result = ContentResult(output=f"Moved to ({x}, {y})") + current = await self._current_url() + if current: + result.url = current if take_screenshot: result = result + ContentResult(base64_image=await self.screenshot()) @@ -369,6 +392,9 @@ async def drag( logger.debug(f"Dragged from {path[0]} through {len(path)} points") result = ContentResult(output=f"Dragged through {len(path)} points") + current = await self._current_url() + if current: + result.url = current if take_screenshot: result = result + ContentResult(base64_image=await self.screenshot()) diff --git a/examples/gemini_agent.py b/examples/gemini_agent.py new file mode 100644 index 00000000..3d9583d5 --- /dev/null +++ b/examples/gemini_agent.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +""" +Gemini Agent Example (Remote Browser) + +This example showcases Gemini-specific features against a remote browser environment: +- Computer use capabilities with normalized coordinates +- Browser automation +- Multi-step reasoning tasks + +Gemini uses a normalized coordinate system (0-999) that is automatically +scaled to actual screen dimensions. +""" + +import asyncio + +import hud +import os +from hud.agents import GeminiAgent +from hud.clients import MCPClient +from hud.settings import settings + + +async def main(): + with hud.trace("Gemini Agent Demo"): + # Remote HUD MCP server using your custom remote-browser image + # Built via environments/remote_browser/Dockerfile + # Build headers with required environment for remote browser + provider = os.getenv("BROWSER_PROVIDER", "anchorbrowser") + headers = { + "Authorization": f"Bearer {settings.api_key}", + "Mcp-Image": "alberthu233/hud-remote-browser:gemini-dev-2", + "Env-Browser-Provider": provider, + } + + # Optionally pass provider-specific API key if available + provider_key_map = { + "anchorbrowser": "ANCHOR_API_KEY", + "steel": "STEEL_API_KEY", + "browserbase": "BROWSERBASE_API_KEY", + "hyperbrowser": "HYPERBROWSER_API_KEY", + "kernel": "KERNEL_API_KEY", + } + if provider in provider_key_map: + key_var = provider_key_map[provider] + key_val = os.getenv(key_var) + if key_val: + header_key = f"Env-{'-'.join(part.capitalize() for part in key_var.split('_'))}" + headers[header_key] = key_val + + mcp_config = {"hud": {"url": "https://mcp.hud.so/v3/mcp", "headers": headers}} + + # Create Gemini-specific agent + client = MCPClient(mcp_config=mcp_config) + agent = GeminiAgent( + mcp_client=client, + model="gemini-2.5-computer-use-preview-10-2025", + allowed_tools=["gemini_computer"], + initial_screenshot=True, + temperature=1.0, + max_output_tokens=8192, + ) + + await client.initialize() + + try: + initial_url = "https://httpbin.org/forms/post" + + prompt = f""" + Please help me fill out a web form one step at a time: + 1. Navigate to {initial_url} + 2. Fill in the customer name as "Gemini Test" + 3. Enter the telephone as "555-0456" + 4. Enter the email as "gemini@test.com" + 5. Type "Submission with Gemini" in the comments + 6. Select a medium pizza size + 7. Choose "mushroom" as a topping + 8. Set delivery time to "16:00" + 9. Submit the form + 10. Verify the submission was successful + """ + + print("šŸ“‹ Task: Multi-step form interaction (Remote Browser)") + print("šŸš€ Running Gemini agent...\n") + + # Setup: navigate to initial URL via setup tool + await client.call_tool( + name="setup", + arguments={"name": "navigate_to_url", "arguments": {"url": initial_url}}, + ) + + # Run the prompt + result = await agent.run(prompt, max_steps=50) + + print(result) + + finally: + await client.shutdown() + + print("\n✨ Gemini agent demo complete!") + + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/hud/agents/__init__.py b/hud/agents/__init__.py index 7470adb3..f69d4ed5 100644 --- a/hud/agents/__init__.py +++ b/hud/agents/__init__.py @@ -2,11 +2,13 @@ from .base import MCPAgent from .claude import ClaudeAgent +from .gemini import GeminiAgent from .openai import OperatorAgent from .openai_chat_generic import GenericOpenAIChatAgent __all__ = [ "ClaudeAgent", + "GeminiAgent", "GenericOpenAIChatAgent", "MCPAgent", "OperatorAgent", diff --git a/hud/agents/gemini.py b/hud/agents/gemini.py new file mode 100644 index 00000000..8c6f3cf7 --- /dev/null +++ b/hud/agents/gemini.py @@ -0,0 +1,495 @@ +"""Gemini MCP Agent implementation.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, ClassVar, cast + +from google import genai +from google.genai import types as genai_types + +import hud + +if TYPE_CHECKING: + from hud.datasets import Task + +import mcp.types as types + +from hud.settings import settings +from hud.tools.computer.settings import computer_settings +from hud.types import AgentResponse, MCPToolCall, MCPToolResult +from hud.utils.hud_console import HUDConsole + +from .base import MCPAgent + +logger = logging.getLogger(__name__) + +# Maximum number of recent turns to keep screenshots for +MAX_RECENT_TURN_WITH_SCREENSHOTS = 3 + +# Predefined Gemini computer use functions +PREDEFINED_COMPUTER_USE_FUNCTIONS = [ + "open_web_browser", + "click_at", + "hover_at", + "type_text_at", + "scroll_document", + "scroll_at", + "wait_5_seconds", + "go_back", + "go_forward", + "search", + "navigate", + "key_combination", + "drag_and_drop", +] + + +class GeminiAgent(MCPAgent): + """ + Gemini agent that uses MCP servers for tool execution. + + This agent uses Gemini's native computer use capabilities but executes + tools through MCP servers instead of direct implementation. + """ + + metadata: ClassVar[dict[str, Any]] = { + "display_width": computer_settings.GEMINI_COMPUTER_WIDTH, + "display_height": computer_settings.GEMINI_COMPUTER_HEIGHT, + } + + def __init__( + self, + model_client: genai.Client | None = None, + model: str = "gemini-2.5-computer-use-preview-10-2025", + temperature: float = 1.0, + top_p: float = 0.95, + top_k: int = 40, + max_output_tokens: int = 8192, + validate_api_key: bool = True, + excluded_predefined_functions: list[str] | None = None, + **kwargs: Any, + ) -> None: + """ + Initialize Gemini MCP agent. + + Args: + model_client: Gemini client (created if not provided) + model: Gemini model to use + temperature: Temperature for response generation + top_p: Top-p sampling parameter + top_k: Top-k sampling parameter + max_output_tokens: Maximum tokens for response + validate_api_key: Whether to validate API key on initialization + excluded_predefined_functions: List of predefined functions to exclude + **kwargs: Additional arguments passed to BaseMCPAgent (including mcp_client) + """ + super().__init__(**kwargs) + + # Initialize client if not provided + if model_client is None: + api_key = settings.gemini_api_key + if not api_key: + raise ValueError("Gemini API key not found. Set GEMINI_API_KEY.") + model_client = genai.Client(api_key=api_key) + + # Validate API key if requested + if validate_api_key: + try: + # Simple validation - try to list models + list(model_client.models.list(config=genai_types.ListModelsConfig(page_size=1))) + except Exception as e: + raise ValueError(f"Gemini API key is invalid: {e}") from e + + self.gemini_client = model_client + self.model = model + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.max_output_tokens = max_output_tokens + self.excluded_predefined_functions = excluded_predefined_functions or [] + self.hud_console = HUDConsole(logger=logger) + + self.model_name = self.model + + # Track mapping from Gemini tool names to MCP tool names + self._gemini_to_mcp_tool_map: dict[str, str] = {} + self.gemini_tools: list[genai_types.Tool] = [] + + # Append Gemini-specific instructions to the base system prompt + gemini_instructions = """ + You are Gemini, a helpful AI assistant created by Google. You are capable of interacting with computer interfaces. + + When working on tasks: + 1. Be thorough and systematic in your approach + 2. Complete tasks autonomously without asking for confirmation + 3. Use available tools efficiently to accomplish your goals + 4. Verify your actions and ensure task completion + 5. Be precise and accurate in all operations + 6. Adapt to the environment and the task at hand + + Remember: You are expected to complete tasks autonomously. The user trusts you to accomplish what they asked. + """.strip() + + # Append Gemini instructions to any base system prompt + if self.system_prompt: + self.system_prompt = f"{self.system_prompt}\n\n{gemini_instructions}" + else: + self.system_prompt = gemini_instructions + + async def initialize(self, task: str | Task | None = None) -> None: + """Initialize the agent and build tool mappings.""" + await super().initialize(task) + # Build tool mappings after tools are discovered + self._convert_tools_for_gemini() + + async def get_system_messages(self) -> list[Any]: + """No system messages for Gemini because applied in get_response""" + return [] + + async def format_blocks( + self, blocks: list[types.ContentBlock] + ) -> list[genai_types.Content]: + """Format messages for Gemini.""" + # Convert MCP content types to Gemini content types + gemini_parts: list[genai_types.Part] = [] + + for block in blocks: + if isinstance(block, types.TextContent): + gemini_parts.append(genai_types.Part(text=block.text)) + elif isinstance(block, types.ImageContent): + # Convert MCP ImageContent to Gemini format + # Need to decode base64 string to bytes + import base64 + image_bytes = base64.b64decode(block.data) + gemini_parts.append( + genai_types.Part.from_bytes(data=image_bytes, mime_type=block.mimeType) + ) + else: + # For other types, try to handle but log a warning + self.hud_console.log(f"Unknown content block type: {type(block)}", level="warning") + + return [genai_types.Content(role="user", parts=gemini_parts)] + + @hud.instrument( + span_type="agent", + record_args=False, # Messages can be large + record_result=True, + ) + async def get_response( + self, messages: list[genai_types.Content] + ) -> AgentResponse: + """Get response from Gemini including any tool calls.""" + + # Build generate content config + generate_config = genai_types.GenerateContentConfig( + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + max_output_tokens=self.max_output_tokens, + tools=cast("Any", self.gemini_tools), + system_instruction=self.system_prompt, + ) + + # Trim screenshots from older turns to manage context growth + self._remove_old_screenshots(messages) + + # Make API call - using a simpler call pattern + response = self.gemini_client.models.generate_content( + model=self.model, + contents=cast("Any", messages), + config=generate_config, + ) + + # Append assistant response (including any function_call) so that + # subsequent FunctionResponse messages correspond to a prior FunctionCall + if response.candidates and len(response.candidates) > 0 and response.candidates[0].content: + cast("list[genai_types.Content]", messages).append(response.candidates[0].content) + + # Process response + result = AgentResponse(content="", tool_calls=[], done=True) + collected_tool_calls: list[MCPToolCall] = [] + + if not response.candidates: + self.hud_console.warning("Response has no candidates") + return result + + candidate = response.candidates[0] + + # Extract text content and function calls + text_content = "" + thinking_content = "" + + if candidate.content and candidate.content.parts: + for part in candidate.content.parts: + if part.function_call: + # Map Gemini tool name back to MCP tool name + func_name = part.function_call.name or "" + mcp_tool_name = self._gemini_to_mcp_tool_map.get(func_name, func_name) + + # Create MCPToolCall object with Gemini metadata + raw_args = dict(part.function_call.args) if part.function_call.args else {} + + # Normalize Gemini Computer Use calls to MCP tool schema + if part.function_call.name in PREDEFINED_COMPUTER_USE_FUNCTIONS: + # Ensure 'action' is present and equals the Gemini function name + normalized_args: dict[str, Any] = {"action": part.function_call.name} + + # Map common argument shapes used by Gemini Computer Use + # 1) Coordinate arrays → x/y + coord = raw_args.get("coordinate") or raw_args.get("coordinates") + if isinstance(coord, (list, tuple)) and len(coord) >= 2: + try: + normalized_args["x"] = int(coord[0]) + normalized_args["y"] = int(coord[1]) + except (TypeError, ValueError): + # Fall back to raw if casting fails + pass + + # Destination coordinate arrays → destination_x/destination_y + dest = ( + raw_args.get("destination") + or raw_args.get("destination_coordinate") + or raw_args.get("destinationCoordinate") + ) + if isinstance(dest, (list, tuple)) and len(dest) >= 2: + try: + normalized_args["destination_x"] = int(dest[0]) + normalized_args["destination_y"] = int(dest[1]) + except (TypeError, ValueError): + pass + + # Pass through supported fields if present (including direct coords) + for key in ( + "text", + "press_enter", + "clear_before_typing", + "safety_decision", + "safetyDecision", + "direction", + "magnitude", + "url", + "keys", + "x", + "y", + "destination_x", + "destination_y", + ): + if key in raw_args: + normalized_args[key] = raw_args[key] + + # Use normalized args for computer tool calls + final_args = normalized_args + else: + # Non-computer tools: pass args as-is + final_args = raw_args + + tool_call = MCPToolCall( + name=mcp_tool_name, + arguments=final_args, + gemini_name=func_name, # type: ignore[arg-type] + ) + collected_tool_calls.append(tool_call) + elif part.text: + text_content += part.text + elif hasattr(part, "thought") and part.thought: + thinking_content += f"Thinking: {part.thought}\n" + + # Assign collected tool calls and mark done status + if collected_tool_calls: + result.tool_calls = collected_tool_calls + result.done = False + + # Combine text and thinking for final content + if thinking_content: + result.content = thinking_content + text_content + else: + result.content = text_content + + return result + + async def format_tool_results( + self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult] + ) -> list[genai_types.Content]: + """Format tool results into Gemini messages.""" + # Process each tool result + function_responses = [] + + for tool_call, result in zip(tool_calls, tool_results, strict=True): + # Get the Gemini function name from metadata + gemini_name = getattr(tool_call, "gemini_name", tool_call.name) + + # Convert MCP tool results to Gemini format + response_dict: dict[str, Any] = {} + url = None + + if result.isError: + # Extract error message from content + error_msg = "Tool execution failed" + for content in result.content: + if isinstance(content, types.TextContent): + # Check if this is a URL metadata block + if content.text.startswith("__URL__:"): + url = content.text.replace("__URL__:", "") + else: + error_msg = content.text + break + response_dict["error"] = error_msg + else: + # Process success content + response_dict["success"] = True + + # Extract URL and screenshot from content + screenshot_parts = [] + for content in result.content: + if isinstance(content, types.TextContent): + # Check if this is a URL metadata block + if content.text.startswith("__URL__:"): + url = content.text.replace("__URL__:", "") + elif isinstance(content, types.ImageContent): + # Decode base64 string to bytes for FunctionResponseBlob + import base64 + image_bytes = base64.b64decode(content.data) + screenshot_parts.append( + genai_types.FunctionResponsePart( + inline_data=genai_types.FunctionResponseBlob( + mime_type=content.mimeType or "image/png", + data=image_bytes, + ) + ) + ) + + # Add URL to response dict (required by Gemini Computer Use model) + # URL must ALWAYS be present per Gemini API requirements + response_dict["url"] = url if url else "about:blank" + + # For Gemini Computer Use actions, always acknowledge safety decisions + requires_ack = False + if tool_call.arguments: + requires_ack = bool( + tool_call.arguments.get("safety_decision") + or tool_call.arguments.get("safetyDecision") + ) + if gemini_name in PREDEFINED_COMPUTER_USE_FUNCTIONS and requires_ack: + # Provide common acknowledgement flags expected by the API + # (include multiple canonical spellings to maximize compatibility) + response_dict["acknowledged"] = True + response_dict["acknowledged_safety"] = True + response_dict["acknowledgedSafetyDecision"] = True + + # Create function response + function_response = genai_types.FunctionResponse( + name=gemini_name, + response=response_dict, + parts=screenshot_parts if screenshot_parts else None, + ) + function_responses.append(function_response) + + # Return as a user message containing all function responses + return [ + genai_types.Content( + role="user", + parts=[genai_types.Part(function_response=fr) for fr in function_responses], + ) + ] + + async def create_user_message(self, text: str) -> genai_types.Content: + """Create a user message in Gemini's format.""" + return genai_types.Content(role="user", parts=[genai_types.Part(text=text)]) + + def _convert_tools_for_gemini(self) -> list[genai_types.Tool]: + """Convert MCP tools to Gemini tool format.""" + gemini_tools = [] + self._gemini_to_mcp_tool_map = {} # Reset mapping + + # Find computer tool by priority + computer_tool_priority = ["gemini_computer", "computer_gemini", "computer"] + selected_computer_tool = None + + for priority_name in computer_tool_priority: + for tool in self.get_available_tools(): + # Check both exact match and suffix match (for prefixed tools) + if tool.name == priority_name or tool.name.endswith(f"_{priority_name}"): + selected_computer_tool = tool + break + if selected_computer_tool: + break + + # Add the selected computer tool if found + if selected_computer_tool: + gemini_tool = genai_types.Tool( + computer_use=genai_types.ComputerUse( + environment=genai_types.Environment.ENVIRONMENT_BROWSER, + excluded_predefined_functions=self.excluded_predefined_functions, + ) + ) + # Map Gemini's computer use functions back to the actual MCP tool name + for func_name in PREDEFINED_COMPUTER_USE_FUNCTIONS: + if func_name not in self.excluded_predefined_functions: + self._gemini_to_mcp_tool_map[func_name] = selected_computer_tool.name + + gemini_tools.append(gemini_tool) + self.hud_console.debug( + f"Using {selected_computer_tool.name} as computer tool for Gemini" + ) + + # Add other non-computer tools as custom functions + for tool in self.get_available_tools(): + # Skip computer tools (already handled) + if any( + tool.name == priority_name or tool.name.endswith(f"_{priority_name}") + for priority_name in computer_tool_priority + ): + continue + + # Convert MCP tool schema to Gemini function declaration + try: + # Ensure parameters have proper Schema format + params = tool.inputSchema or {"type": "object", "properties": {}} + function_decl = genai_types.FunctionDeclaration( + name=tool.name, + description=tool.description or f"Execute {tool.name}", + parameters=genai_types.Schema(**params) if isinstance(params, dict) else params, # type: ignore + ) + custom_tool = genai_types.Tool(function_declarations=[function_decl]) + gemini_tools.append(custom_tool) + # Direct mapping for non-computer tools + self._gemini_to_mcp_tool_map[tool.name] = tool.name + except Exception: + self.hud_console.warning(f"Failed to convert tool {tool.name} to Gemini format") + + self.gemini_tools = gemini_tools + return gemini_tools + + def _remove_old_screenshots(self, messages: list[genai_types.Content]) -> None: + """ + Remove screenshots from old turns to manage context length. + Keeps only the last MAX_RECENT_TURN_WITH_SCREENSHOTS turns with screenshots. + """ + turn_with_screenshots_found = 0 + + for content in reversed(messages): + if content.role == "user" and content.parts: + # Check if content has screenshots (function responses with images) + has_screenshot = False + for part in content.parts: + if ( + part.function_response + and part.function_response.parts + and part.function_response.name in PREDEFINED_COMPUTER_USE_FUNCTIONS + ): + has_screenshot = True + break + + if has_screenshot: + turn_with_screenshots_found += 1 + # Remove the screenshot image if the number of screenshots exceeds the limit + if turn_with_screenshots_found > MAX_RECENT_TURN_WITH_SCREENSHOTS: + for part in content.parts: + if ( + part.function_response + and part.function_response.parts + and part.function_response.name in PREDEFINED_COMPUTER_USE_FUNCTIONS + ): + # Clear the parts (screenshots) + part.function_response.parts = None diff --git a/hud/agents/tests/test_gemini.py b/hud/agents/tests/test_gemini.py new file mode 100644 index 00000000..423f74be --- /dev/null +++ b/hud/agents/tests/test_gemini.py @@ -0,0 +1,358 @@ +"""Tests for Gemini MCP Agent implementation.""" + +from __future__ import annotations + +import base64 +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from google.genai import types as genai_types +from mcp import types + +from hud.agents.gemini import GeminiAgent +from hud.types import MCPToolCall, MCPToolResult + + +class TestGeminiAgent: + """Test GeminiAgent class.""" + + @pytest.fixture + def mock_mcp_client(self): + """Create a mock MCP client.""" + mcp_client = AsyncMock() + # Set up the mcp_config attribute as a regular dict, not a coroutine + mcp_client.mcp_config = {"test_server": {"url": "http://test"}} + # Mock list_tools to return gemini_computer tool + mcp_client.list_tools = AsyncMock( + return_value=[ + types.Tool( + name="gemini_computer", + description="Gemini computer use tool", + inputSchema={}, + ) + ] + ) + mcp_client.initialize = AsyncMock() + return mcp_client + + @pytest.fixture + def mock_gemini_client(self): + """Create a mock Gemini client.""" + client = MagicMock() + client.api_key = "test_key" + # Mock models.list for validation + client.models = MagicMock() + client.models.list = MagicMock(return_value=iter([])) + return client + + @pytest.mark.asyncio + async def test_init(self, mock_mcp_client, mock_gemini_client): + """Test agent initialization.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + model="gemini-2.5-computer-use-preview-10-2025", + validate_api_key=False, # Skip validation in tests + ) + + assert agent.model_name == "gemini-2.5-computer-use-preview-10-2025" + assert agent.model == "gemini-2.5-computer-use-preview-10-2025" + assert agent.gemini_client == mock_gemini_client + + @pytest.mark.asyncio + async def test_init_without_model_client(self, mock_mcp_client): + """Test agent initialization without model client.""" + with patch("hud.settings.settings.gemini_api_key", "test_key"): + with patch("hud.agents.gemini.genai.Client") as mock_client_class: + mock_client = MagicMock() + mock_client.api_key = "test_key" + mock_client.models = MagicMock() + mock_client.models.list = MagicMock(return_value=iter([])) + mock_client_class.return_value = mock_client + + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model="gemini-2.5-computer-use-preview-10-2025", + validate_api_key=False, + ) + + assert agent.model_name == "gemini-2.5-computer-use-preview-10-2025" + assert agent.gemini_client is not None + + @pytest.mark.asyncio + async def test_format_blocks(self, mock_mcp_client, mock_gemini_client): + """Test formatting content blocks into Gemini messages.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + # Test with text only + text_blocks: list[types.ContentBlock] = [ + types.TextContent(type="text", text="Hello, Gemini!") + ] + messages = await agent.format_blocks(text_blocks) + assert len(messages) == 1 + assert messages[0].role == "user" + assert len(messages[0].parts) == 1 + assert messages[0].parts[0].text == "Hello, Gemini!" + + # Test with screenshot + image_blocks: list[types.ContentBlock] = [ + types.TextContent(type="text", text="Look at this"), + types.ImageContent( + type="image", + data=base64.b64encode(b"fakeimage").decode("utf-8"), + mimeType="image/png", + ), + ] + messages = await agent.format_blocks(image_blocks) + assert len(messages) == 1 + assert messages[0].role == "user" + assert len(messages[0].parts) == 2 + # First part is text + assert messages[0].parts[0].text == "Look at this" + # Second part is image - check that it was created from bytes + assert messages[0].parts[1].inline_data is not None + + @pytest.mark.asyncio + async def test_format_tool_results(self, mock_mcp_client, mock_gemini_client): + """Test the agent's format_tool_results method.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + tool_calls = [ + MCPToolCall( + name="gemini_computer", + arguments={"action": "click_at", "x": 100, "y": 200}, + id="call_1", # type: ignore + gemini_name="click_at", # type: ignore + ), + ] + + tool_results = [ + MCPToolResult( + content=[ + types.TextContent(type="text", text="Clicked successfully"), + types.ImageContent( + type="image", + data=base64.b64encode(b"screenshot").decode("utf-8"), + mimeType="image/png", + ), + ], + isError=False, + ), + ] + + messages = await agent.format_tool_results(tool_calls, tool_results) + + # format_tool_results returns a single user message with function responses + assert len(messages) == 1 + assert messages[0].role == "user" + # The content contains function response parts + assert len(messages[0].parts) == 1 + assert messages[0].parts[0].function_response is not None + assert messages[0].parts[0].function_response.name == "click_at" + assert messages[0].parts[0].function_response.response.get("success") is True + + @pytest.mark.asyncio + async def test_format_tool_results_with_error(self, mock_mcp_client, mock_gemini_client): + """Test formatting tool results with errors.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + tool_calls = [ + MCPToolCall( + name="gemini_computer", + arguments={"action": "invalid"}, + id="call_error", # type: ignore + gemini_name="invalid_action", # type: ignore + ), + ] + + tool_results = [ + MCPToolResult( + content=[types.TextContent(type="text", text="Action failed: invalid action")], + isError=True, + ), + ] + + messages = await agent.format_tool_results(tool_calls, tool_results) + + # Check that error is in the response + assert len(messages) == 1 + assert messages[0].role == "user" + assert messages[0].parts[0].function_response is not None + assert "error" in messages[0].parts[0].function_response.response + + @pytest.mark.asyncio + async def test_get_response(self, mock_mcp_client, mock_gemini_client): + """Test getting model response from Gemini API.""" + # Disable telemetry for this test + with patch("hud.settings.settings.telemetry_enabled", False): + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + # Set up available tools + agent._available_tools = [ + types.Tool( + name="gemini_computer", description="Computer tool", inputSchema={} + ) + ] + + # Mock the API response + mock_response = MagicMock() + mock_candidate = MagicMock() + + # Create text part + text_part = MagicMock() + text_part.text = "I will click at coordinates" + text_part.function_call = None + + # Create function call part + function_call_part = MagicMock() + function_call_part.text = None + function_call_part.function_call = MagicMock() + function_call_part.function_call.name = "click_at" + function_call_part.function_call.args = {"x": 100, "y": 200} + + mock_candidate.content = MagicMock() + mock_candidate.content.parts = [text_part, function_call_part] + + mock_response.candidates = [mock_candidate] + + mock_gemini_client.models = MagicMock() + mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) + + messages = [genai_types.Content(role="user", parts=[genai_types.Part(text="Click")])] + response = await agent.get_response(messages) + + assert response.content == "I will click at coordinates" + assert len(response.tool_calls) == 1 + assert response.tool_calls[0].arguments == {"action": "click_at", "x": 100, "y": 200} + assert response.done is False + + @pytest.mark.asyncio + async def test_get_response_text_only(self, mock_mcp_client, mock_gemini_client): + """Test getting text-only response.""" + # Disable telemetry for this test + with patch("hud.settings.settings.telemetry_enabled", False): + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + # Mock the API response with text only + mock_response = MagicMock() + mock_candidate = MagicMock() + + text_part = MagicMock() + text_part.text = "Task completed successfully" + text_part.function_call = None + + mock_candidate.content = MagicMock() + mock_candidate.content.parts = [text_part] + + mock_response.candidates = [mock_candidate] + + mock_gemini_client.models = MagicMock() + mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) + + messages = [ + genai_types.Content(role="user", parts=[genai_types.Part(text="Status?")]) + ] + response = await agent.get_response(messages) + + assert response.content == "Task completed successfully" + assert response.tool_calls == [] + assert response.done is True + + @pytest.mark.asyncio + async def test_convert_tools_for_gemini(self, mock_mcp_client, mock_gemini_client): + """Test converting MCP tools to Gemini format.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + # Set up available tools + agent._available_tools = [ + types.Tool( + name="gemini_computer", + description="Computer tool", + inputSchema={"type": "object"}, + ), + types.Tool( + name="calculator", + description="Calculator tool", + inputSchema={ + "type": "object", + "properties": {"operation": {"type": "string"}}, + }, + ), + ] + + gemini_tools = agent._convert_tools_for_gemini() + + # Should have 2 tools: computer_use and calculator + assert len(gemini_tools) == 2 + + # First should be computer use tool + assert gemini_tools[0].computer_use is not None + assert gemini_tools[0].computer_use.environment == genai_types.Environment.ENVIRONMENT_BROWSER + + # Second should be calculator as function declaration + assert gemini_tools[1].function_declarations is not None + assert len(gemini_tools[1].function_declarations) == 1 + assert gemini_tools[1].function_declarations[0].name == "calculator" + + @pytest.mark.asyncio + async def test_create_user_message(self, mock_mcp_client, mock_gemini_client): + """Test creating a user message.""" + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + message = await agent.create_user_message("Hello Gemini") + + assert message.role == "user" + assert len(message.parts) == 1 + assert message.parts[0].text == "Hello Gemini" + + @pytest.mark.asyncio + async def test_handle_empty_response(self, mock_mcp_client, mock_gemini_client): + """Test handling empty response from API.""" + with patch("hud.settings.settings.telemetry_enabled", False): + agent = GeminiAgent( + mcp_client=mock_mcp_client, + model_client=mock_gemini_client, + validate_api_key=False, + ) + + # Mock empty response + mock_response = MagicMock() + mock_response.candidates = [] + + mock_gemini_client.models = MagicMock() + mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response) + + messages = [genai_types.Content(role="user", parts=[genai_types.Part(text="Hi")])] + response = await agent.get_response(messages) + + assert response.content == "" + assert response.tool_calls == [] + assert response.done is True diff --git a/hud/settings.py b/hud/settings.py index 4e8787c9..bb2013a1 100644 --- a/hud/settings.py +++ b/hud/settings.py @@ -94,6 +94,12 @@ def settings_customise_sources( validation_alias="OPENAI_API_KEY", ) + gemini_api_key: str | None = Field( + default=None, + description="API key for Google Gemini models", + validation_alias="GEMINI_API_KEY", + ) + openrouter_api_key: str | None = Field( default=None, description="API key for OpenRouter models", diff --git a/hud/tools/__init__.py b/hud/tools/__init__.py index c3b6bb96..38d06bb6 100644 --- a/hud/tools/__init__.py +++ b/hud/tools/__init__.py @@ -12,7 +12,12 @@ from .submit import SubmitTool if TYPE_CHECKING: - from .computer import AnthropicComputerTool, HudComputerTool, OpenAIComputerTool + from .computer import ( + AnthropicComputerTool, + GeminiComputerTool, + HudComputerTool, + OpenAIComputerTool, + ) __all__ = [ "AnthropicComputerTool", @@ -20,6 +25,7 @@ "BaseTool", "BashTool", "EditTool", + "GeminiComputerTool", "HudComputerTool", "OpenAIComputerTool", "PlaywrightTool", @@ -30,7 +36,7 @@ def __getattr__(name: str) -> Any: """Lazy import computer tools to avoid importing pyautogui unless needed.""" - if name in ("AnthropicComputerTool", "HudComputerTool", "OpenAIComputerTool"): + if name in ("AnthropicComputerTool", "HudComputerTool", "OpenAIComputerTool", "GeminiComputerTool"): from . import computer return getattr(computer, name) diff --git a/hud/tools/computer/__init__.py b/hud/tools/computer/__init__.py index 67814548..9697488a 100644 --- a/hud/tools/computer/__init__.py +++ b/hud/tools/computer/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations from .anthropic import AnthropicComputerTool +from .gemini import GeminiComputerTool from .hud import HudComputerTool from .openai import OpenAIComputerTool from .qwen import QwenComputerTool @@ -10,6 +11,7 @@ __all__ = [ "AnthropicComputerTool", + "GeminiComputerTool", "HudComputerTool", "OpenAIComputerTool", "QwenComputerTool", diff --git a/hud/tools/computer/gemini.py b/hud/tools/computer/gemini.py new file mode 100644 index 00000000..4de680df --- /dev/null +++ b/hud/tools/computer/gemini.py @@ -0,0 +1,334 @@ +from __future__ import annotations + +import logging +import platform +from typing import TYPE_CHECKING, Any, Literal + +from mcp import ErrorData, McpError +from mcp.types import INVALID_PARAMS, ContentBlock +from pydantic import Field + +from hud.tools.types import ContentResult + +from .hud import HudComputerTool +from .settings import computer_settings + +if TYPE_CHECKING: + from hud.tools.executors.base import BaseExecutor + +logger = logging.getLogger(__name__) + + +class GeminiComputerTool(HudComputerTool): + """ + Gemini Computer Use tool for interacting with a computer via MCP. + + Maps Gemini's predefined function names (open_web_browser, click_at, hover_at, + type_text_at, scroll_document, scroll_at, wait_5_seconds, go_back, go_forward, + search, navigate, key_combination, drag_and_drop) to executor actions. + """ + + def __init__( + self, + # Define within environment based on platform + executor: BaseExecutor | None = None, + platform_type: Literal["auto", "xdo", "pyautogui"] = "auto", + display_num: int | None = None, + # Overrides for what dimensions the agent thinks it operates in + width: int = computer_settings.GEMINI_COMPUTER_WIDTH, + height: int = computer_settings.GEMINI_COMPUTER_HEIGHT, + rescale_images: bool = computer_settings.GEMINI_RESCALE_IMAGES, + # What the agent sees as the tool's name, title, and description + name: str | None = None, + title: str | None = None, + description: str | None = None, + **kwargs: Any, + ) -> None: + """ + Initialize with Gemini's default dimensions. + """ + super().__init__( + executor=executor, + platform_type=platform_type, + display_num=display_num, + width=width, + height=height, + rescale_images=rescale_images, + name=name or "gemini_computer", + title=title or "Gemini Computer Tool", + description=description or "Control computer with mouse, keyboard, and screenshots", + **kwargs, + ) + + async def __call__( + self, + action: str = Field(..., description="Gemini Computer Use action to perform"), + # Common coordinates + x: int | None = Field(None, description="X coordinate (pixels in agent space)"), + y: int | None = Field(None, description="Y coordinate (pixels in agent space)"), + # Text input + text: str | None = Field(None, description="Text to type"), + press_enter: bool | None = Field( + None, description="Whether to press Enter after typing (type_text_at)" + ), + clear_before_typing: bool | None = Field( + None, description="Whether to select-all before typing (type_text_at)" + ), + # Scroll parameters + direction: Literal["up", "down", "left", "right"] | None = Field( + None, description="Scroll direction for scroll_document/scroll_at" + ), + magnitude: int | None = Field( + None, description="Scroll magnitude (pixels in agent space)" + ), + # Navigation + url: str | None = Field(None, description="Target URL for navigate"), + # Key combos + keys: list[str] | str | None = Field(None, description="Keys for key_combination"), + # Drag parameters + destination_x: int | None = Field( + None, description="Destination X for drag_and_drop (agent space)" + ), + destination_y: int | None = Field( + None, description="Destination Y for drag_and_drop (agent space)" + ), + # Behavior + take_screenshot_on_click: bool = Field( + True, description="Whether to include a screenshot for interactive actions" + ), + ) -> list[ContentBlock]: + """ + Handle Gemini Computer Use API calls by mapping to executor actions. + + Returns: + List of MCP content blocks + """ + logger.info("GeminiComputerTool received action: %s", action) + + # Helper to finalize ContentResult: rescale if requested and ensure URL metadata + async def _finalize(result: ContentResult, requested_url: str | None = None) -> list[ContentBlock]: + if result.base64_image and self.rescale_images: + try: + result.base64_image = await self._rescale_screenshot(result.base64_image) + except Exception as e: + logger.warning("Failed to rescale screenshot: %s", e) + # Always include URL metadata if provided; otherwise default to about:blank + result.url = requested_url or result.url or "about:blank" + return result.to_content_blocks() + + # Scale coordinates helper + def _scale(xv: int | None, yv: int | None) -> tuple[int | None, int | None]: + return self._scale_coordinates(xv, yv) + + # Gemini emits coordinates/magnitudes in a 0-1000 normalized space. + def _denormalize(value: int | float | None, axis: Literal["x", "y"]) -> int | None: + if value is None: + return None + try: + numeric = float(value) + except (TypeError, ValueError): + try: + return int(value) # type: ignore[arg-type] + except (TypeError, ValueError): + return None + + # Treat values within the normalized range (including defaults like 800). + if 0 <= numeric <= 1000: + target = self.width if axis == "x" else self.height + numeric = numeric / 1000 * target + + return int(round(numeric)) + + def _scale_distance(value: int | None, axis: Literal["x", "y"]) -> int | None: + if value is None: + return None + scale = self.scale_x if axis == "x" else self.scale_y + if scale != 1.0: + return int(round(value / scale)) + return value + + # Map actions + if action == "open_web_browser": + screenshot = await self.executor.screenshot() + if screenshot: + result = ContentResult(base64_image=screenshot, url="about:blank") + else: + result = ContentResult(error="Failed to take screenshot", url="about:blank") + return await _finalize(result) + + elif action == "click_at": + if x is None or y is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="x and y are required")) + dx = _denormalize(x, "x") + dy = _denormalize(y, "y") + sx, sy = _scale(dx, dy) + result = await self.executor.click(x=sx, y=sy) + return await _finalize(result) + + elif action == "hover_at": + if x is None or y is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="x and y are required")) + dx = _denormalize(x, "x") + dy = _denormalize(y, "y") + sx, sy = _scale(dx, dy) + result = await self.executor.move(x=sx, y=sy) + return await _finalize(result) + + elif action == "type_text_at": + if x is None or y is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="x and y are required")) + if text is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="text is required")) + + dx = _denormalize(x, "x") + dy = _denormalize(y, "y") + sx, sy = _scale(dx, dy) + + # Focus the field + await self.executor.move(x=sx, y=sy, take_screenshot=False) + await self.executor.click(x=sx, y=sy, take_screenshot=False) + + # Clear existing text if requested + if clear_before_typing is None or clear_before_typing: + is_mac = platform.system().lower() == "darwin" + combo = ["cmd", "a"] if is_mac else ["ctrl", "a"] + await self.executor.press(keys=combo, take_screenshot=False) + delete_key = "backspace" if is_mac else "delete" + await self.executor.press(keys=[delete_key], take_screenshot=False) + + # Type (optionally press enter after) + result = await self.executor.write(text=text, enter_after=bool(press_enter)) + return await _finalize(result) + + elif action == "scroll_document": + if direction is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="direction is required")) + # Default magnitude similar to reference implementation + mag = magnitude if magnitude is not None else 800 + # Convert to environment units while preserving sign + if direction in ("down", "up"): + distance = _denormalize(mag, "y") + if distance is None: + raise McpError( + ErrorData(code=INVALID_PARAMS, message="Unable to determine scroll magnitude") + ) + distance = _scale_distance(distance, "y") + scroll_y = distance if direction == "down" else -distance + scroll_x = None + elif direction in ("right", "left"): + distance = _denormalize(mag, "x") + if distance is None: + raise McpError( + ErrorData(code=INVALID_PARAMS, message="Unable to determine scroll magnitude") + ) + distance = _scale_distance(distance, "x") + scroll_x = distance if direction == "right" else -distance + scroll_y = None + else: + raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Invalid direction: {direction}")) + result = await self.executor.scroll(scroll_x=scroll_x, scroll_y=scroll_y) + return await _finalize(result) + + elif action == "scroll_at": + if direction is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="direction is required")) + if x is None or y is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="x and y are required")) + mag = magnitude if magnitude is not None else 800 + dx = _denormalize(x, "x") + dy = _denormalize(y, "y") + sx, sy = _scale(dx, dy) + if direction in ("down", "up"): + distance = _denormalize(mag, "y") + if distance is None: + raise McpError( + ErrorData(code=INVALID_PARAMS, message="Unable to determine scroll magnitude") + ) + distance = _scale_distance(distance, "y") + scroll_y = distance if direction == "down" else -distance + scroll_x = None + elif direction in ("right", "left"): + distance = _denormalize(mag, "x") + if distance is None: + raise McpError( + ErrorData(code=INVALID_PARAMS, message="Unable to determine scroll magnitude") + ) + distance = _scale_distance(distance, "x") + scroll_x = distance if direction == "right" else -distance + scroll_y = None + else: + raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Invalid direction: {direction}")) + result = await self.executor.scroll(x=sx, y=sy, scroll_x=scroll_x, scroll_y=scroll_y) + return await _finalize(result) + + elif action == "wait_5_seconds": + result = await self.executor.wait(time=5000) + return await _finalize(result) + + elif action == "go_back": + is_mac = platform.system().lower() == "darwin" + combo = ["cmd", "["] if is_mac else ["alt", "left"] + result = await self.executor.press(keys=combo) + return await _finalize(result) + + elif action == "go_forward": + is_mac = platform.system().lower() == "darwin" + combo = ["cmd", "]"] if is_mac else ["alt", "right"] + result = await self.executor.press(keys=combo) + return await _finalize(result) + + elif action == "search": + # Best-effort navigate to a default search page + target = url or "https://www.google.com" + is_mac = platform.system().lower() == "darwin" + await self.executor.press(keys=["cmd", "l"] if is_mac else ["ctrl", "l"], take_screenshot=False) + result = await self.executor.write(text=target, enter_after=True) + return await _finalize(result, requested_url=target) + + elif action == "navigate": + if not url: + raise McpError(ErrorData(code=INVALID_PARAMS, message="url is required")) + is_mac = platform.system().lower() == "darwin" + await self.executor.press(keys=["cmd", "l"] if is_mac else ["ctrl", "l"], take_screenshot=False) + result = await self.executor.write(text=url, enter_after=True) + return await _finalize(result, requested_url=url) + + elif action == "key_combination": + if keys is None: + raise McpError(ErrorData(code=INVALID_PARAMS, message="keys is required")) + if isinstance(keys, str): + # Accept formats like "ctrl+c" or "ctrl+shift+t" + key_list = [k.strip() for k in keys.split("+") if k.strip()] + else: + key_list = keys + result = await self.executor.press(keys=key_list) + return await _finalize(result) + + elif action == "drag_and_drop": + if x is None or y is None or destination_x is None or destination_y is None: + raise McpError( + ErrorData( + code=INVALID_PARAMS, + message="x, y, destination_x, and destination_y are required", + ) + ) + sx_norm = _denormalize(x, "x") + sy_norm = _denormalize(y, "y") + dx_norm = _denormalize(destination_x, "x") + dy_norm = _denormalize(destination_y, "y") + sx, sy = _scale(sx_norm, sy_norm) + dx_scaled, dy_scaled = _scale(dx_norm, dy_norm) + # Build a two-point path + path = [] # type: list[tuple[int, int]] + if ( + sx is not None + and sy is not None + and dx_scaled is not None + and dy_scaled is not None + ): + path = [(sx, sy), (dx_scaled, dy_scaled)] + result = await self.executor.drag(path=path) + return await _finalize(result) + + else: + raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Unknown action: {action}")) diff --git a/hud/tools/computer/settings.py b/hud/tools/computer/settings.py index 9a122af7..94dfdfc0 100644 --- a/hud/tools/computer/settings.py +++ b/hud/tools/computer/settings.py @@ -94,5 +94,21 @@ class ComputerSettings(BaseSettings): validation_alias="QWEN_RESCALE_IMAGES", ) + GEMINI_COMPUTER_WIDTH: int = Field( + default=1440, + description="Width of the display to use for the Gemini computer tools", + validation_alias="GEMINI_COMPUTER_WIDTH", + ) + GEMINI_COMPUTER_HEIGHT: int = Field( + default=900, + description="Height of the display to use for the Gemini computer tools", + validation_alias="GEMINI_COMPUTER_HEIGHT", + ) + GEMINI_RESCALE_IMAGES: bool = Field( + default=True, + description="Whether to rescale images to the agent width and height", + validation_alias="GEMINI_RESCALE_IMAGES", + ) + computer_settings = ComputerSettings() diff --git a/hud/tools/computer/tests/test_gemini_tool.py b/hud/tools/computer/tests/test_gemini_tool.py new file mode 100644 index 00000000..3dbb0bf2 --- /dev/null +++ b/hud/tools/computer/tests/test_gemini_tool.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import platform +from typing import Any + +import pytest + +from hud.tools.computer.gemini import GeminiComputerTool +from hud.tools.types import ContentResult + + +class DummyExecutor: + """Minimal async executor used for tooling tests.""" + + def __init__(self) -> None: + self.calls: list[tuple[str, dict[str, Any]]] = [] + + async def click(self, *, x: int | None, y: int | None, **kwargs: Any) -> ContentResult: + self.calls.append(("click", {"x": x, "y": y, **kwargs})) + return ContentResult(url="about:blank") + + async def move(self, *, x: int | None, y: int | None, **kwargs: Any) -> ContentResult: + self.calls.append(("move", {"x": x, "y": y, **kwargs})) + return ContentResult(url="about:blank") + + async def press(self, *, keys: list[str], **kwargs: Any) -> ContentResult: + self.calls.append(("press", {"keys": keys, **kwargs})) + return ContentResult(url="about:blank") + + async def write(self, *, text: str, enter_after: bool, **kwargs: Any) -> ContentResult: + self.calls.append(("write", {"text": text, "enter_after": enter_after, **kwargs})) + return ContentResult(url="about:blank") + + async def scroll( + self, + *, + x: int | None = None, + y: int | None = None, + scroll_x: int | None = None, + scroll_y: int | None = None, + **kwargs: Any, + ) -> ContentResult: + self.calls.append( + ( + "scroll", + {"x": x, "y": y, "scroll_x": scroll_x, "scroll_y": scroll_y, **kwargs}, + ) + ) + return ContentResult(url="about:blank") + + async def drag(self, *, path: list[tuple[int, int]], **kwargs: Any) -> ContentResult: + self.calls.append(("drag", {"path": path, **kwargs})) + return ContentResult(url="about:blank") + + async def screenshot(self) -> str | None: # pragma: no cover - not used in tests + return None + + async def wait(self, **kwargs: Any) -> ContentResult: # pragma: no cover - unused + return ContentResult(url="about:blank") + + +@pytest.mark.asyncio +async def test_coordinates_denormalize_to_screen_space(): + executor = DummyExecutor() + tool = GeminiComputerTool( + executor=executor, + width=1920, + height=1080, + rescale_images=False, + ) + + await tool(action="click_at", x=500, y=250) + + assert executor.calls[0][0] == "click" + # 500/1000 * 1920 = 960, 250/1000 * 1080 = 270 + assert executor.calls[0][1]["x"] == 960 + assert executor.calls[0][1]["y"] == 270 + + +@pytest.mark.asyncio +async def test_type_text_clears_before_typing(monkeypatch: pytest.MonkeyPatch): + executor = DummyExecutor() + tool = GeminiComputerTool( + executor=executor, + width=1920, + height=1080, + rescale_images=False, + ) + + monkeypatch.setattr(platform, "system", lambda: "Linux") + + await tool(action="type_text_at", x=100, y=100, text="hello", press_enter=True) + + # Expect move + click for focus + assert executor.calls[0][0] == "move" + assert executor.calls[1][0] == "click" + + # Clearing sequence should press ctrl+a then delete + press_calls = [call for call in executor.calls if call[0] == "press"] + assert press_calls[0][1]["keys"] == ["ctrl", "a"] + assert press_calls[1][1]["keys"] == ["delete"] + + # Final write call should include the text and Enter flag + write_call = next(call for call in executor.calls if call[0] == "write") + assert write_call[1]["text"] == "hello" + assert write_call[1]["enter_after"] is True + + +@pytest.mark.asyncio +async def test_drag_and_drop_denormalizes_path(): + executor = DummyExecutor() + tool = GeminiComputerTool( + executor=executor, + width=1920, + height=1080, + rescale_images=False, + ) + + await tool(action="drag_and_drop", x=100, y=200, destination_x=900, destination_y=800) + + drag_call = next(call for call in executor.calls if call[0] == "drag") + path = drag_call[1]["path"] + # Ensure both start and end points are denormalized from 0-1000 range + assert path[0] == (192, 216) # 100/1000 of width/height + assert path[1] == (1728, 864) # 900/1000 of width/height diff --git a/hud/tools/playwright.py b/hud/tools/playwright.py index e170f3af..d728a648 100644 --- a/hud/tools/playwright.py +++ b/hud/tools/playwright.py @@ -84,6 +84,9 @@ async def __call__( code=INVALID_PARAMS, message="url parameter is required for navigate" ) ) + # Guard against pydantic FieldInfo default leaking through + if not isinstance(wait_for_load_state, str): + wait_for_load_state = None result = await self.navigate(url, wait_for_load_state or "networkidle") elif action == "screenshot": @@ -179,11 +182,16 @@ async def _ensure_browser(self) -> None: if self._browser is None: raise RuntimeError("Failed to connect to remote browser") - # Use existing context or create new one + # Reuse existing context and page where possible to avoid spawning new windows contexts = self._browser.contexts if contexts: self._browser_context = contexts[0] + # Prefer the first existing page to keep using the already visible window/tab + existing_pages = self._browser_context.pages + if existing_pages: + self.page = existing_pages[0] else: + # As a fallback, create a new context self._browser_context = await self._browser.new_context( viewport={"width": 1920, "height": 1080}, ignore_https_errors=True, @@ -225,6 +233,9 @@ async def _ensure_browser(self) -> None: if self._browser_context is None: raise RuntimeError("Browser context failed to initialize") + # Only create a new page if we didn't already reuse one above + if self.page is None: + self.page = await self._browser_context.new_page() # Reuse existing page if available (for CDP connections), otherwise create new one pages = self._browser_context.pages if pages: diff --git a/hud/tools/types.py b/hud/tools/types.py index faf92efc..7ca7a593 100644 --- a/hud/tools/types.py +++ b/hud/tools/types.py @@ -28,6 +28,7 @@ class ContentResult(BaseModel): error: str | None = Field(default=None, description="Error message") base64_image: str | None = Field(default=None, description="Base64-encoded image") system: str | None = Field(default=None, description="System message") + url: str | None = Field(default=None, description="Current page URL (for browser automation)") def __add__(self, other: ContentResult) -> ContentResult: def combine_fields( @@ -44,6 +45,7 @@ def combine_fields( error=combine_fields(self.error, other.error), base64_image=combine_fields(self.base64_image, other.base64_image, False), system=combine_fields(self.system, other.system), + url=combine_fields(self.url, other.url, False), ) def to_content_blocks(self) -> list[ContentBlock]: @@ -55,7 +57,7 @@ def to_content_blocks(self) -> list[ContentBlock]: result: ContentResult to convert Returns: - List of ContentBlock + List of ContentBlock with URL embedded as metadata if available """ blocks: list[ContentBlock] = [] @@ -65,6 +67,12 @@ def to_content_blocks(self) -> list[ContentBlock]: blocks.append(TextContent(text=self.error, type="text")) if self.base64_image: blocks.append(ImageContent(data=self.base64_image, mimeType="image/png", type="image")) + + # Add URL as a special metadata text block (for Gemini Computer Use) + # Always include URL if set, even if it's a placeholder like "about:blank" + if self.url: + blocks.append(TextContent(text=f"__URL__:{self.url}", type="text")) + return blocks diff --git a/pyproject.toml b/pyproject.toml index 693e4bb1..c500ddb9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ # AI providers "anthropic", "openai", + "google-genai", ] classifiers = [ "Development Status :: 4 - Beta",