Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"absolufy-imports>=0.3.1, <1.0.0", # For Agent Engine deployment.
"anyio>=4.9.0, <5.0.0;python_version>='3.10'", # For MCP Session Manager
"authlib>=1.5.1, <2.0.0", # For RestAPI Tool
"cachetools>=5.3.3, <6.0.0", # For LRU cache
"click>=8.1.8, <9.0.0", # For CLI tools
"fastapi>=0.115.0, <1.0.0", # FastAPI framework
"google-api-python-client>=2.157.0, <3.0.0", # Google API client discovery
Expand Down
61 changes: 46 additions & 15 deletions src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any
from typing import TYPE_CHECKING

from cachetools import LRUCache
from google.genai import types
from pydantic import model_validator
from typing_extensions import override
Expand Down Expand Up @@ -46,10 +47,23 @@ class AgentTool(BaseTool):
agent: The agent to wrap.
skip_summarization: Whether to skip summarization of the agent output.
"""
_runners: Optional[dict[str, tuple['Runner', 'Session']]]
_runners: Optional[dict[str, tuple['Runner', 'Session']]]

def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
def __init__(
self,
agent: BaseAgent,
skip_summarization: bool = False,
persist_memory: bool = False,
cache_size: int = 128,
):
self.agent = agent
self.skip_summarization: bool = skip_summarization
self.persist_memory: bool = persist_memory
if self.persist_memory:
self._runners = LRUCache(maxsize=cache_size)
else:
self._runners = None

super().__init__(name=agent.name, description=agent.description)

Expand Down Expand Up @@ -125,19 +139,27 @@ async def run_async(
role='user',
parts=[types.Part.from_text(text=args['request'])],
)
runner = Runner(
app_name=self.agent.name,
agent=self.agent,
artifact_service=ForwardingArtifactService(tool_context),
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
credential_service=tool_context._invocation_context.credential_service,
)
session = await runner.session_service.create_session(
app_name=self.agent.name,
user_id=tool_context._invocation_context.user_id,
state=tool_context.state.to_dict(),
)

session_id = tool_context._invocation_context.session.id
if self.persist_memory and session_id in self._runners:
runner, session = self._runners[session_id]
else:
runner = Runner(
app_name=self.agent.name,
agent=self.agent,
artifact_service=ForwardingArtifactService(tool_context),
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
credential_service=tool_context._invocation_context.credential_service,
)
session = await runner.session_service.create_session(
app_name=self.agent.name,
user_id=tool_context._invocation_context.user_id,
state=tool_context.state.to_dict(),
session_id=session_id,
)
if self.persist_memory:
self._runners[session_id] = (runner, session)

last_content = None
async with Aclosing(
Expand Down Expand Up @@ -176,7 +198,10 @@ def from_config(
agent_tool_config.agent, config_abs_path
)
return cls(
agent=agent, skip_summarization=agent_tool_config.skip_summarization
agent=agent,
skip_summarization=agent_tool_config.skip_summarization,
persist_memory=agent_tool_config.persist_memory,
cache_size=agent_tool_config.cache_size,
)


Expand All @@ -188,3 +213,9 @@ class AgentToolConfig(BaseToolConfig):

skip_summarization: bool = False
"""Whether to skip summarization of the agent output."""

persist_memory: bool = False
"""Whether to persist the agent's memory across tool calls."""

cache_size: int = 128
"""The maximum number of runners to cache."""
111 changes: 111 additions & 0 deletions tests/unittests/tools/test_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import MagicMock
from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService
from google.adk.tools.agent_tool import AgentTool
from google.adk.tools.tool_context import ToolContext
from google.adk.utils.variant_utils import GoogleLLMVariant
from google.genai import types
from google.genai.types import Part
from pydantic import BaseModel
from pytest import fixture
from pytest import mark

from .. import testing_utils
Expand All @@ -41,6 +45,31 @@
)


@fixture
def agent_tool_test_setup():
"""Sets up common objects for agent tool tests."""
tool_mock_model = testing_utils.MockModel.create(
responses=["Sub-agent response 1", "Sub-agent response 2", "Sub-agent response 3"]
)
tool_agent = Agent(
name="tool_agent",
model=tool_mock_model,
)

