
Commit af921ce

squash
trainer.experiment_name='rllm-apigen-mt-16k-stage2'
batch_size=128
comment to the back
back to full finetune
comment
double parse
update ground_truth
ground_truth
update ground_truth = [tool_call["function"] for tool_call in ground_truth]
\n instead of \\n
lora
remove nulls in tools
tool_calls = [tool_call.to_dict() for tool_call in tool_calls]
tool_call_str
question = json.loads(question)
agent_args = {}
tool calling AST environment
observation = json.loads(observation)
Update dataset_registry.json
prepare dataset
with _timer(
only necessary dependencies
16
revert env
bring them back
clip_advantages
mask_truncated_samples
remove chat_scheduler
cleanup
_target_
critic: enable: False
revert trainer.device=cuda
Timer
Timer
timer
only import ToolASTAgent
from rllm.agents.tool_ast_agent import ToolASTAgent
examples.tool_calling.train_apigen_mt
1 parent b75aa4b commit af921ce

File tree

20 files changed: +405 -90 lines changed
examples/deepscaler/train_deepscaler_8k.sh

Lines changed: 3 additions & 3 deletions

@@ -13,7 +13,7 @@ MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

 python3 -m examples.deepscaler.train_deepscaler \
     algorithm.adv_estimator=grpo \
-    data.train_batch_size=128 \
+    data.train_batch_size=16 \
     data.val_batch_size=30 \
     data.max_prompt_length=2048 \
     data.max_response_length=8192 \
@@ -22,9 +22,9 @@ python3 -m examples.deepscaler.train_deepscaler \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.model.use_remove_padding=True \
     actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
-    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=16 \
     actor_rollout_ref.actor.use_dynamic_bsz=True \
-    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=16384 \
     actor_rollout_ref.actor.use_kl_loss=False \
     actor_rollout_ref.actor.clip_ratio_high=0.28 \
     actor_rollout_ref.actor.kl_loss_coef=0.001 \
Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
import json
from typing import Any

import datasets

from rllm.data.dataset import DatasetRegistry
from rllm.parser import get_tool_parser
from rllm.parser.tool_parser.tool_parser_base import ToolParser


def remove_nulls_recursive(value: Any) -> Any:
    if isinstance(value, dict):
        return {k: remove_nulls_recursive(v) for k, v in value.items() if v is not None}
    elif isinstance(value, list):
        return [remove_nulls_recursive(x) for x in value]
    return value


def prepare_apigen_mt_data(train_size: int = None, test_size: int = None, parser_name: str = "qwen"):
    train_dataset = datasets.load_from_disk("/Users/tianyi/dataset/data/")
    test_dataset = datasets.load_from_disk("/Users/tianyi/data/")

    parser_class: type[ToolParser] = get_tool_parser(parser_name=parser_name)
    tool_parser = parser_class()

    def preprocess_fn(example, idx):
        messages = example["messages"]
        messages = remove_nulls_recursive(messages)
        tools = example.get("tools", [])
        tools = remove_nulls_recursive(tools)
        last_turn_tool_calls = messages[-1]["tool_calls"]

        ground_truth = []
        for tool_call in last_turn_tool_calls:
            tool_call = tool_call["function"]
            # Arguments may be double-encoded JSON, so parse a second time if needed.
            tool_call["arguments"] = json.loads(tool_call["arguments"])
            if isinstance(tool_call["arguments"], str):
                tool_call["arguments"] = json.loads(tool_call["arguments"])
            ground_truth.append(tool_call)

        # for tool_call in ground_truth:
        #     if isinstance(tool_call["arguments"], str):
        #         print("Something wrong")
        #     elif isinstance(tool_call["arguments"], dict):
        #         print("Something right")

        tools_prompt = tool_parser.get_tool_prompt(json.dumps(tools))

        possible_system_message = messages[0]
        if possible_system_message["role"] == "system":
            system_content = possible_system_message["content"]
            messages[0]["content"] = system_content + tools_prompt
        else:
            # No system message present; create one that carries the tool prompt.
            system_message = {"role": "system", "content": tools_prompt}
            messages = [system_message] + messages

        return {
            "prompt": json.dumps(messages[:-1]),
            "ground_truth": json.dumps(ground_truth),
            "data_source": "apigen_mt",
        }

    if train_size:
        train_dataset = train_dataset.select(range(min(train_size, len(train_dataset))))
    if test_size:
        test_dataset = test_dataset.select(range(min(test_size, len(test_dataset))))

    train_dataset = train_dataset.map(preprocess_fn, with_indices=True, writer_batch_size=10, num_proc=16)
    test_dataset = test_dataset.map(preprocess_fn, with_indices=True, writer_batch_size=10, num_proc=16)

    train_dataset = DatasetRegistry.register_dataset("apigen_mt", train_dataset, "train")
    test_dataset = DatasetRegistry.register_dataset("apigen_mt", test_dataset, "test")
    return train_dataset, test_dataset


if __name__ == "__main__":
    train_dataset, test_dataset = prepare_apigen_mt_data(test_size=100)
    print(f" - Train dataset: {len(train_dataset.get_data())} examples")
    print(f" - Test dataset: {len(test_dataset.get_data())} examples")
    # print(train_dataset.get_data()[0])
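For reference, here is a minimal sketch of what a single record produced by preprocess_fn looks like. The conversation text, tool name, and tool-prompt snippet are made up for illustration; the real tool-prompt format depends on the parser returned by get_tool_parser.

import json

# Hypothetical record mirroring the fields produced by preprocess_fn.
example_record = {
    # Every message except the final assistant turn, serialized as JSON;
    # the system message carries the appended tool prompt.
    "prompt": json.dumps([
        {"role": "system", "content": "You are a helpful assistant. <tools>...</tools>"},
        {"role": "user", "content": "What's the weather in Paris?"},
    ]),
    # The final assistant turn's tool calls, with "arguments" decoded into a dict.
    "ground_truth": json.dumps([
        {"name": "get_weather", "arguments": {"city": "Paris"}},
    ]),
    "data_source": "apigen_mt",
}

ground_truth = json.loads(example_record["ground_truth"])
assert isinstance(ground_truth[0]["arguments"], dict)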
Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
import hydra

from rllm.agents.tool_ast_agent import ToolASTAgent
from rllm.data.dataset import DatasetRegistry
from rllm.environments.base.single_turn_env import SingleTurnEnvironment
from rllm.rewards.reward_fn import tool_calling_ast_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="ppo_trainer", version_base=None)
def main(config):
    train_dataset = DatasetRegistry.load_dataset("apigen_mt", "train")
    test_dataset = DatasetRegistry.load_dataset("apigen_mt", "test")

    env_args = {"reward_fn": tool_calling_ast_reward_fn}
    agent_args = {}

    trainer = AgentTrainer(
        agent_class=ToolASTAgent,
        agent_args=agent_args,
        env_args=env_args,
        env_class=SingleTurnEnvironment,
        config=config,
        train_dataset=train_dataset,
        val_dataset=test_dataset,
    )
    trainer.train()


if __name__ == "__main__":
    main()
Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
set -x

ulimit -n 1048576
export VLLM_ATTENTION_BACKEND=FLASH_ATTN
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
export VLLM_USE_V1=1
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
export VLLM_ENGINE_ITERATION_TIMEOUT_S=1000000000

# Find the directory where rllm package is located
RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")

MODEL_PATH=Qwen/Qwen3-4B-Instruct-2507

python3 -m examples.tool_calling.train_apigen_mt \
    algorithm.adv_estimator=grpo \
    data.train_batch_size=128 \
    data.val_batch_size=512 \
    data.max_prompt_length=16384 \
    data.max_response_length=4096 \
    actor_rollout_ref.model.path=$MODEL_PATH \
    actor_rollout_ref.hybrid_engine=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size=128 \
    actor_rollout_ref.actor.ppo_epochs=1 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=20480 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.0001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.actor.grad_clip=1.0 \
    actor_rollout_ref.actor.clip_ratio_low=0.2 \
    actor_rollout_ref.actor.clip_ratio_high=0.28 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode="async" \
    actor_rollout_ref.rollout.chat_scheduler=verl.schedulers.completions_scheduler.CompletionsScheduler \
    actor_rollout_ref.rollout.enforce_eager=False \
    actor_rollout_ref.rollout.temperature=1 \
    actor_rollout_ref.rollout.top_p=0.95 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.n=4 \
    actor_rollout_ref.rollout.val_kwargs.n=2 \
    actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
    actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    algorithm.kl_ctrl.kl_coef=0.001 \
    algorithm.mask_truncated_samples=True \
    algorithm.clip_advantages=False \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='rllm-apigen_mt_16k' \
    trainer.experiment_name='rllm-apigen-mt-16k-stage2' \
    trainer.val_before_train=True \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=10 \
    trainer.test_freq=10 \
    trainer.default_hdfs_dir=null \
    agent.max_steps=1 \
    agent.use_stepwise_advantage=False \
    trainer.total_epochs=100

# actor_rollout_ref.model.lora_rank=32 \
# actor_rollout_ref.model.lora_alpha=32 \
# actor_rollout_ref.model.target_modules=all-linear \
# actor_rollout_ref.actor.optim.lr=3e-5 \

pyproject.toml

Lines changed: 1 addition & 65 deletions

@@ -18,74 +18,10 @@ classifiers = [
 ]
 license = {file = "LICENSE"}
 dependencies = [
-    # Core ML/AI packages
-    "torch>=2.7",
-    "transformers",
-    "accelerate",
-    "flash-attn>=2.8.0.post2",
     "sentence-transformers",
-    "torchmetrics",
-
-    # Training and inference
-    "deepspeed",
-    "vllm>=0.8.3",
-    "sgl-kernel>=0.2.0",
-    "sglang>=0.4.8.post1",
-    "sglang-router",
-    "peft",
-    "torchao",
-
-    # Data processing
-    "datasets",
     "polars",
-    "dm-tree",
-
-    # Cloud and infrastructure
-    "google-cloud-aiplatform",
     "vertexai",
-    "kubernetes",
-    "ray",
-
-    # Web and automation
-    "gradio",
-    "selenium",
-    "browsergym",
     "firecrawl",
-
-    # Math and science
-    "latex2sympy2",
-    "pylatexenc",
-    "nltk",
-    "scipy",
-    "scikit-learn",
-
-    # Code evaluation
-    "swebench",
-    "e2b_code_interpreter",
-
-    # Utilities
-    "fire",
-    "gdown",
-    "tabulate",
-    "sortedcontainers",
-    "PyMuPDF",
-    "together",
-    "wandb",
-    "pybind11",
-    "gym",
-
-    # Development and testing
-    "pytest",
-    "pre-commit",
-    "ruff",
-    "mypy",
-
-    # Documentation
-    "mkdocs>=1.5.0",
-    "mkdocs-material>=9.0.0",
-    "mkdocstrings[python]>=0.24.0",
-    "mkdocs-autorefs>=0.5.0",
-    "pymdown-extensions>=10.0.0",
 ]

 [tool.ruff]
@@ -119,7 +55,7 @@ ignore = [
     # `.log()` statement uses f-string
     "G004",
     # equality check using x == True
-    "E712",
+    "E712",
 ]

 [tool.ruff.lint.per-file-ignores]

rllm/agents/__init__.py

Lines changed: 8 additions & 3 deletions

@@ -1,7 +1,12 @@
-from rllm.agents.math_agent import MathAgent
-from rllm.agents.tool_agent import ToolAgent
+# from rllm.agents.math_agent import MathAgent
+# from rllm.agents.tool_agent import ToolAgent
+from rllm.agents.tool_ast_agent import ToolASTAgent

-__all__ = ["MathAgent", "ToolAgent"]
+__all__ = [
+    # "MathAgent",
+    # "ToolAgent",
+    "ToolASTAgent",
+]


 def safe_import(module_path, class_name):

rllm/agents/tool_ast_agent.py

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
import copy
import json
from typing import Any

from rllm.agents.agent import Action, BaseAgent, Step, Trajectory


class ToolASTAgent(BaseAgent):
    """
    A tool agent that only parses the AST of tool calls to check correctness, following the BaseAgent interface.
    Always single turn.
    """

    def __init__(self, accumulate_thinking=True):
        """
        Initialize the ToolASTAgent.
        """
        self._trajectory = Trajectory()
        self.messages = []
        self.accumulate_thinking = accumulate_thinking

    def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
        if not self.trajectory.steps:
            # Initial problem presentation
            assert isinstance(observation, dict) and "prompt" in observation
            question = observation["prompt"]
            question = json.loads(question)
            assert isinstance(question, list)
            self.messages = question
        else:
            # Placeholder, as it's always single turn.
            self.messages.append({"role": "user", "content": "Hi!"})

    def update_from_model(self, response: str, **kwargs) -> Action:
        """
        Updates the agent's internal state based on the model's response.
        """
        self.messages.append({"role": "assistant", "content": response})
        new_step = Step(chat_completions=copy.deepcopy(self.chat_completions))
        self.trajectory.steps.append(new_step)

        return Action(action=response)

    def reset(self):
        """Reset agent state for a new episode."""
        self._trajectory = Trajectory()
        self.messages = []

    @property
    def chat_completions(self) -> list[dict[str, str]]:
        """Return conversation history for model interaction."""
        # Remove thinking from assistant messages (except the last one) if not accumulate_thinking.
        messages = copy.deepcopy(self.messages)
        if not self.accumulate_thinking:
            for msg in messages[:-1]:
                if msg["role"] == "assistant":
                    _, sep, after = msg["content"].partition("</think>")
                    if sep:
                        msg["content"] = after
        return messages

    @property
    def trajectory(self) -> Trajectory:
        """Return complete interaction trajectory."""
        return self._trajectory

    def get_current_state(self) -> Step:
        """Returns the current step/state of the agent."""
        assert self._trajectory.steps, "Trajectory should not be empty when get_current_state is called."
        return self._trajectory.steps[-1]
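As a quick reference, a minimal sketch of the single-turn flow this agent expects, using only the methods defined above. The prompt payload and the assistant's tool-call string are made up for illustration; reward computation happens in the environment and is not shown.

import json

from rllm.agents.tool_ast_agent import ToolASTAgent

agent = ToolASTAgent()
agent.reset()

# The environment passes the serialized conversation (all turns except the last) as "prompt".
observation = {"prompt": json.dumps([
    {"role": "system", "content": "You can call tools."},
    {"role": "user", "content": "Book a table for two at 7pm."},
])}
agent.update_from_env(observation, reward=0.0, done=False, info={})

# Hypothetical raw model response containing a tool call.
response = '<tool_call>{"name": "book_table", "arguments": {"people": 2, "time": "19:00"}}</tool_call>'
action = agent.update_from_model(response)

print(action.action)                # the raw response, forwarded as the action
print(len(agent.trajectory.steps))  # 1 -- single-turn trajectory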
115 KB — Binary file not shown.
117 KB — Binary file not shown.
4.08 MB — Binary file not shown.
