[WIP] RL goes brrr #533

Merged: 9 commits, Mar 24, 2025
20 changes: 3 additions & 17 deletions README.md
@@ -160,26 +160,12 @@ ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_con
```

### GRPO

To train via the GRPO trainer, we use one GPU to run vLLM for faster generation and the remaining GPUs for training. For example, on a node with 8 GPUs, set `--num_processes` to override the default value in the `accelerate` configs:

We use TRL's new distributed vLLM server and `GRPOTrainer` to scale to larger (>7B) models. We provide an example Slurm script:
```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
    --num_processes=7 src/open_r1/grpo.py \
    --config recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml
sbatch --job-name=trl-Qwen2.5-Math-7B-config_simple_rl --nodes=2 slurm/train.slurm Qwen2.5-Math-7B grpo config_simple_rl zero3
```

> [!WARNING]
> The chat template used in the distilled DeepSeek models omits the contents of the reasoning block within the `<think>` and `</think>` tags. It also prefills the assistant response with `<think>`, which interferes with the format reward function. To handle this, override the chat template as done in e.g. [recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml](./recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml).
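
For illustration, a minimal sketch of what such an override can look like in the recipe YAML. The `chat_template` key and the template string below are simplified assumptions for this example; see the linked recipe for the actual override used:

```yaml
# Hypothetical, simplified chat template: keeps the reasoning block visible and does not
# prefill the assistant turn with <think>. The real recipe ships a fuller template.
chat_template: "{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
```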


We provide a minimal reproducible experiment using GRPO for mathematical reasoning, referencing the approach from [SimpleRL-Reason](https://hkust-nlp.notion.site/simplerl-reason), which uses a 7B model trained on 8K examples. Running this on 8 H100 (80GB) GPUs takes about 3 hours:

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file recipes/accelerate_configs/zero2.yaml \
    --num_processes=7 src/open_r1/grpo.py \
    --config recipes/Qwen2.5-Math-7B/grpo/config_simple_rl.yaml
```
You will need to adapt the `slurm/train.slurm` script to match your cluster.
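
Typical adaptations are the `#SBATCH` resource directives at the top of the script; the values below are hypothetical placeholders:

```shell
# Hypothetical example values -- replace with whatever your cluster provides.
#SBATCH --partition=your-gpu-partition   # your GPU partition
#SBATCH --gres=gpu:8                     # GPUs available per node
```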

Our final [model](https://huggingface.co/Dongwei/Qwen-2.5-7B_Base_Math_smalllr), while using different learning rates, loss functions, and reward structures from SimpleRL-Reason, achieves 69.4% accuracy on MATH-500, demonstrating a 17%+ improvement over the base model.

2 changes: 0 additions & 2 deletions recipes/DeepSeek-R1-Distill-Qwen-1.5B/grpo/config_demo.yaml
@@ -13,8 +13,6 @@ system_prompt: "You are a helpful AI Assistant that provides well-reasoned and d
# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
2 changes: 0 additions & 2 deletions recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo.yaml
@@ -11,8 +11,6 @@ system_prompt: "You are a helpful AI Assistant that provides well-reasoned and d
# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
2 changes: 0 additions & 2 deletions recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code.yaml
@@ -12,8 +12,6 @@ system_prompt: "You are a helpful AI Assistant that provides well-reasoned and d
beta: 0.01
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
2 changes: 0 additions & 2 deletions recipes/Qwen2.5-1.5B-Instruct/grpo/config_demo_code_ioi.yaml
@@ -12,8 +12,6 @@ system_prompt: "You are a helpful AI Assistant that provides well-reasoned and d
beta: 0.01
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.9
do_eval: false
gradient_accumulation_steps: 4
gradient_checkpointing: true
54 changes: 54 additions & 0 deletions recipes/Qwen2.5-7B-Instruct/grpo/config_demo.yaml
@@ -0,0 +1,54 @@
# Model arguments
model_name_or_path: Qwen/Qwen2.5-7B-Instruct
model_revision: main
torch_dtype: bfloat16
attn_implementation: flash_attention_2

# Data training arguments
dataset_name: open-r1/OpenR1-Math-cn_k12-86k
system_prompt: "You are a helpful AI Assistant that provides well-reasoned and detailed responses. You first think about the reasoning process as an internal monologue and then provide the user with the answer. Respond in the following format: <think>\n...\n</think>\n<answer>\n...\n</answer>"

# GRPO trainer config
beta: 0.001
bf16: true
do_eval: false
eval_strategy: "no"
use_vllm: true
gradient_accumulation_steps: 16
gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
hub_model_id: Qwen2.5-7B-Instruct-GRPO
hub_strategy: every_save
learning_rate: 1.0e-06
log_completions: true
log_level: info
logging_first_step: true
logging_steps: 1
logging_strategy: steps
lr_scheduler_type: constant_with_warmup
max_grad_norm: 0.2
max_prompt_length: 1024
max_completion_length: 4096
max_steps: -1
num_generations: 16
num_train_epochs: 1
output_dir: data/Qwen2.5-7B-Instruct-GRPO
overwrite_output_dir: true
per_device_train_batch_size: 4
push_to_hub: true
report_to:
- wandb
reward_funcs:
- accuracy
- format
reward_weights:
- 1.0
- 0.2
save_strategy: "steps"
save_steps: 0.1
save_total_limit: 1
seed: 42
temperature: 0.7
warmup_ratio: 0.1
2 changes: 0 additions & 2 deletions recipes/Qwen2.5-Math-7B/grpo/config_simple_rl.yaml
@@ -12,8 +12,6 @@ system_prompt: "You are a helpful AI Assistant, designed to provided well-reason
# GRPO trainer config
bf16: true
use_vllm: true
vllm_device: auto
vllm_gpu_memory_utilization: 0.7
do_eval: true
eval_strategy: steps
eval_steps: 100
28 changes: 28 additions & 0 deletions scripts/get_tensor_parallel_size.py
@@ -0,0 +1,28 @@
import argparse
from transformers import AutoConfig
from math import gcd


def get_tensor_parallel_size(model_name: str, revision: str = None, default_tp: int = 8) -> int:
    """Return a tensor parallel size that evenly divides the model's attention heads."""
    try:
        config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
        num_heads = getattr(config, 'num_attention_heads', None)

        if num_heads is not None and num_heads % default_tp != 0:
            # vLLM requires the attention head count to be divisible by the TP size,
            # so fall back to the greatest common divisor of the two.
            tp = gcd(num_heads, default_tp)
            return max(tp, 1)
        else:
            return default_tp
    except Exception as e:
        print(f"Warning: Failed to fetch config for {model_name}@{revision}: {e}")
        return default_tp


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True, help="Hugging Face model name or path")
    parser.add_argument("--revision", type=str, default=None, help="Model revision if applicable")
    parser.add_argument("--default_tp", type=int, default=8, help="Default TP size (usually GPUs per node)")

    args = parser.parse_args()

    tp = get_tensor_parallel_size(args.model_name, args.revision, args.default_tp)
    print(tp)
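
For reference, the helper can also be run standalone; the model name below is just an example (any Hugging Face model id works):

```shell
# Prints a tensor-parallel size that divides the model's attention head count
python scripts/get_tensor_parallel_size.py --model_name Qwen/Qwen2.5-Math-7B --default_tp 8
```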
4 changes: 2 additions & 2 deletions setup.py
@@ -66,8 +66,8 @@
"safetensors>=0.3.3",
"sentencepiece>=0.1.99",
"torch==2.5.1",
"transformers==4.49.0",
"trl @ git+https://github.com/huggingface/trl.git@69ad852e5654a77f1695eb4c608906fe0c7e8624",
"transformers==4.50.0",
"trl==0.16.0",
"vllm==0.7.2",
"wandb>=0.19.1",
]
73 changes: 32 additions & 41 deletions slurm/train.slurm
@@ -9,10 +9,7 @@
#SBATCH --requeue

# Specific configuration optimized for the Hugging Face Compute Cluster
# Be ye warned this may not work on other clusters!
module load cuda/12.4


set -x -e

source ~/.bashrc
@@ -24,42 +24,43 @@ TASK=$2
CONFIG_SUFFIX=$3
ACCELERATOR=$4
OPTIONAL_ARGS=$5
# Due to conflicts between Accelerate's DeepSpeed configs and Transformers' TrainingArguments, we need to parse the gradient accumulation steps from the config file to ensure they match
CONFIG_FILE=recipes/$MODEL/$TASK/config_$CONFIG_SUFFIX.yaml
GRAD_ACC_STEPS=$(grep 'gradient_accumulation_steps' $CONFIG_FILE | awk '{print $2}')
MODEL=$(grep 'model_name_or_path:' $CONFIG_FILE | awk '{print $2}')
REVISION=$(grep 'model_revision:' $CONFIG_FILE | head -n 1 | awk '{print $2}')

# Training setup
# Distributed configuration
NUM_NODES=$SLURM_NNODES
GPUS_PER_NODE=8
WORLD_SIZE=$(($NUM_NODES*$GPUS_PER_NODE))
NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
MASTER_ADDR=${NODELIST[0]} # First node for main process
MASTER_PORT=6000
TRAIN_NODES=("${NODELIST[@]}")

# Check if we are running vLLM during training to adjust the world size
if grep -q 'use_vllm:\s*true' "$CONFIG_FILE"; then
USE_VLLM="false"
if [[ -f "$CONFIG_FILE" ]] && grep -qE '^\s*use_vllm:\s*true' "$CONFIG_FILE"; then
USE_VLLM="true"
else
USE_VLLM="false"
fi

# if using vllm
if [[ "$USE_VLLM" == "true" ]]; then
    WORLD_SIZE=$(($WORLD_SIZE - 1))
    TRAIN_NODES=("${NODELIST[@]:0:$((NUM_NODES - 1))}")
    VLLM_NODE=${NODELIST[-1]} # Last node
    TP=$(python scripts/get_tensor_parallel_size.py --model_name $MODEL --revision $REVISION --default_tp $GPUS_PER_NODE)
    WORLD_SIZE=$((WORLD_SIZE - GPUS_PER_NODE))
    NUM_NODES=$((NUM_NODES - 1))
    srun --nodes=1 --ntasks=1 --nodelist=$VLLM_NODE trl vllm-serve --model $MODEL --revision $REVISION --tensor_parallel_size $TP &

    OPTIONAL_ARGS="$OPTIONAL_ARGS --vllm_server_host=$VLLM_NODE"
fi

# Split the string into individual arguments
IFS=' ' read -ra ARGS <<< "$OPTIONAL_ARGS"

# Loop through the arguments and find the one with "--gradient_accumulation_steps"
for arg in "${ARGS[@]}"; do
    if [[ "$arg" == "--gradient_accumulation_steps="* ]]; then
        # Extract the value after the equals sign
        GRAD_ACC_STEPS="${arg#*=}"
        break # Exit the loop once we find the desired argument
    fi
done

echo "Gradient accumulation steps: $GRAD_ACC_STEPS"
# so processes know who to talk to
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
MASTER_PORT=6000
# force crashing on nccl issues like hanging broadcast
export NCCL_ASYNC_ERROR_HANDLING=1
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=COLL
# export NCCL_SOCKET_NTHREADS=1
# export NCCL_NSOCKS_PERTHREAD=1
# export CUDA_LAUNCH_BLOCKING=1

export CMD=" \
src/open_r1/$TASK.py --config $CONFIG_FILE $OPTIONAL_ARGS
@@ -72,29 +70,22 @@ export LAUNCHER="HF_HUB_ENABLE_HF_TRANSFER=1 ACCELERATE_LOG_LEVEL=info TRANSFORM
    --num_processes $WORLD_SIZE \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT" \
    --machine_rank $SLURM_PROCID \
    --rdzv_backend=c10d \
    --max_restarts 1 \
    --role \$(hostname -s): \
    --tee 3 \
"


# srun error handling:
# --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
# --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
SRUN_ARGS=" \
    --wait=60 \
    --kill-on-bad-exit=1 \
    --nodes=$NUM_NODES \
    --ntasks=$NUM_NODES \
    --nodelist=$TRAIN_NODES
"

clear; srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$LAUNCHER --role \$SLURMD_NODENAME: $CMD" 2>&1

echo "END TIME: $(date)"