-
Notifications
You must be signed in to change notification settings - Fork 485
Expand file tree
/
Copy pathtool_manager.py
More file actions
260 lines (206 loc) · 8.67 KB
/
tool_manager.py
File metadata and controls
260 lines (206 loc) · 8.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
import re
from areal.utils import logging
from tools import ( # isort: skip
BaseTool,
CalculatorTool,
PythonTool,
ToolCallStatus,
ToolType,
)
# Module-level logger (from areal.utils.logging), used by ToolRegistry and ToolManager.
logger = logging.getLogger("Tool Manager")
class ToolRegistry:
    """Tool registry that manages all available tools.

    Every known tool is instantiated up front (``all_tools``); the subset
    selected through ``enabled_tools`` is kept in ``tools`` and is the only
    subset seen by the query methods below.
    """

    # User-facing tool name (as written in ``enabled_tools``) -> ToolType.
    TOOL_NAMES = {
        "python": ToolType.PYTHON,
        "calculator": ToolType.CALCULATOR,
    }

    def __init__(
        self,
        timeout: int = 30,
        enabled_tools: str | None = "python;calculator",
        debug_mode: bool = False,
    ):
        """Instantiate every tool and keep the ones named in ``enabled_tools``.

        Args:
            timeout: Forwarded to each tool's constructor.
            enabled_tools: Semicolon-separated tool names, e.g.
                ``"python;calculator"``. Surrounding whitespace and letter case
                are ignored, empty entries and duplicates are dropped, and
                unknown names are logged and skipped. ``None`` enables every
                tool listed in ``TOOL_NAMES``.
            debug_mode: Forwarded to each tool's constructor.
        """
        # All available tools, constructed whether or not they end up enabled.
        self.all_tools = {
            ToolType.PYTHON: PythonTool(timeout, debug_mode),
            ToolType.CALCULATOR: CalculatorTool(timeout, debug_mode),
        }

        if enabled_tools is None:
            # Default: enable all tools
            self.enabled_tools = list(self.TOOL_NAMES.values())
        else:
            self.enabled_tools = self._parse_enabled_tools(enabled_tools)

        # Only keep enabled tools
        self.tools = {
            tool_type: self.all_tools[tool_type] for tool_type in self.enabled_tools
        }
        logger.info(
            f"ToolRegistry initialized with enabled tools: {[t.value for t in self.enabled_tools]}"
        )

    def _parse_enabled_tools(self, enabled_tools: str) -> list[ToolType]:
        """Translate a semicolon-separated name list into unique ToolTypes, in order."""
        parsed: list[ToolType] = []
        for raw_name in enabled_tools.split(";"):
            name = raw_name.strip()
            if not name:
                # Tolerate stray separators such as "python;" or "python;;calculator".
                continue
            tool_type = self.TOOL_NAMES.get(name.lower())
            if tool_type is None:
                logger.warning(f"Unknown tool type: {name}, skipping")
                continue
            # Keep the first occurrence only, so "python;python" enables python once.
            if tool_type not in parsed:
                parsed.append(tool_type)
        return parsed

    def get_tool(self, tool_type: ToolType) -> BaseTool | None:
        """Return the enabled tool instance for ``tool_type``, or ``None`` if not enabled."""
        return self.tools.get(tool_type)

    def get_all_tools(self) -> dict[ToolType, BaseTool]:
        """Return the mapping of enabled tool types to tool instances.

        Note: this is the registry's own dict, not a copy.
        """
        return self.tools

    def get_tool_markers(self) -> dict[ToolType, tuple[list[str], list[str]]]:
        """Get marker information for enabled tools only.

        Returns:
            Dict[ToolType, Tuple[List[str], List[str]]]: Tool type ->
            (start markers list, end markers list)
        """
        return {
            tool_type: (tool.markers.start_markers, tool.markers.end_markers)
            for tool_type, tool in self.tools.items()
        }

    def get_all_start_markers(self) -> list[str]:
        """Get all start markers for enabled tools only, in tool order.

        Returns:
            List[str]: List of all start markers
        """
        return [
            marker
            for tool in self.tools.values()
            for marker in tool.markers.start_markers
        ]

    def get_all_end_markers(self) -> list[str]:
        """Get all end markers for enabled tools only, in tool order.

        Returns:
            List[str]: List of all end markers
        """
        return [
            marker
            for tool in self.tools.values()
            for marker in tool.markers.end_markers
        ]

    def get_all_markers(self) -> list[str]:
        """Get all markers for enabled tools only: every start marker, then every end marker.

        Returns:
            List[str]: List of all markers
        """
        return self.get_all_start_markers() + self.get_all_end_markers()

    def get_tool_descriptions_prompt(self) -> str:
        """Generate tool description prompt text for external calls (enabled tools only)."""
        prompt_parts = ["Tools List:\n"]
        for tool in self.tools.values():
            desc = tool.description
            prompt_parts.extend(
                [
                    f"Tool Name: {desc.name}",
                    f"Description: {desc.description}",
                    f"Parameter Description: {desc.parameter_prompt}",
                    f"Usage Example: {desc.example}",
                    "---",
                ]
            )
        return "\n".join(prompt_parts)

    def get_enabled_tools(self) -> list[ToolType]:
        """Get a copy of the list of enabled tools.

        Returns:
            List[ToolType]: List of enabled tool types
        """
        return self.enabled_tools.copy()
