Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
HandoffMessage,
InnerMessage,
ResetMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
Expand Down Expand Up @@ -232,8 +231,8 @@ def __init__(
def produced_message_types(self) -> List[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
if self._handoffs:
return [TextMessage, HandoffMessage, StopMessage]
return [TextMessage, StopMessage]
return [TextMessage, HandoffMessage]
return [TextMessage]

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
async for message in self.on_messages_stream(messages, cancellation_token):
Expand Down Expand Up @@ -303,16 +302,9 @@ async def on_messages_stream(
self._model_context.append(AssistantMessage(content=result.content, source=self.name))

assert isinstance(result.content, str)
# Detect stop request.
request_stop = "terminate" in result.content.strip().lower()
if request_stop:
yield Response(
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
else:
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)

async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
Expand Down
28 changes: 14 additions & 14 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
ToolCallMessage,
ToolCallResultMessage,
)
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
from autogen_agentchat.task import MaxMessageTermination, TextMentionTermination
from autogen_agentchat.teams import (
RoundRobinGroupChat,
SelectorGroupChat,
Expand Down Expand Up @@ -151,7 +151,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
expected_messages = [
"Write a program that prints 'Hello, world!'",
Expand All @@ -172,7 +172,7 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -247,7 +247,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)

assert len(result.messages) == 6
Expand All @@ -256,7 +256,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
assert isinstance(result.messages[2], ToolCallResultMessage) # tool call result
assert isinstance(result.messages[3], TextMessage) # tool use agent response
assert isinstance(result.messages[4], TextMessage) # echo agent response
assert isinstance(result.messages[5], StopMessage) # tool use agent response
assert isinstance(result.messages[5], TextMessage) # tool use agent response

context = tool_use_agent._model_context # pyright: ignore
assert context[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -275,7 +275,7 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -351,7 +351,7 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
)
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
assert len(result.messages) == 6
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -366,7 +366,7 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -401,7 +401,7 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
)
result = await team.run(
"Write a program that prints 'Hello, world!'",
termination_condition=StopMessageTermination(),
termination_condition=TextMentionTermination("TERMINATE"),
)
assert len(result.messages) == 5
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -417,7 +417,7 @@ async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch)
agent1._count = 0 # pyright: ignore
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -472,7 +472,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
allow_repeated_speaker=True,
)
result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
)
assert len(result.messages) == 4
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
Expand All @@ -484,7 +484,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
mock.reset()
index = 0
async for message in team.run_stream(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
"Write a program that prints 'Hello, world!'", termination_condition=TextMentionTermination("TERMINATE")
):
if isinstance(message, TaskResult):
assert message == result
Expand Down Expand Up @@ -649,7 +649,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
)
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
team = Swarm([agent1, agent2])
result = await team.run("task", termination_condition=StopMessageTermination())
result = await team.run("task", termination_condition=TextMentionTermination("TERMINATE"))
assert len(result.messages) == 7
assert result.messages[0].content == "task"
assert isinstance(result.messages[1], ToolCallMessage)
Expand All @@ -663,7 +663,7 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
agent1._model_context.clear() # pyright: ignore
mock.reset()
index = 0
stream = team.run_stream("task", termination_condition=StopMessageTermination())
stream = team.run_stream("task", termination_condition=TextMentionTermination("TERMINATE"))
async for message in stream:
if isinstance(message, TaskResult):
assert message == result
Expand Down