Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions hud/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from .base import MCPAgent
from .claude import ClaudeAgent
from .gemini import GeminiAgent
from .openai import OperatorAgent
from .openai_chat_generic import GenericOpenAIChatAgent

__all__ = [
"ClaudeAgent",
"GeminiAgent",
"GenericOpenAIChatAgent",
"MCPAgent",
"OperatorAgent",
Expand Down
489 changes: 489 additions & 0 deletions hud/agents/gemini.py

Large diffs are not rendered by default.

372 changes: 372 additions & 0 deletions hud/agents/tests/test_gemini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,372 @@
"""Tests for Gemini MCP Agent implementation."""

from __future__ import annotations

import base64
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from google.genai import types as genai_types
from mcp import types

from hud.agents.gemini import GeminiAgent
from hud.types import MCPToolCall, MCPToolResult


class TestGeminiAgent:
"""Test GeminiAgent class."""

@pytest.fixture
def mock_mcp_client(self):
"""Create a mock MCP client."""
mcp_client = AsyncMock()
# Set up the mcp_config attribute as a regular dict, not a coroutine
mcp_client.mcp_config = {"test_server": {"url": "http://test"}}
# Mock list_tools to return gemini_computer tool
mcp_client.list_tools = AsyncMock(
return_value=[
types.Tool(
name="gemini_computer",
description="Gemini computer use tool",
inputSchema={},
)
]
)
mcp_client.initialize = AsyncMock()
return mcp_client

@pytest.fixture
def mock_gemini_client(self):
"""Create a mock Gemini client."""
client = MagicMock()
client.api_key = "test_key"
# Mock models.list for validation
client.models = MagicMock()
client.models.list = MagicMock(return_value=iter([]))
return client

@pytest.mark.asyncio
async def test_init(self, mock_mcp_client, mock_gemini_client):
"""Test agent initialization."""
agent = GeminiAgent(
mcp_client=mock_mcp_client,
model_client=mock_gemini_client,
model="gemini-2.5-computer-use-preview-10-2025",
validate_api_key=False, # Skip validation in tests
)

assert agent.model_name == "gemini-2.5-computer-use-preview-10-2025"
assert agent.model == "gemini-2.5-computer-use-preview-10-2025"
assert agent.gemini_client == mock_gemini_client

@pytest.mark.asyncio
async def test_init_without_model_client(self, mock_mcp_client):
"""Test agent initialization without model client."""
with (
patch("hud.settings.settings.gemini_api_key", "test_key"),
patch("hud.agents.gemini.genai.Client") as mock_client_class,
):
mock_client = MagicMock()
mock_client.api_key = "test_key"
mock_client.models = MagicMock()
mock_client.models.list = MagicMock(return_value=iter([]))
mock_client_class.return_value = mock_client

agent = GeminiAgent(
mcp_client=mock_mcp_client,
model="gemini-2.5-computer-use-preview-10-2025",
validate_api_key=False,
)

assert agent.model_name == "gemini-2.5-computer-use-preview-10-2025"
assert agent.gemini_client is not None

@pytest.mark.asyncio
async def test_format_blocks(self, mock_mcp_client, mock_gemini_client):
"""Test formatting content blocks into Gemini messages."""
agent = GeminiAgent(
mcp_client=mock_mcp_client,
model_client=mock_gemini_client,
validate_api_key=False,
)

# Test with text only
text_blocks: list[types.ContentBlock] = [
types.TextContent(type="text", text="Hello, Gemini!")
]
messages = await agent.format_blocks(text_blocks)
assert len(messages) == 1
assert messages[0].role == "user"
parts = messages[0].parts
assert parts is not None
assert len(parts) == 1
assert parts[0].text == "Hello, Gemini!"

# Test with screenshot
image_blocks: list[types.ContentBlock] = [
types.TextContent(type="text", text="Look at this"),
types.ImageContent(
type="image",
data=base64.b64encode(b"fakeimage").decode("utf-8"),
mimeType="image/png",
),
]
messages = await agent.format_blocks(image_blocks)
assert len(messages) == 1
assert messages[0].role == "user"
parts = messages[0].parts
assert parts is not None
assert len(parts) == 2
# First part is text
assert parts[0].text == "Look at this"
# Second part is image - check that it was created from bytes
assert parts[1].inline_data is not None

@pytest.mark.asyncio
async def test_format_tool_results(self, mock_mcp_client, mock_gemini_client):
"""Test the agent's format_tool_results method."""
agent = GeminiAgent(
mcp_client=mock_mcp_client,
model_client=mock_gemini_client,
validate_api_key=False,
)

tool_calls = [
MCPToolCall(
name="gemini_computer",
arguments={"action": "click_at", "x": 100, "y": 200},
id="call_1", # type: ignore
gemini_name="click_at", # type: ignore
),
]

tool_results = [
MCPToolResult(
content=[
types.TextContent(type="text", text="Clicked successfully"),
types.ImageContent(
type="image",
data=base64.b64encode(b"screenshot").decode("utf-8"),
mimeType="image/png",
),
],
isError=False,
),
]

messages = await agent.format_tool_results(tool_calls, tool_results)

# format_tool_results returns a single user message with function responses
assert len(messages) == 1
assert messages[0].role == "user"
# The content contains function response parts
parts = messages[0].parts
assert parts is not None
assert len(parts) == 1
function_response = parts[0].function_response
assert function_response is not None
assert function_response.name == "click_at"
response_payload = function_response.response or {}
assert response_payload.get("success") is True

@pytest.mark.asyncio
async def test_format_tool_results_with_error(self, mock_mcp_client, mock_gemini_client):
"""Test formatting tool results with errors."""
agent = GeminiAgent(
mcp_client=mock_mcp_client,
model_client=mock_gemini_client,
validate_api_key=False,
)

tool_calls = [
MCPToolCall(
name="gemini_computer",
arguments={"action": "invalid"},
id="call_error", # type: ignore
gemini_name="invalid_action", # type: ignore
),
]

tool_results = [
MCPToolResult(
content=[types.TextContent(type="text", text="Action failed: invalid action")],
isError=True,
),
]

messages = await agent.format_tool_results(tool_calls, tool_results)

# Check that error is in the response
assert len(messages) == 1
assert messages[0].role == "user"
parts = messages[0].parts
assert parts is not None
function_response = parts[0].function_response
assert function_response is not None
response_payload = function_response.response or {}
assert "error" in response_payload

@pytest.mark.asyncio
async def test_get_response(self, mock_mcp_client, mock_gemini_client):
"""Test getting model response from Gemini API."""
# Disable telemetry for this test
with patch("hud.settings.settings.telemetry_enabled", False):
agent = GeminiAgent(
mcp_client=mock_mcp_client,
model_client=mock_gemini_client,
validate_api_key=False,
)

# Set up available tools
agent._available_tools = [
types.Tool(name="gemini_computer", description="Computer tool", inputSchema={})
]

# Mock the API response
mock_response = MagicMock()
mock_candidate = MagicMock()

# Create text part
text_part = MagicMock()
text_part.text = "I will click at coordinates"
text_part.function_call = None

# Create function call part
function_call_part = MagicMock()
function_call_part.text = None
function_call_part.function_call = MagicMock()
function_call_part.function_call.name = "click_at"
function_call_part.function_call.args = {"x": 100, "y": 200}

mock_candidate.content = MagicMock()
mock_candidate.content.parts = [text_part, function_call_part]

mock_response.candidates = [mock_candidate]

mock_gemini_client.models = MagicMock()
mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response)

