Skip to content

Commit 6f3b0bb

Browse files
author
Jean-Marc Le Roux
committed
Add the add_tool(), remove_tool() and remove_all_tools() methods for AssistantAgent
1 parent 7eaffa8 commit 6f3b0bb

File tree

2 files changed

+119
-28
lines changed

2 files changed

+119
-28
lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -231,24 +231,10 @@ def __init__(
231231
else:
232232
self._system_messages = [SystemMessage(content=system_message)]
233233
self._tools: List[Tool] = []
234-
if tools is not None:
235-
if model_client.capabilities["function_calling"] is False:
236-
raise ValueError("The model does not support function calling.")
237-
for tool in tools:
238-
if isinstance(tool, Tool):
239-
self._tools.append(tool)
240-
elif callable(tool):
241-
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
242-
description = tool.__doc__
243-
else:
244-
description = ""
245-
self._tools.append(FunctionTool(tool, description=description))
246-
else:
247-
raise ValueError(f"Unsupported tool type: {type(tool)}")
248-
# Check if tool names are unique.
249-
tool_names = [tool.name for tool in self._tools]
250-
if len(tool_names) != len(set(tool_names)):
251-
raise ValueError(f"Tool names must be unique: {tool_names}")
234+
self._model_context: List[LLMMessage] = []
235+
self._reflect_on_tool_use = reflect_on_tool_use
236+
self._tool_call_summary_format = tool_call_summary_format
237+
self._is_running = False
252238
# Handoff tools.
253239
self._handoff_tools: List[Tool] = []
254240
self._handoffs: Dict[str, HandoffBase] = {}
@@ -258,24 +244,54 @@ def __init__(
258244
for handoff in handoffs:
259245
if isinstance(handoff, str):
260246
handoff = HandoffBase(target=handoff)
247+
if handoff.name in self._handoffs:
248+
raise ValueError(f"Handoff name {handoff.name} already exists.")
261249
if isinstance(handoff, HandoffBase):
262250
self._handoff_tools.append(handoff.handoff_tool)
263251
self._handoffs[handoff.name] = handoff
264252
else:
265253
raise ValueError(f"Unsupported handoff type: {type(handoff)}")
266-
# Check if handoff tool names are unique.
267-
handoff_tool_names = [tool.name for tool in self._handoff_tools]
268-
if len(handoff_tool_names) != len(set(handoff_tool_names)):
269-
raise ValueError(f"Handoff names must be unique: {handoff_tool_names}")
254+
if tools is not None:
255+
for tool in tools:
256+
self.add_tool(tool)
257+
258+
def add_tool(self, tool: Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]) -> None:
259+
new_tool = None
260+
if self._model_client.capabilities["function_calling"] is False:
261+
raise ValueError("The model does not support function calling.")
262+
if isinstance(tool, Tool):
263+
new_tool = tool
264+
elif callable(tool):
265+
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
266+
description = tool.__doc__
267+
else:
268+
description = ""
269+
new_tool = FunctionTool(tool, description=description)
270+
else:
271+
raise ValueError(f"Unsupported tool type: {type(tool)}")
272+
# Check if tool names are unique.
273+
if any(tool.name == new_tool.name for tool in self._tools):
274+
raise ValueError(f"Tool names must be unique: {new_tool.name}")
270275
# Check if handoff tool names not in tool names.
271-
if any(name in tool_names for name in handoff_tool_names):
276+
handoff_tool_names = [handoff.name for handoff in self._handoffs.values()]
277+
if new_tool.name in handoff_tool_names:
272278
raise ValueError(
273-
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
279+
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; "
280+
f"tool names: {new_tool.name}"
274281
)
275-
self._model_context: List[LLMMessage] = []
276-
self._reflect_on_tool_use = reflect_on_tool_use
277-
self._tool_call_summary_format = tool_call_summary_format
278-
self._is_running = False
282+
self._tools.append(new_tool)
283+
284+
def remove_all_tools(self) -> None:
285+
"""Remove all tools."""
286+
self._tools.clear()
287+
288+
def remove_tool(self, tool_name: str) -> None:
289+
"""Remove tools by name."""
290+
for tool in self._tools:
291+
if tool.name == tool_name:
292+
self._tools.remove(tool)
293+
return
294+
raise ValueError(f"Tool {tool_name} not found.")
279295

280296
@property
281297
def produced_message_types(self) -> List[type[ChatMessage]]:

python/packages/autogen-agentchat/tests/test_assistant_agent.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,3 +467,78 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
467467
else:
468468
assert message == result.messages[index]
469469
index += 1
470+
471+
472+
def test_tool_management():
473+
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
474+
agent = AssistantAgent(name="test_assistant", model_client=model_client)
475+
476+
# Test function to be used as a tool
477+
def sample_tool() -> str:
478+
return "sample result"
479+
480+
# Test adding a tool
481+
tool = FunctionTool(sample_tool, description="Sample tool")
482+
agent.add_tool(tool)
483+
assert len(agent._tools) == 1
484+
485+
# Test adding duplicate tool
486+
with pytest.raises(ValueError, match="Tool names must be unique"):
487+
agent.add_tool(tool)
488+
489+
# Test tool collision with handoff
490+
agent_with_handoff = AssistantAgent(
491+
name="test_assistant", model_client=model_client, handoffs=[Handoff(target="other_agent")]
492+
)
493+
494+
conflicting_tool = FunctionTool(sample_tool, name="transfer_to_other_agent", description="Sample tool")
495+
with pytest.raises(ValueError, match="Handoff names must be unique from tool names"):
496+
agent_with_handoff.add_tool(conflicting_tool)
497+
498+
# Test removing a tool
499+
agent.remove_tool(tool.name)
500+
assert len(agent._tools) == 0
501+
502+
# Test removing non-existent tool
503+
with pytest.raises(ValueError, match="Tool non_existent_tool not found"):
504+
agent.remove_tool("non_existent_tool")
505+
506+
# Test removing all tools
507+
agent.add_tool(tool)
508+
assert len(agent._tools) == 1
509+
agent.remove_all_tools()
510+
assert len(agent._tools) == 0
511+
512+
# Test idempotency of remove_all_tools
513+
agent.remove_all_tools()
514+
assert len(agent._tools) == 0
515+
516+
517+
def test_callable_tool_addition():
518+
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
519+
agent = AssistantAgent(name="test_assistant", model_client=model_client)
520+
521+
# Test adding a callable directly
522+
def documented_tool() -> str:
523+
"""This is a documented tool"""
524+
return "result"
525+
526+
agent.add_tool(documented_tool)
527+
assert len(agent._tools) == 1
528+
assert agent._tools[0].description == "This is a documented tool"
529+
530+
# Test adding async callable
531+
async def async_tool() -> str:
532+
return "async result"
533+
534+
agent.add_tool(async_tool)
535+
assert len(agent._tools) == 2
536+
537+
538+
def test_invalid_tool_addition():
539+
model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="")
540+
agent = AssistantAgent(name="test_assistant", model_client=model_client)
541+
542+
# Test adding invalid tool type
543+
with pytest.raises(ValueError, match="Unsupported tool type"):
544+
agent.add_tool("not a tool")

0 commit comments

Comments
 (0)