diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md
index 9f13263f43..7a89d79a75 100644
--- a/docs/source/grpo_trainer.md
+++ b/docs/source/grpo_trainer.md
@@ -144,6 +144,16 @@ $$

 This constant is recommended to be the maximum completion length. To use this formulation, set `loss_type="dr_grpo"` in the [`GRPOConfig`].

+### CISPO: Truncated importance-sampling REINFORCE
+
+The ScaleRL paper[^scalerl] introduces CISPO, a variant of truncated importance-sampling REINFORCE that keeps DAPO's global-batch token normalization while replacing the PPO-style min operator with a stop-gradient truncation of the importance ratios:
+
+$$
+\mathcal{L}_{\text{CISPO}}(\theta) = - \frac{1}{T_G} \sum_{i=1}^{G} \sum_{t=1}^{|o_i|} \operatorname{sg}\!\left(\min(\rho_{i,t}, \epsilon_{\max})\right) \, \hat{A}_i \log \pi_\theta(o_{i,t} \mid q, o_{i, < t}) \,,
+$$
+
+where \\( \rho_{i,t} = \tfrac{\pi_\theta(o_{i,t} \mid q, o_{i,<t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,<t})} \\) is the importance sampling ratio, \\( \epsilon_{\max} \\) is the truncation threshold, \\( \operatorname{sg}(\cdot) \\) denotes stop-gradient, and \\( T_G = \sum_{i=1}^{G} |o_i| \\) is the total number of completion tokens in the group. To use this formulation, set `loss_type="cispo"` in the [`GRPOConfig`]; the truncation threshold is controlled by `cispo_clip_max`.
+
@@ ... @@
 - `clip_ratio/high_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
 - `clip_ratio/high_max`: The maximum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\).
+- `cispo/importance_ratio/mean`: (Only when `loss_type="cispo"`.) Average importance ratio \\( \rho_{i,t} \\) before truncation.
+- `cispo/importance_ratio/truncated_mean`: (Only when `loss_type="cispo"`.) Average truncated ratio \\( \min(\rho_{i,t}, \epsilon_{\max}) \\).
+- `cispo/importance_ratio/max`: (Only when `loss_type="cispo"`.) Maximum observed importance ratio \\( \rho_{i,t} \\) in the batch.
+- `cispo/importance_ratio/max_truncated`: (Only when `loss_type="cispo"`.) Maximum truncated ratio after applying \\( \epsilon_{\max} \\).
+- `cispo/clip_fraction`: (Only when `loss_type="cispo"`.) Fraction of tokens whose importance ratio exceeded \\( \epsilon_{\max} \\).

 ## Customization

@@ -185,6 +200,8 @@ Generation is often the main bottleneck when training with online methods. To ac
 pip install trl[vllm]
 ```

+[^scalerl]: Yao et al., *ScaleRL: Scaling RL Compute Effectively and Predictably*, 2025.
+
 We support two ways of using vLLM during training: **server mode** and **colocate mode**.

 > [!TIP]
diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py
index f6f3c6e346..cc1892f22c 100644
--- a/trl/trainer/grpo_config.py
+++ b/trl/trainer/grpo_config.py
@@ -538,22 +538,18 @@ class GRPOConfig(TrainingArguments):
     loss_type: str = field(
         default="dapo",
         metadata={
-            "help": "Specifies the loss formulation to use. Supported values are 'grpo', 'dapo', 'bnpo', and "
-            "'dr_grpo'. "
-            "'grpo': Aggregates token-level losses by normalizing over sequence length. Not recommended due to length "
-            "bias—this approach tends to prefer shorter completions with positive advantages and longer ones with "
-            "negative advantages. "
-            "'dapo' (default): Aggregates token-level losses by normalizing with the number of active token in the "
-            "global accumulated batch. This method was introduced in the DAPO paper to eliminate length bias. "
-            "'dr_grpo': Aggregates token-level losses by normalizing with a global constant. This method was "
-            "introduced in the Dr. GRPO paper to eliminate length bias. The value of the constant corresponds to "
-            "`max_completion_length`. "
-            "'bnpo': Aggregates token-level losses by normalizing with the number of active token in the local batch. "
" - "Note that normalization is performed over the local batch only, so results may slightly vary depending " - "on the local batch size, despite a constant effective batch size. When using " - "`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss." + "help": "Specifies the loss formulation to use. Supported values are 'grpo', 'dapo', 'bnpo', 'dr_grpo', and 'cispo'. 'grpo': Aggregates token-level losses by normalizing over sequence length. Not recommended due to length bias-this approach tends to prefer shorter completions with positive advantages and longer ones with negative advantages. 'dapo' (default): Aggregates token-level losses by normalizing with the number of active token in the global accumulated batch. This method was introduced in the DAPO paper to eliminate length bias. 'dr_grpo': Aggregates token-level losses by normalizing with a global constant. This method was introduced in the Dr. GRPO paper to eliminate length bias. The value of the constant corresponds to `max_completion_length`. 'bnpo': Aggregates token-level losses by normalizing with the number of active token in the local batch. Note that normalization is performed over the local batch only, so results may slightly vary depending on the local batch size, despite a constant effective batch size. When using `per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss. 'cispo': Uses the truncated importance-sampling REINFORCE loss introduced in the ScaleRL paper (Eq. 4), truncating importance ratios at `cispo_clip_max` with gradients stopped through the truncation.", + "choices": ["grpo", "dapo", "bnpo", "dr_grpo", "cispo"], }, ) + + cispo_clip_max: float = field( + default=5.0, + metadata={ + "help": "Upper truncation epsilon_max applied to the importance sampling ratio for the CISPO loss. Weights are set to min(rho, epsilon_max) with gradients stopped through the truncation, following ScaleRL Eq. 4.", + }, + ) + mask_truncated_completions: bool = field( default=False, metadata={ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 352a0144ef..a27fc5e97e 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -391,6 +391,9 @@ def __init__( self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper self.epsilon_low = args.epsilon self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + self.cispo_clip_max = args.cispo_clip_max + if self.loss_type == "cispo" and self.cispo_clip_max <= 0: + raise ValueError("`cispo_clip_max` must be a positive float when using the CISPO loss.") # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle self._step = 0 # Buffer the batch to reuse generated outputs across multiple updates. For more details, see @@ -445,6 +448,8 @@ def __init__( # Liger loss if self.use_liger_loss: + if self.loss_type == "cispo": + raise NotImplementedError("Liger kernels do not currently support the CISPO loss.") if not is_liger_kernel_available(): raise ImportError( "Liger is required to use `liger_loss` as the GRPO loss. Run `pip install liger-kernel`." @@ -1705,19 +1710,28 @@ def _compute_loss(self, model, inputs): f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " "and 'sequence'." ) - # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) 
-        # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
-
+        # From here, log_importance_weights (and all subsequent tensors) shape depends on the importance sampling
+        # level: "token" level -> (B, T); "sequence" level -> (B, 1)
         coef_1 = torch.exp(log_importance_weights)
-        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+        cispo_truncated_weights = None
+        cispo_clipped_mask = None
+
+        if self.loss_type == "cispo":
+            cispo_cap = torch.full_like(coef_1, self.cispo_clip_max)
+            cispo_truncated_weights = torch.minimum(coef_1, cispo_cap)
+            cispo_clipped_mask = coef_1 > cispo_cap
+            cispo_weights = cispo_truncated_weights.detach()
+            per_token_loss = -cispo_weights * advantages.unsqueeze(1) * per_token_logps
+        else:
+            coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)

-        # Two-sided clipping
-        if self.args.delta is not None:
-            coef_1 = torch.clamp(coef_1, max=self.args.delta)
+            # Two-sided clipping
+            if self.args.delta is not None:
+                coef_1 = torch.clamp(coef_1, max=self.args.delta)

-        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
-        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
-        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+            per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+            per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+            per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

         if entropy_mask is not None:
             per_token_loss = per_token_loss * entropy_mask
@@ -1739,6 +1753,9 @@ def _compute_loss(self, model, inputs):
         elif self.loss_type == "dapo":
             normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
             loss = (per_token_loss * completion_mask).sum() / normalizer
+        elif self.loss_type == "cispo":
+            normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes
+            loss = (per_token_loss * completion_mask).sum() / normalizer
         else:
             raise ValueError(f"Unknown loss type: {self.loss_type}")

@@ -1760,6 +1777,41 @@ def masked_batch_mean(x):
         mean_entropy = masked_batch_mean(entropies)
         self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())

+        if self.loss_type == "cispo":
+            truncated_mean = masked_batch_mean(cispo_truncated_weights)
+            ratio_mean = masked_batch_mean(coef_1)
+            clip_fraction = masked_batch_mean(cispo_clipped_mask.float())
+
+            gathered_truncated_mean = self.accelerator.gather(truncated_mean)
+            gathered_ratio_mean = self.accelerator.gather(ratio_mean)
+            gathered_clip_fraction = self.accelerator.gather(clip_fraction)
+
+            self._metrics[mode]["cispo/importance_ratio/truncated_mean"].append(
+                gathered_truncated_mean.nanmean().item()
+            )
+            self._metrics[mode]["cispo/importance_ratio/mean"].append(gathered_ratio_mean.nanmean().item())
+            self._metrics[mode]["cispo/clip_fraction"].append(gathered_clip_fraction.nanmean().item())
+
+            if cispo_truncated_weights.shape[1] == 1:
+                flat_original = coef_1.squeeze(1)
+                flat_truncated = cispo_truncated_weights.squeeze(1)
+            else:
+                mask = completion_mask.bool()
+                flat_original = coef_1.masked_select(mask)
+                flat_truncated = cispo_truncated_weights.masked_select(mask)
+
+            max_ratio = flat_original.max() if flat_original.numel() > 0 else torch.tensor(0.0, device=coef_1.device)
+            max_truncated = (
+                flat_truncated.max() if flat_truncated.numel() > 0 else torch.tensor(0.0, device=coef_1.device)
+            )
+            self._metrics[mode]["cispo/importance_ratio/max"].append(
+                nanmax(self.accelerator.gather(max_ratio)).item()
+            )
+            self._metrics[mode]["cispo/importance_ratio/max_truncated"].append(
+                nanmax(self.accelerator.gather(max_truncated)).item()
+            )
+            return loss
+
         # Compute the clipped probability ratios
         is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
         is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
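For reviewers who want to sanity-check the math outside the trainer, here is a minimal, illustrative PyTorch sketch of the per-token CISPO term this patch implements: a truncated importance ratio with a stop-gradient, weighting a REINFORCE-style advantage-times-log-probability term. The function name, tensor shapes, and the plain `.mean()` reduction are assumptions made for the example; the trainer itself masks padding tokens and divides by the number of active tokens in the global accumulated batch, as in the `dapo` branch above.

```python
import torch


def cispo_per_token_loss(per_token_logps, old_per_token_logps, advantages, clip_max=5.0):
    """Illustrative per-token CISPO term: -sg(min(rho, eps_max)) * A * log pi_theta.

    per_token_logps:     (B, T) log-probs under the current policy (requires grad)
    old_per_token_logps: (B, T) log-probs under the behaviour policy
    advantages:          (B,)   per-completion advantages
    clip_max:            upper truncation threshold eps_max (cf. `cispo_clip_max`)
    """
    # Importance ratio rho_{i,t} = pi_theta / pi_theta_old, computed from log-probabilities
    rho = torch.exp(per_token_logps - old_per_token_logps.detach())
    # Truncate at eps_max and stop the gradient through the weight (the sg(.) in the formula)
    weights = torch.clamp(rho, max=clip_max).detach()
    # REINFORCE-style term, negated so that minimizing the loss maximizes the objective
    return -weights * advantages.unsqueeze(1) * per_token_logps


# Tiny smoke test with random tensors (masking and batch normalization omitted for brevity)
B, T = 2, 5
logps = torch.randn(B, T, requires_grad=True)
old_logps = logps.detach() + 0.1 * torch.randn(B, T)
adv = torch.tensor([1.0, -0.5])
cispo_per_token_loss(logps, old_logps, adv).mean().backward()
```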
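Assuming the patch is installed, selecting the new loss from user code should only involve the two new `GRPOConfig` values. The snippet below adapts the TRL quick-start example (toy length-based reward, `trl-lib/tldr` prompts, a small Qwen model); those particular choices are illustrative and not part of this PR.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


# Toy reward used only for illustration: prefer completions close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


dataset = load_dataset("trl-lib/tldr", split="train")

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO-CISPO",
    loss_type="cispo",   # truncated importance-sampling REINFORCE loss added in this patch
    cispo_clip_max=5.0,  # eps_max: upper truncation applied to the importance ratios
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```

With `loss_type="cispo"`, the trainer also logs the `cispo/*` ratios described in the documentation hunk, which is the quickest way to check whether `cispo_clip_max` is truncating too aggressively.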