|
| 1 | +from datetime import timedelta |
1 | 2 | from unittest.mock import MagicMock |
2 | 3 |
|
3 | 4 | import pytest |
@@ -88,5 +89,31 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist): |
88 | 89 |
|
89 | 90 | assert tru_events == exp_events |
90 | 91 | mock_mcp_client.call_tool_async.assert_called_once_with( |
91 | | - tool_use_id="test-123", name="test_tool", arguments={"param": "value"} |
| 92 | + tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=None |
| 93 | + ) |
| 94 | + |
| 95 | + |
| 96 | +def test_timeout_initialization(mock_mcp_tool, mock_mcp_client): |
| 97 | + timeout = timedelta(seconds=30) |
| 98 | + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout) |
| 99 | + assert agent_tool.timeout == timeout |
| 100 | + |
| 101 | + |
| 102 | +def test_timeout_default_none(mock_mcp_tool, mock_mcp_client): |
| 103 | + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client) |
| 104 | + assert agent_tool.timeout is None |
| 105 | + |
| 106 | + |
| 107 | +@pytest.mark.asyncio |
| 108 | +async def test_stream_with_timeout(mock_mcp_tool, mock_mcp_client, alist): |
| 109 | + timeout = timedelta(seconds=45) |
| 110 | + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout) |
| 111 | + tool_use = {"toolUseId": "test-456", "name": "test_tool", "input": {"param": "value"}} |
| 112 | + |
| 113 | + tru_events = await alist(agent_tool.stream(tool_use, {})) |
| 114 | + exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)] |
| 115 | + |
| 116 | + assert tru_events == exp_events |
| 117 | + mock_mcp_client.call_tool_async.assert_called_once_with( |
| 118 | + tool_use_id="test-456", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=timeout |
92 | 119 | ) |
0 commit comments