Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from .agent_output import AgentOutputSchemaBase
from .guardrail import InputGuardrail, OutputGuardrail
from .handoffs import Handoff
from .items import ItemHelpers
from .logger import logger
from .mcp import MCPUtil
from .model_settings import ModelSettings
Expand Down Expand Up @@ -417,7 +416,7 @@ def as_tool(
description_override=tool_description or "",
is_enabled=is_enabled,
)
async def run_agent(context: RunContextWrapper, input: str) -> str:
async def run_agent(context: RunContextWrapper, input: str) -> Any:
from .run import DEFAULT_MAX_TURNS, Runner

resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS
Expand All @@ -436,7 +435,7 @@ async def run_agent(context: RunContextWrapper, input: str) -> str:
if custom_output_extractor:
return await custom_output_extractor(output)

return ItemHelpers.text_message_outputs(output.new_items)
return output.final_output

return run_agent

Expand Down
21 changes: 3 additions & 18 deletions tests/test_agent_as_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,30 +225,15 @@ async def custom_extractor(result):


@pytest.mark.asyncio
async def test_agent_as_tool_returns_concatenated_text(monkeypatch: pytest.MonkeyPatch) -> None:
"""Agent tool should use default text aggregation when no custom extractor is provided."""
async def test_agent_as_tool_returns_final_output(monkeypatch: pytest.MonkeyPatch) -> None:
"""Agent tool should return final_output when no custom extractor is provided."""

agent = Agent(name="storyteller")

message = ResponseOutputMessage(
id="msg_1",
role="assistant",
status="completed",
type="message",
content=[
ResponseOutputText(
annotations=[],
text="Hello world",
type="output_text",
logprobs=[],
)
],
)

result = type(
"DummyResult",
(),
{"new_items": [MessageOutputItem(agent=agent, raw_item=message)]},
{"final_output": "Hello world"},
)()

async def fake_run(
Expand Down