messages = [genai_types.Content(role="user", parts=[genai_types.Part(text="Click")])]
response = await agent.get_response(messages)

assert response.content == "I will click at coordinates"
assert len(response.tool_calls) == 1
assert response.tool_calls[0].arguments == {"action": "click_at", "x": 100, "y": 200}
assert response.done is False

@pytest.mark.asyncio
async def test_get_response_text_only(self, mock_mcp_client, mock_gemini_client):
"""Test getting text-only response."""
# Disable telemetry for this test
with patch("hud.settings.settings.telemetry_enabled", False):
agent = GeminiAgent(
mcp_client=mock_mcp_client,
model_client=mock_gemini_client,
validate_api_key=False,
)

# Mock the API response with text only
mock_response = MagicMock()
mock_candidate = MagicMock()

text_part = MagicMock()
text_part.text = "Task completed successfully"
text_part.function_call = None

mock_candidate.content = MagicMock()
mock_candidate.content.parts = [text_part]

mock_response.candidates = [mock_candidate]

mock_gemini_client.models = MagicMock()
mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response)

messages = [genai_types.Content(role="user", parts=[genai_types.Part(text="Status?")])]
response = await agent.get_response(messages)

assert response.content == "Task completed successfully"
assert response.tool_calls == []
assert response.done is True

