6 changes: 3 additions & 3 deletions examples/deepscaler/train_deepscaler_8k.sh
@@ -13,7 +13,7 @@ MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B

python3 -m examples.deepscaler.train_deepscaler \
algorithm.adv_estimator=grpo \
data.train_batch_size=128 \
data.train_batch_size=16 \
data.val_batch_size=30 \
data.max_prompt_length=2048 \
data.max_response_length=8192 \
@@ -22,9 +22,9 @@ python3 -m examples.deepscaler.train_deepscaler \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_mini_batch_size=16 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=16384 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.clip_ratio_high=0.28 \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
79 changes: 79 additions & 0 deletions examples/tool_calling/prepare_apigen_mt_data.py
@@ -0,0 +1,79 @@
import json
from typing import Any

import datasets

from rllm.data.dataset import DatasetRegistry
from rllm.parser import get_tool_parser
from rllm.parser.tool_parser.tool_parser_base import ToolParser


def remove_nulls_recursive(value: Any) -> Any:
if isinstance(value, dict):
return {k: remove_nulls_recursive(v) for k, v in value.items() if v is not None}
elif isinstance(value, list):
return [remove_nulls_recursive(x) for x in value]
return value

def prepare_apigen_mt_data(train_size: int | None = None, test_size: int | None = None, parser_name: str = "qwen"):
    # NOTE: these load local copies of the APIGen-MT splits; point them at wherever the data lives on your machine.
    train_dataset = datasets.load_from_disk("/Users/tianyi/dataset/data/")
    test_dataset = datasets.load_from_disk("/Users/tianyi/data/")

parser_class: type[ToolParser] = get_tool_parser(parser_name=parser_name)
tool_parser = parser_class()

def preprocess_fn(example, idx):
messages = example["messages"]
messages = remove_nulls_recursive(messages)
tools = example.get("tools", [])
tools = remove_nulls_recursive(tools)
last_turn_tool_calls = messages[-1]['tool_calls']

        ground_truth = []
        for tool_call in last_turn_tool_calls:
            tool_call = tool_call["function"]
            # Arguments arrive as JSON strings and are sometimes double-encoded, so decode until we hold a dict.
            tool_call["arguments"] = json.loads(tool_call["arguments"])
            if isinstance(tool_call["arguments"], str):
                tool_call["arguments"] = json.loads(tool_call["arguments"])
            ground_truth.append(tool_call)

tools_prompt = tool_parser.get_tool_prompt(json.dumps(tools))

        possible_system_message = messages[0]
        if possible_system_message["role"] == "system":
            # Fold the tool definitions into the existing system message.
            messages[0]["content"] = possible_system_message["content"] + tools_prompt
        else:
            # No system message yet: prepend one that carries only the tool definitions.
            messages = [{"role": "system", "content": tools_prompt}] + messages

return {
"prompt": json.dumps(messages[:-1]),
"ground_truth": json.dumps(ground_truth),
"data_source": "apigen_mt",
}
if train_size:
train_dataset = train_dataset.select(range(min(train_size, len(train_dataset))))
if test_size:
test_dataset = test_dataset.select(range(min(test_size, len(test_dataset))))

train_dataset = train_dataset.map(preprocess_fn, with_indices=True, writer_batch_size=10, num_proc=16)
test_dataset = test_dataset.map(preprocess_fn, with_indices=True, writer_batch_size=10, num_proc=16)

train_dataset = DatasetRegistry.register_dataset("apigen_mt", train_dataset, "train")
test_dataset = DatasetRegistry.register_dataset("apigen_mt", test_dataset, "test")
return train_dataset, test_dataset

if __name__ == "__main__":
train_dataset, test_dataset = prepare_apigen_mt_data(test_size=100)
print(f" - Train dataset: {len(train_dataset.get_data())} examples")
print(f" - Test dataset: {len(test_dataset.get_data())} examples")
# print(train_dataset.get_data()[0])
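A quick way to sanity-check the registered splits is to load one back and decode the JSON fields that preprocess_fn writes. The sketch below only uses the DatasetRegistry.load_dataset and get_data calls already shown in this PR; the field names ("prompt", "ground_truth", "data_source") mirror preprocess_fn's return value.

import json

from rllm.data.dataset import DatasetRegistry

# Load the split registered above and decode one example.
dataset = DatasetRegistry.load_dataset("apigen_mt", "test")
example = dataset.get_data()[0]
messages = json.loads(example["prompt"])              # chat history without the final assistant turn
expected_calls = json.loads(example["ground_truth"])  # list of {"name": ..., "arguments": {...}} dicts
print(len(messages), [call["name"] for call in expected_calls])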
31 changes: 31 additions & 0 deletions examples/tool_calling/train_apigen_mt.py
@@ -0,0 +1,31 @@
import hydra

from rllm.agents.tool_ast_agent import ToolASTAgent
from rllm.data.dataset import DatasetRegistry
from rllm.environments.base.single_turn_env import SingleTurnEnvironment
from rllm.rewards.reward_fn import tool_calling_ast_reward_fn
from rllm.trainer.agent_trainer import AgentTrainer


@hydra.main(config_path="pkg://rllm.trainer.config", config_name="ppo_trainer", version_base=None)
def main(config):
train_dataset = DatasetRegistry.load_dataset("apigen_mt", "train")
test_dataset = DatasetRegistry.load_dataset("apigen_mt", "test")

env_args = {"reward_fn": tool_calling_ast_reward_fn}
agent_args = {}

trainer = AgentTrainer(
agent_class=ToolASTAgent,
agent_args=agent_args,
env_args=env_args,
env_class=SingleTurnEnvironment,
config=config,
train_dataset=train_dataset,
val_dataset=test_dataset,
)
trainer.train()


if __name__ == "__main__":
main()
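The environment's reward comes from rllm.rewards.reward_fn.tool_calling_ast_reward_fn, which scores the assistant's tool calls against the dataset's ground_truth field. That function's internals are not part of this diff, so the snippet below is only a hypothetical illustration of an order-insensitive name-plus-arguments match, not the library implementation.

import json


def ast_match_reward(predicted_calls: list[dict], expected_calls: list[dict]) -> float:
    """Hypothetical sketch: 1.0 when predicted and expected tool calls agree on
    function names and parsed arguments (order-insensitive), 0.0 otherwise."""
    def canonical(calls: list[dict]) -> list[str]:
        return sorted(json.dumps({"name": c["name"], "arguments": c["arguments"]}, sort_keys=True) for c in calls)

    return 1.0 if canonical(predicted_calls) == canonical(expected_calls) else 0.0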
78 changes: 78 additions & 0 deletions examples/tool_calling/train_apigen_mt_16k.sh
@@ -0,0 +1,78 @@
set -x

ulimit -n 1048576
export VLLM_ATTENTION_BACKEND=FLASH_ATTN
export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
export VLLM_USE_V1=1
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
export VLLM_ENGINE_ITERATION_TIMEOUT_S=1000000000

# Find the directory where rllm package is located
RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")

MODEL_PATH=Qwen/Qwen3-8B
TRAIN_BATCH_SIZE=32

python3 -m examples.tool_calling.train_apigen_mt \
algorithm.adv_estimator=grpo \
data.train_batch_size=$TRAIN_BATCH_SIZE \
data.val_batch_size=512 \
data.max_prompt_length=16384 \
data.max_response_length=8192 \
actor_rollout_ref.model.path=$MODEL_PATH \
actor_rollout_ref.hybrid_engine=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_BATCH_SIZE \
actor_rollout_ref.actor.ppo_micro_batch_size=$TRAIN_BATCH_SIZE \
actor_rollout_ref.actor.ppo_epochs=1 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=10240 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.0001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=8 \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.clip_ratio_low=0.2 \
actor_rollout_ref.actor.clip_ratio_high=0.28 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode="async" \
actor_rollout_ref.rollout.chat_scheduler=verl.schedulers.completions_scheduler.CompletionsScheduler \
actor_rollout_ref.rollout.enforce_eager=False \
actor_rollout_ref.rollout.temperature=1 \
actor_rollout_ref.rollout.top_p=0.95 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.n=4 \
actor_rollout_ref.rollout.val_kwargs.n=2 \
actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
algorithm.kl_ctrl.kl_coef=0.001 \
algorithm.mask_truncated_samples=True \
algorithm.clip_advantages=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='rllm-apigen_mt_16k' \
trainer.experiment_name='rllm-apigen-mt-16k-8b-stage1' \
trainer.val_before_train=True \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=10 \
trainer.test_freq=10 \
trainer.default_hdfs_dir=null \
agent.max_steps=1 \
agent.use_stepwise_advantage=False \
trainer.total_epochs=100

# actor_rollout_ref.model.lora_rank=32 \
# actor_rollout_ref.model.lora_alpha=32 \
# actor_rollout_ref.model.target_modules=all-linear \
# actor_rollout_ref.actor.optim.lr=3e-5 \
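As a back-of-envelope check on the sizing above (plain arithmetic over the listed overrides, not an rllm or verl API): each PPO step samples rollout.n completions per prompt, and use_dynamic_bsz with ppo_max_token_len_per_gpu then bounds how many of those tokens are processed on a single GPU at once.

# Rollout volume per PPO step implied by the overrides above (values copied from the script).
train_batch_size = 32                    # data.train_batch_size
rollout_n = 4                            # actor_rollout_ref.rollout.n
max_prompt, max_response = 16384, 8192   # data.max_prompt_length / data.max_response_length

trajectories_per_step = train_batch_size * rollout_n    # 128 sampled trajectories
max_tokens_per_trajectory = max_prompt + max_response   # 24576 tokens at the truncation limit
print(trajectories_per_step, max_tokens_per_trajectory)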
66 changes: 1 addition & 65 deletions pyproject.toml
@@ -18,74 +18,10 @@ classifiers = [
]
license = {file = "LICENSE"}
dependencies = [
# Core ML/AI packages
"torch>=2.7",
"transformers",
"accelerate",
"flash-attn>=2.8.0.post2",
"sentence-transformers",
"torchmetrics",

# Training and inference
"deepspeed",
"vllm>=0.8.3",
"sgl-kernel>=0.2.0",
"sglang>=0.4.8.post1",
"sglang-router",
"peft",
"torchao",

# Data processing
"datasets",
"polars",
"dm-tree",

# Cloud and infrastructure
"google-cloud-aiplatform",
"vertexai",
"kubernetes",
"ray",

# Web and automation
"gradio",
"selenium",
"browsergym",
"firecrawl",

# Math and science
"latex2sympy2",
"pylatexenc",
"nltk",
"scipy",
"scikit-learn",

# Code evaluation
"swebench",
"e2b_code_interpreter",

# Utilities
"fire",
"gdown",
"tabulate",
"sortedcontainers",
"PyMuPDF",
"together",
"wandb",
"pybind11",
"gym",

# Development and testing
"pytest",
"pre-commit",
"ruff",
"mypy",

# Documentation
"mkdocs>=1.5.0",
"mkdocs-material>=9.0.0",
"mkdocstrings[python]>=0.24.0",
"mkdocs-autorefs>=0.5.0",
"pymdown-extensions>=10.0.0",
]

[tool.ruff]
@@ -119,7 +55,7 @@ ignore = [
# `.log()` statement uses f-string
"G004",
# equality check using x == True
"E712",
"E712",
]

[tool.ruff.lint.per-file-ignores]
11 changes: 8 additions & 3 deletions rllm/agents/__init__.py
@@ -1,7 +1,12 @@
from rllm.agents.math_agent import MathAgent
from rllm.agents.tool_agent import ToolAgent
# from rllm.agents.math_agent import MathAgent
# from rllm.agents.tool_agent import ToolAgent
from rllm.agents.tool_ast_agent import ToolASTAgent

__all__ = ["MathAgent", "ToolAgent"]
__all__ = [
# "MathAgent",
# "ToolAgent",
"ToolASTAgent",
]


def safe_import(module_path, class_name):
69 changes: 69 additions & 0 deletions rllm/agents/tool_ast_agent.py
@@ -0,0 +1,69 @@
import copy
import json
from typing import Any

from rllm.agents.agent import Action, BaseAgent, Step, Trajectory


class ToolASTAgent(BaseAgent):
"""
A tool agent that only parses AST tree to check correctness, following the BaseAgent interface.
Always single turn.
"""

def __init__(self, accumulate_thinking=True):
"""
        Initialize the ToolASTAgent.
"""
self._trajectory = Trajectory()
self.messages = []
self.accumulate_thinking = accumulate_thinking

def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
if not self.trajectory.steps:
# Initial problem presentation
assert isinstance(observation, dict) and "prompt" in observation
question = observation["prompt"]
question = json.loads(question)
assert isinstance(question, list)
self.messages = question
else:
            # Placeholder: the agent is always single-turn, so this branch should never carry real content.
            self.messages.append({"role": "user", "content": "Hi!"})

def update_from_model(self, response: str, **kwargs) -> Action:
"""
Updates the agent's internal state based on the model's response.
"""
self.messages.append({"role": "assistant", "content": response})
new_step = Step(chat_completions=copy.deepcopy(self.chat_completions))
self.trajectory.steps.append(new_step)

return Action(action=response)

def reset(self):
"""Reset agent state for new episode."""
self._trajectory = Trajectory()
self.messages = []

@property
def chat_completions(self) -> list[dict[str, str]]:
"""Return conversation history for model interaction."""
        # Unless accumulate_thinking is set, strip the <think> section from every assistant message except the last.
messages = copy.deepcopy(self.messages)
if not self.accumulate_thinking:
for msg in messages[:-1]:
if msg["role"] == "assistant":
_, sep, after = msg["content"].partition("</think>")
if sep:
msg["content"] = after
return messages

@property
def trajectory(self) -> Trajectory:
"""Return complete interaction trajectory."""
return self._trajectory

def get_current_state(self) -> Step:
"""Returns the current step/state of the agent."""
assert self._trajectory.steps, "Trajectory should not be empty when get_current_state is called."
return self._trajectory.steps[-1]
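For orientation, here is a minimal offline walk-through of the agent contract defined above, using only the methods in this file. The observation shape matches the "prompt" field emitted by prepare_apigen_mt_data.py; the user message and the <tool_call> response string are made-up examples, and the exact tool-call format depends on the model and parser in use. In training this loop is driven by SingleTurnEnvironment and the trainer rather than called by hand.

import json

from rllm.agents.tool_ast_agent import ToolASTAgent

agent = ToolASTAgent()
agent.reset()

# First observation: the serialized chat history produced by the data-prep script.
prompt = json.dumps([{"role": "user", "content": "Book a table for two at 7pm."}])
agent.update_from_env(observation={"prompt": prompt}, reward=0.0, done=False, info={})

# The model's raw response (tool calls and all) becomes the single assistant step.
agent.update_from_model('<tool_call>{"name": "book_table", "arguments": {"people": 2, "time": "19:00"}}</tool_call>')

print(len(agent.trajectory.steps), agent.chat_completions[-1]["role"])  # 1 assistant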
Binary file added rllm/data/datasets/apigen_mt/test.parquet
Binary file not shown.
Binary file added rllm/data/datasets/apigen_mt/test_verl.parquet
Binary file not shown.
Binary file added rllm/data/datasets/apigen_mt/train.parquet
Binary file not shown.
Binary file added rllm/data/datasets/apigen_mt/train_verl.parquet
Binary file not shown.