|
4 | 4 |
|
5 | 5 | from syrupy.assertion import SnapshotAssertion
|
6 | 6 |
|
| 7 | +from pydantic import BaseModel, Field |
7 | 8 | from langchain_core.language_models import BaseChatModel
|
8 | 9 | from langchain_core.language_models.chat_models import BaseChatModel
|
9 | 10 | from langchain_core.messages import (
|
|
28 | 29 | from langgraph.checkpoint.memory import InMemorySaver
|
29 | 30 | from langgraph.constants import END
|
30 | 31 | from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
31 |
| -from langgraph.prebuilt.interrupt import ActionRequest |
| 32 | +from langchain.agents.structured_output import ToolStrategy |
32 | 33 |
|
33 | 34 | from .messages import _AnyIdHumanMessage, _AnyIdToolMessage
|
34 | 35 | from .model import FakeToolCallingModel
|
@@ -1275,3 +1276,118 @@ def modify_model_request(self, request: ModelRequest, state: AgentState) -> Mode
|
1275 | 1276 | assert (
|
1276 | 1277 | result["messages"][2].content == "You are a helpful assistant.-Hello-remember to be nice!"
|
1277 | 1278 | )
|
| 1279 | + |
| 1280 | + |
| 1281 | +def test_tools_to_model_edge_with_structured_and_regular_tool_calls(): |
| 1282 | + """Test that when there are both structured and regular tool calls, we execute regular and jump to END.""" |
| 1283 | + |
| 1284 | + class WeatherResponse(BaseModel): |
| 1285 | + """Weather response.""" |
| 1286 | + |
| 1287 | + temperature: float = Field(description="Temperature in fahrenheit") |
| 1288 | + condition: str = Field(description="Weather condition") |
| 1289 | + |
| 1290 | + @tool |
| 1291 | + def regular_tool(query: str) -> str: |
| 1292 | + """A regular tool that returns a string.""" |
| 1293 | + return f"Regular tool result for: {query}" |
| 1294 | + |
| 1295 | + # Create a fake model that returns both structured and regular tool calls |
| 1296 | + class FakeModelWithBothToolCalls(FakeToolCallingModel): |
| 1297 | + def __init__(self): |
| 1298 | + super().__init__() |
| 1299 | + self.tool_calls = [ |
| 1300 | + [ |
| 1301 | + ToolCall( |
| 1302 | + name="WeatherResponse", |
| 1303 | + args={"temperature": 72.0, "condition": "sunny"}, |
| 1304 | + id="structured_call_1", |
| 1305 | + ), |
| 1306 | + ToolCall( |
| 1307 | + name="regular_tool", args={"query": "test query"}, id="regular_call_1" |
| 1308 | + ), |
| 1309 | + ] |
| 1310 | + ] |
| 1311 | + |
| 1312 | + # Create agent with both structured output and regular tools |
| 1313 | + agent = create_agent( |
| 1314 | + model=FakeModelWithBothToolCalls(), |
| 1315 | + tools=[regular_tool], |
| 1316 | + response_format=ToolStrategy(schema=WeatherResponse), |
| 1317 | + ) |
| 1318 | + |
| 1319 | + # Compile and invoke the agent |
| 1320 | + compiled_agent = agent.compile() |
| 1321 | + result = compiled_agent.invoke( |
| 1322 | + {"messages": [HumanMessage("What's the weather and help me with a query?")]} |
| 1323 | + ) |
| 1324 | + |
| 1325 | + # Verify that we have the expected messages: |
| 1326 | + # 1. Human message |
| 1327 | + # 2. AI message with both tool calls |
| 1328 | + # 3. Tool message from structured tool call |
| 1329 | + # 4. Tool message from regular tool call |
| 1330 | + |
| 1331 | + messages = result["messages"] |
| 1332 | + assert len(messages) >= 4 |
| 1333 | + |
| 1334 | + # Check that we have the AI message with both tool calls |
| 1335 | + ai_message = messages[1] |
| 1336 | + assert isinstance(ai_message, AIMessage) |
| 1337 | + assert len(ai_message.tool_calls) == 2 |
| 1338 | + |
| 1339 | + # Check that we have a tool message from the regular tool |
| 1340 | + tool_messages = [m for m in messages if isinstance(m, ToolMessage)] |
| 1341 | + assert len(tool_messages) >= 1 |
| 1342 | + |
| 1343 | + # The regular tool should have been executed |
| 1344 | + regular_tool_message = next((m for m in tool_messages if m.name == "regular_tool"), None) |
| 1345 | + assert regular_tool_message is not None |
| 1346 | + assert "Regular tool result for: test query" in regular_tool_message.content |
| 1347 | + |
| 1348 | + # Verify that the structured response is available in the result |
| 1349 | + assert "response" in result |
| 1350 | + assert result["response"] is not None |
| 1351 | + assert hasattr(result["response"], "temperature") |
| 1352 | + assert result["response"].temperature == 72.0 |
| 1353 | + assert result["response"].condition == "sunny" |
| 1354 | + |
| 1355 | + |
| 1356 | +def test_human_in_the_loop_middleware_with_structured_response() -> None: |
| 1357 | + """Test that we can get structured response with human in the loop middleware.""" |
| 1358 | + |
| 1359 | + class WeatherBaseModel(BaseModel): |
| 1360 | + temperature: float = Field(description="Temperature in fahrenheit") |
| 1361 | + condition: str = Field(description="Weather condition") |
| 1362 | + |
| 1363 | + tool_calls = [ |
| 1364 | + [ |
| 1365 | + {"args": {"a": 1, "b": 2}, "id": "1", "name": "add_numbers"}, |
| 1366 | + { |
| 1367 | + "name": "WeatherBaseModel", |
| 1368 | + "id": "2", |
| 1369 | + "args": {"temperature": 72.0, "condition": "sunny"}, |
| 1370 | + }, |
| 1371 | + ], |
| 1372 | + ] |
| 1373 | + |
| 1374 | + @tool |
| 1375 | + def add_numbers(a: int, b: int) -> int: |
| 1376 | + """Add two numbers.""" |
| 1377 | + return a + b |
| 1378 | + |
| 1379 | + model = FakeToolCallingModel(tool_calls=tool_calls) |
| 1380 | + |
| 1381 | + agent = create_agent(model=model, tools=[add_numbers], response_format=WeatherBaseModel) |
| 1382 | + agent = agent.compile() |
| 1383 | + response = agent.invoke( |
| 1384 | + {"messages": [HumanMessage("Add 1 and 2, then return the weather forecast.")]} |
| 1385 | + ) |
| 1386 | + |
| 1387 | + assert response["response"] == WeatherBaseModel(temperature=72.0, condition="sunny") |
| 1388 | + messages = response["messages"] |
| 1389 | + assert len(messages) == 4 |
| 1390 | + assert messages[0].content == "Add 1 and 2, then return the weather forecast." |
| 1391 | + assert len(messages[1].tool_calls) == 2 |
| 1392 | + assert messages[2].name == "WeatherBaseModel" |
| 1393 | + assert messages[3].name == "add_numbers" |
0 commit comments