
Commit 225eb9d

feat: introduce ModelClientStreamingChunkEvent for streaming model output and update handling in agents and console (microsoft#5208)
Resolves microsoft#3983

* Introduce a `model_client_stream` parameter in `AssistantAgent` to enable token-level streaming output.
* Introduce `ModelClientStreamingChunkEvent` as a type of `AgentEvent` to pass the streaming chunks to the application via `run_stream` and `on_messages_stream`. These chunk events do not affect the inner messages list in the final `Response` or `TaskResult`.
* Handle this new message type in `Console`.
1 parent 8a0daf8 commit 225eb9d
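
For context, a minimal usage sketch of the new streaming mode (not part of this diff): it assumes the `OpenAIChatCompletionClient` from `autogen_ext.models.openai` is installed and an API key is available in the environment. With `model_client_stream=True`, `Console` prints each `ModelClientStreamingChunkEvent` as it arrives instead of waiting for the complete message.

    import asyncio

    from autogen_agentchat.agents import AssistantAgent
    from autogen_agentchat.ui import Console
    from autogen_ext.models.openai import OpenAIChatCompletionClient  # assumed available


    async def main() -> None:
        model_client = OpenAIChatCompletionClient(model="gpt-4o")  # assumes OPENAI_API_KEY is set
        agent = AssistantAgent(
            name="assistant",
            model_client=model_client,
            model_client_stream=True,  # enable token-level streaming output
        )
        # Console consumes the stream and prints chunk events as they arrive.
        await Console(agent.run_stream(task="Write a haiku about streaming."))


    asyncio.run(main())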

File tree

13 files changed: +330 −32 lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py

+61 −9
@@ -22,13 +22,14 @@
 from autogen_core.models import (
     AssistantMessage,
     ChatCompletionClient,
+    CreateResult,
     FunctionExecutionResult,
     FunctionExecutionResultMessage,
     LLMMessage,
     SystemMessage,
     UserMessage,
 )
-from autogen_core.tools import FunctionTool, BaseTool
+from autogen_core.tools import BaseTool, FunctionTool
 from pydantic import BaseModel
 from typing_extensions import Self
 
@@ -40,6 +41,7 @@
     ChatMessage,
     HandoffMessage,
     MemoryQueryEvent,
+    ModelClientStreamingChunkEvent,
     MultiModalMessage,
     TextMessage,
     ToolCallExecutionEvent,
@@ -62,6 +64,7 @@ class AssistantAgentConfig(BaseModel):
     model_context: ComponentModel | None = None
     description: str
     system_message: str | None = None
+    model_client_stream: bool
     reflect_on_tool_use: bool
     tool_call_summary_format: str
 
@@ -126,6 +129,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
     This will limit the number of recent messages sent to the model and can be useful
     when the model has a limit on the number of tokens it can process.
 
+    Streaming mode:
+
+    The assistant agent can be used in streaming mode by setting `model_client_stream=True`.
+    In this mode, the :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield
+    :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
+    messages as the model client produces chunks of response.
+    The chunk messages will not be included in the final response's inner messages.
+
 
     Args:
         name (str): The name of the agent.
@@ -138,6 +149,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
         model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
         description (str, optional): The description of the agent.
         system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
+        model_client_stream (bool, optional): If `True`, the model client will be used in streaming mode.
+            :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
+            messages as the model client produces chunks of response. Defaults to `False`.
         reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
             to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
         tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
@@ -268,12 +282,14 @@ def __init__(
         system_message: (
             str | None
         ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
+        model_client_stream: bool = False,
         reflect_on_tool_use: bool = False,
         tool_call_summary_format: str = "{result}",
         memory: Sequence[Memory] | None = None,
     ):
         super().__init__(name=name, description=description)
         self._model_client = model_client
+        self._model_client_stream = model_client_stream
         self._memory = None
         if memory is not None:
             if isinstance(memory, list):
@@ -340,7 +356,7 @@ def __init__(
 
     @property
     def produced_message_types(self) -> Sequence[type[ChatMessage]]:
-        """The types of messages that the assistant agent produces."""
+        """The types of final response messages that the assistant agent produces."""
         message_types: List[type[ChatMessage]] = [TextMessage]
         if self._handoffs:
             message_types.append(HandoffMessage)
@@ -383,9 +399,23 @@ async def on_messages_stream(
 
         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + await self._model_context.get_messages()
-        model_result = await self._model_client.create(
-            llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
-        )
+        model_result: CreateResult | None = None
+        if self._model_client_stream:
+            # Stream the model client.
+            async for chunk in self._model_client.create_stream(
+                llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
+            ):
+                if isinstance(chunk, CreateResult):
+                    model_result = chunk
+                elif isinstance(chunk, str):
+                    yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
+                else:
+                    raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
+            assert isinstance(model_result, CreateResult)
+        else:
+            model_result = await self._model_client.create(
+                llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
+            )
 
         # Add the response to the model context.
         await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
@@ -465,14 +495,34 @@ async def on_messages_stream(
         if self._reflect_on_tool_use:
             # Generate another inference result based on the tool call and result.
             llm_messages = self._system_messages + await self._model_context.get_messages()
-            model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
-            assert isinstance(model_result.content, str)
+            reflection_model_result: CreateResult | None = None
+            if self._model_client_stream:
+                # Stream the model client.
+                async for chunk in self._model_client.create_stream(
+                    llm_messages, cancellation_token=cancellation_token
+                ):
+                    if isinstance(chunk, CreateResult):
+                        reflection_model_result = chunk
+                    elif isinstance(chunk, str):
+                        yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
+                    else:
+                        raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
+                assert isinstance(reflection_model_result, CreateResult)
+            else:
+                reflection_model_result = await self._model_client.create(
+                    llm_messages, cancellation_token=cancellation_token
+                )
+            assert isinstance(reflection_model_result.content, str)
             # Add the response to the model context.
-            await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
+            await self._model_context.add_message(
+                AssistantMessage(content=reflection_model_result.content, source=self.name)
+            )
             # Yield the response.
             yield Response(
                 chat_message=TextMessage(
-                    content=model_result.content, source=self.name, models_usage=model_result.usage
+                    content=reflection_model_result.content,
+                    source=self.name,
+                    models_usage=reflection_model_result.usage,
                 ),
                 inner_messages=inner_messages,
             )
@@ -538,6 +588,7 @@ def _to_config(self) -> AssistantAgentConfig:
             system_message=self._system_messages[0].content
            if self._system_messages and isinstance(self._system_messages[0].content, str)
            else None,
+            model_client_stream=self._model_client_stream,
            reflect_on_tool_use=self._reflect_on_tool_use,
            tool_call_summary_format=self._tool_call_summary_format,
        )
@@ -553,6 +604,7 @@ def _from_config(cls, config: AssistantAgentConfig) -> Self:
            model_context=None,
            description=config.description,
            system_message=config.system_message,
+            model_client_stream=config.model_client_stream,
            reflect_on_tool_use=config.reflect_on_tool_use,
            tool_call_summary_format=config.tool_call_summary_format,
        )
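
The streaming branch added above relies on the model client contract that `create_stream` yields `str` chunks and ends with a single `CreateResult`. A standalone sketch of consuming that contract, with `model_client` and `llm_messages` as placeholder arguments rather than code from this commit:

    from autogen_core.models import CreateResult


    async def consume_stream(model_client, llm_messages) -> CreateResult:
        # Collect partial text and keep the final CreateResult, mirroring the
        # pattern used in AssistantAgent.on_messages_stream above.
        final_result: CreateResult | None = None
        async for chunk in model_client.create_stream(llm_messages):
            if isinstance(chunk, str):
                print(chunk, end="", flush=True)  # partial text chunk
            elif isinstance(chunk, CreateResult):
                final_result = chunk  # the last item carries the full content and usage
            else:
                raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
        assert final_result is not None
        return final_result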

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py

+5 −1
@@ -9,6 +9,7 @@
     AgentEvent,
     BaseChatMessage,
     ChatMessage,
+    ModelClientStreamingChunkEvent,
     TextMessage,
 )
 from ..state import BaseState
@@ -178,8 +179,11 @@ async def run_stream(
                 output_messages.append(message.chat_message)
                 yield TaskResult(messages=output_messages)
             else:
-                output_messages.append(message)
                 yield message
+                if isinstance(message, ModelClientStreamingChunkEvent):
+                    # Skip the model client streaming chunk events.
+                    continue
+                output_messages.append(message)
 
     @abstractmethod
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
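
On the consumer side, the change above means `run_stream` still yields every chunk event while `TaskResult.messages` keeps only complete messages. A rough sketch of iterating the stream directly, assuming `agent` was constructed with `model_client_stream=True` as in the first example:

    from autogen_agentchat.base import TaskResult
    from autogen_agentchat.messages import ModelClientStreamingChunkEvent


    async def print_stream(agent) -> None:
        async for item in agent.run_stream(task="Summarize the commit."):
            if isinstance(item, ModelClientStreamingChunkEvent):
                print(item.content, end="", flush=True)  # partial tokens as they arrive
            elif isinstance(item, TaskResult):
                # Chunk events are not collected here; only complete messages are.
                print(f"\nmessages in result: {len(item.messages)}")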

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py

+4 −0
@@ -13,6 +13,7 @@
     AgentEvent,
     BaseChatMessage,
     ChatMessage,
+    ModelClientStreamingChunkEvent,
     TextMessage,
 )
 from ._base_chat_agent import BaseChatAgent
@@ -150,6 +151,9 @@ async def on_messages_stream(
                 # Skip the task messages.
                 continue
             yield inner_msg
+            if isinstance(inner_msg, ModelClientStreamingChunkEvent):
+                # Skip the model client streaming chunk events.
+                continue
             inner_messages.append(inner_msg)
         assert result is not None

python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py

+1 −1
@@ -1,7 +1,7 @@
 import logging
 from typing import Any, Dict
 
-from autogen_core.tools import FunctionTool, BaseTool
+from autogen_core.tools import BaseTool, FunctionTool
 from pydantic import BaseModel, Field, model_validator
 
 from .. import EVENT_LOGGER_NAME

python/packages/autogen-agentchat/src/autogen_agentchat/messages.py

+15 −1
@@ -128,14 +128,27 @@ class MemoryQueryEvent(BaseAgentEvent):
     type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"
 
 
+class ModelClientStreamingChunkEvent(BaseAgentEvent):
+    """An event signaling a text output chunk from a model client in streaming mode."""
+
+    content: str
+    """The partial text chunk."""
+
+    type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"
+
+
 ChatMessage = Annotated[
     TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
 ]
 """Messages for agent-to-agent communication only."""
 
 
 AgentEvent = Annotated[
-    ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent | UserInputRequestedEvent,
+    ToolCallRequestEvent
+    | ToolCallExecutionEvent
+    | MemoryQueryEvent
+    | UserInputRequestedEvent
+    | ModelClientStreamingChunkEvent,
     Field(discriminator="type"),
 ]
 """Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
@@ -154,4 +167,5 @@ class MemoryQueryEvent(BaseAgentEvent):
     "ToolCallSummaryMessage",
     "MemoryQueryEvent",
     "UserInputRequestedEvent",
+    "ModelClientStreamingChunkEvent",
 ]
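
Since `ModelClientStreamingChunkEvent` now participates in the `AgentEvent` discriminated union, it should round-trip through pydantic like the other event types. A small sketch of that assumption:

    from pydantic import TypeAdapter

    from autogen_agentchat.messages import AgentEvent, ModelClientStreamingChunkEvent

    chunk = ModelClientStreamingChunkEvent(content="Hel", source="assistant")
    payload = chunk.model_dump()  # includes the "type" discriminator

    # Validating against the AgentEvent union dispatches on the "type" field.
    restored = TypeAdapter(AgentEvent).validate_python(payload)
    assert isinstance(restored, ModelClientStreamingChunkEvent)
    assert restored.content == "Hel"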

python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py

+17 −2
@@ -21,7 +21,7 @@
 
 from ... import EVENT_LOGGER_NAME
 from ...base import ChatAgent, TaskResult, Team, TerminationCondition
-from ...messages import AgentEvent, BaseChatMessage, ChatMessage, TextMessage
+from ...messages import AgentEvent, BaseChatMessage, ChatMessage, ModelClientStreamingChunkEvent, TextMessage
 from ...state import TeamState
 from ._chat_agent_container import ChatAgentContainer
 from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
@@ -190,6 +190,9 @@ async def run(
                 and it may not reset the termination condition.
                 To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
 
+        Returns:
+            result: The result of the task as :class:`~autogen_agentchat.base.TaskResult`. The result contains the messages produced by the team and the stop reason.
+
         Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
 
@@ -279,16 +282,25 @@ async def run_stream(
         cancellation_token: CancellationToken | None = None,
     ) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
         """Run the team and produces a stream of messages and the final result
-        of the type :class:`TaskResult` as the last item in the stream. Once the
+        of the type :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream. Once the
         team is stopped, the termination condition is reset.
 
+        .. note::
+
+            If an agent produces :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`,
+            the message will be yielded in the stream but it will not be included in the
+            :attr:`~autogen_agentchat.base.TaskResult.messages`.
+
         Args:
             task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage`, or a list of :class:`ChatMessage`.
             cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
                 Setting the cancellation token potentially put the team in an inconsistent state,
                 and it may not reset the termination condition.
                 To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
 
+        Returns:
+            stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~autogen_agentchat.messages.AgentEvent`, :class:`~autogen_agentchat.messages.ChatMessage`, and the final result :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream.
+
         Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
 
         .. code-block:: python
@@ -422,6 +434,9 @@ async def stop_runtime() -> None:
                 if message is None:
                     break
                 yield message
+                if isinstance(message, ModelClientStreamingChunkEvent):
+                    # Skip the model client streaming chunk events.
+                    continue
                 output_messages.append(message)
 
             # Yield the final result.
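
The same rule applies at the team level: chunk events from a streaming participant are yielded by `run_stream` but skipped when `TaskResult.messages` is assembled. A sketch using `RoundRobinGroupChat`; the termination condition and model client choices are assumptions, not part of this diff:

    import asyncio

    from autogen_agentchat.agents import AssistantAgent
    from autogen_agentchat.conditions import MaxMessageTermination
    from autogen_agentchat.teams import RoundRobinGroupChat
    from autogen_agentchat.ui import Console
    from autogen_ext.models.openai import OpenAIChatCompletionClient  # assumed available


    async def main() -> None:
        model_client = OpenAIChatCompletionClient(model="gpt-4o")
        agent = AssistantAgent("writer", model_client=model_client, model_client_stream=True)
        team = RoundRobinGroupChat([agent], termination_condition=MaxMessageTermination(2))
        # Chunk events are printed live by Console but excluded from the final TaskResult.
        result = await Console(team.run_stream(task="Draft a one-line release note."))
        print(len(result.messages))


    asyncio.run(main())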

python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py

+31 −8
@@ -10,7 +10,13 @@
 
 from autogen_agentchat.agents import UserProxyAgent
 from autogen_agentchat.base import Response, TaskResult
-from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent
+from autogen_agentchat.messages import (
+    AgentEvent,
+    ChatMessage,
+    ModelClientStreamingChunkEvent,
+    MultiModalMessage,
+    UserInputRequestedEvent,
+)
 
 
 def _is_running_in_iterm() -> bool:
@@ -106,6 +112,8 @@ async def Console(
 
     last_processed: Optional[T] = None
 
+    streaming_chunks: List[str] = []
+
     async for message in stream:
         if isinstance(message, TaskResult):
             duration = time.time() - start_time
@@ -159,13 +167,28 @@
         else:
             # Cast required for mypy to be happy
             message = cast(AgentEvent | ChatMessage, message)  # type: ignore
-            output = f"{'-' * 10} {message.source} {'-' * 10}\n{_message_to_str(message, render_image_iterm=render_image_iterm)}\n"
-            if message.models_usage:
-                if output_stats:
-                    output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
-                total_usage.completion_tokens += message.models_usage.completion_tokens
-                total_usage.prompt_tokens += message.models_usage.prompt_tokens
-            await aprint(output, end="")
+            if not streaming_chunks:
+                # Print message sender.
+                await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n")
+            if isinstance(message, ModelClientStreamingChunkEvent):
+                await aprint(message.content, end="")
+                streaming_chunks.append(message.content)
+            else:
+                if streaming_chunks:
+                    streaming_chunks.clear()
+                    # Chunked messages are already printed, so we just print a newline.
+                    await aprint("", end="\n")
+                else:
+                    # Print message content.
+                    await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n")
+                if message.models_usage:
+                    if output_stats:
+                        await aprint(
+                            f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]",
+                            end="\n",
+                        )
+                    total_usage.completion_tokens += message.models_usage.completion_tokens
+                    total_usage.prompt_tokens += message.models_usage.prompt_tokens
 
     if last_processed is None:
         raise ValueError("No TaskResult or Response was processed.")
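
The `Console` change prints the sender header once, streams chunk text inline, and, when the complete message arrives after a run of chunks, clears the buffer and emits only a trailing newline since the text is already on screen. The same buffering idea in a self-contained sketch, with plain strings standing in for chunk events:

    import asyncio
    from typing import AsyncGenerator, List


    async def fake_stream() -> AsyncGenerator[str | dict, None]:
        # str items stand in for ModelClientStreamingChunkEvent.content;
        # the dict stands in for the complete message that follows the chunks.
        for part in ["Hello", ", ", "world"]:
            yield part
        yield {"content": "Hello, world"}


    async def render() -> None:
        streaming_chunks: List[str] = []
        async for item in fake_stream():
            if isinstance(item, str):
                print(item, end="", flush=True)
                streaming_chunks.append(item)
            else:
                if streaming_chunks:
                    streaming_chunks.clear()
                    print()  # chunks already printed; just end the line
                else:
                    print(item["content"])


    asyncio.run(render())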
