Latest commit: 526f8a4 · Apr 22, 2025

GRPO

Paper Links

  • DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models
  • DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning

Environment

pip install math_verify  # required by the accuracy reward function
pip install -U trl

FAQ

  1. It is normal for the loss to approach zero during training. Refer to this issue for more details.
  2. How are the training steps calculated? Refer to this issue for more details.
  3. Why is the clip_ratio always 1? Refer to this issue for more details.

Cluster Support

The GRPO training framework supports the integration of high-performance inference engines (such as vLLM) to accelerate the sampling process, offering the following two deployment modes:

1. Internal Integration Mode

  • Launch the inference service directly within the Trainer.
  • Provides two resource allocation strategies:
    • Colocate Mode: Training and inference share GPU resources.
    • Async Mode: Training and inference use separate GPU resources.

GRPO Training Resource Allocation Scheme

| Configuration Scenario | NPROC_PER_NODE | num_infer_workers | Resource Allocation Description |
|---|---|---|---|
| Colocate | = Total GPUs | = Total GPUs | Training and inference share all GPU resources. |
| Async | = Training GPUs | = Inference GPUs | Must satisfy: Training GPUs + Inference GPUs = Total GPUs. |

Note:

  1. In Colocate mode, it is recommended to set sleep_level=1 to release the GPU memory occupied by vLLM during model training.
  2. Total GPUs refers to the total number of visible GPU devices.
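The allocation rules in the table can be expressed as a small consistency check. This is an illustrative sketch only: plan_gpus is a hypothetical helper, not part of Swift, and it assumes a single node where Total GPUs is the number of devices listed in CUDA_VISIBLE_DEVICES.

```python
def plan_gpus(total_gpus: int, nproc_per_node: int, num_infer_workers: int) -> str:
    """Classify a single-node GPU layout (hypothetical helper, not a Swift API)."""
    if min(total_gpus, nproc_per_node, num_infer_workers) < 1:
        raise ValueError("all GPU counts must be positive")
    if nproc_per_node == total_gpus and num_infer_workers == total_gpus:
        # Colocate: every GPU hosts both a training process and a vLLM worker.
        return "colocate"
    if nproc_per_node + num_infer_workers == total_gpus:
        # Async: training and inference run on disjoint sets of GPUs.
        return "async"
    raise ValueError(
        f"invalid layout: NPROC_PER_NODE={nproc_per_node}, "
        f"num_infer_workers={num_infer_workers}, total GPUs={total_gpus}"
    )


# Layouts used by the multi-GPU example scripts in this document.
print(plan_gpus(8, 8, 8))  # colocate
print(plan_gpus(8, 7, 1))  # async
```

Any other combination (for example 6 training processes plus 1 inference worker on 8 GPUs) raises an error, which mirrors the "must satisfy" condition in the Async row.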

2. External Service Mode

Connect to an external vLLM inference server. In this mode, pass the following arguments to the trainer so that it can reach the external vLLM server:

--vllm_server_host <Server IP> \
--vllm_server_port <Server Port> \
--vllm_server_timeout <Timeout>

Reward Functions

Custom Reward Functions

A reward function takes the text completions generated by the model, along with the other dataset columns passed as keyword arguments (kwargs), and returns a score for each completion. The example below implements a simple length-based reward function: it returns a reward of 1.0 if the generated text is longer than 1024 characters, and 0.0 otherwise.

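from swift.plugin import ORM, orms

class DummyLengthRewardFunction(ORM):
    def __call__(self, completions, **kwargs):
        # completions: list of generated strings; kwargs carries the other dataset columns
        return [1.0 if len(completion) > 1024 else 0.0 for completion in completions]

# Register under the name 'dummy' so that it can be selected via reward_funcs
orms['dummy'] = DummyLengthRewardFunction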

Add this reward function to swift/examples/train/grpo/plugin/plugin.py, load the plugin with --external_plugins examples/train/grpo/plugin/plugin.py, and then select it by its registered name with the reward_funcs parameter (for example, --reward_funcs dummy).

For an example execution script, see here.

Built-in Reward Functions

Swift provides five built-in rule-based reward functions (the code is in swift/plugin/orm.py):

| Reward Function | Paper |
|---|---|
| accuracy | DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via RL |
| format | Same as above |
| cosine | Demystifying Long Chain-of-Thought Reasoning in LLMs |
| repetition | Same as above |
| soft_overlong | Decoupled Clip and Dynamic sAmpling Policy Optimization (DAPO) |

1. accuracy

This function compares the model's generated result with the solution column in the dataset to calculate an accuracy score. If the generated result matches the standard answer, the score is 1.0; otherwise, it is 0.0.

Note: This reward function uses the math_verify library to parse the generated results and the answers in the solution, and it may only be applicable to specific mathematical datasets.

2. format

The paper uses the following system prompt to enforce a fixed format for model responses:

A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>.

This function checks whether the model generates text in the format <think>think content</think><answer>answer content</answer>. If the generated text adheres to the format requirements, the score is 1.0; otherwise, it is 0.0.
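The format check can be approximated with a regular expression. This is a sketch under assumptions, not the code in swift/plugin/orm.py: format_reward is a hypothetical plain function that mirrors the `__call__(self, completions, **kwargs)` signature of the custom reward example, requires the whole completion to be one think block followed by one answer block, and (as an assumption) tolerates whitespace between the two blocks.

```python
import re
from typing import List

# The whole completion must be <think>...</think> followed by <answer>...</answer>.
# re.DOTALL lets ".*?" span newlines inside the reasoning and the answer.
FORMAT_PATTERN = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)


def format_reward(completions: List[str], **kwargs) -> List[float]:
    # fullmatch rejects any text before <think> or after </answer>.
    return [1.0 if FORMAT_PATTERN.fullmatch(c) else 0.0 for c in completions]


print(format_reward([
    "<think>2 + 2 = 4</think><answer>4</answer>",
    "The answer is 4.",
]))  # [1.0, 0.0]
```

Using fullmatch rather than search is the key design choice: a completion that merely contains the tags somewhere, or appends text after the answer, still scores 0.0.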

3. cosine

The paper found that training with only the accuracy reward function could lead to overly long generated sequences, affecting training performance. The cosine reward function optimizes the training process by controlling the length of the generated sequences:

  • For correct answers, the reward decreases as the length increases, encouraging concise responses.
  • For incorrect answers, the reward increases as the length increases, encouraging the model to reason for longer.

A cosine function is used to smoothly adjust the reward value, ensuring that the changes are within a reasonable range. The parameters for the cosine function include the length of the generated text, the maximum length limit, and the minimum and maximum reward values.

Parameters:

  • cosine_min_len_value_wrong (default: -0.5): Reward value corresponding to the minimum length when the answer is incorrect.
  • cosine_max_len_value_wrong (default: 0.0): Reward value corresponding to the maximum length when the answer is incorrect.
  • cosine_min_len_value_correct (default: 1.0): Reward value corresponding to the minimum length when the answer is correct.
  • cosine_max_len_value_correct (default: 0.5): Reward value corresponding to the maximum length when the answer is correct.
  • cosine_max_len (default value equal to the model's maximum generation capacity): Maximum length limit for generated text.
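The length schedule can be sketched as a half-cosine interpolation between the min-length and max-length reward values listed above. This follows the cosine schedule described in the paper as we understand it; the function name, the clamping of lengths above cosine_max_len, and the omission of any extra penalty for exceeding the limit are assumptions of this sketch, not a description of Swift's implementation.

```python
import math


def cosine_reward(gen_len: int, is_correct: bool, max_len: int,
                  min_len_value_wrong: float = -0.5,
                  max_len_value_wrong: float = 0.0,
                  min_len_value_correct: float = 1.0,
                  max_len_value_correct: float = 0.5) -> float:
    """Hypothetical helper: interpolate the reward along a half cosine wave."""
    if is_correct:
        start, end = min_len_value_correct, max_len_value_correct
    else:
        start, end = min_len_value_wrong, max_len_value_wrong
    # Assumption: lengths beyond max_len are clamped to max_len.
    progress = min(gen_len, max_len) / max_len
    # (1 + cos(pi * progress)) / 2 falls smoothly from 1 at length 0 to 0 at max_len,
    # so the reward moves from `start` to `end`.
    return end + (start - end) * (1.0 + math.cos(math.pi * progress)) / 2.0


for n in (0, 512, 1024):
    print(n, round(cosine_reward(n, True, 1024), 3), round(cosine_reward(n, False, 1024), 3))
```

With the default values, the reward for a correct answer falls from 1.0 to 0.5 as the length grows toward max_len, while the reward for a wrong answer rises from -0.5 to 0.0.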

4. repetition

This function penalizes repetition in generated text by detecting repeated n-gram patterns and assigning penalties based on the level of repetition.

The function splits the generated text into words and extracts n-grams of a specified size (default is 3-gram). It calculates the repetition ratio based on the proportion of unique n-grams to the total number of n-grams. If the proportion of repeated n-grams is high, a significant negative reward (penalty) is applied. The penalty value is computed based on the repetition ratio and a maximum penalty value (default: -1.0).

Parameters:

  • repetition_n_grams (default: 3): Size of the n-gram used to detect repetition.
  • repetition_max_penalty (default: -1.0): Maximum penalty value, which controls the intensity of the penalty.
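A minimal sketch of the n-gram penalty, assuming whitespace tokenization and a penalty that scales linearly with the repetition ratio (1 minus the fraction of unique n-grams). The helper name is hypothetical, and the exact formula in swift/plugin/orm.py may differ.

```python
from typing import List, Tuple


def repetition_penalty(text: str, n_grams: int = 3, max_penalty: float = -1.0) -> float:
    """Hypothetical helper: 0.0 for no repeats, approaching max_penalty as repeats dominate."""
    words = text.split()
    if len(words) < n_grams:
        return 0.0  # too short to contain a single n-gram
    ngrams: List[Tuple[str, ...]] = [
        tuple(words[i:i + n_grams]) for i in range(len(words) - n_grams + 1)
    ]
    # Share of n-grams that are repeats of an earlier one.
    repetition_ratio = 1.0 - len(set(ngrams)) / len(ngrams)
    return repetition_ratio * max_penalty


print(repetition_penalty("a b a b a b"))  # -0.5: only 2 of the 4 trigrams are unique
```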

5. soft overlong punishment

This function defines a length penalty interval. Within this interval, a linear penalty in the range [-1, 0] is applied, becoming more negative as the response gets longer.

Parameters:

  • soft_max_length: L_max in the paper, the maximum generation length of the model, default is equal to max_completion_length.
  • soft_cache_length: L_cache in the paper, controls the length penalty interval, which is defined as [soft_max_length - soft_cache_length, soft_max_length].

Original text from the paper:

a length-aware penalty mechanism designed to shape the reward for truncated samples. Specifically, when the response length exceeds the predefined maximum value, we define a punishment interval. Within this interval, the longer the response, the greater the punishment it receives. This penalty is added to the original rule-based correctness reward, thereby signaling to the model to avoid excessively long responses.
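The penalty follows the piecewise-linear length reward from the DAPO paper: with L_max = soft_max_length and L_cache = soft_cache_length, responses up to L_max - L_cache are not penalized, the penalty then falls linearly to -1 at L_max, and anything longer receives -1. The sketch below assumes lengths are measured in tokens; the function name is hypothetical.

```python
def soft_overlong_penalty(response_len: int, soft_max_length: int, soft_cache_length: int) -> float:
    """Hypothetical helper implementing DAPO-style soft overlong punishment."""
    expected_len = soft_max_length - soft_cache_length  # start of the penalty interval
    if response_len <= expected_len:
        return 0.0
    if response_len <= soft_max_length:
        # Linear ramp from 0 (at expected_len) down to -1 (at soft_max_length).
        return (expected_len - response_len) / soft_cache_length
    return -1.0


# Values from the DAPO example script: max_completion_length 4096, soft_cache_length 819.
for length in (3000, 3686, 4096):
    print(length, round(soft_overlong_penalty(length, 4096, 819), 3))
```

As the quoted passage notes, this penalty is added to the correctness reward, which is why the DAPO example script passes both accuracy and soft_overlong to --reward_funcs.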

6. Reward Models

In addition to rule-based reward functions, this framework also supports using reward models as reward functions. To use a reward model, specify the reward_model parameter; like the model parameter, it takes the path or name of the reward model. At least one of reward_model and reward_funcs must be specified.

Arguments and Execution Script

Arguments

  • num_generations: The number of completions sampled for each prompt, referred to as the G value in the paper. The global batch size (per_device_train_batch_size * nproc_per_node) must be divisible by num_generations.
  • max_completion_length: The maximum length for sampling generation, default is 512.
  • ds3_gather_for_generation: This parameter applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, improving generation speed. However, disabling this option allows training models that exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation. Disabling this option is not compatible with vLLM generation. The default is True.
  • reward_funcs: Reward functions used to score the model's generated results. The built-in rule-based functions are accuracy, format, cosine, repetition, and soft_overlong; see swift/plugin/orm.py for details.
  • reward_weights: Weights for each reward function. Must match the number of reward functions. If None, all rewards are weighted equally with weight 1.0.
    • Note: If --reward_model is set, the reward model is appended after the reward functions, so its weight is the last entry in reward_weights.
  • log_completions: Whether to log the model-generated content during training, to be used in conjunction with --report_to wandb, default is False.
    • Note: If --report_to wandb is not set, a completions.jsonl will be created in the checkpoint to store the generated content.
  • use_vllm: Whether to use vLLM as the sampling backend. Default is False; enabling it is recommended to speed up training.
  • vllm_device: Device for deploying vLLM, default is auto, meaning the first unused GPU. Use cuda:x to specify a particular card.
  • vllm_gpu_memory_utilization: vLLM passthrough parameter, default is 0.9.
  • vllm_max_model_len: vLLM passthrough parameter, default is None.
  • vllm_max_num_seqs: vLLM passthrough parameter, default is 256.
  • vllm_enforce_eager: vLLM passthrough parameter, default is False.
  • vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None.
  • vllm_enable_prefix_caching: vLLM passthrough parameter, default is True.
  • vllm_server_host: The host address of the vLLM server. Default is None. This is used when connecting to an external vLLM server.
  • vllm_server_port: The service port of the vLLM server. Default is 8000.
  • vllm_server_timeout: The connection timeout for the vLLM server. Default is 120 seconds.
  • reward_model: Specified in the same way as model; uses a reward model as a reward function. At least one of reward_funcs and reward_model must be specified.
  • num_iterations: Number of optimization iterations per batch (μ in the GRPO algorithm). Default is 1.
  • epsilon: Epsilon value for clipping. Default is 0.2.
  • epsilon_high: Upper clipping coefficient. Default is None. When set, the importance ratio is clipped to [1 - epsilon, 1 + epsilon_high] instead of the symmetric [1 - epsilon, 1 + epsilon].
  • async_generate: Use asynchronous rollout to improve training speed. Default is False.
  • sleep_level: vLLM-specific. When the actor and the rollout engine share the same GPUs, this puts vLLM to sleep while the model is training, freeing the GPU memory it occupies.
  • move_model_batches: When moving model parameters to fast inference frameworks such as vLLM/LMDeploy, determines how many batches to divide the layers into. The default is None, which means the entire model is not split. Otherwise, the model is split into move_model_batches + 1 (non-layer parameters) + 1 (multi-modal component parameters) batches.
  • offload_optimizer: Whether to offload optimizer parameters during inference with vLLM/LMDeploy. The default is False.
  • offload_model: Whether to offload the model itself during inference with vLLM/LMDeploy. The default is False.
    • Note: If this parameter is set to True and the grad_norm remains zero during training, please install vllm==0.7.3.
  • gc_collect_after_offload: Whether to perform garbage collection (both Python GC and GPU GC) after offloading. The default is False.
  • multi_turn_func: Name of the multi-turn GRPO plugin. Add your multi-turn implementation in plugin/multi_turn.py.
  • mini_batch_size: Further splits the per-device batch (per_device_train_batch_size) into smaller sub-batches. For the split to be valid, per_device_train_batch_size must be divisible by mini_batch_size.
  • dynamic_sample: Discard groups whose rewards have a standard deviation of 0 and sample additional data to replace them. Default is False.
  • max_resample_times: Maximum number of resampling attempts when dynamic_sample is enabled. Default is 3.
  • overlong_filter: Skip truncated (overlong) samples so that they are excluded from the loss calculation. Default is False.

The hyperparameters for the reward functions are described in the Built-in Reward Functions section.
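To show how num_generations, reward_weights, and dynamic_sample fit together, here is a sketch of the path from per-function rewards to group-relative advantages. It is illustrative only: all helper names are hypothetical, it assumes the global batch (per_device_train_batch_size * nproc_per_node) is laid out as consecutive groups of num_generations completions of the same prompt, and it uses the sample standard deviation plus a small epsilon, which may differ in detail from Swift's trainer.

```python
import statistics
from typing import List, Optional, Sequence, Tuple


def check_batch(per_device_train_batch_size: int, nproc_per_node: int, num_generations: int) -> int:
    """Return the number of distinct prompts per step (hypothetical helper)."""
    global_batch = per_device_train_batch_size * nproc_per_node
    if global_batch % num_generations != 0:
        raise ValueError(f"global batch {global_batch} is not divisible by num_generations={num_generations}")
    return global_batch // num_generations


def combine_rewards(per_func_rewards: Sequence[Sequence[float]],
                    reward_weights: Optional[Sequence[float]] = None) -> List[float]:
    """Weighted sum across reward functions; each inner list holds one function's scores."""
    if reward_weights is None:
        reward_weights = [1.0] * len(per_func_rewards)  # equal weights, as with reward_weights=None
    if len(reward_weights) != len(per_func_rewards):
        raise ValueError("reward_weights must match the number of reward functions")
    return [sum(w * scores[i] for w, scores in zip(reward_weights, per_func_rewards))
            for i in range(len(per_func_rewards[0]))]


def group_advantages(rewards: Sequence[float], num_generations: int, dynamic_sample: bool = False,
                     eps: float = 1e-4) -> Tuple[List[float], List[int]]:
    """Normalize rewards within each group; optionally drop zero-variance groups."""
    if len(rewards) % num_generations != 0:
        raise ValueError("rewards must form complete groups")
    advantages: List[float] = []
    kept_groups: List[int] = []
    for start in range(0, len(rewards), num_generations):
        group = list(rewards[start:start + num_generations])
        std = statistics.stdev(group)  # sample standard deviation of the group
        if dynamic_sample and std == 0:
            continue  # every advantage would be 0, so the group carries no learning signal
        mean = statistics.fmean(group)
        advantages.extend((r - mean) / (std + eps) for r in group)
        kept_groups.append(start // num_generations)
    return advantages, kept_groups


print(check_batch(1, 8, 8))  # 1 prompt per step, as in the colocate example
```

The zero-variance case is what dynamic_sample targets: when all completions for a prompt receive the same reward, their normalized advantages are all 0, so the group contributes nothing to the policy gradient and resampling replaces it with a more informative one.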

You can use vLLM and LMDeploy as sampling backends to accelerate training.

Multi-GPU vLLM

# async mode
# The requirement is that num_infer_workers (deployment) + NPROC_PER_NODE (training) = device_count.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=7 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B \
    --reward_funcs accuracy format cosine repetition \
    --use_vllm true \
    --vllm_device auto \
    --vllm_gpu_memory_utilization 0.7 \
    --vllm_max_model_len 8192 \
    --num_infer_workers 1 \
    --train_type full \
    --torch_dtype bfloat16 \
    --dataset 'AI-MO/NuminaMath-TIR#5000' \
    --max_completion_length 2048 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-6 \
    --gradient_accumulation_steps 2 \
    --eval_steps 200 \
    --save_steps 200 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --max_length 4096 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --dataset_num_proc 4 \
    --num_generations 7 \
    --temperature 0.9 \
    --system 'examples/train/grpo/prompt.txt' \
    --deepspeed zero2 \
    --log_completions true

# colocate mode
# The requirement is that num_infer_workers (deployment) = NPROC_PER_NODE (training) = device_count.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-1.5B \
    --reward_funcs accuracy format \
    --use_vllm true \
    --vllm_device auto \
    --vllm_gpu_memory_utilization 0.7 \
    --vllm_max_model_len 8192 \
    --num_infer_workers 8 \
    --train_type full \
    --torch_dtype bfloat16 \
    --dataset 'AI-MO/NuminaMath-TIR#5000' \
    --max_completion_length 2048 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-6 \
    --gradient_accumulation_steps 2 \
    --eval_steps 200 \
    --save_steps 200 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --max_length 4096 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --dataset_num_proc 4 \
    --num_generations 8 \
    --temperature 0.9 \
    --system 'examples/train/grpo/prompt.txt' \
    --deepspeed zero2 \
    --log_completions true \
    --sleep_level 1 \
    --offload_model true \
    --offload_optimizer true \
    --gc_collect_after_offload true

Single-GPU

# PT backend
CUDA_VISIBLE_DEVICES=0 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B \
    --reward_funcs accuracy format cosine repetition \
    --train_type lora \
    --lora_rank 8 \
    --lora_alpha 32 \
    --target_modules all-linear \
    --torch_dtype bfloat16 \
    --dataset 'AI-MO/NuminaMath-TIR#1000' \
    --max_completion_length 1024 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --learning_rate 1e-5 \
    --gradient_accumulation_steps 1 \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --max_length 2048 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --dataset_num_proc 4 \
    --num_generations 4 \
    --temperature 0.9 \
    --system 'examples/train/grpo/prompt.txt' \
    --log_completions true

# vLLM backend
CUDA_VISIBLE_DEVICES=0 \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-7B \
    --vllm_gpu_memory_utilization 0.5 \
    --use_vllm true \
    --sleep_level 1 \
    --offload_model true \
    --offload_optimizer true \
    --gc_collect_after_offload true \
    --reward_funcs accuracy format \
    --train_type lora \
    --lora_rank 8 \
    --lora_alpha 32 \
    --target_modules all-linear \
    --torch_dtype bfloat16 \
    --dataset 'AI-MO/NuminaMath-TIR#1000' \
    --max_completion_length 1024 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --learning_rate 1e-5 \
    --gradient_accumulation_steps 1 \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --max_length 2048 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --dataset_num_proc 4 \
    --num_generations 4 \
    --temperature 0.9 \
    --system 'examples/train/grpo/prompt.txt' \
    --log_completions true

For multi-node training, see here.

Note: In internal integration mode, the GPU configuration and training parameters must be identical across all nodes.

DAPO

Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO) introduces the following techniques on top of GRPO:

  • Clip Higher
  • Dynamic Sampling
  • Overlong Filtering
  • Token-level Loss
  • Soft Overlong Punishment

Among these, Token-level Loss is implemented by default and requires no additional settings. The other techniques can be enabled in GRPOTrainer by setting the following parameters:

| Parameter | Type | Value |
|---|---|---|
| --epsilon_high | float | 0.28 |
| --dynamic_sample | bool | true |
| --overlong_filter | bool | true |
| --reward_funcs | str | soft_overlong |
| --max_resample_times | int | 3 |
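The sketch below illustrates two of these techniques on scalar inputs: Clip Higher, where the importance ratio is clipped to [1 - epsilon, 1 + epsilon_high], and token-level loss, where the clipped objective is averaged over all tokens in the batch instead of first averaging within each sequence. The function names are hypothetical, the KL term is omitted (consistent with --beta 0.0 in the reference script), and this is not Swift's trainer code.

```python
from typing import List, Optional, Sequence


def clipped_objective(ratio: float, advantage: float,
                      epsilon: float = 0.2, epsilon_high: Optional[float] = None) -> float:
    """PPO-style clipped surrogate for one token, with a decoupled upper bound."""
    upper = epsilon if epsilon_high is None else epsilon_high
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + upper)
    # Taking the minimum keeps the more pessimistic of the two estimates.
    return min(ratio * advantage, clipped_ratio * advantage)


def token_level_loss(ratios: Sequence[Sequence[float]], advantages: Sequence[float],
                     epsilon: float = 0.2, epsilon_high: Optional[float] = None) -> float:
    """Average over every token in the batch, so longer sequences carry more weight."""
    objectives: List[float] = [
        clipped_objective(r, adv, epsilon, epsilon_high)
        for seq_ratios, adv in zip(ratios, advantages)  # one advantage per sequence
        for r in seq_ratios                              # one ratio per token
    ]
    return -sum(objectives) / len(objectives)  # negate: the trainer minimizes the loss
```

For a positive advantage, raising epsilon_high from 0.2 to 0.28 lets the ratio grow further before clipping, which is the extra room for low-probability tokens that Clip Higher is meant to provide. For token-level averaging: with one two-token sequence of advantage 1.0 and one one-token sequence of advantage -1.0 (all ratios 1.0), token_level_loss returns -1/3, whereas averaging per sequence first would give 0.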

Reference training script (for 8-card colocate mode):

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
WANDB_API_KEY=xxx \
swift rlhf \
    --rlhf_type grpo \
    --model Qwen/Qwen2.5-1.5B \
    --reward_funcs accuracy soft_overlong \
    --max_completion_length 4096 \
    --soft_cache_length 819 \
    --epsilon 0.2 \
    --epsilon_high 0.28 \
    --dynamic_sample true \
    --overlong_filter true \
    --max_resample_times 3 \
    --use_vllm true \
    --vllm_gpu_memory_utilization 0.6 \
    --num_infer_workers 8 \
    --train_type full \
    --torch_dtype bfloat16 \
    --dataset AI-MO/NuminaMath-TIR#5000 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --learning_rate 1e-6 \
    --eval_steps 1000 \
    --save_steps 1000 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --dataset_num_proc 4 \
    --num_generations 8 \
    --temperature 1.0 \
    --top_p 1.0 \
    --deepspeed zero2 \
    --log_completions true \
    --num_iterations 1 \
    --report_to tensorboard wandb \
    --beta 0.0