Commit 7ec090c

add/document gemini 3.0 support and ReflectAndRetryToolPlugin

1 parent b0d1987 · commit 7ec090c

3 files changed (+135, -46 lines)

python/agents/tau2-benchmark-agent/README.md

Lines changed: 53 additions & 6 deletions
````diff
@@ -48,6 +48,18 @@ tenacity dependecy version may conflict with that of tau2 repo. Upgrading it bac
 pip install --upgrade tenacity
 ```
 
+**IMPORTANT:** The Gemini 3 Pro model makes sending thought signatures mandatory. Tau2-bench relies on litellm for user simulation and non-ADK agent simulation. Until https://github.com/BerriAI/litellm/pull/16812 is merged into the litellm repository, the PR needs to be applied as shown below:
+
+```bash
+git clone --filter=blob:none --quiet https://github.com/BerriAI/litellm.git /tmp/litellm-pr-16812
+cd /tmp/litellm-pr-16812
+git fetch origin pull/16812/head:pr-16812
+git checkout pr-16812
+pip install .
+cd -
+```
+
 ## 3. Add env params
 
 Create `.env` file at root with the following content.
````
````diff
@@ -130,11 +142,17 @@ def _create_agent(name: str, model: Union[str, BaseLlm], instruction: str, tools
 Here is an example command to run the agent on an airline domain task:
 
 ```bash
-tau2 run --domain airline --agent adk_agent --agent-llm vertex_ai/gemini-2.5-pro --user-llm vertex_ai/gemini-2.5-pro --num-trials 1 --num-tasks 1
+tau2 run --domain airline --agent adk_agent --agent-llm vertex_ai/gemini-3-pro-preview --user-llm vertex_ai/gemini-3-pro-preview --num-trials 1 --num-tasks 1 --user-llm-args '{"temperature": 1, "reasoning_effort": "high"}' --agent-llm-args '{"temperature": 1, "reasoning_effort": "high"}'
 ```
 
 Optionally, you can run specific example by using `--task-ids` instead of `--num-tasks`.
 
+**temperature:** Defaults to 1 when adk_agent is used. The commands in this document set it explicitly using llm_args for both the user and agent models.
+
+**reasoning_effort:** Only applies to the Gemini 3 Pro model. It defaults to high for adk_agent when this model is used; otherwise the model falls back to dynamic thinking. Again, this document demonstrates setting it explicitly using llm_args.
+
+**NOTE:** It is normal to see `This model isn't mapped yet` error logs. They come from the litellm cost-calculation workflow used by `--user-llm`. You can suppress them temporarily by swapping `--user-llm vertex_ai/gemini-3-pro-preview` with `--user-llm vertex_ai/gemini-2.5-pro`.
+
 ### Viewing trajectories
 
 You can use the following command to view trajectories after following the default options:
````
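
For reference, these two llm_args keys are consumed by the `_create_agent` changes later in this commit roughly as follows. This is a minimal sketch mirroring that code, assuming google-genai is installed; the `build_configs` helper and the sample call at the bottom are illustrative, not part of the benchmark code.

```python
# Minimal sketch of how temperature / reasoning_effort flow into the ADK agent.
# build_configs is an illustrative helper name, not benchmark code.
from typing import Any, Dict, Optional, Tuple

from google.genai import types


def build_configs(
    model: str, llm_args: Dict[str, Any]
) -> Tuple[types.GenerateContentConfig, types.ThinkingConfig]:
    generate_content_config = types.GenerateContentConfig()
    # Defaults to 1, the recommended temperature for Gemini models.
    generate_content_config.temperature = llm_args.get("temperature", 1)

    # reasoning_effort is only honored for Gemini 3 models; otherwise the model
    # falls back to dynamic thinking (thinking_level stays None).
    thinking_level: Optional[str] = None
    if model.startswith("gemini-3") and "reasoning_effort" in llm_args:
        thinking_level = llm_args["reasoning_effort"]

    thinking_config = types.ThinkingConfig(
        include_thoughts=True, thinking_level=thinking_level, thinking_budget=None
    )
    return generate_content_config, thinking_config


# The values passed via --agent-llm-args '{"temperature": 1, "reasoning_effort": "high"}'
gen_cfg, think_cfg = build_configs(
    "gemini-3-pro-preview", {"temperature": 1, "reasoning_effort": "high"}
)
print(gen_cfg.temperature, think_cfg.thinking_level)
```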
````diff
@@ -149,18 +167,47 @@ Full run requires dropping the arg `--task-ids`.
 
 ```bash
 # Example: Run complete evaluation for all domains
-tau2 run --domain retail --agent adk_agent --agent-llm vertex_ai/gemini-2.5-pro --user-llm vertex_ai/gemini-2.5-pro --num-trials 4 --save-to my_model_retail
-tau2 run --domain airline --agent adk_agent --agent-llm vertex_ai/gemini-2.5-pro --user-llm vertex_ai/gemini-2.5-pro --num-trials 4 --save-to my_model_airline
-tau2 run --domain telecom --agent adk_agent --agent-llm vertex_ai/gemini-2.5-pro --user-llm vertex_ai/gemini-2.5-pro --num-trials 4 --save-to my_model_telecom
+tau2 run \
+  --domain retail \
+  --agent adk_agent \
+  --agent-llm vertex_ai/gemini-3-pro-preview \
+  --user-llm vertex_ai/gemini-3-pro-preview \
+  --num-trials 4 \
+  --save-to gemini_3_pro_retail \
+  --user-llm-args '{"temperature": 1, "reasoning_effort": "high"}' \
+  --agent-llm-args '{"temperature": 1, "reasoning_effort": "high"}'
+
+
+tau2 run \
+  --domain airline \
+  --agent adk_agent \
+  --agent-llm vertex_ai/gemini-3-pro-preview \
+  --user-llm vertex_ai/gemini-3-pro-preview \
+  --num-trials 4 \
+  --save-to gemini_3_pro_airline \
+  --user-llm-args '{"temperature": 1, "reasoning_effort": "high"}' \
+  --agent-llm-args '{"temperature": 1, "reasoning_effort": "high"}'
+
+
+tau2 run \
+  --domain telecom \
+  --agent adk_agent \
+  --agent-llm vertex_ai/gemini-3-pro-preview \
+  --user-llm vertex_ai/gemini-3-pro-preview \
+  --num-trials 4 \
+  --save-to gemini_3_pro_telecom \
+  --user-llm-args '{"temperature": 1, "reasoning_effort": "high"}' \
+  --agent-llm-args '{"temperature": 1, "reasoning_effort": "high"}'
 ```
 
 ### Prepare Submission Package
 
 ```bash
-tau2 submit prepare data/tau2/simulations/my_model_*.json --output ./my_submission
+tau2 submit prepare data/tau2/simulations/gemini_3_pro_*.json --output ./gemini_3_pro_submission
 ```
 
 This command will:
+
 - Verify all trajectory files are valid
 - Check that submission requirements are met
 - Compute performance metrics (Pass^k rates)
````
````diff
@@ -185,4 +232,4 @@ pip install pytest-cov
 
 ```bash
 pytest --cov=tau2.agent.adk_agent --cov-report=html tests/test_adk_agent.py
-````
+````
````

python/agents/tau2-benchmark-agent/tau2_agent/adk_agent.py

Lines changed: 74 additions & 33 deletions
````diff
@@ -14,12 +14,13 @@
 
 
 import asyncio
-from typing import Any, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from google.adk import Agent as AdkLlmAgent
 from google.adk.agents import BaseAgent
 from google.adk.models.base_llm import BaseLlm
 from google.adk.planners import built_in_planner
+from google.adk.plugins import ReflectAndRetryToolPlugin
 from google.adk.runners import InMemoryRunner
 from google.adk.tools import base_tool
 from google.genai import types
````
````diff
@@ -42,7 +43,7 @@ def __init__(self, function_declaration: types.FunctionDeclaration):
         """Initialize the AdkTool with a function declaration.
 
         Args:
-            function_declaration: The function declaration for the tool.
+            function_declaration: The function declaration for the tool.
         """
         super().__init__(
             name=function_declaration.name,
````
````diff
@@ -64,20 +65,24 @@ async def run_async(self, *, args, tool_context) -> Any:
 
 
 def _create_agent(
-    name: str, model: Union[str, BaseLlm], instruction: str, tools: List[Tool]
+    name: str,
+    model: Union[str, BaseLlm],
+    instruction: str,
+    tools: List[Tool],
+    llm_args: Dict[str, Any],
 ) -> BaseAgent:
     """Create an ADK LLM Agent with the given parameters.
 
     Args:
-        name: The name of the agent.
-        model: The LLM model to use.
-        instruction: The system prompt/instruction for the agent.
-        tools: The list of tools available to the agent.
+        name: The name of the agent.
+        model: The LLM model to use.
+        instruction: The system prompt/instruction for the agent.
+        tools: The list of tools available to the agent.
+        llm_args: Additional arguments for the LLM.
 
     Returns:
-        An instance of BaseAgent (which also allows workflow agents).
+        An instance of BaseAgent (which also allows workflow agents).
     """
-
     adk_tools = [
         AdkTool(
             types.FunctionDeclaration(
````
````diff
@@ -88,14 +93,33 @@ def _create_agent(
         )
         for tool in tools
     ]
+
+    generate_content_config = types.GenerateContentConfig()
+    generate_content_config.temperature = llm_args.get(
+        "temperature", 1
+    )  # default to recommended temperature for gemini models
+
+    thinking_level = None
+    if (
+        isinstance(model, str)
+        and model.startswith("gemini-3")
+        and "reasoning_effort" in llm_args
+    ):
+        thinking_level = llm_args["reasoning_effort"]
+
+    thinking_config = types.ThinkingConfig(
+        include_thoughts=True, thinking_level=thinking_level, thinking_budget=None
+    )
+
     return AdkLlmAgent(
         model=model,
         name=name,
         instruction=instruction,
         tools=adk_tools,
         planner=built_in_planner.BuiltInPlanner(
-            thinking_config=types.ThinkingConfig(include_thoughts=True),
+            thinking_config=thinking_config,
         ),
+        generate_content_config=generate_content_config,
     )
 
 
````
````diff
@@ -112,12 +136,11 @@ def __init__(
         """Initialize the AdkAgent with the given parameters.
 
         Args:
-            tools: The list of tools available to the agent.
-            domain_policy: The domain policy for the agent.
-            llm: The LLM model to use.
-            llm_args: Additional arguments for the LLM.
+            tools: The list of tools available to the agent.
+            domain_policy: The domain policy for the agent.
+            llm: The LLM model to use.
+            llm_args: Additional arguments for the LLM.
         """
-
         super().__init__(
             tools=tools, domain_policy=domain_policy, llm=llm, llm_args=llm_args
         )
````
````diff
@@ -127,15 +150,24 @@ def __init__(
         ), "AdkAgent only supports gemini models for this benchmark."
         if model_name.startswith("vertex_ai/"):
             model_name = model_name.replace("vertex_ai/", "")
+        if model_name.startswith("gemini/"):
+            model_name = model_name.replace("gemini/", "")
         self._adk_root_agent = _create_agent(
             name="customer_service_agent",
             model=self.llm_args.get("model_obj", model_name),
             instruction=self.system_prompt,
             tools=tools,
+            llm_args=llm_args,
+        )
+
+        error_handling_plugin = ReflectAndRetryToolPlugin(
+            max_retries=3, throw_exception_if_retry_exceeded=False
         )
-        self.long_running_call_infos = []
+
         self._runner = InMemoryRunner(
-            agent=self._adk_root_agent, app_name="tau2_adk_app"
+            agent=self._adk_root_agent,
+            app_name="tau2_adk_app",
+            plugins=[error_handling_plugin],
         )
         self._app_name = "tau2_adk_app"
         self._user_id = "tau2_user"
````
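
As a standalone illustration of the plugin wiring added above: a minimal sketch in which the `flaky_lookup` tool, the demo agent, and the app name are made-up examples, while the `ReflectAndRetryToolPlugin` and `InMemoryRunner` arguments match the commit. No model call is made here; the assumed behavior (reflect on the tool error and retry up to 3 times without re-raising) follows from the plugin's name and the settings used.

```python
# Wiring sketch only; the tool and names below are illustrative assumptions.
from google.adk import Agent as AdkLlmAgent
from google.adk.plugins import ReflectAndRetryToolPlugin
from google.adk.runners import InMemoryRunner


def flaky_lookup(order_id: str) -> dict:
    """Example tool; an exception raised here is what the plugin reacts to."""
    raise RuntimeError(f"backend unavailable for {order_id}")


demo_agent = AdkLlmAgent(
    name="demo_agent",
    model="gemini-3-pro-preview",
    instruction="Answer order questions using the available tools.",
    tools=[flaky_lookup],
)

# On a tool error, the plugin feeds the failure back to the model so it can
# reflect and retry, up to 3 attempts; with throw_exception_if_retry_exceeded=False
# the exception is not re-raised to the caller after the final attempt.
error_handling_plugin = ReflectAndRetryToolPlugin(
    max_retries=3, throw_exception_if_retry_exceeded=False
)

runner = InMemoryRunner(
    agent=demo_agent, app_name="demo_app", plugins=[error_handling_plugin]
)
```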
````diff
@@ -165,10 +197,11 @@ async def _run_prompt_async(
         """Run the prompt asynchronously and return the assistant message.
 
         Args:
-            new_message: The new message from the user.
-            function_responses: The list of function responses from tools.
+            new_message: The new message from the user.
+            function_responses: The list of function responses from tools.
+
         Returns:
-            An AssistantMessage containing the response from the agent.
+            An AssistantMessage containing the response from the agent.
         """
         if new_message is not None:
             content = types.Content(
````
````diff
@@ -186,6 +219,9 @@ async def _run_prompt_async(
         async for event in self._runner.run_async(
             user_id=self._user_id, session_id=self.session.id, new_message=content
         ):
+            if event is None or event.content is None:
+                continue
+
             logger.info(f"** Event received: {event.content.parts}")
             for part in event.content.parts:
                 if part.function_call:
````
````diff
@@ -206,7 +242,6 @@ async def _run_prompt_async(
                     )
                 elif part.text:
                     if not part.thought:
-                        text_content += "\n" if text_content else ""
                         text_content += part.text
                 else:
                     logger.info(f"** Other part type received: {part}")
````
````diff
@@ -223,13 +258,12 @@ def generate_next_message(
         """Generate the next message from the agent based on the input message.
 
         Args:
-            message: The input message from the user or tool.
-            state: The current state of the agent.
+            message: The input message from the user or tool.
+            state: The current state of the agent.
 
         Returns:
-            A tuple containing the assistant message and the updated agent state.
+            A tuple containing the assistant message and the updated agent state.
         """
-
         if isinstance(message, MultiToolMessage):
             state.messages.extend(message.tool_messages)
         else:
````
````diff
@@ -292,17 +326,20 @@ def add_long_running_call_info(self, call_info: tuple[str, str]):
         """Add information about a long-running call.
 
         Args:
-            call_info: A tuple containing the call ID and call name.
+            call_info: A tuple containing the call ID and call name.
         """
+        if not hasattr(self, "long_running_call_infos"):
+            self.long_running_call_infos = []
         self.long_running_call_infos.append(call_info)
 
     def pop_long_running_call_info(self):
         """Pop the oldest long-running call information.
 
         Returns:
-            A tuple containing the call ID and call name, or None if no information is available.
+            A tuple containing the call ID and call name, or None if no information
+            is available.
         """
-        if self.long_running_call_infos:
+        if hasattr(self, "long_running_call_infos") and self.long_running_call_infos:
            return self.long_running_call_infos.pop(0)
         return None
 
````
````diff
@@ -312,12 +349,16 @@ def pop_long_running_call_info_with_id(
         """Pop long-running call information by call ID.
 
         Args:
-            call_id: The ID of the long-running call to pop.
+            call_id: The ID of the long-running call to pop.
 
         Returns:
-            A tuple containing the call ID and call name, or None if no information is available.
+            A tuple containing the call ID and call name, or None if no information
+            is available.
         """
-        for i, (stored_call_id, call_name) in enumerate(self.long_running_call_infos):
-            if stored_call_id == call_id:
-                return self.long_running_call_infos.pop(i)
+        if hasattr(self, "long_running_call_infos") and self.long_running_call_infos:
+            for i, (stored_call_id, call_name) in enumerate(
+                self.long_running_call_infos
+            ):
+                if stored_call_id == call_id:
+                    return self.long_running_call_infos.pop(i)
         return None
````
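
The two hunks above drop the eager `self.long_running_call_infos = []` from `__init__` (removed in an earlier hunk) and guard every access with `hasattr` instead. A toy sketch of that lazy-initialization pattern; `CallInfoStore` and its method names are made up for illustration, only the guarded-access idea matches the commit.

```python
# Toy illustration of lazy, hasattr-guarded bookkeeping; not benchmark code.
from typing import List, Optional, Tuple


class CallInfoStore:
    def add(self, call_info: Tuple[str, str]) -> None:
        # Create the list on first use instead of in __init__.
        if not hasattr(self, "_infos"):
            self._infos: List[Tuple[str, str]] = []
        self._infos.append(call_info)

    def pop_by_id(self, call_id: str) -> Optional[Tuple[str, str]]:
        # Every reader must guard, since the attribute may not exist yet.
        if hasattr(self, "_infos") and self._infos:
            for i, (stored_id, _name) in enumerate(self._infos):
                if stored_id == call_id:
                    return self._infos.pop(i)
        return None


store = CallInfoStore()
store.add(("call-1", "transfer_to_human"))
assert store.pop_by_id("call-1") == ("call-1", "transfer_to_human")
assert store.pop_by_id("missing") is None
```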

python/agents/tau2-benchmark-agent/tests/conftest.py

Lines changed: 8 additions & 7 deletions
````diff
@@ -1,19 +1,20 @@
 import sys
+
 import pytest
 import tau2.agent
 
 try:
     from tau2_agent import adk_agent
 except ImportError:
     # Fallback: try to import from relative path if installed as editable but path issues
-    import os
     import importlib.util
-
+    import os
+
     # Assuming this conftest is in tests/ and tau2_agent is in ../tau2_agent/
     project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
     if project_root not in sys.path:
         sys.path.insert(0, project_root)
-
+
     from tau2_agent import adk_agent
 
 # Inject the local adk_agent module into the tau2.agent namespace
````
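
The fallback branch above imports `importlib.util`, but its use is not visible in this hunk. A generic sketch of the technique that import is commonly used for, loading a module directly from a file path; the path and module name below are assumptions for illustration, not necessarily what this conftest actually does.

```python
# Generic importlib.util-based loading; the path below is an assumed example.
import importlib.util
import os
import sys

module_path = os.path.join(
    os.path.dirname(__file__), "..", "tau2_agent", "adk_agent.py"
)

spec = importlib.util.spec_from_file_location("tau2_agent.adk_agent", module_path)
adk_agent = importlib.util.module_from_spec(spec)
sys.modules["tau2_agent.adk_agent"] = adk_agent
spec.loader.exec_module(adk_agent)  # executes the module body, populating adk_agent
```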
````diff
@@ -22,9 +23,11 @@
 tau2.agent.adk_agent = adk_agent
 sys.modules["tau2.agent.adk_agent"] = adk_agent
 
+
 @pytest.fixture
 def get_environment():
     """Fixture to provide a mock environment with tools and policy."""
+
     class MockTool:
         def __init__(self, name="mock_tool"):
             self.openai_schema = {
````
````diff
@@ -33,10 +36,8 @@ def __init__(self, name="mock_tool"):
                     "description": f"Description for {name}",
                     "parameters": {
                         "type": "object",
-                        "properties": {
-                            "arg1": {"type": "string"}
-                        }
-                    }
+                        "properties": {"arg1": {"type": "string"}},
+                    },
                 }
             }
 
````