diff --git a/python/agents/tau2-benchmark-agent/README.md b/python/agents/tau2-benchmark-agent/README.md
new file mode 100644
index 000000000..67b749235
--- /dev/null
+++ b/python/agents/tau2-benchmark-agent/README.md
@@ -0,0 +1,235 @@
+# Using ADK Agent with τ-bench
+
+This guide explains how to set up and use a custom agent that integrates the Google ADK (Agent Development Kit) with the τ-bench framework.
+
+## 1. Clone the Repository
+
+First, clone the specific version of the τ-bench repository that is compatible with these modifications.
+
+```bash
+git clone https://github.com/sierra-research/tau2-bench.git
+cd tau2-bench
+git checkout cc97b34
+```
+
+## 2. Setup and Installation
+
+Follow the standard installation instructions to set up the environment for the project.
+
+Ensure that the `venv` module is available:
+
+```bash
+sudo apt install python3.13-venv
+```
+
+```bash
+python3 -m venv .venv
+source .venv/bin/activate
+
+pip install -e .
+```
+
+You can verify the tau2 data:
+
+```bash
+tau2 check-data
+```
+
+After the standard installation, install the Google ADK package.
+
+```bash
+pip install google-adk
+```
+
+The `tenacity` dependency version may conflict with the one required by the tau2 repo. Upgrade it back in favor of tau2:
+
+```bash
+pip install --upgrade tenacity
+```
+
+**IMPORTANT:** The Gemini 3 Pro model makes sending thought signatures mandatory. Tau2-bench relies on LiteLLM for user simulation and for non-ADK agent simulation. Until https://github.com/BerriAI/litellm/pull/16812 is merged into the LiteLLM repository, apply the PR as shown below:
+
+```bash
+git clone --filter=blob:none --quiet https://github.com/BerriAI/litellm.git /tmp/litellm-pr-16812
+cd /tmp/litellm-pr-16812
+git fetch origin pull/16812/head:pr-16812
+git checkout pr-16812
+pip install .
+cd -
+```
+
+## 3. Add Environment Parameters
+
+Create a `.env` file at the repository root with the following content.
+
+```bash
+GOOGLE_GENAI_USE_VERTEXAI=true
+GOOGLE_CLOUD_PROJECT=your_project_id
+GOOGLE_CLOUD_LOCATION=global
+VERTEXAI_LOCATION=global
+```
+
+## 4. Copy Modified Files
+
+You will need to copy two files into your local `tau2-bench` directory. These files contain the implementation of the ADK agent and its unit tests.
+
+Assuming you have the following files available:
+
+* `tau2_agent/adk_agent.py`
+* `tests/test_adk_agent.py`
+
+Copy them to the correct locations within the project:
+
+```bash
+cp ../tau2_agent/adk_agent.py src/tau2/agent/adk_agent.py
+cp ../tests/test_adk_agent.py tests/test_adk_agent.py
+```
+
+## 4.1. Registering the ADK Agent
+
+To enable the `adk_agent` within the `tau2-bench` framework, you need to manually modify the `src/tau2/registry.py` file in your `tau2-bench` repository.
+
+1. **Add the import statement** for `AdkAgent` at the top of `src/tau2/registry.py`:
+
+   ```python
+   from tau2.agent.adk_agent import AdkAgent
+   ```
+
+2. **Register the agent** within the `try` block where other default components are registered (look for `registry = Registry()`):
+
+   ```python
+   try:
+       registry = Registry()
+       logger.debug("Registering default components...")
+       # ... existing registrations ...
+       registry.register_agent(AdkAgent, "adk_agent")  # Add this line
+       # ... more existing registrations ...
+   ```
+
+   This allows the `adk_agent` to be selected via command-line arguments (e.g., `--agent adk_agent`).
+
+## 5. Running the ADK Agent
+
+Once the files are in place and dependencies are installed, you can run a simulation using the `adk_agent`.
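+Before running against Vertex AI, make sure your environment is authenticated to Google Cloud. A minimal setup, assuming you authenticate with the `gcloud` CLI and Application Default Credentials (adjust if you use a service account instead), looks like this:
+
+```bash
+# Authenticate with Application Default Credentials and select the project from your .env file
+gcloud auth application-default login
+gcloud config set project your_project_id
+```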
+
+### Changing baseline ADK agent
+
+You can implement an improved agent and replace the `AdkLlmAgent` instance that `_create_agent` returns, as shown below. The return type is `BaseAgent`, which accommodates workflow agents as well.
+
+```py
+def _create_agent(name: str, model: Union[str, BaseLlm], instruction: str, tools: List[Tool]) -> BaseAgent:
+    adk_tools = [
+        AdkTool(
+            types.FunctionDeclaration(
+                name=tool.openai_schema['function']['name'],
+                description=tool.openai_schema['function'].get('description', ''),
+                parameters_json_schema=tool.openai_schema['function']['parameters'],
+            )
+        )
+        for tool in tools
+    ]
+    return AdkLlmAgent(
+        model=model,
+        name=name,
+        instruction=instruction,
+        tools=adk_tools,
+        planner=built_in_planner.BuiltInPlanner(
+            thinking_config=types.ThinkingConfig(include_thoughts=True),
+        ),
+    )
+```
+
+### Limited run
+
+Here is an example command to run the agent on an airline domain task:
+
+```bash
+tau2 run --domain airline --agent adk_agent --agent-llm vertex_ai/gemini-3-pro-preview --user-llm vertex_ai/gemini-3-pro-preview --num-trials 1 --num-tasks 1 --user-llm-args '{"temperature": 1, "reasoning_effort": "high"}' --agent-llm-args '{"temperature": 1, "reasoning_effort": "high"}'
+```
+
+Optionally, you can run specific tasks by using `--task-ids` instead of `--num-tasks`.
+
+**temperature:** Defaults to 1 when `adk_agent` is used. The commands in this document set it explicitly via `llm_args` for both the user and agent models.
+
+**reasoning_effort:** Only applies to the Gemini 3 Pro model. It defaults to `high` for `adk_agent` when this model is used; otherwise the model defaults to dynamic thinking. Again, this document demonstrates setting it explicitly via `llm_args`.
+
+**NOTE**: It is normal to see `This model isn't mapped yet` error logs. They come from the LiteLLM cost-calculation workflow used by `--user-llm`. You can suppress them temporarily by swapping `--user-llm vertex_ai/gemini-3-pro-preview` for `--user-llm vertex_ai/gemini-2.5-pro`.
+
+### Viewing trajectories
+
+You can use the following command to view trajectories after a run with the default options:
+
+```bash
+tau2 view
+```
+
+### Full run
+
+A full run requires dropping the `--num-tasks` / `--task-ids` arguments.
+ +```bash +# Example: Run complete evaluation for all domains +tau2 run \ + --domain retail \ + --agent adk_agent \ + --agent-llm vertex_ai/gemini-3-pro-preview \ + --user-llm vertex_ai/gemini-3-pro-preview \ + --num-trials 4 \ + --save-to gemini_3_pro_retail \ + --user-llm-args '{"temperature": 1, "reasoning_effort": "high"}' \ + --agent-llm-args '{"temperature": 1, "reasoning_effort": "high"}' + + +tau2 run \ + --domain airline \ + --agent adk_agent \ + --agent-llm vertex_ai/gemini-3-pro-preview \ + --user-llm vertex_ai/gemini-3-pro-preview \ + --num-trials 4 \ + --save-to gemini_3_pro_airline \ + --user-llm-args '{"temperature": 1, "reasoning_effort": "high"}' \ + --agent-llm-args '{"temperature": 1, "reasoning_effort": "high"}' + + +tau2 run \ + --domain telecom \ + --agent adk_agent \ + --agent-llm vertex_ai/gemini-3-pro-preview \ + --user-llm vertex_ai/gemini-3-pro-preview \ + --num-trials 4 \ + --save-to gemini_3_pro_telecom \ + --user-llm-args '{"temperature": 1, "reasoning_effort": "high"}' \ + --agent-llm-args '{"temperature": 1, "reasoning_effort": "high"}' +``` + +### Prepare Submission Package + +```bash +tau2 submit prepare data/tau2/simulations/gemini_3_pro_*.json --output ./gemini_3_pro_submission +``` + +This command will: + +- Verify all trajectory files are valid +- Check that submission requirements are met +- Compute performance metrics (Pass^k rates) +- Prompt for required metadata (model name, organization, contact email) -> you can pass dummy values here as we are not submitting yet. +- Create a structured submission directory with: + - `submission.json`: Metadata and metrics + - `trajectories/`: Your trajectory files + +## 6. Testing the Agent + +To verify that the agent is set up correctly, you can run its unit tests using `pytest`. + +```bash +pytest tests/test_adk_agent.py +``` + +To see coverage (optional): + +```bash +pip install pytest-cov +``` + +```bash +pytest --cov=tau2.agent.adk_agent --cov-report=html tests/test_adk_agent.py +```` diff --git a/python/agents/tau2-benchmark-agent/pyproject.toml b/python/agents/tau2-benchmark-agent/pyproject.toml new file mode 100644 index 000000000..d2787c754 --- /dev/null +++ b/python/agents/tau2-benchmark-agent/pyproject.toml @@ -0,0 +1,29 @@ +[project] +name = "tau2-benchmark-agent" +version = "0.1.0" +description = "ADK Agent for Tau-Bench" +authors = [{ name = "Google ADK Team", email = "no-reply@google.com" }] +license = "Apache-2.0" +readme = "README.md" +dependencies = [ + "google-adk==1.19.0", + "google-genai==1.51.0", + "loguru==0.7.3", + "tenacity==9.1.2", + "tau2 @ git+https://github.com/sierra-research/tau2-bench.git@cc97b34", +] +requires-python = ">=3.10" + +[dependency-groups] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", +] + +[tool.setuptools.packages.find] +where = ["."] +include = ["tau2_agent*"] + +[tool.pytest.ini_options] +pythonpath = "." 
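+# "auto" mode lets pytest-asyncio collect async test functions without per-test markers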
+asyncio_mode = "auto" \ No newline at end of file diff --git a/python/agents/tau2-benchmark-agent/tau2_agent/__init__.py b/python/agents/tau2-benchmark-agent/tau2_agent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/agents/tau2-benchmark-agent/tau2_agent/adk_agent.py b/python/agents/tau2-benchmark-agent/tau2_agent/adk_agent.py new file mode 100644 index 000000000..7633fa546 --- /dev/null +++ b/python/agents/tau2-benchmark-agent/tau2_agent/adk_agent.py @@ -0,0 +1,364 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import asyncio +from typing import Any, Dict, List, Optional, Union + +from google.adk import Agent as AdkLlmAgent +from google.adk.agents import BaseAgent +from google.adk.models.base_llm import BaseLlm +from google.adk.planners import built_in_planner +from google.adk.plugins import ReflectAndRetryToolPlugin +from google.adk.runners import InMemoryRunner +from google.adk.tools import base_tool +from google.genai import types +from loguru import logger +from tau2.agent.llm_agent import LLMAgent, LLMAgentState +from tau2.data_model.message import ( + AssistantMessage, + MultiToolMessage, + ToolCall, + ToolMessage, + UserMessage, +) +from tau2.environment.tool import Tool + + +class AdkTool(base_tool.BaseTool): + """Long running tool that escalates and stops invocation for ADK Agent""" + + def __init__(self, function_declaration: types.FunctionDeclaration): + """Initialize the AdkTool with a function declaration. + + Args: + function_declaration: The function declaration for the tool. + """ + super().__init__( + name=function_declaration.name, + description=function_declaration.description, + is_long_running=True, + ) + self._function_declaration = function_declaration + + def _get_declaration(self): + """Get the function declaration for the tool.""" + + return self._function_declaration + + async def run_async(self, *, args, tool_context) -> Any: + """Run the tool asynchronously.""" + + tool_context.actions.escalate = True + return None + + +def _create_agent( + name: str, + model: Union[str, BaseLlm], + instruction: str, + tools: List[Tool], + llm_args: Dict[str, Any], +) -> BaseAgent: + """Create an ADK LLM Agent with the given parameters. + + Args: + name: The name of the agent. + model: The LLM model to use. + instruction: The system prompt/instruction for the agent. + tools: The list of tools available to the agent. + llm_args: Additional arguments for the LLM. + + Returns: + An instance of BaseAgent (which also allows workflow agents). 
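+
+    Note:
+        Each tau2 Tool is exposed as a long-running AdkTool, so the ADK runner
+        escalates on every tool call and returns it to the tau2 orchestrator
+        instead of executing it locally.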
+ """ + adk_tools = [ + AdkTool( + types.FunctionDeclaration( + name=tool.openai_schema["function"]["name"], + description=tool.openai_schema["function"].get("description", ""), + parameters_json_schema=tool.openai_schema["function"]["parameters"], + ) + ) + for tool in tools + ] + + generate_content_config = types.GenerateContentConfig() + generate_content_config.temperature = llm_args.get( + "temperature", 1 + ) # default to recommended temperature for gemini models + + thinking_level = None + if ( + isinstance(model, str) + and model.startswith("gemini-3") + and "reasoning_effort" in llm_args + ): + thinking_level = llm_args["reasoning_effort"] + + thinking_config = types.ThinkingConfig( + include_thoughts=True, thinking_level=thinking_level, thinking_budget=None + ) + + return AdkLlmAgent( + model=model, + name=name, + instruction=instruction, + tools=adk_tools, + planner=built_in_planner.BuiltInPlanner( + thinking_config=thinking_config, + ), + generate_content_config=generate_content_config, + ) + + +class AdkAgent(LLMAgent): + """Agent that uses ADK to interact with LLMs and tools.""" + + def __init__( + self, + tools: List[Tool], + domain_policy: str, + llm: Optional[str] = None, + llm_args: Optional[dict] = None, + ): + """Initialize the AdkAgent with the given parameters. + + Args: + tools: The list of tools available to the agent. + domain_policy: The domain policy for the agent. + llm: The LLM model to use. + llm_args: Additional arguments for the LLM. + """ + super().__init__( + tools=tools, domain_policy=domain_policy, llm=llm, llm_args=llm_args + ) + model_name = llm or "gemini-2.5-pro" + assert ( + "gemini" in model_name + ), "AdkAgent only supports gemini models for this benchmark." + if model_name.startswith("vertex_ai/"): + model_name = model_name.replace("vertex_ai/", "") + if model_name.startswith("gemini/"): + model_name = model_name.replace("gemini/", "") + self._adk_root_agent = _create_agent( + name="customer_service_agent", + model=self.llm_args.get("model_obj", model_name), + instruction=self.system_prompt, + tools=tools, + llm_args=llm_args, + ) + + error_handling_plugin = ReflectAndRetryToolPlugin( + max_retries=3, throw_exception_if_retry_exceeded=False + ) + + self._runner = InMemoryRunner( + agent=self._adk_root_agent, + app_name="tau2_adk_app", + plugins=[error_handling_plugin], + ) + self._app_name = "tau2_adk_app" + self._user_id = "tau2_user" + try: + self.session = asyncio.run( + self._runner.session_service.create_session( + app_name=self._app_name, + user_id=self._user_id, + ) + ) + except RuntimeError: + self.session = None + + async def async_setup(self) -> None: + """Asynchronous setup for the AdkAgent.""" + + if self.session is None: + self.session = await self._runner.session_service.create_session( + app_name=self._app_name, user_id=self._user_id + ) + + async def _run_prompt_async( + self, + new_message: Optional[str], + function_responses: Optional[list[types.FunctionResponse]] = None, + ) -> AssistantMessage: + """Run the prompt asynchronously and return the assistant message. + + Args: + new_message: The new message from the user. + function_responses: The list of function responses from tools. + + Returns: + An AssistantMessage containing the response from the agent. 
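+
+        Note:
+            Function-call parts from the event stream are recorded as pending
+            long-running calls and surfaced as tau2 ToolCalls; thought parts
+            are logged but excluded from the returned text content.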
+ """ + if new_message is not None: + content = types.Content( + role="user", parts=[types.Part.from_text(text=new_message)] + ) + else: + content = types.Content( + role="user", + parts=[types.Part(function_response=fr) for fr in function_responses], + ) + + logger.info(f"** User says: {content.model_dump(exclude_none=True)}") + text_content = "" + tool_calls: list[ToolCall] = [] + async for event in self._runner.run_async( + user_id=self._user_id, session_id=self.session.id, new_message=content + ): + if event is None or event.content is None: + continue + + logger.info(f"** Event received: {event.content.parts}") + for part in event.content.parts: + if part.function_call: + logger.info( + f"** Tool call: {part.function_call.name} with arguments" + f" {part.function_call.args}" + ) + self.add_long_running_call_info( + (part.function_call.id, part.function_call.name) + ) + tool_calls.append( + ToolCall( + id=part.function_call.id, + name=part.function_call.name, + arguments=part.function_call.args, + requestor="assistant", + ) + ) + elif part.text: + if not part.thought: + text_content += part.text + else: + logger.info(f"** Other part type received: {part}") + + return AssistantMessage( + role="assistant", + content=text_content or None, + tool_calls=tool_calls or None, + ) + + def generate_next_message( + self, message: Any, state: LLMAgentState + ) -> tuple[AssistantMessage, LLMAgentState]: + """Generate the next message from the agent based on the input message. + + Args: + message: The input message from the user or tool. + state: The current state of the agent. + + Returns: + A tuple containing the assistant message and the updated agent state. + """ + if isinstance(message, MultiToolMessage): + state.messages.extend(message.tool_messages) + else: + state.messages.append(message) + + if getattr(self, "session", None) is None: + try: + asyncio.run(self.async_setup()) + except RuntimeError: + raise RuntimeError( + "Cannot create ADK session: an event loop is already running." + ) + + if isinstance(message, UserMessage): + assistant_message = asyncio.run( + self._run_prompt_async(new_message=message.content) + ) + elif isinstance(message, ToolMessage): + call_id, call_name = self.pop_long_running_call_info_with_id(message.id) + if not call_id or not call_name: + call_id, call_name = self.pop_long_running_call_info() + + json_response = {"result": message.content} + function_response = types.FunctionResponse( + id=call_id, + name=call_name, + response=json_response, + ) + assistant_message = asyncio.run( + self._run_prompt_async( + new_message=None, function_responses=[function_response] + ) + ) + elif isinstance(message, MultiToolMessage): + function_responses = [] + for tm in message.tool_messages: + call_id, call_name = self.pop_long_running_call_info_with_id(tm.id) + if not call_id or not call_name: + call_id, call_name = self.pop_long_running_call_info() + + json_response = {"result": tm.content} + function_response = types.FunctionResponse( + id=call_id, + name=call_name, + response=json_response, + ) + function_responses.append(function_response) + assistant_message = asyncio.run( + self._run_prompt_async( + new_message=None, function_responses=function_responses + ) + ) + else: + assistant_message = asyncio.run(self._run_prompt_async(new_message="")) + + state.messages.append(assistant_message) + return assistant_message, state + + def add_long_running_call_info(self, call_info: tuple[str, str]): + """Add information about a long-running call. 
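+
+        The (call_id, call_name) pair is stored so the matching FunctionResponse
+        can later be routed back to the correct pending tool call.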
+ + Args: + call_info: A tuple containing the call ID and call name. + """ + if not hasattr(self, "long_running_call_infos"): + self.long_running_call_infos = [] + self.long_running_call_infos.append(call_info) + + def pop_long_running_call_info(self): + """Pop the oldest long-running call information. + + Returns: + A tuple containing the call ID and call name, or None if no information + is available. + """ + if hasattr(self, "long_running_call_infos") and self.long_running_call_infos: + return self.long_running_call_infos.pop(0) + return None + + def pop_long_running_call_info_with_id( + self, call_id: str + ) -> Optional[tuple[str, str]]: + """Pop long-running call information by call ID. + + Args: + call_id: The ID of the long-running call to pop. + + Returns: + A tuple containing the call ID and call name, or None if no information + is available. + """ + if hasattr(self, "long_running_call_infos") and self.long_running_call_infos: + for i, (stored_call_id, call_name) in enumerate( + self.long_running_call_infos + ): + if stored_call_id == call_id: + return self.long_running_call_infos.pop(i) + return None diff --git a/python/agents/tau2-benchmark-agent/tests/conftest.py b/python/agents/tau2-benchmark-agent/tests/conftest.py new file mode 100644 index 000000000..712002f82 --- /dev/null +++ b/python/agents/tau2-benchmark-agent/tests/conftest.py @@ -0,0 +1,50 @@ +import sys + +import pytest +import tau2.agent + +try: + from tau2_agent import adk_agent +except ImportError: + # Fallback: try to import from relative path if installed as editable but path issues + import os + + # Assuming this conftest is in tests/ and tau2_agent is in ../tau2_agent/ + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + if project_root not in sys.path: + sys.path.insert(0, project_root) + + from tau2_agent import adk_agent + +# Inject the local adk_agent module into the tau2.agent namespace +# This allows 'from tau2.agent.adk_agent import AdkAgent' to work +# even though adk_agent.py is not physically in the installed tau2 package. +tau2.agent.adk_agent = adk_agent +sys.modules["tau2.agent.adk_agent"] = adk_agent + + +@pytest.fixture +def get_environment(): + """Fixture to provide a mock environment with tools and policy.""" + + class MockTool: + def __init__(self, name="mock_tool"): + self.openai_schema = { + "function": { + "name": name, + "description": f"Description for {name}", + "parameters": { + "type": "object", + "properties": {"arg1": {"type": "string"}}, + }, + } + } + + class MockEnv: + def get_tools(self): + return [MockTool("create_task"), MockTool("get_users")] + + def get_policy(self): + return "You are a helpful assistant." + + return lambda: MockEnv() diff --git a/python/agents/tau2-benchmark-agent/tests/test_adk_agent.py b/python/agents/tau2-benchmark-agent/tests/test_adk_agent.py new file mode 100644 index 000000000..988cf83e7 --- /dev/null +++ b/python/agents/tau2-benchmark-agent/tests/test_adk_agent.py @@ -0,0 +1,233 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import AsyncGenerator + +import pytest +from google.adk.models.base_llm import BaseLlm +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.genai import types +from tau2.agent.adk_agent import AdkAgent +from tau2.data_model.message import ( + AssistantMessage, + MultiToolMessage, + ToolMessage, + UserMessage, +) + + +class MockLlm(BaseLlm): + model: str = "mock-llm" + response_type: str = "text" + + async def generate_content_async( + self, llm_request: LlmRequest, stream: bool = False + ) -> AsyncGenerator[LlmResponse, None]: + if self.response_type == "tool_call": + tool_call = types.FunctionCall( + name="create_task", args={"user_id": "123", "title": "Test Task"} + ) + llm_response = LlmResponse( + content=types.Content( + parts=[types.Part(function_call=tool_call)], + role="model", + ) + ) + elif self.response_type == "multi_tool_call": + tool_call_1 = types.FunctionCall( + name="create_task", args={"user_id": "123", "title": "Test Task 1"} + ) + tool_call_2 = types.FunctionCall(name="get_users", args={}) + llm_response = LlmResponse( + content=types.Content( + parts=[ + types.Part(function_call=tool_call_1), + types.Part(function_call=tool_call_2), + ], + role="model", + ) + ) + else: + response_text = "Mock response" + llm_response = LlmResponse( + content=types.Content( + parts=[types.Part(text=response_text)], + role="model", + ) + ) + yield llm_response + + +@pytest.fixture(params=["text", "tool_call", "multi_tool_call"]) +def adk_agent(get_environment, request) -> AdkAgent: + """Fixture for AdkAgent with a mocked LLM.""" + mock_llm = MockLlm(response_type=request.param) + return AdkAgent( + llm="gemini-2.5-pro", + tools=get_environment().get_tools(), + domain_policy=get_environment().get_policy(), + llm_args={"model_obj": mock_llm}, + ) + + +@pytest.fixture +def first_user_message(): + """Fixture for the first user message.""" + return UserMessage(content="Hello can you help me create a task?", role="user") + + +def test_adk_agent(adk_agent: AdkAgent, first_user_message: UserMessage): + """Test case for AdkAgent.""" + agent_state = adk_agent.get_init_state() + assert agent_state is not None + agent_msg, agent_state = adk_agent.generate_next_message( + first_user_message, agent_state + ) + # Check the response is an assistant message + assert isinstance(agent_msg, AssistantMessage) + + response_type = adk_agent._adk_root_agent.model.response_type + if response_type == "text": + assert agent_msg.content == "Mock response" + elif response_type == "tool_call": + assert agent_msg.tool_calls is not None + assert len(agent_msg.tool_calls) == 1 + assert agent_msg.tool_calls[0].name == "create_task" + elif response_type == "multi_tool_call": + assert agent_msg.tool_calls is not None + assert len(agent_msg.tool_calls) == 2 + + # Check the state is updated + assert agent_state is not None + assert len(agent_state.messages) == 2 + # Check the messages are of the correct type + assert isinstance(agent_state.messages[0], UserMessage) + assert isinstance(agent_state.messages[1], AssistantMessage) + assert agent_state.messages[0].content == first_user_message.content + assert agent_state.messages[1].content == agent_msg.content + + +def test_adk_agent_with_tool_call(get_environment, first_user_message: UserMessage): + """Test case for AdkAgent with a tool call and response.""" + # Setup agent to respond 
with a tool call first + mock_llm_tool_call = MockLlm(response_type="tool_call") + agent = AdkAgent( + llm="gemini-2.5-pro", + tools=get_environment().get_tools(), + domain_policy=get_environment().get_policy(), + llm_args={"model_obj": mock_llm_tool_call}, + ) + + # 1. First interaction: User message -> Agent tool call + agent_state = agent.get_init_state() + agent_msg, agent_state = agent.generate_next_message( + first_user_message, agent_state + ) + + assert isinstance(agent_msg, AssistantMessage) + assert agent_msg.tool_calls is not None + assert len(agent_msg.tool_calls) == 1 + tool_call = agent_msg.tool_calls[0] + assert tool_call.name == "create_task" + + # 2. Second interaction: Tool response -> Agent final message + tool_message = ToolMessage( + id=tool_call.id, + name=tool_call.name, + role="tool", + requestor="assistant", + content="Task created successfully", + error=False, + ) + + # Switch mock to respond with text + mock_llm_text = MockLlm(response_type="text") + agent._adk_root_agent.model = mock_llm_text # pytype: disable=attribute-error + + agent_msg_final, agent_state = agent.generate_next_message( + tool_message, agent_state + ) + + assert isinstance(agent_msg_final, AssistantMessage) + assert agent_msg_final.content == "Mock response" + assert agent_msg_final.tool_calls is None + assert ( + len(agent_state.messages) == 4 + ) # User, Assistant (tool call), ToolMessage, Assistant (text) + + +def test_adk_agent_with_multi_tool_call( + get_environment, first_user_message: UserMessage +): + """Test case for AdkAgent with multiple tool calls.""" + # Setup agent to respond with multiple tool calls + mock_llm_multi_tool_call = MockLlm(response_type="multi_tool_call") + agent = AdkAgent( + llm="gemini-2.5-pro", + tools=get_environment().get_tools(), + domain_policy=get_environment().get_policy(), + llm_args={"model_obj": mock_llm_multi_tool_call}, + ) + + # 1. First interaction: User message -> Agent multi tool call + agent_state = agent.get_init_state() + agent_msg, agent_state = agent.generate_next_message( + first_user_message, agent_state + ) + + assert isinstance(agent_msg, AssistantMessage) + assert agent_msg.tool_calls is not None + assert len(agent_msg.tool_calls) == 2 + tool_call_1 = agent_msg.tool_calls[0] + tool_call_2 = agent_msg.tool_calls[1] + assert tool_call_1.name == "create_task" + assert tool_call_2.name == "get_users" + + # 2. Second interaction: MultiToolMessage response -> Agent final message + tool_message_1 = ToolMessage( + id=tool_call_1.id, + name=tool_call_1.name, + role="tool", + content="Task created", + requestor="assistant", + error=False, + ) + tool_message_2 = ToolMessage( + id=tool_call_2.id, + name=tool_call_2.name, + role="tool", + content="['user1', 'user2']", + requestor="assistant", + error=False, + ) + multi_tool_message = MultiToolMessage( + role="tool", tool_messages=[tool_message_1, tool_message_2] + ) + + # Switch mock to respond with text + mock_llm_text = MockLlm(response_type="text") + agent._adk_root_agent.model = mock_llm_text # pytype: disable=attribute-error + + agent_msg_final, agent_state = agent.generate_next_message( + multi_tool_message, agent_state + ) + + assert isinstance(agent_msg_final, AssistantMessage) + assert agent_msg_final.content == "Mock response" + assert agent_msg_final.tool_calls is None + assert ( + len(agent_state.messages) == 5 + ) # System, User, Assistant (tool calls), MultiToolMessage, Assistant (text)