class ToolRouter:
    """Decide which enabled tool a piece of text addresses, by searching for its markers."""

    def __init__(self, registry: ToolRegistry):
        """Keep ``registry`` and precompute one search pattern per marker pair."""
        self.registry = registry
        # (tool type, regex source) pairs derived from the registry's enabled tools.
        self.tool_markers = self._build_tool_markers()

    def _build_tool_markers(self) -> list[tuple[ToolType, str]]:
        """Return ``(tool_type, pattern)`` for every start/end marker combination.

        Each start marker of a tool is combined with each end marker of that same
        tool. Both markers are regex-escaped so they match literally, and the
        lazy group captures the text between them.
        """
        return [
            (tool_type, f"{re.escape(opener)}(.*?){re.escape(closer)}")
            for tool_type, tool in self.registry.tools.items()
            for opener in tool.markers.start_markers
            for closer in tool.markers.end_markers
        ]

    def route(self, text: str) -> ToolType | None:
        """Return the tool type of the first pattern that matches ``text``, else ``None``.

        Patterns are tried in the order produced by ``_build_tool_markers``.
        Matching ignores case, ``.`` also spans newlines, and ``text`` is
        stripped of surrounding whitespace first.
        """
        candidate = text.strip()
        search_flags = re.DOTALL | re.IGNORECASE
        return next(
            (
                tool_type
                for tool_type, pattern in self.tool_markers
                if re.search(pattern, candidate, search_flags)
            ),
            None,
        )
class ToolManager:
    """General tool manager responsible for coordinating tool calls.

    Combines a ToolRegistry (which tools exist and are enabled) with a
    ToolRouter (which enabled tool a piece of text is addressed to).
    """

    def __init__(
        self,
        timeout: int = 30,
        enabled_tools: str | None = "python;calculator",
        debug_mode: bool = False,
    ):
        """Build the registry and router.

        Args:
            timeout: Forwarded to ToolRegistry (and from there to each tool).
            enabled_tools: Semicolon-separated tool names forwarded to
                ToolRegistry; ``None`` enables every registered tool.
            debug_mode: Forwarded to ToolRegistry (and from there to each tool).
        """
        self.timeout = timeout
        self.debug_mode = debug_mode
        self.registry = ToolRegistry(timeout, enabled_tools, debug_mode)
        self.router = ToolRouter(self.registry)
        logger.info(
            f"Initialized ToolManager (debug_mode={debug_mode}, enabled_tools={[t.value for t in self.registry.get_enabled_tools()]})"
        )

    def get_tool_descriptions_prompt(self) -> str:
        """Get tool description prompt text for external calls (enabled tools only)."""
        return self.registry.get_tool_descriptions_prompt()

    def get_tool_markers(self) -> dict[ToolType, tuple[list[str], list[str]]]:
        """Get marker information for the enabled tools.

        Returns:
            Dict[ToolType, Tuple[List[str], List[str]]]: Tool type ->
            (start markers list, end markers list)
        """
        return self.registry.get_tool_markers()

    def get_all_start_markers(self) -> list[str]:
        """Get all start markers of the enabled tools, e.g. for setting stop tokens.

        Returns:
            List[str]: List of all start markers, e.g. ['```python\\n', '<calculator>']
        """
        return self.registry.get_all_start_markers()

    def get_all_end_markers(self) -> list[str]:
        """Get all end markers of the enabled tools, e.g. for setting stop tokens.

        Returns:
            List[str]: List of all end markers, e.g. ['\\n```', '</calculator>']
        """
        return self.registry.get_all_end_markers()

    def get_all_markers(self) -> list[str]:
        """Get all markers (start markers, then end markers) of the enabled tools.

        Returns:
            List[str]: List of all markers, e.g.
            ['```python\\n', '<calculator>', '\\n```', '</calculator>']
        """
        return self.registry.get_all_markers()

    def execute_tool_call(self, text: str) -> tuple[str, ToolCallStatus]:
        """Unified tool call interface: route ``text``, parse parameters, execute.

        Failures are reported through the returned status rather than raised:
        parameter-parsing and execution exceptions are caught and logged.

        Returns:
            Tuple[str, ToolCallStatus]: (result, status). On success ``result``
            is the tool's output unchanged. Otherwise ``result`` is an
            ``"Error: ..."`` message and ``status`` is ``NOT_FOUND`` (no tool
            matched / tool not enabled), ``ERROR`` (an exception was raised),
            or the non-success status returned by the tool itself.
        """
        # 1. Routing: determine which tool to call.
        # Compare against None explicitly: an Enum member's truthiness is not a
        # reliable "found" signal (e.g. an IntEnum member with value 0 is falsy).
        tool_type = self.router.route(text)
        if tool_type is None:
            return (
                "Error: No suitable tool found for the given text",
                ToolCallStatus.NOT_FOUND,
            )

        # 2. Get tool instance
        tool = self.registry.get_tool(tool_type)
        if tool is None:
            return f"Error: Tool {tool_type.value} not found", ToolCallStatus.NOT_FOUND

        # 3. Parse parameters
        try:
            parameters = tool.parse_parameters(text)
        except Exception as e:
            logger.error(f"Parameter parsing error: {e}")
            return f"Error: Failed to parse parameters - {e}", ToolCallStatus.ERROR
        logger.debug(f"Parsed parameters: {parameters}")

        # 4. Execute tool. An exception here is reported like a parse failure so
        # callers always receive a (result, status) pair.
        try:
            result, status = tool.execute(parameters)
        except Exception as e:
            logger.error(f"Tool execution error: {e}")
            return f"Error: Tool execution failed - {e}", ToolCallStatus.ERROR

        if status == ToolCallStatus.SUCCESS:
            logger.debug(f"Tool execution completed: {result}")
            return result, status
        logger.error(f"Tool execution error: {result}")
        return f"Error: Tool execution failed - {result}", status