Skip to content

Commit a3128ce

Browse files
authored
feat(agents): Add on_llm_start and on_llm_end Lifecycle Hooks (#987)
1 parent cb72933 commit a3128ce

File tree

3 files changed

+192
-1
lines changed

3 files changed

+192
-1
lines changed

src/agents/lifecycle.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from typing import Any, Generic
1+
from typing import Any, Generic, Optional
22

33
from typing_extensions import TypeVar
44

55
from .agent import Agent, AgentBase
6+
from .items import ModelResponse, TResponseInputItem
67
from .run_context import RunContextWrapper, TContext
78
from .tool import Tool
89

@@ -14,6 +15,25 @@ class RunHooksBase(Generic[TContext, TAgent]):
1415
override the methods you need.
1516
"""
1617

18+
async def on_llm_start(
19+
self,
20+
context: RunContextWrapper[TContext],
21+
agent: Agent[TContext],
22+
system_prompt: Optional[str],
23+
input_items: list[TResponseInputItem],
24+
) -> None:
25+
"""Called just before invoking the LLM for this agent."""
26+
pass
27+
28+
async def on_llm_end(
29+
self,
30+
context: RunContextWrapper[TContext],
31+
agent: Agent[TContext],
32+
response: ModelResponse,
33+
) -> None:
34+
"""Called immediately after the LLM call returns for this agent."""
35+
pass
36+
1737
async def on_agent_start(self, context: RunContextWrapper[TContext], agent: TAgent) -> None:
1838
"""Called before the agent is invoked. Called each time the current agent changes."""
1939
pass
@@ -106,6 +126,25 @@ async def on_tool_end(
106126
"""Called after a tool is invoked."""
107127
pass
108128

129+
async def on_llm_start(
130+
self,
131+
context: RunContextWrapper[TContext],
132+
agent: Agent[TContext],
133+
system_prompt: Optional[str],
134+
input_items: list[TResponseInputItem],
135+
) -> None:
136+
"""Called immediately before the agent issues an LLM call."""
137+
pass
138+
139+
async def on_llm_end(
140+
self,
141+
context: RunContextWrapper[TContext],
142+
agent: Agent[TContext],
143+
response: ModelResponse,
144+
) -> None:
145+
"""Called immediately after the agent receives the LLM response."""
146+
pass
147+
109148

110149
RunHooks = RunHooksBase[TContext, Agent]
111150
"""Run hooks when using `Agent`."""

src/agents/run.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,7 @@ async def _run_single_turn_streamed(
935935
input = ItemHelpers.input_to_new_input_list(streamed_result.input)
936936
input.extend([item.to_input_item() for item in streamed_result.new_items])
937937

938+
# THIS IS THE RESOLVED CONFLICT BLOCK
938939
filtered = await cls._maybe_filter_model_input(
939940
agent=agent,
940941
run_config=run_config,
@@ -943,6 +944,12 @@ async def _run_single_turn_streamed(
943944
system_instructions=system_prompt,
944945
)
945946

947+
# Call hook just before the model is invoked, with the correct system_prompt.
948+
if agent.hooks:
949+
await agent.hooks.on_llm_start(
950+
context_wrapper, agent, filtered.instructions, filtered.input
951+
)
952+
946953
# 1. Stream the output events
947954
async for event in model.stream_response(
948955
filtered.instructions,
@@ -979,6 +986,10 @@ async def _run_single_turn_streamed(
979986

980987
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
981988

989+
# Call hook just after the model response is finalized.
990+
if agent.hooks and final_response is not None:
991+
await agent.hooks.on_llm_end(context_wrapper, agent, final_response)
992+
982993
# 2. At this point, the streaming is complete for this turn of the agent loop.
983994
if not final_response:
984995
raise ModelBehaviorError("Model did not produce a final response!")
@@ -1252,6 +1263,14 @@ async def _get_new_response(
12521263
model = cls._get_model(agent, run_config)
12531264
model_settings = agent.model_settings.resolve(run_config.model_settings)
12541265
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
1266+
# If the agent has hooks, we need to call them before and after the LLM call
1267+
if agent.hooks:
1268+
await agent.hooks.on_llm_start(
1269+
context_wrapper,
1270+
agent,
1271+
filtered.instructions, # Use filtered instructions
1272+
filtered.input, # Use filtered input
1273+
)
12551274

12561275
new_response = await model.get_response(
12571276
system_instructions=filtered.instructions,
@@ -1266,6 +1285,9 @@ async def _get_new_response(
12661285
previous_response_id=previous_response_id,
12671286
prompt=prompt_config,
12681287
)
1288+
# If the agent has hooks, we need to call them after the LLM call
1289+
if agent.hooks:
1290+
await agent.hooks.on_llm_end(context_wrapper, agent, new_response)
12691291

12701292
context_wrapper.usage.add(new_response.usage)
12711293

tests/test_agent_llm_hooks.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from collections import defaultdict
2+
from typing import Any, Optional
3+
4+
import pytest
5+
6+
from agents.agent import Agent
7+
from agents.items import ItemHelpers, ModelResponse, TResponseInputItem
8+
from agents.lifecycle import AgentHooks
9+
from agents.run import Runner
10+
from agents.run_context import RunContextWrapper, TContext
11+
from agents.tool import Tool
12+
13+
from .fake_model import FakeModel
14+
from .test_responses import (
15+
get_function_tool,
16+
get_text_message,
17+
)
18+
19+
20+
class AgentHooksForTests(AgentHooks):
    """Agent-scoped hooks that count how many times each lifecycle event fires."""

    def __init__(self):
        # Maps event name -> invocation count; missing keys default to 0.
        self.events: dict[str, int] = defaultdict(int)

    def reset(self):
        """Forget every recorded event."""
        self.events.clear()

    def _record(self, name: str) -> None:
        # Shared counting helper used by every hook override below.
        self.events[name] += 1

    async def on_start(self, context: RunContextWrapper[TContext], agent: Agent[TContext]) -> None:
        self._record("on_start")

    async def on_end(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext], output: Any
    ) -> None:
        self._record("on_end")

    async def on_handoff(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext], source: Agent[TContext]
    ) -> None:
        self._record("on_handoff")

    async def on_tool_start(
        self, context: RunContextWrapper[TContext], agent: Agent[TContext], tool: Tool
    ) -> None:
        self._record("on_tool_start")

    async def on_tool_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        tool: Tool,
        result: str,
    ) -> None:
        self._record("on_tool_end")

    # LLM-level hooks introduced alongside the agent lifecycle hooks.
    async def on_llm_start(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        system_prompt: Optional[str],
        input_items: list[TResponseInputItem],
    ) -> None:
        self._record("on_llm_start")

    async def on_llm_end(
        self,
        context: RunContextWrapper[TContext],
        agent: Agent[TContext],
        response: ModelResponse,
    ) -> None:
        self._record("on_llm_end")
71+
72+
73+
# Example test using the above hooks:
74+
@pytest.mark.asyncio
async def test_async_agent_hooks_with_llm():
    """A plain async run fires each agent + LLM hook exactly once per turn."""
    hooks = AgentHooksForTests()
    model = FakeModel()
    agent = Agent(
        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
    )
    # One model turn that immediately yields a final text message.
    model.set_next_output([get_text_message("hello")])
    await Runner.run(agent, input="hello")
    # A single-turn run should record one of each lifecycle event.
    expected = {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}
    assert hooks.events == expected
86+
87+
88+
def test_sync_agent_hook_with_llm():
    """Same hook counting as the async test, but driven through Runner.run_sync."""
    hooks = AgentHooksForTests()
    model = FakeModel()
    agent = Agent(
        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
    )
    # One model turn that immediately yields a final text message.
    model.set_next_output([get_text_message("hello")])
    Runner.run_sync(agent, input="hello")
    # A single-turn run should record one of each lifecycle event.
    expected = {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}
    assert hooks.events == expected
100+
101+
102+
@pytest.mark.asyncio
async def test_streamed_agent_hooks_with_llm():
    """Streaming must not change hook cardinality: one of each hook per turn."""
    hooks = AgentHooksForTests()
    model = FakeModel()
    agent = Agent(
        name="A", model=model, tools=[get_function_tool("f", "res")], handoffs=[], hooks=hooks
    )
    # One model turn that immediately yields a final text message.
    model.set_next_output([get_text_message("hello")])
    stream = Runner.run_streamed(agent, input="hello")

    # Drain the stream so the run completes; raw per-token events are skipped.
    async for event in stream.stream_events():
        if event.type == "raw_response_event":
            continue
        if event.type == "agent_updated_stream_event":
            print(f"[EVENT] agent_updated → {event.new_agent.name}")
            continue
        if event.type != "run_item_stream_event":
            continue
        item = event.item
        if item.type == "tool_call_item":
            print("[EVENT] tool_call_item")
        elif item.type == "tool_call_output_item":
            print(f"[EVENT] tool_call_output_item → {item.output}")
        elif item.type == "message_output_item":
            text = ItemHelpers.text_message_output(item)
            print(f"[EVENT] message_output_item → {text}")

    # The streamed run should record one of each lifecycle event.
    assert hooks.events == {"on_start": 1, "on_llm_start": 1, "on_llm_end": 1, "on_end": 1}

0 commit comments

Comments
 (0)