From 4533d80cb58e6eaf6392d265eb20213e65bf944c Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Thu, 13 Nov 2025 11:35:40 +0100
Subject: [PATCH 1/5] use the liger triton_grpo_loss

---
 tests/test_grpo_trainer.py  |   5 +-
 trl/trainer/grpo_trainer.py | 136 ++++++++++++------------------------
 2 files changed, 47 insertions(+), 94 deletions(-)

diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index b3844a399c..8c46399711 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -1504,7 +1504,8 @@ def reward_func(completions, **kwargs):
     )
     @require_vision
     @require_liger_kernel
-    def test_training_vlm_and_liger(self, model_id):
+    @pytest.mark.parametrize("loss_type", ["grpo", "bnpo", "dr_grpo", "dapo"])
+    def test_training_vlm_and_liger(self, model_id, loss_type):
         dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
 
         def reward_func(completions, **kwargs):
@@ -1518,7 +1519,7 @@ def reward_func(completions, **kwargs):
             num_generations=3,  # reduce the number of generations to reduce memory usage
             max_completion_length=8,  # reduce the completion length to reduce memory usage
             use_liger_kernel=True,  # enable Liger kernel
-            loss_type="bnpo",  # default dapo is not supported yet
+            loss_type=loss_type,
             report_to="none",
         )
         trainer = GRPOTrainer(
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 6b245b3511..c7311cbcca 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -60,7 +60,6 @@
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_liger_kernel_available, is_vllm_available
 from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
-from ..models.utils import _ForwardRedirection
 from .base_trainer import BaseTrainer
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
@@ -88,7 +87,7 @@
     from peft import PeftConfig, PeftModel
 
 if is_liger_kernel_available():
-    from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
+    from liger_kernel.transformers.grpo_loss import triton_grpo_loss
 
 if is_vllm_available():
     from vllm import LLM, SamplingParams
@@ -515,19 +514,6 @@ def cast_outputs_to_original_dtype(module, args, output):
                 raise ImportError(
                     "Liger is required to use `use_liger_kernel` as the GRPO loss. Run `pip install liger-kernel`."
                 )
 
-            # redirect the model.module forward to the model forward to ensure pre-forward hooks are called
-            self._forward_redirection = _ForwardRedirection()
-
-            self.liger_grpo_loss = LigerFusedLinearGRPOLoss(
-                beta=self.beta,
-                epsilon_low=self.epsilon_low,
-                epsilon_high=self.epsilon_high,
-                temperature=self.temperature,
-                use_ref_model=self.beta != 0.0,
-                loss_type=self.loss_type,
-                max_completion_length=self.max_completion_length,
-            )
-
         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
         self._total_train_tokens = 0
@@ -765,51 +751,6 @@ def _get_eval_sampler(self, eval_dataset) -> Sampler:
             seed=self.args.seed,
         )
 
-    @profiling_decorator
-    def _get_last_hidden_state(
-        self,
-        unwrapped_model,
-        input_ids,
-        attention_mask,
-        logits_to_keep,
-        pixel_values=None,
-        image_grid_thw=None,
-        pixel_attention_mask=None,
-        image_sizes=None,
-    ):
-        if is_peft_model(unwrapped_model):
-            unwrapped_model = unwrapped_model.base_model.model
-
-        # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
-        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-
-        # For Qwen models:
-        if image_grid_thw is not None and pixel_values is not None:
-            model_inputs["image_grid_thw"] = image_grid_thw
-        # For Gemma, SmolVLM2, LLaVa-Next etc.:
-        if pixel_values is not None:
-            model_inputs["pixel_values"] = pixel_values
-        # For SmolVLM2
-        if pixel_attention_mask is not None:
-            model_inputs["pixel_attention_mask"] = pixel_attention_mask
-        # For LLaVa-Next
-        if image_sizes is not None:
-            model_inputs["image_sizes"] = image_sizes
-
-        # Only add logits_to_keep if the model supports it
-        if "logits_to_keep" in self.model_kwarg_keys:
-            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-            model_inputs["logits_to_keep"] = logits_to_keep + 1
-
-        model_inputs["use_cache"] = False  # only used in generation; set False to suppress warnings
-
-        last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state
-        # Exclude the last value: it corresponds to the next token pred
-        last_hidden_state = last_hidden_state[:, :-1, :]  # (B, L-1, H)
-        # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op.
-        last_hidden_state = last_hidden_state[:, -logits_to_keep:, :]  # (B, logits_to_keep, H)
-        return last_hidden_state
-
     def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
         """
         Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold.
@@ -1710,45 +1651,58 @@ def _generate_and_score_completions(
             output["num_images"] = num_images
         return output
 
-    def compute_liger_loss(self, unwrapped_model, inputs):
-        # Compute the per-token log probabilities for the model
+    def compute_liger_loss(self, model, inputs):
+        # Compute logits directly and run Triton GRPO loss
         prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
-        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+        completion_ids = inputs["completion_ids"].contiguous()
+        completion_mask = inputs["completion_mask"].contiguous()
         input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        logits_to_keep = completion_ids.size(1)
 
-        # Get the last hidden state of the model
-        last_hidden_state = self._get_last_hidden_state(
-            unwrapped_model,
-            input_ids,
-            attention_mask,
-            logits_to_keep,
-            inputs.get("pixel_values"),
-            inputs.get("image_grid_thw"),
-            inputs.get("pixel_attention_mask"),
-            inputs.get("image_sizes"),
-        )
+        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
+        optional_keys = [
+            "pixel_values",
+            "image_grid_thw",
+            "pixel_attention_mask",
+            "image_sizes",
+            "token_type_ids",
+        ]
+        for key in optional_keys:
+            if key in inputs:
+                model_inputs[key] = inputs[key]
+        model_inputs["use_cache"] = False
+        if "logits_to_keep" in self.model_kwarg_keys:
+            model_inputs["logits_to_keep"] = logits_to_keep + 1
+
+        logits = model(**model_inputs).logits
+        logits = logits[:, -(logits_to_keep + 1) :, :].contiguous()
 
-        # compute loss and metrics using liger grpo loss
-        loss, metrics = self.liger_grpo_loss(
-            _input=last_hidden_state,
-            lin_weight=unwrapped_model.lm_head.weight,
-            selected_token_ids=completion_ids,
-            attention_mask=completion_mask,
+        loss, metrics = triton_grpo_loss(
+            logits=logits,
+            old_logp=inputs.get("old_per_token_logps"),
+            ref_logp=inputs.get("ref_per_token_logps"),
+            completion_ids=completion_ids,
             advantages=inputs["advantages"],
-            bias=unwrapped_model.lm_head.bias,
-            old_per_token_logps=inputs.get("old_per_token_logps"),
-            ref_per_token_logps=inputs.get("ref_per_token_logps"),
+            completion_mask=completion_mask,
+            temperature=self.temperature,
+            beta=self.beta,
+            eps_low=self.epsilon_low,
+            eps_high=self.epsilon_high,
+            inplace=True,
+            loss_type=self.loss_type,
+            max_completion_length=self.max_completion_length,
+            importance_sampling_level=self.importance_sampling_level,
+            reduce=True,
         )
-        # Extract metrics from the liger_grpo_loss output
-        # KL divergence is the first metric when beta is non-zero
-        mean_kl = metrics[0] if self.beta != 0.0 else None
-        clip_ratio = metrics[-1]
 
         mode = "train" if self.model.training else "eval"
+        metric_offset = 0
         if self.beta != 0.0:
-            self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item())
+            kl_metric = metrics[0]
+            self._metrics[mode]["kl"].append(self.accelerator.gather(kl_metric).mean().item())
+            metric_offset = 1
+        clip_ratio = metrics[metric_offset]
         self._metrics[mode]["clip_ratio"].append(self.accelerator.gather(clip_ratio).mean().item())
 
         return loss / self.current_gradient_accumulation_steps
@@ -1757,9 +1711,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         if return_outputs:
             raise ValueError("The GRPOTrainer does not support returning outputs")
         if self.use_liger_kernel:
-            # Compute the loss using the liger grpo loss
-            unwrapped_model = self.accelerator.unwrap_model(model)
-            return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
+            return self.compute_liger_loss(model, inputs)
         else:
             return self._compute_loss(model, inputs)
 

From b33aff8dff56ea627d0063709fccb9a7a3d17bee Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Sat, 22 Nov 2025 14:18:17 +0100
Subject: [PATCH 2/5] needs 0.6.4

---
 pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1b84ff50d7..61e8c0f472 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -56,7 +56,7 @@ kernels = [
     "kernels"
 ]
 liger = [
-    "liger-kernel>=0.6.2"
+    "liger-kernel>=0.6.4"
 ]
 peft = [
     "peft>=0.8.0"
@@ -104,7 +104,7 @@ dev = [
     # kernels
     "kernels",
     # liger
-    "liger-kernel>=0.6.2",
+    "liger-kernel>=0.6.4",
     # peft
     "peft>=0.8.0",
     # quality

From c779a23a018ac91b9853e56786f02723aee8ccc4 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 1 Dec 2025 15:22:47 +0100
Subject: [PATCH 3/5] Update trl/trainer/grpo_trainer.py

Co-authored-by: lewtun
---
 trl/trainer/grpo_trainer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 8c4cb24eb5..571ae74ff7 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1713,6 +1713,7 @@ def compute_liger_loss(self, model, inputs):
                 model_inputs[key] = inputs[key]
         model_inputs["use_cache"] = False
         if "logits_to_keep" in self.model_kwarg_keys:
+            # Add 1 because the last logits of the sequence is later excluded
             model_inputs["logits_to_keep"] = logits_to_keep + 1
 
         logits = model(**model_inputs).logits

From 39d7c8da849d22a3691deb461736f1eb94a68e15 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 1 Dec 2025 18:53:24 +0000
Subject: [PATCH 4/5] two-step slicing logic for the logits

---
 trl/trainer/grpo_trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 571ae74ff7..9a7fb76385 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1717,7 +1717,10 @@ def compute_liger_loss(self, model, inputs):
             model_inputs["logits_to_keep"] = logits_to_keep + 1
 
         logits = model(**model_inputs).logits
-        logits = logits[:, -(logits_to_keep + 1) :, :].contiguous()
+        # Exclude the last value: it corresponds to the next token pred
+        logits = logits[:, :-1, :]
+        # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op.
+        logits = logits[:, -logits_to_keep:, :].contiguous()
 
         loss, metrics = triton_grpo_loss(
             logits=logits,

From c57d4ad47fa71321381b6b937a889533e3181196 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Mon, 1 Dec 2025 19:29:58 +0000
Subject: [PATCH 5/5] liger triton expects logits_to_keep + 1

---
 trl/trainer/grpo_trainer.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 6b8684235f..bcd00d3c07 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -1745,10 +1745,7 @@ def compute_liger_loss(self, model, inputs):
             model_inputs["logits_to_keep"] = logits_to_keep + 1
 
         logits = model(**model_inputs).logits
-        # Exclude the last value: it corresponds to the next token pred
-        logits = logits[:, :-1, :]
-        # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op.
-        logits = logits[:, -logits_to_keep:, :].contiguous()
+        logits = logits[:, -(logits_to_keep + 1) :, :].contiguous()
 
         loss, metrics = triton_grpo_loss(
             logits=logits,
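
Note (not part of the patch series): PATCH 4 and PATCH 5 differ only in how many trailing logit rows are kept before calling the Triton kernel. The short PyTorch sketch below uses made-up shapes (B, P, C, V) and no liger dependency; it illustrates why keeping `logits_to_keep + 1` positions and excluding the final row later (as the comment added in PATCH 3 and the subject of PATCH 5 describe) selects the same per-token logits as the two-step slice from PATCH 4.

# Standalone sketch; illustrative shapes only, not the trainer code.
import torch

B, P, C, V = 2, 5, 4, 11            # batch, prompt length, completion length, vocab size
logits = torch.randn(B, P + C, V)   # stand-in for model(**model_inputs).logits
logits_to_keep = C                  # completion_ids.size(1) in the trainer

# PATCH 5: keep the trailing logits_to_keep + 1 positions and let the kernel
# exclude the final row (it predicts the token *after* the last completion token).
kept = logits[:, -(logits_to_keep + 1):, :]

# PATCH 4: drop the final position first, then keep logits_to_keep rows,
# so row i scores completion token i.
per_token = logits[:, :-1, :][:, -logits_to_keep:, :]

# Both views agree on the rows that actually score completion tokens.
assert torch.equal(kept[:, :-1, :], per_token)
print(kept.shape, per_token.shape)  # torch.Size([2, 5, 11]) torch.Size([2, 4, 11])

The extra row kept in PATCH 5 never scores a completion token, so the two slicing strategies yield the same per-token values; the difference is only whether the trainer or the kernel discards that final position.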