Commit 66a370c

structured output w/ multiple pending tool calls
1 parent 6745b79 commit 66a370c

2 files changed: +143 −33 lines changed

2 files changed

+143
-33
lines changed

libs/langchain_v1/langchain/agents/middleware_agent.py

Lines changed: 26 additions & 32 deletions
@@ -5,7 +5,7 @@
 from typing import Any, cast

 from langchain_core.language_models.chat_models import BaseChatModel
-from langchain_core.messages import AIMessage, SystemMessage, ToolMessage
+from langchain_core.messages import AIMessage, AnyMessage, SystemMessage, ToolMessage
 from langchain_core.runnables import Runnable
 from langchain_core.tools import BaseTool
 from langgraph.constants import END, START
@@ -219,8 +219,6 @@ def _handle_model_output(state: dict[str, Any], output: AIMessage) -> dict[str,
     if not output.tool_calls and native_output_binding:
         structured_response = native_output_binding.parse(output)
         return {"messages": [output], "response": structured_response}
-    if state.get("response") is not None:
-        return {"messages": [output], "response": None}
     return {"messages": [output]}

     # Handle structured output with tools strategy
@@ -418,7 +416,7 @@ async def amodel_request(state: dict[str, Any]) -> dict[str, Any]:
     if tool_node is not None:
         graph.add_conditional_edges(
             "tools",
-            _make_tools_to_model_edge(tool_node, first_node),
+            _make_tools_to_model_edge(tool_node, first_node, structured_output_tools),
             [first_node, END],
         )
     graph.add_conditional_edges(
@@ -482,38 +480,30 @@ def _resolve_jump(jump_to: JumpTo | None, first_node: str) -> str | None:
     return None


+def _fetch_last_ai_and_tool_messages(
+    messages: list[AnyMessage],
+) -> tuple[AIMessage, list[ToolMessage]]:
+    last_ai_index: int
+    last_ai_message: AIMessage
+
+    for i in range(len(messages) - 1, -1, -1):
+        if isinstance(messages[i], AIMessage):
+            last_ai_index = i
+            last_ai_message = cast("AIMessage", messages[i])
+            break
+
+    tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
+    return last_ai_message, tool_messages
+
+
 def _make_model_to_tools_edge(
     first_node: str, structured_output_tools: dict[str, OutputToolBinding], tool_node: ToolNode
 ) -> Callable[[AgentState], str | list[Send] | None]:
     def model_to_tools(state: AgentState) -> str | list[Send] | None:
         if jump_to := state.get("jump_to"):
             return _resolve_jump(jump_to, first_node)

-        last_message = state["messages"][-1]
-
-        # Check if this is a ToolMessage from structured output - if so, end
-        # interesting, should we be auto ending here? should we execute other tools?
-        if isinstance(last_message, ToolMessage) and last_message.name in structured_output_tools:
-            return END
-
-        # Find the last AI message and all tool messages since said AI message
-        last_ai_index = None
-        last_ai_message: AIMessage
-        for i in range(len(state["messages"]) - 1, -1, -1):
-            if isinstance(state["messages"][i], AIMessage):
-                last_ai_index = i
-                last_ai_message = cast("AIMessage", state["messages"][i])
-                break
-
-        tool_messages = (
-            [
-                m.tool_call_id
-                for m in state["messages"][last_ai_index + 1 :]
-                if isinstance(m, ToolMessage)
-            ]
-            if last_ai_index is not None
-            else []
-        )
+        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])

         pending_tool_calls = [
             c
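For context, here is a minimal standalone sketch of how the new _fetch_last_ai_and_tool_messages helper behaves. The helper body is copied from the hunk above; the surrounding transcript and assertions are illustrative only and are not part of the commit.

from typing import cast

from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage


def _fetch_last_ai_and_tool_messages(
    messages: list[AnyMessage],
) -> tuple[AIMessage, list[ToolMessage]]:
    # Walk backwards to the most recent AIMessage, then collect every
    # ToolMessage that follows it (the results of that turn's tool calls).
    last_ai_index: int
    last_ai_message: AIMessage

    for i in range(len(messages) - 1, -1, -1):
        if isinstance(messages[i], AIMessage):
            last_ai_index = i
            last_ai_message = cast("AIMessage", messages[i])
            break

    tool_messages = [m for m in messages[last_ai_index + 1 :] if isinstance(m, ToolMessage)]
    return last_ai_message, tool_messages


# Hypothetical transcript: one AI turn issuing two tool calls, both already answered.
history: list[AnyMessage] = [
    HumanMessage("What's the weather?"),
    AIMessage(
        content="",
        tool_calls=[
            {"name": "WeatherResponse", "args": {"temperature": 72.0, "condition": "sunny"}, "id": "1"},
            {"name": "regular_tool", "args": {"query": "test query"}, "id": "2"},
        ],
    ),
    ToolMessage(content="structured output recorded", name="WeatherResponse", tool_call_id="1"),
    ToolMessage(content="Regular tool result for: test query", name="regular_tool", tool_call_id="2"),
]

ai, tools = _fetch_last_ai_and_tool_messages(history)
assert len(ai.tool_calls) == 2
assert [t.name for t in tools] == ["WeatherResponse", "regular_tool"]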
@@ -538,17 +528,21 @@ def model_to_tools(state: AgentState) -> str | list[Send] | None:


 def _make_tools_to_model_edge(
-    tool_node: ToolNode, next_node: str
+    tool_node: ToolNode, next_node: str, structured_output_tools: dict[str, OutputToolBinding]
 ) -> Callable[[AgentState], str | None]:
     def tools_to_model(state: AgentState) -> str | None:
-        ai_message = [m for m in state["messages"] if isinstance(m, AIMessage)][-1]
+        last_ai_message, tool_messages = _fetch_last_ai_and_tool_messages(state["messages"])
+
         if all(
             tool_node.tools_by_name[c["name"]].return_direct
-            for c in ai_message.tool_calls
+            for c in last_ai_message.tool_calls
             if c["name"] in tool_node.tools_by_name
         ):
             return END

+        if any(t.name in structured_output_tools for t in tool_messages):
+            return END
+
         return next_node

     return tools_to_model
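The effect of threading structured_output_tools into the tools→model edge is that the agent loop now ends once a structured-output tool has produced its ToolMessage, even if regular tools were executed in the same turn. A hedged sketch of just that decision follows; the structured_output_tools mapping and message fixtures below are illustrative stand-ins, only the any(...) check mirrors the diff.

from langchain_core.messages import AIMessage, ToolMessage

# Assumed to be keyed by tool name, as in the diff's structured_output_tools mapping.
structured_output_tools = {"WeatherResponse": object()}

last_ai_message = AIMessage(
    content="",
    tool_calls=[
        {"name": "WeatherResponse", "args": {"temperature": 72.0, "condition": "sunny"}, "id": "1"},
        {"name": "regular_tool", "args": {"query": "test query"}, "id": "2"},
    ],
)
tool_messages = [
    ToolMessage(content="structured output recorded", name="WeatherResponse", tool_call_id="1"),
    ToolMessage(content="Regular tool result for: test query", name="regular_tool", tool_call_id="2"),
]

# Mirrors the new check in tools_to_model: if any completed tool call belongs to a
# structured-output tool, route to END instead of calling the model again.
should_end = any(t.name in structured_output_tools for t in tool_messages)
assert should_end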

libs/langchain_v1/tests/unit_tests/agents/test_middleware_agent.py

Lines changed: 117 additions & 1 deletion
@@ -4,6 +4,7 @@

 from syrupy.assertion import SnapshotAssertion

+from pydantic import BaseModel, Field
 from langchain_core.language_models import BaseChatModel
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
@@ -28,7 +29,7 @@
 from langgraph.checkpoint.memory import InMemorySaver
 from langgraph.constants import END
 from langgraph.graph.message import REMOVE_ALL_MESSAGES
-from langgraph.prebuilt.interrupt import ActionRequest
+from langchain.agents.structured_output import ToolStrategy

 from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
 from .model import FakeToolCallingModel
@@ -1275,3 +1276,118 @@ def modify_model_request(self, request: ModelRequest, state: AgentState) -> Mode
     assert (
         result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
     )
+
+
+def test_tools_to_model_edge_with_structured_and_regular_tool_calls():
+    """Test that when there are both structured and regular tool calls, we execute regular and jump to END."""
+
+    class WeatherResponse(BaseModel):
+        """Weather response."""
+
+        temperature: float = Field(description="Temperature in fahrenheit")
+        condition: str = Field(description="Weather condition")
+
+    @tool
+    def regular_tool(query: str) -> str:
+        """A regular tool that returns a string."""
+        return f"Regular tool result for: {query}"
+
+    # Create a fake model that returns both structured and regular tool calls
+    class FakeModelWithBothToolCalls(FakeToolCallingModel):
+        def __init__(self):
+            super().__init__()
+            self.tool_calls = [
+                [
+                    ToolCall(
+                        name="WeatherResponse",
+                        args={"temperature": 72.0, "condition": "sunny"},
+                        id="structured_call_1",
+                    ),
+                    ToolCall(
+                        name="regular_tool", args={"query": "test query"}, id="regular_call_1"
+                    ),
+                ]
+            ]
+
+    # Create agent with both structured output and regular tools
+    agent = create_agent(
+        model=FakeModelWithBothToolCalls(),
+        tools=[regular_tool],
+        response_format=ToolStrategy(schema=WeatherResponse),
+    )
+
+    # Compile and invoke the agent
+    compiled_agent = agent.compile()
+    result = compiled_agent.invoke(
+        {"messages": [HumanMessage("What's the weather and help me with a query?")]}
+    )
+
+    # Verify that we have the expected messages:
+    # 1. Human message
+    # 2. AI message with both tool calls
+    # 3. Tool message from structured tool call
+    # 4. Tool message from regular tool call
+
+    messages = result["messages"]
+    assert len(messages) >= 4
+
+    # Check that we have the AI message with both tool calls
+    ai_message = messages[1]
+    assert isinstance(ai_message, AIMessage)
+    assert len(ai_message.tool_calls) == 2
+
+    # Check that we have a tool message from the regular tool
+    tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
+    assert len(tool_messages) >= 1
+
+    # The regular tool should have been executed
+    regular_tool_message = next((m for m in tool_messages if m.name == "regular_tool"), None)
+    assert regular_tool_message is not None
+    assert "Regular tool result for: test query" in regular_tool_message.content
+
+    # Verify that the structured response is available in the result
+    assert "response" in result
+    assert result["response"] is not None
+    assert hasattr(result["response"], "temperature")
+    assert result["response"].temperature == 72.0
+    assert result["response"].condition == "sunny"
+
+
+def test_human_in_the_loop_middleware_with_structured_response() -> None:
+    """Test that we can get structured response with human in the loop middleware."""
+
+    class WeatherBaseModel(BaseModel):
+        temperature: float = Field(description="Temperature in fahrenheit")
+        condition: str = Field(description="Weather condition")
+
+    tool_calls = [
+        [
+            {"args": {"a": 1, "b": 2}, "id": "1", "name": "add_numbers"},
+            {
+                "name": "WeatherBaseModel",
+                "id": "2",
+                "args": {"temperature": 72.0, "condition": "sunny"},
+            },
+        ],
+    ]
+
+    @tool
+    def add_numbers(a: int, b: int) -> int:
+        """Add two numbers."""
+        return a + b
+
+    model = FakeToolCallingModel(tool_calls=tool_calls)
+
+    agent = create_agent(model=model, tools=[add_numbers], response_format=WeatherBaseModel)
+    agent = agent.compile()
+    response = agent.invoke(
+        {"messages": [HumanMessage("Add 1 and 2, then return the weather forecast.")]}
+    )
+
+    assert response["response"] == WeatherBaseModel(temperature=72.0, condition="sunny")
+    messages = response["messages"]
+    assert len(messages) == 4
+    assert messages[0].content == "Add 1 and 2, then return the weather forecast."
+    assert len(messages[1].tool_calls) == 2
+    assert messages[2].name == "WeatherBaseModel"
+    assert messages[3].name == "add_numbers"
