diff --git a/examples/deepscaler/train_deepscaler_8k.sh b/examples/deepscaler/train_deepscaler_8k.sh
index acc8c075..48078acd 100755
--- a/examples/deepscaler/train_deepscaler_8k.sh
+++ b/examples/deepscaler/train_deepscaler_8k.sh
@@ -13,7 +13,7 @@ MODEL_PATH=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
 python3 -m examples.deepscaler.train_deepscaler \
     algorithm.adv_estimator=grpo \
-    data.train_batch_size=128 \
+    data.train_batch_size=16 \
     data.val_batch_size=30 \
     data.max_prompt_length=2048 \
     data.max_response_length=8192 \
@@ -22,9 +22,9 @@ python3 -m examples.deepscaler.train_deepscaler \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.model.use_remove_padding=True \
     actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
-    actor_rollout_ref.actor.ppo_mini_batch_size=64 \
+    actor_rollout_ref.actor.ppo_mini_batch_size=16 \
     actor_rollout_ref.actor.use_dynamic_bsz=True \
-    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=30000 \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=16384 \
     actor_rollout_ref.actor.use_kl_loss=False \
     actor_rollout_ref.actor.clip_ratio_high=0.28 \
     actor_rollout_ref.actor.kl_loss_coef=0.001 \
diff --git a/examples/tool_calling/prepare_apigen_mt_data.py b/examples/tool_calling/prepare_apigen_mt_data.py
new file mode 100644
index 00000000..b5bf16dc
--- /dev/null
+++ b/examples/tool_calling/prepare_apigen_mt_data.py
@@ -0,0 +1,79 @@
+import json
+
+import datasets
+from typing import Any
+
+from rllm.data.dataset import DatasetRegistry
+
+from rllm.parser import get_tool_parser
+from rllm.parser.tool_parser.tool_parser_base import ToolParser
+
+
+
+def remove_nulls_recursive(value: Any) -> Any:
+    if isinstance(value, dict):
+        return {k: remove_nulls_recursive(v) for k, v in value.items() if v is not None}
+    elif isinstance(value, list):
+        return [remove_nulls_recursive(x) for x in value]
+    return value
+
+def prepare_apigen_mt_data(train_size: int | None = None, test_size: int | None = None, parser_name: str = "qwen"):
+    train_dataset = datasets.load_from_disk("/Users/tianyi/dataset/data/")
+    test_dataset = datasets.load_from_disk("/Users/tianyi/data/")
+
+    parser_class: type[ToolParser] = get_tool_parser(parser_name=parser_name)
+    tool_parser = parser_class()
+
+    def preprocess_fn(example, idx):
+        messages = example["messages"]
+        messages = remove_nulls_recursive(messages)
+        tools = example.get("tools", [])
+        tools = remove_nulls_recursive(tools)
+        last_turn_tool_calls = messages[-1]['tool_calls']
+
+        ground_truth = []
+        for tool_call in last_turn_tool_calls:
+            tool_call = tool_call["function"]
+            tool_call["arguments"] = json.loads(tool_call["arguments"])
+            if isinstance(tool_call["arguments"], str):
+                tool_call["arguments"] = json.loads(tool_call["arguments"])
+            ground_truth.append(tool_call)
+
+        # for tool_call in ground_truth:
+        #     if isinstance(tool_call["arguments"], str):
+        #         print("Something wrong")
+        #     elif isinstance(tool_call["arguments"], dict):
+        #         print("Something right")
+
+        tools_prompt = tool_parser.get_tool_prompt(json.dumps(tools))
+
+        possible_system_message = messages[0]
+        if possible_system_message["role"] == "system":
+            system_content = possible_system_message["content"]
+            messages[0]["content"] = system_content + tools_prompt
+        else:
+            system_message = {"role": "system", "content": tools_prompt}
+            messages = [system_message] + messages
+
+        return {
+            "prompt": json.dumps(messages[:-1]),
+            "ground_truth": json.dumps(ground_truth),
+            "data_source": "apigen_mt",
+        }
+    if train_size:
+        train_dataset = train_dataset.select(range(min(train_size, len(train_dataset))))
+    if test_size:
+        test_dataset = test_dataset.select(range(min(test_size, len(test_dataset))))
+
+    train_dataset = train_dataset.map(preprocess_fn, with_indices=True, writer_batch_size=10, num_proc=16)
+    test_dataset = test_dataset.map(preprocess_fn, with_indices=True, writer_batch_size=10, num_proc=16)
+
+    train_dataset = DatasetRegistry.register_dataset("apigen_mt", train_dataset, "train")
+    test_dataset = DatasetRegistry.register_dataset("apigen_mt", test_dataset, "test")
+    return train_dataset, test_dataset
+
+if __name__ == "__main__":
+    train_dataset, test_dataset = prepare_apigen_mt_data(test_size=100)
+    print(f" - Train dataset: {len(train_dataset.get_data())} examples")
+    print(f" - Test dataset: {len(test_dataset.get_data())} examples")
+    # print(train_dataset.get_data()[0])
diff --git a/examples/tool_calling/train_apigen_mt.py b/examples/tool_calling/train_apigen_mt.py
new file mode 100644
index 00000000..a47b3b1e
--- /dev/null
+++ b/examples/tool_calling/train_apigen_mt.py
@@ -0,0 +1,31 @@
+import hydra
+
+from rllm.agents.tool_ast_agent import ToolASTAgent
+from rllm.data.dataset import DatasetRegistry
+from rllm.environments.base.single_turn_env import SingleTurnEnvironment
+from rllm.rewards.reward_fn import tool_calling_ast_reward_fn
+from rllm.trainer.agent_trainer import AgentTrainer
+
+
+@hydra.main(config_path="pkg://rllm.trainer.config", config_name="ppo_trainer", version_base=None)
+def main(config):
+    train_dataset = DatasetRegistry.load_dataset("apigen_mt", "train")
+    test_dataset = DatasetRegistry.load_dataset("apigen_mt", "test")
+
+    env_args = {"reward_fn": tool_calling_ast_reward_fn}
+    agent_args = {}
+
+    trainer = AgentTrainer(
+        agent_class=ToolASTAgent,
+        agent_args=agent_args,
+        env_args=env_args,
+        env_class=SingleTurnEnvironment,
+        config=config,
+        train_dataset=train_dataset,
+        val_dataset=test_dataset,
+    )
+    trainer.train()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/tool_calling/train_apigen_mt_16k.sh b/examples/tool_calling/train_apigen_mt_16k.sh
new file mode 100644
index 00000000..bca84724
--- /dev/null
+++ b/examples/tool_calling/train_apigen_mt_16k.sh
@@ -0,0 +1,78 @@
+set -x
+
+ulimit -n 1048576
+export VLLM_ATTENTION_BACKEND=FLASH_ATTN
+export PYTORCH_CUDA_ALLOC_CONF="expandable_segments:False"
+export VLLM_USE_V1=1
+export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
+export VLLM_ENGINE_ITERATION_TIMEOUT_S=1000000000
+
+# Find the directory where rllm package is located
+RLLM_DIR=$(python3 -c "import rllm; import os; print(os.path.dirname(os.path.dirname(rllm.__file__)))")
+
+MODEL_PATH=Qwen/Qwen3-8B
+TRAIN_BATCH_SIZE=32
+
+python3 -m examples.tool_calling.train_apigen_mt \
+    algorithm.adv_estimator=grpo \
+    data.train_batch_size=$TRAIN_BATCH_SIZE \
+    data.val_batch_size=512 \
+    data.max_prompt_length=16384 \
+    data.max_response_length=8192 \
+    actor_rollout_ref.model.path=$MODEL_PATH \
+    actor_rollout_ref.hybrid_engine=True \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.loss_agg_mode=seq-mean-token-mean \
+    actor_rollout_ref.actor.ppo_mini_batch_size=$TRAIN_BATCH_SIZE \
+    actor_rollout_ref.actor.ppo_micro_batch_size=$TRAIN_BATCH_SIZE \
+    actor_rollout_ref.actor.ppo_epochs=1 \
+    actor_rollout_ref.actor.use_dynamic_bsz=True \
+    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=10240 \
+    actor_rollout_ref.actor.use_kl_loss=False \
+    actor_rollout_ref.actor.kl_loss_coef=0.0001 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.actor.ulysses_sequence_parallel_size=8 \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.actor.grad_clip=1.0 \
+    actor_rollout_ref.actor.clip_ratio_low=0.2 \
+    actor_rollout_ref.actor.clip_ratio_high=0.28 \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.mode="async" \
+    actor_rollout_ref.rollout.chat_scheduler=verl.schedulers.completions_scheduler.CompletionsScheduler \
+    actor_rollout_ref.rollout.enforce_eager=False \
+    actor_rollout_ref.rollout.temperature=1 \
+    actor_rollout_ref.rollout.top_p=0.95 \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
+    actor_rollout_ref.rollout.n=4 \
+    actor_rollout_ref.rollout.val_kwargs.n=2 \
+    actor_rollout_ref.rollout.val_kwargs.temperature=0.6 \
+    actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
+    algorithm.kl_ctrl.kl_coef=0.001 \
+    algorithm.mask_truncated_samples=True \
+    algorithm.clip_advantages=False \
+    trainer.critic_warmup=0 \
+    trainer.logger=['console','wandb'] \
+    trainer.project_name='rllm-apigen_mt_16k' \
+    trainer.experiment_name='rllm-apigen-mt-16k-8b-stage1' \
+    trainer.val_before_train=True \
+    trainer.n_gpus_per_node=8 \
+    trainer.nnodes=1 \
+    trainer.save_freq=10 \
+    trainer.test_freq=10 \
+    trainer.default_hdfs_dir=null \
+    agent.max_steps=1 \
+    agent.use_stepwise_advantage=False \
+    trainer.total_epochs=100
+
+    # actor_rollout_ref.model.lora_rank=32 \
+    # actor_rollout_ref.model.lora_alpha=32 \
+    # actor_rollout_ref.model.target_modules=all-linear \
+    # actor_rollout_ref.actor.optim.lr=3e-5 \
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 35f3d445..a1979b22 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,74 +18,10 @@ classifiers = [
 ]
 license = {file = "LICENSE"}
 dependencies = [
-    # Core ML/AI packages
-    "torch>=2.7",
-    "transformers",
-    "accelerate",
-    "flash-attn>=2.8.0.post2",
     "sentence-transformers",
-    "torchmetrics",
-
-    # Training and inference
-    "deepspeed",
-    "vllm>=0.8.3",
-    "sgl-kernel>=0.2.0",
-    "sglang>=0.4.8.post1",
-    "sglang-router",
-    "peft",
-    "torchao",
-
-    # Data processing
-    "datasets",
     "polars",
-    "dm-tree",
-
-    # Cloud and infrastructure
-    "google-cloud-aiplatform",
     "vertexai",
-    "kubernetes",
-    "ray",
-
-    # Web and automation
-    "gradio",
-    "selenium",
-    "browsergym",
     "firecrawl",
-
-    # Math and science
-    "latex2sympy2",
-    "pylatexenc",
-    "nltk",
-    "scipy",
-    "scikit-learn",
-
-    # Code evaluation
-    "swebench",
-    "e2b_code_interpreter",
-
-    # Utilities
-    "fire",
-    "gdown",
-    "tabulate",
-    "sortedcontainers",
-    "PyMuPDF",
-    "together",
-    "wandb",
-    "pybind11",
-    "gym",
-
-    # Development and testing
-    "pytest",
-    "pre-commit",
-    "ruff",
-    "mypy",
-
-    # Documentation
-    "mkdocs>=1.5.0",
-    "mkdocs-material>=9.0.0",
-    "mkdocstrings[python]>=0.24.0",
-    "mkdocs-autorefs>=0.5.0",
-    "pymdown-extensions>=10.0.0",
 ]
 
 [tool.ruff]
@@ -119,7 +55,7 @@ ignore = [
     # `.log()` statement uses f-string
     "G004",
     # equality check using x == True
-    "E712", 
+    "E712",
 ]
 
 [tool.ruff.lint.per-file-ignores]
diff --git a/rllm/agents/__init__.py b/rllm/agents/__init__.py
index 50b04c3c..08ad8753 100644
--- a/rllm/agents/__init__.py
+++ b/rllm/agents/__init__.py
@@ -1,7 +1,12 @@
-from rllm.agents.math_agent import MathAgent
-from rllm.agents.tool_agent import ToolAgent
+# from rllm.agents.math_agent import MathAgent
+# from rllm.agents.tool_agent import ToolAgent
+from rllm.agents.tool_ast_agent import ToolASTAgent
 
-__all__ = ["MathAgent", "ToolAgent"]
+__all__ = [
+    # "MathAgent",
+    # "ToolAgent",
+    "ToolASTAgent",
+]
 
 
 def safe_import(module_path, class_name):
diff --git a/rllm/agents/tool_ast_agent.py b/rllm/agents/tool_ast_agent.py
new file mode 100644
index 00000000..f0d3d17a
--- /dev/null
+++ b/rllm/agents/tool_ast_agent.py
@@ -0,0 +1,69 @@
+import copy
+from typing import Any
+import json
+from rllm.agents.agent import Action, BaseAgent, Step, Trajectory
+
+
+class ToolASTAgent(BaseAgent):
+    """
+    A tool-calling agent whose responses are checked for correctness by AST matching of the parsed tool calls, following the BaseAgent interface.
+    Always single turn.
+    """
+
+    def __init__(self, accumulate_thinking=True):
+        """
+        Initialize the ToolASTAgent.
+        """
+        self._trajectory = Trajectory()
+        self.messages = []
+        self.accumulate_thinking = accumulate_thinking
+
+    def update_from_env(self, observation: Any, reward: float, done: bool, info: dict, **kwargs):
+        if not self.trajectory.steps:
+            # Initial problem presentation
+            assert isinstance(observation, dict) and "prompt" in observation
+            question = observation["prompt"]
+            question = json.loads(question)
+            assert isinstance(question, list)
+            self.messages = question
+        else:
+            # Placeholder, as the agent is always single-turn.
+            self.messages.append({"role": "user", "content": "Hi!"})
+
+    def update_from_model(self, response: str, **kwargs) -> Action:
+        """
+        Updates the agent's internal state based on the model's response.
+        """
+        self.messages.append({"role": "assistant", "content": response})
+        new_step = Step(chat_completions=copy.deepcopy(self.chat_completions))
+        self.trajectory.steps.append(new_step)
+
+        return Action(action=response)
+
+    def reset(self):
+        """Reset agent state for new episode."""
+        self._trajectory = Trajectory()
+        self.messages = []
+
+    @property
+    def chat_completions(self) -> list[dict[str, str]]:
+        """Return conversation history for model interaction."""
+        # remove thinking from assistant messages if not accumulate_thinking, except for the last one
+        messages = copy.deepcopy(self.messages)
+        if not self.accumulate_thinking:
+            for msg in messages[:-1]:
+                if msg["role"] == "assistant":
+                    _, sep, after = msg["content"].partition("</think>")
+                    if sep:
+                        msg["content"] = after
+        return messages
+
+    @property
+    def trajectory(self) -> Trajectory:
+        """Return complete interaction trajectory."""
+        return self._trajectory
+
+    def get_current_state(self) -> Step:
+        """Returns the current step/state of the agent."""
+        assert self._trajectory.steps, "Trajectory should not be empty when get_current_state is called."
+        return self._trajectory.steps[-1]
diff --git a/rllm/data/datasets/apigen_mt/test.parquet b/rllm/data/datasets/apigen_mt/test.parquet
new file mode 100644
index 00000000..7a039cc1
Binary files /dev/null and b/rllm/data/datasets/apigen_mt/test.parquet differ
diff --git a/rllm/data/datasets/apigen_mt/test_verl.parquet b/rllm/data/datasets/apigen_mt/test_verl.parquet
new file mode 100644
index 00000000..ac4dd3a7
Binary files /dev/null and b/rllm/data/datasets/apigen_mt/test_verl.parquet differ
diff --git a/rllm/data/datasets/apigen_mt/train.parquet b/rllm/data/datasets/apigen_mt/train.parquet
new file mode 100644
index 00000000..288500f6
Binary files /dev/null and b/rllm/data/datasets/apigen_mt/train.parquet differ
diff --git a/rllm/data/datasets/apigen_mt/train_verl.parquet b/rllm/data/datasets/apigen_mt/train_verl.parquet
new file mode 100644
index 00000000..b721b0f5
Binary files /dev/null and b/rllm/data/datasets/apigen_mt/train_verl.parquet differ
diff --git a/rllm/parser/chat_template/parser.py b/rllm/parser/chat_template/parser.py
index 97d0cf47..93fc3926 100644
--- a/rllm/parser/chat_template/parser.py
+++ b/rllm/parser/chat_template/parser.py
@@ -171,7 +171,18 @@ def parse_user(self, message):
         return self.user_token + message["content"] + self.eot_token
 
     def parse_assistant(self, message):
-        result = self.assistant_token + message["content"] + self.eot_token
+        tool_call_str = ""
+        if "tool_calls" in message:
+            # if there are tool calls, we need to insert them into the assistant message
+            for tool_call in message["tool_calls"]:
+                tool_call = tool_call["function"]
+                tool_call_name = tool_call["name"]
+                tool_call_arguments = tool_call["arguments"]
+                tool_call_prefix = '<tool_call>\n{"name": "'
+                tool_call_mid = '", "arguments": '
+                tool_call_suffix = '}\n</tool_call>'
+                tool_call_str += tool_call_prefix + tool_call_name + tool_call_mid + tool_call_arguments + tool_call_suffix
+        result = self.assistant_token + message.get("content", "") + tool_call_str + self.eot_token
         return result
 
     def parse_tool(self, message):
diff --git a/rllm/registry/dataset_registry.json b/rllm/registry/dataset_registry.json
new file mode 100644
index 00000000..931e1eed
--- /dev/null
+++ b/rllm/registry/dataset_registry.json
@@ -0,0 +1,6 @@
+{
+    "apigen_mt": {
+        "train": "/workspace/rllm/rllm/data/datasets/apigen_mt/train.parquet",
+        "test": "/workspace/rllm/rllm/data/datasets/apigen_mt/test.parquet"
+    }
+}
\ No newline at end of file
diff --git a/rllm/rewards/__init__.py b/rllm/rewards/__init__.py
index ea95e2c2..517fe634 100644
--- a/rllm/rewards/__init__.py
+++ b/rllm/rewards/__init__.py
@@ -3,4 +3,5 @@
 from .reward_fn import RewardFunction, zero_reward
 from .reward_types import RewardConfig, RewardInput, RewardOutput, RewardType
+from .tool_calling_ast_reward import RewardToolCallingASTFn
 
-__all__ = ["RewardInput", "RewardOutput", "RewardType", "RewardConfig", "RewardFunction", "zero_reward"]
+__all__ = ["RewardInput", "RewardOutput", "RewardType", "RewardConfig", "RewardFunction", "zero_reward", "RewardToolCallingASTFn"]
diff --git a/rllm/rewards/reward_fn.py b/rllm/rewards/reward_fn.py
index 96a057f8..f401fc26 100644
--- a/rllm/rewards/reward_fn.py
+++ b/rllm/rewards/reward_fn.py
@@ -1,6 +1,7 @@
 from typing import Protocol, runtime_checkable
 
 from rllm.rewards.code_reward import RewardCodeFn
+from rllm.rewards.tool_calling_ast_reward import RewardToolCallingASTFn
 from rllm.rewards.math_reward import RewardMathFn
 from rllm.rewards.reward_types import RewardConfig, RewardInput, RewardOutput
 from rllm.rewards.search_reward import RewardSearchFn
@@ -90,3 +91,20 @@ def code_reward_fn(task_info: dict, action: str) -> RewardOutput:
     reward_config = RewardConfig()
     reward_fn = RewardCodeFn(reward_config)
     return reward_fn(task_info, action)
+
+
+def tool_calling_ast_reward_fn(task_info: dict, action: str) -> RewardOutput:
+    """
+    A reward function for tool calling AST matching tasks that implements the RewardFunction protocol.
+
+    Args:
+        task_info: The task dictionary containing data_source, ground_truth and other metadata
+        action: The agent's response/solution
+
+    Returns:
+        RewardOutput: The calculated reward value based on AST matching
+    """
+    reward_config = RewardConfig()
+    reward_fn = RewardToolCallingASTFn(reward_config)
+    return reward_fn(task_info, action)
+
diff --git a/rllm/rewards/tool_calling_ast_reward.py b/rllm/rewards/tool_calling_ast_reward.py
new file mode 100644
index 00000000..a62a9f1f
--- /dev/null
+++ b/rllm/rewards/tool_calling_ast_reward.py
@@ -0,0 +1,76 @@
+"""
+This module contains the RewardToolCallingASTFn class, which parses tool calls
+from model outputs and evaluates them against ground truth using AST matching.
+"""
+import json
+from collections import Counter
+
+from rllm.parser.tool_parser.tool_parser_base import ToolParser
+from rllm.parser import get_tool_parser
+
+from rllm.rewards.reward_types import RewardConfig, RewardOutput
+
+class RewardToolCallingASTFn:
+    """
+    Reward function for evaluating tool-calling answers via AST matching.
+
+    This class implements the RewardFunction protocol to process the input and determine
+    the reward based on whether the parsed tool calls match the ground truth.
+    """
+
+    def __init__(self, config: RewardConfig, parser_name: str = "qwen") -> None:
+        self.config = config
+        parser_class: type[ToolParser] = get_tool_parser(parser_name=parser_name)
+        self.tool_parser = parser_class()
+
+    def __call__(self, task_info: dict, action: str) -> RewardOutput:
+        """
+        Calculate the reward for a tool-calling task based on the agent's action.
+
+        Args:
+            task_info: Dictionary containing problem, data_source, problem_type, and ground_truth
+            action: The agent's response/solution
+
+        Returns:
+            RewardOutput: The calculated reward with correctness information
+        """
+        # Extract information from task_info
+        model_response = action
+
+        # Handle None or empty response
+        if model_response is None or model_response == "":
+            print("DEBUG: Empty or None response")
+            return RewardOutput(reward=self.config.format_error_reward, is_correct=False)
+
+        # Extract tool calls from the model response.
+        try:
+            tool_calls = self.tool_parser.parse(model_response)
+        except Exception:
+            return RewardOutput(reward=self.config.format_error_reward, is_correct=False)
+        tool_calls = [tool_call.to_dict() for tool_call in tool_calls]
+        # Process the ground truth(s)
+        ground_truths = task_info.get("ground_truth", None)
+        if ground_truths is None:
+            return RewardOutput(reward=self.config.unk_error_reward, is_correct=False)
+        ground_truths = json.loads(ground_truths)
+
+        if compare_tool_calls(tool_calls, ground_truths):
+            return RewardOutput(reward=self.config.correct_reward, is_correct=True)
+        else:
+            return RewardOutput(reward=self.config.incorrect_reward, is_correct=False)
+
+
+def compare_tool_calls(generated_tool_calls: list, gt_tool_calls: list) -> bool:
+    if len(generated_tool_calls) != len(gt_tool_calls):
+        return False
+
+    generated_tool_calls_serialized = [json.dumps(item, sort_keys=True) for item in generated_tool_calls]
+    gt_tool_calls_serialized = [json.dumps(item, sort_keys=True) for item in gt_tool_calls]
+
+    result = Counter(generated_tool_calls_serialized) == Counter(gt_tool_calls_serialized)
+    if not result:
+        print("Tool calls mismatch")
+        print("Generation: ", generated_tool_calls)
+        print("Ground Truth: ", gt_tool_calls)
+    return result
+
diff --git a/rllm/tools/__init__.py b/rllm/tools/__init__.py
index 4de7fa84..58a04532 100644
--- a/rllm/tools/__init__.py
+++ b/rllm/tools/__init__.py
@@ -1,6 +1,6 @@
-from rllm.tools.code_tools import (
-    PythonInterpreter,
-)
+# from rllm.tools.code_tools import (
+#     PythonInterpreter,
+# )
 from rllm.tools.registry import ToolRegistry
 from rllm.tools.web_tools import (
     FirecrawlTool,
@@ -11,7 +11,7 @@
 
 # Define default tools dict
 DEFAULT_TOOLS = {
-    "python": PythonInterpreter,
+    # "python": PythonInterpreter,
     "google_search": GoogleSearchTool,
     "firecrawl": FirecrawlTool,
     "tavily_extract": TavilyExtractTool,
diff --git a/rllm/tools/code_tools/__init__.py b/rllm/tools/code_tools/__init__.py
index b45a44e5..6f704b8c 100644
--- a/rllm/tools/code_tools/__init__.py
+++ b/rllm/tools/code_tools/__init__.py
@@ -1,12 +1,12 @@
-from rllm.tools.code_tools.e2b_tool import E2BPythonInterpreter
-from rllm.tools.code_tools.lcb_tool import LCBPythonInterpreter
-from rllm.tools.code_tools.python_interpreter import PythonInterpreter
-from rllm.tools.code_tools.together_tool import TogetherCodeTool
+# from rllm.tools.code_tools.e2b_tool import E2BPythonInterpreter
+# from rllm.tools.code_tools.lcb_tool import LCBPythonInterpreter
+# from rllm.tools.code_tools.python_interpreter import PythonInterpreter
+# from rllm.tools.code_tools.together_tool import TogetherCodeTool
 
 __all__ = [
-    "PythonInterpreter",  # New unified interpreter
-    "E2BPythonInterpreter",  # Legacy interpreters for backward compatibility
-    "LocalPythonInterpreter",
-    "LCBPythonInterpreter",
-    "TogetherCodeTool",
+    # "PythonInterpreter",  # New unified interpreter
+    # "E2BPythonInterpreter",  # Legacy interpreters for backward compatibility
+    # "LocalPythonInterpreter",
+    # "LCBPythonInterpreter",
+    # "TogetherCodeTool",
 ]
diff --git a/rllm/trainer/config/ppo_trainer.yaml b/rllm/trainer/config/ppo_trainer.yaml
index 43e5c112..ff95e0ba 100644
--- a/rllm/trainer/config/ppo_trainer.yaml
+++ b/rllm/trainer/config/ppo_trainer.yaml
@@ -144,7 +144,7 @@ actor_rollout_ref:
       temperature: 0
       n: 1
       do_sample: False # default eager for validation
-    multi_turn: 
+    multi_turn:
       enable: False # should set rollout.name to sglang_async if True
       max_turns: null # null for no limit (default max_length // 3)
       tool_config_path: null # null for no tool
@@ -301,6 +301,6 @@ agent:
   overlong_filter: False
   normalize_step_advantage: False
   use_stepwise_advantage: False
-  stepwise_advantage_mode: "broadcast" 
+  stepwise_advantage_mode: "broadcast"
   agent_args: {}
-  engine_args: {}
+  engine_args: {}
\ No newline at end of file
diff --git a/rllm/trainer/verl/agent_ppo_trainer.py b/rllm/trainer/verl/agent_ppo_trainer.py
index a7d15b6c..84c74b8c 100644
--- a/rllm/trainer/verl/agent_ppo_trainer.py
+++ b/rllm/trainer/verl/agent_ppo_trainer.py
@@ -12,6 +12,7 @@
 import numpy as np
 import torch
 from omegaconf import OmegaConf
+from contextlib import contextmanager
 
 from rllm.engine.agent_execution_engine import AsyncAgentExecutionEngine
 from verl import DataProto
@@ -22,13 +23,19 @@
     ResourcePoolManager,
     Role,
     WorkerType,
-    _timer,
     compute_advantage,
     compute_data_metrics,
     compute_response_mask,
     compute_timing_metrics,
     reduce_metrics,
 )
+from codetiming import Timer
+
+@contextmanager
+def _timer(name: str, timing_raw: dict[str, float]):
+    with Timer(name=name, logger=None) as timer:
+        yield
+    timing_raw[name] = timer.last
 
 
 class AgentPPOTrainer(RayPPOTrainer):
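
Usage sketch (not part of the patch): a minimal example of how the new AST-matching reward path introduced above could be exercised, assuming the Qwen tool parser accepts `<tool_call>{"name": ..., "arguments": ...}</tool_call>` blocks and that its parsed calls serialize to `{"name", "arguments"}` dicts. The task payload and model response here are made up for illustration.

import json

from rllm.rewards.reward_fn import tool_calling_ast_reward_fn

# Hypothetical task_info mirroring what prepare_apigen_mt_data.py stores:
# "ground_truth" is a JSON-encoded list of {"name", "arguments"} dicts.
task_info = {
    "data_source": "apigen_mt",
    "ground_truth": json.dumps([{"name": "get_weather", "arguments": {"city": "Paris"}}]),
}

# Hypothetical model response in the Qwen tool-call format.
action = '<tool_call>\n{"name": "get_weather", "arguments": {"city": "Paris"}}\n</tool_call>'

result = tool_calling_ast_reward_fn(task_info, action)
print(result.is_correct)  # expected True when the parsed call matches the ground truth exactly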