mock_context = MagicMock(spec=ToolContext)
mock_invocation_context = MagicMock()
mock_session = MagicMock()
mock_session.id = "test_session"
mock_invocation_context.session = mock_session
mock_invocation_context.user_id = "test_user"
mock_invocation_context.credential_service = InMemoryCredentialService()
mock_context._invocation_context = mock_invocation_context
mock_context.state = MagicMock()
mock_context.state.to_dict.return_value = {}

return tool_agent, tool_mock_model, mock_context


def change_state_callback(callback_context: CallbackContext):
callback_context.state['state_1'] = 'changed_value'
print('change_state_callback: ', callback_context.state)
Expand Down Expand Up @@ -121,6 +150,7 @@ async def before_tool_agent(callback_context: CallbackContext):

tool_agent = SequentialAgent(
name='tool_agent',
sub_agents= [],
before_agent_callback=before_tool_agent,
)

Expand Down Expand Up @@ -355,3 +385,84 @@ class CustomInput(BaseModel):
# Should have string response schema for VERTEX_AI when no output_schema
assert declaration.response is not None
assert declaration.response.type == types.Type.STRING

@mark.asyncio
async def test_persist_memory(agent_tool_test_setup):
"""Tests that the agent tool can persist memory across tool calls."""
tool_agent, tool_mock_model, mock_context = agent_tool_test_setup
agent_tool = AgentTool(agent=tool_agent, persist_memory=True)

# First call to the tool
await agent_tool.run_async(
args={"request": "test1"}, tool_context=mock_context
)

# Second call to the tool
await agent_tool.run_async(
args={"request": "test2"}, tool_context=mock_context
)

# Check the history of the sub-agent's model
# The second request should contain the history of the first call.
assert len(tool_mock_model.requests) == 2
second_request_contents = tool_mock_model.requests[1].contents
assert len(second_request_contents) == 3
assert second_request_contents[0].role == "user"
assert second_request_contents[0].parts[0].text == "test1"
assert second_request_contents[1].role == "model"
assert second_request_contents[1].parts[0].text == "Sub-agent response 1"
assert second_request_contents[2].role == "user"
assert second_request_contents[2].parts[0].text == "test2"


@mark.asyncio
async def test_no_persist_memory(agent_tool_test_setup):
"""Tests that the agent tool does not persist memory across tool calls."""
tool_agent, tool_mock_model, mock_context = agent_tool_test_setup
agent_tool = AgentTool(agent=tool_agent, persist_memory=False)

# First call to the tool
await agent_tool.run_async(
args={"request": "test1"}, tool_context=mock_context
)

# Second call to the tool
await agent_tool.run_async(
args={"request": "test2"}, tool_context=mock_context
)

# Check the history of the sub-agent's model
# The second request should not contain the history of the first call.
assert len(tool_mock_model.requests) == 2
second_request_contents = tool_mock_model.requests[1].contents
assert len(second_request_contents) == 1
assert second_request_contents[0].role == "user"
assert second_request_contents[0].parts[0].text == "test2"

@mark.asyncio
async def test_lru_cache(agent_tool_test_setup):
"""Tests that the LRU cache evicts runners."""
tool_agent, _, mock_context = agent_tool_test_setup
agent_tool = AgentTool(agent=tool_agent, persist_memory=True, cache_size=2)

# First call to the tool
mock_context._invocation_context.session.id = "session1"
await agent_tool.run_async(
args={"request": "test1"}, tool_context=mock_context
)
assert len(agent_tool._runners) == 1

# Second call to the tool
mock_context._invocation_context.session.id = "session2"
await agent_tool.run_async(
args={"request": "test2"}, tool_context=mock_context
)
assert len(agent_tool._runners) == 2

# Third call to the tool, should evict the first runner
mock_context._invocation_context.session.id = "session3"
await agent_tool.run_async(
args={"request": "test3"}, tool_context=mock_context
)
assert len(agent_tool._runners) == 2
assert "session1" not in agent_tool._runners