@pytest.mark.asyncio
async def test_convert_tools_for_gemini(self, mock_mcp_client, mock_gemini_client):
"""Test converting MCP tools to Gemini format."""
agent = GeminiAgent(
mcp_client=mock_mcp_client,
model_client=mock_gemini_client,
validate_api_key=False,
)

# Set up available tools
agent._available_tools = [
types.Tool(
name="gemini_computer",
description="Computer tool",
inputSchema={"type": "object"},
),
types.Tool(
name="calculator",
description="Calculator tool",
inputSchema={
"type": "object",
"properties": {"operation": {"type": "string"}},
},
),
]

gemini_tools = agent._convert_tools_for_gemini()

# Should have 2 tools: computer_use and calculator
assert len(gemini_tools) == 2

# First should be computer use tool
assert gemini_tools[0].computer_use is not None
assert (
gemini_tools[0].computer_use.environment == genai_types.Environment.ENVIRONMENT_BROWSER
)

# Second should be calculator as function declaration
assert gemini_tools[1].function_declarations is not None
assert len(gemini_tools[1].function_declarations) == 1
assert gemini_tools[1].function_declarations[0].name == "calculator"

@pytest.mark.asyncio
async def test_create_user_message(self, mock_mcp_client, mock_gemini_client):
"""Test creating a user message."""
agent = GeminiAgent(
mcp_client=mock_mcp_client,
model_client=mock_gemini_client,
validate_api_key=False,
)

message = await agent.create_user_message("Hello Gemini")

assert message.role == "user"
parts = message.parts
assert parts is not None
assert len(parts) == 1
assert parts[0].text == "Hello Gemini"

@pytest.mark.asyncio
async def test_handle_empty_response(self, mock_mcp_client, mock_gemini_client):
"""Test handling empty response from API."""
with patch("hud.settings.settings.telemetry_enabled", False):
agent = GeminiAgent(
mcp_client=mock_mcp_client,
model_client=mock_gemini_client,
validate_api_key=False,
)

# Mock empty response
mock_response = MagicMock()
mock_response.candidates = []

mock_gemini_client.models = MagicMock()
mock_gemini_client.models.generate_content = MagicMock(return_value=mock_response)

messages = [genai_types.Content(role="user", parts=[genai_types.Part(text="Hi")])]
response = await agent.get_response(messages)

assert response.content == ""
assert response.tool_calls == []
assert response.done is True
1 change: 1 addition & 0 deletions hud/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ def eval(
[
{"name": "Claude 4 Sonnet", "value": AgentType.CLAUDE},
{"name": "OpenAI Computer Use", "value": AgentType.OPENAI},
{"name": "Gemini Computer Use", "value": AgentType.GEMINI},
{"name": "vLLM (Local Server)", "value": AgentType.VLLM},
{"name": "LiteLLM (Multi-provider)", "value": AgentType.LITELLM},
]
Expand Down
Loading
Loading