Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
58 commits
Select commit Hold shift + click to select a range
efa3b1d
fixes for tool use
hamishivi Oct 16, 2025
e95e9c7
whoops, fix
hamishivi Oct 16, 2025
14690cd
simpler approach
hamishivi Oct 16, 2025
c9ea292
lint
hamishivi Oct 16, 2025
d1ada94
Refactor tool architecture to use Ray actors
hamishivi Oct 16, 2025
402dfb1
some changes away from mcp
hamishivi Oct 16, 2025
70ccf31
Merge branch 'main' into tool-refactor
hamishivi Oct 16, 2025
c5bfff9
fix test
hamishivi Oct 16, 2025
ef89152
align tool vllm and llm ray actor
hamishivi Oct 16, 2025
4bfd702
tool vllm demo
hamishivi Oct 16, 2025
eeead3d
lint
hamishivi Oct 17, 2025
07023fe
fix maybe
hamishivi Oct 17, 2025
7041e23
update tool vllm
hamishivi Oct 17, 2025
185cbd3
forgot callable
hamishivi Oct 17, 2025
c2237ad
tool vllm now working
hamishivi Oct 17, 2025
2725eef
lint
hamishivi Oct 17, 2025
5264183
cleanly shut down vllm
hamishivi Oct 17, 2025
005d697
Refactor: Reorganize tools code into cleaner folder structure
hamishivi Oct 17, 2025
25041fe
cleaning up imports, slightly moving stuff around
hamishivi Oct 17, 2025
f8e1acb
More cleaning up + readme
hamishivi Oct 17, 2025
2aed0fa
Minor fixes
hamishivi Oct 17, 2025
d312de0
Don't overlap query strings
hamishivi Oct 17, 2025
d734adc
Add system prompt and example system prompt.
hamishivi Oct 17, 2025
8c99b8f
lint
hamishivi Oct 17, 2025
1ecfac9
new debug system prompt + debug script
hamishivi Oct 17, 2025
c5ce620
exclude git from ray
hamishivi Oct 17, 2025
1c3e17c
script fix
hamishivi Oct 17, 2025
4f5e4d0
Update accelerate to unpin deepspeed
hamishivi Oct 17, 2025
64c8caa
lint
hamishivi Oct 17, 2025
335c7cb
upper bound deepspeed
Oct 17, 2025
acc9165
Merge branch 'main' into tool-refactor
hamishivi Oct 20, 2025
93502be
fix import
hamishivi Oct 20, 2025
4d917fa
Merge branch 'main' into tool-refactor
hamishivi Oct 20, 2025
33b5c36
merge branch 'main' into tool-refactor
hamishivi Oct 20, 2025
44a4570
fix
hamishivi Oct 20, 2025
3058e22
more tooooools
hamishivi Oct 20, 2025
9de4e7c
fix system prompt
hamishivi Oct 21, 2025
ffd059c
yay fix tools
hamishivi Oct 21, 2025
f44df1b
Truncate webpage
hamishivi Oct 21, 2025
5941e13
more normal outputs
hamishivi Oct 21, 2025
873a6fd
Restructure crawl.
hamishivi Oct 21, 2025
660c062
make sure browse tool is reasonable
hamishivi Oct 22, 2025
ea8127f
fix serper tool
hamishivi Oct 22, 2025
7fdfe2d
optional api endpoint
hamishivi Oct 22, 2025
50a33f9
fix str behaviour
hamishivi Oct 22, 2025
f74c090
more fixes...
hamishivi Oct 22, 2025
bddf9f8
add some logging
hamishivi Oct 22, 2025
7a38e69
Merge branch 'main' into tool-refactor
hamishivi Oct 24, 2025
d4fce78
duplicated sys prompt override
hamishivi Oct 24, 2025
dd04f47
clean
hamishivi Oct 24, 2025
c09aeaf
deepspeed broke on me :(
hamishivi Oct 24, 2025
b23c83c
Merge branch 'main' into tool-refactor
hamishivi Oct 31, 2025
c7c3941
vibe cobe some restructuring
hamishivi Nov 3, 2025
822ede3
unsafe serialization
hamishivi Nov 3, 2025
a1d4ae2
fix
hamishivi Nov 3, 2025
6b4b6aa
tool
hamishivi Nov 3, 2025
9d5245b
fix
hamishivi Nov 3, 2025
d1b20d5
fix
hamishivi Nov 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 46 additions & 28 deletions open_instruct/grpo_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@
push_folder_to_hub,
)
from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
from open_instruct.tools.tool_actor import TOOL_CLASS_REGISTRY, ToolActor
from open_instruct.tools.utils.tool_proxy import ToolProxy
from open_instruct.rl_utils import Timer, pack_sequences
from open_instruct.utils import (
ArgumentParserPlus,
Expand Down Expand Up @@ -420,16 +422,19 @@ class Args:
"""Whether to mask the tool output. By default on."""
only_reward_good_outputs: bool = False
"""Whether to only reward good outputs. By default off. Useful to force the model to use the tool(s)."""
tool_max_concurrency: int = 512
"""The maximum number of concurrent tool calls allowed across all rollouts per tool."""

# rl-rag specific settings
# code-tool specific settings
code_tool_api_endpoint: str | None = None

# search-tool specific settings
# rl-rag tool settings. These are shared across different tools.
number_documents_to_search: int = 3
"""The maximum number of documents to retrieve for each query."""
search_api_endpoint: str | None = None
"""The API endpoint for the search engine."""

# code-tool specific settings
code_tool_api_endpoint: str | None = None

def __post_init__(self):
if os.environ.get("VLLM_USE_V1") == "0":
logger.warning("When using the v0 version of vLLM, caching is broken and will never be invalidated.")
Expand Down Expand Up @@ -492,8 +497,10 @@ def __post_init__(self):
calibrate_checkpoint_state_dir(self.checkpoint_state_dir)
if self.tools is not None and len(self.tools) > 0:
for tool in self.tools:
if tool not in ["search", "code"]:
raise ValueError(f"Tool {tool} is not supported. Supported tools are: search, code")
if tool not in TOOL_CLASS_REGISTRY:
raise ValueError(
f"Tool {tool} is not supported. Supported tools are: {', '.join(TOOL_CLASS_REGISTRY.keys())}"
)
assert len(self.tools) == len(set(self.tools)), "Duplicate tools are not allowed"
if self.use_vllm_logprobs or self.truncated_importance_sampling_ratio_cap > 0.0:
assert self.mask_tool_use, (
Expand Down Expand Up @@ -2195,29 +2202,40 @@ def create_model_and_optimizer(
# Set up tools
max_len = args.max_prompt_token_length + args.response_length
tool_objects = {}
tool_max_conc = args.tool_max_concurrency

def _register_actor_backed_tool(tool_name: str, class_path: str, init_kwargs: dict):
    """Spawn a ``ToolActor`` for one tool and register a proxy per stop string.

    Args:
        tool_name: Registry name of the tool; also the name given to every
            ``ToolProxy`` created here.
        class_path: Forwarded unchanged to the ``ToolActor`` constructor.
        init_kwargs: Forwarded unchanged to the ``ToolActor`` constructor.

    Side effects:
        Adds one entry per stop string to the enclosing ``tool_objects`` dict
        and appends each stop string to ``args.stop_strings``.
    """
    handle = ToolActor.options(max_concurrency=tool_max_conc).remote(
        tool_name=tool_name, class_path=class_path, init_kwargs=init_kwargs
    )

    # The actor's self-reported name is only used for a sanity check; the
    # registry name always wins, so proxies below are built with `tool_name`.
    tool_name_from_actor = ray.get(handle.get_name.remote())
    if tool_name_from_actor != tool_name:
        logger.warning(
            f"Tool name mismatch: registry name '{tool_name}' vs tool.get_name() '{tool_name_from_actor}'. "
            f"Using registry name '{tool_name}' for consistency."
        )

    opening_marker = ray.get(handle.get_start_str.remote())
    closing_markers = ray.get(handle.get_stop_strings.remote())

    # `tool_objects` is keyed by the end string (used for stop-string checks
    # during generation), while call tracking relies on the proxy's name.
    for end_str in closing_markers:
        tool_objects[end_str] = ToolProxy(
            actor_handle=handle, start_str=opening_marker, end_str=end_str, name=tool_name
        )
        # The generator must also stop on this tool's end string.
        args.stop_strings.append(end_str)

# Register tools via actors
if args.tools:
for tool in args.tools:
if tool.lower() == "search":
from open_instruct.search_utils.search_tool import SearchTool

tool = SearchTool(
start_str="<query>",
end_str="</query>",
api_endpoint=args.search_api_endpoint,
number_documents_to_search=args.number_documents_to_search,
)
tool_objects[tool.end_str] = tool
# Add tool end string to stop_strings
args.stop_strings.append(tool.end_str)
elif tool.lower() == "code":
from open_instruct.tool_utils.tools import PythonCodeTool

tool = PythonCodeTool(start_str="<code>", end_str="</code>", api_endpoint=args.code_tool_api_endpoint)
tool_objects[tool.end_str] = tool
# Add tool end string to stop_strings
args.stop_strings.append(tool.end_str)
else:
raise ValueError(f"Unknown tool: {tool}")
for tool_registry_name in args.tools:
registry_key = tool_registry_name.lower()
class_path = TOOL_CLASS_REGISTRY.get(registry_key, None)
if class_path is None:
raise ValueError(f"Unknown tool: {tool_registry_name}")
# Pass the registry name so the tool is created with the correct name
_register_actor_backed_tool(tool_name=registry_key, class_path=class_path, init_kwargs=vars(args))

queues_to_monitor = {
"Inference Results Queue": inference_results_Q,
Expand Down
4 changes: 2 additions & 2 deletions open_instruct/ppo_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,7 +1588,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
if args.tools:
for tool in args.tools:
if tool.lower() == "search":
from open_instruct.search_utils.search_tool import SearchTool
from open_instruct.tools.search_tool.search_tool import SearchTool

tool = SearchTool(
start_str="<query>",
Expand All @@ -1598,7 +1598,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig, reward_fn:
)
tool_objects[tool.end_str] = tool
elif tool.lower() == "code":
from open_instruct.tool_utils.tools import PythonCodeTool
from open_instruct.tools.python_tool.tool import PythonCodeTool

tool = PythonCodeTool(start_str="<code>", end_str="</code>", api_endpoint=args.code_tool_api_endpoint)
tool_objects[tool.end_str] = tool
Expand Down
66 changes: 0 additions & 66 deletions open_instruct/search_utils/search_tool.py

This file was deleted.

7 changes: 4 additions & 3 deletions open_instruct/tool_utils/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import time
import unittest

from open_instruct.tool_utils.tools import MaxCallsExceededTool, PythonCodeTool, Tool, ToolOutput
from open_instruct.tools.python_tool.tool import PythonCodeTool
from open_instruct.tools.utils.tool_classes import MaxCallsExceededTool, Tool, ToolOutput


class TestToolOutput(unittest.TestCase):
Expand Down Expand Up @@ -51,7 +52,7 @@ def test_max_calls_exceeded_output(self):
self.assertIsInstance(result, ToolOutput)
self.assertEqual(result.output, "Max tool calls exceeded.")
self.assertFalse(result.called)
self.assertEqual(result.error, "")
self.assertEqual(result.error, "Max tool calls exceeded")
self.assertFalse(result.timeout)
self.assertEqual(result.runtime, 0)

Expand All @@ -63,7 +64,7 @@ def setUpClass(cls):
# Start the server in a subprocess
cls.server_process = subprocess.Popen(
["uv", "run", "uvicorn", "tool_server:app", "--host", "0.0.0.0", "--port", "1212"],
cwd="open_instruct/tool_utils",
cwd="open_instruct/tools/python_tool/python_server",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
start_new_session=True, # Create new process group
Expand Down
Loading