diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index baaea524b2..80a9bf5956 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -1527,7 +1527,8 @@ def reward_func(completions, **kwargs):
     )
     @require_vision
     @require_liger_kernel
-    def test_training_vlm_and_liger(self, model_id):
+    @pytest.mark.parametrize("loss_type", ["grpo", "bnpo", "dr_grpo", "dapo"])
+    def test_training_vlm_and_liger(self, model_id, loss_type):
         dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
 
         def reward_func(completions, **kwargs):
@@ -1541,6 +1542,7 @@ def reward_func(completions, **kwargs):
             num_generations=3,  # reduce the number of generations to reduce memory usage
             max_completion_length=8,  # reduce the completion length to reduce memory usage
             use_liger_kernel=True,  # enable Liger kernel
+            loss_type=loss_type,
             report_to="none",
         )
         trainer = GRPOTrainer(
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 995bfdc17a..bcd00d3c07 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -62,7 +62,6 @@
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_liger_kernel_available, is_vllm_available
 from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
-from ..models.utils import _ForwardRedirection
 from .base_trainer import BaseTrainer
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
@@ -90,7 +89,7 @@
 from peft import PeftConfig, PeftModel
 
 if is_liger_kernel_available():
-    from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
+    from liger_kernel.transformers.grpo_loss import triton_grpo_loss
 
 if is_vllm_available():
     from vllm import LLM, SamplingParams
@@ -515,19 +514,6 @@ def cast_outputs_to_original_dtype(module, args, output):
                 raise ImportError(
                     "Liger is required to use `use_liger_kernel` as the GRPO loss. Run `pip install liger-kernel`."
                 )
-            # redirect the model.module forward to the model forward to ensure pre-forward hooks are called
-            self._forward_redirection = _ForwardRedirection()
-
-            self.liger_grpo_loss = LigerFusedLinearGRPOLoss(
-                beta=self.beta,
-                epsilon_low=self.epsilon_low,
-                epsilon_high=self.epsilon_high,
-                temperature=self.temperature,
-                use_ref_model=self.beta != 0.0,
-                loss_type=self.loss_type,
-                max_completion_length=self.max_completion_length,
-            )
-
         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
         self._total_train_tokens = 0
@@ -776,51 +762,6 @@ def _get_eval_sampler(self, eval_dataset) -> Sampler:
             seed=self.args.seed,
         )
 
-    @profiling_decorator
-    def _get_last_hidden_state(
-        self,
-        unwrapped_model,
-        input_ids,
-        attention_mask,
-        logits_to_keep,
-        pixel_values=None,
-        image_grid_thw=None,
-        pixel_attention_mask=None,
-        image_sizes=None,
-    ):
-        if is_peft_model(unwrapped_model):
-            unwrapped_model = unwrapped_model.base_model.model
-
-        # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
-        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-
-        # For Qwen models:
-        if image_grid_thw is not None and pixel_values is not None:
-            model_inputs["image_grid_thw"] = image_grid_thw
-        # For Gemma, SmolVLM2, LLaVa-Next etc.:
-        if pixel_values is not None:
-            model_inputs["pixel_values"] = pixel_values
-        # For SmolVLM2
-        if pixel_attention_mask is not None:
-            model_inputs["pixel_attention_mask"] = pixel_attention_mask
-        # For LLaVa-Next
-        if image_sizes is not None:
-            model_inputs["image_sizes"] = image_sizes
-
-        # Only add logits_to_keep if the model supports it
-        if "logits_to_keep" in self.model_kwarg_keys:
-            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-            model_inputs["logits_to_keep"] = logits_to_keep + 1
-
-        model_inputs["use_cache"] = False  # only used in generation; set False to suppress warnings
-
-        last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state
-        # Exclude the last value: it corresponds to the next token pred
-        last_hidden_state = last_hidden_state[:, :-1, :]  # (B, L-1, H)
-        # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op.
-        last_hidden_state = last_hidden_state[:, -logits_to_keep:, :]  # (B, logits_to_keep, H)
-        return last_hidden_state
-
     def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
         """
         Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold.
@@ -1778,45 +1719,59 @@ def _generate_and_score_completions(
             output["num_images"] = num_images
         return output
 
-    def compute_liger_loss(self, unwrapped_model, inputs):
-        # Compute the per-token log probabilities for the model
+    def compute_liger_loss(self, model, inputs):
+        # Compute logits directly and run Triton GRPO loss
         prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
-        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+        completion_ids = inputs["completion_ids"].contiguous()
+        completion_mask = inputs["completion_mask"].contiguous()
         input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        logits_to_keep = completion_ids.size(1)
 
-        # Get the last hidden state of the model
-        last_hidden_state = self._get_last_hidden_state(
-            unwrapped_model,
-            input_ids,
-            attention_mask,
-            logits_to_keep,
-            inputs.get("pixel_values"),
-            inputs.get("image_grid_thw"),
-            inputs.get("pixel_attention_mask"),
-            inputs.get("image_sizes"),
-        )
+        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
+        optional_keys = [
+            "pixel_values",
+            "image_grid_thw",
+            "pixel_attention_mask",
+            "image_sizes",
+            "token_type_ids",
+        ]
+        for key in optional_keys:
+            if key in inputs:
+                model_inputs[key] = inputs[key]
+        model_inputs["use_cache"] = False
+        if "logits_to_keep" in self.model_kwarg_keys:
+            # Add 1 because the last logits of the sequence is later excluded
+            model_inputs["logits_to_keep"] = logits_to_keep + 1
+
+        logits = model(**model_inputs).logits
+        logits = logits[:, -(logits_to_keep + 1) :, :].contiguous()
 
-        # compute loss and metrics using liger grpo loss
-        loss, metrics = self.liger_grpo_loss(
-            _input=last_hidden_state,
-            lin_weight=unwrapped_model.lm_head.weight,
-            selected_token_ids=completion_ids,
-            attention_mask=completion_mask,
+        loss, metrics = triton_grpo_loss(
+            logits=logits,
+            old_logp=inputs.get("old_per_token_logps"),
+            ref_logp=inputs.get("ref_per_token_logps"),
+            completion_ids=completion_ids,
             advantages=inputs["advantages"],
-            bias=unwrapped_model.lm_head.bias,
-            old_per_token_logps=inputs.get("old_per_token_logps"),
-            ref_per_token_logps=inputs.get("ref_per_token_logps"),
+            completion_mask=completion_mask,
+            temperature=self.temperature,
+            beta=self.beta,
+            eps_low=self.epsilon_low,
+            eps_high=self.epsilon_high,
+            inplace=True,
+            loss_type=self.loss_type,
+            max_completion_length=self.max_completion_length,
+            importance_sampling_level=self.importance_sampling_level,
+            reduce=True,
         )
-        # Extract metrics from the liger_grpo_loss output
-        # KL divergence is the first metric when beta is non-zero
-        mean_kl = metrics[0] if self.beta != 0.0 else None
-        clip_ratio = metrics[-1]
 
         mode = "train" if self.model.training else "eval"
+        metric_offset = 0
         if self.beta != 0.0:
-            self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item())
+            kl_metric = metrics[0]
+            self._metrics[mode]["kl"].append(self.accelerator.gather(kl_metric).mean().item())
+            metric_offset = 1
+        clip_ratio = metrics[metric_offset]
         self._metrics[mode]["clip_ratio"].append(self.accelerator.gather(clip_ratio).mean().item())
 
         return loss / self.current_gradient_accumulation_steps
 
@@ -1825,9 +1780,7 @@
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         if return_outputs:
             raise ValueError("The GRPOTrainer does not support returning outputs")
         if self.use_liger_kernel:
-            # Compute the loss using the liger grpo loss
-            unwrapped_model = self.accelerator.unwrap_model(model)
-            return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
+            return self.compute_liger_kernel(model, inputs) if False else self.compute_liger_loss(model, inputs)
         else:
            return self._compute_loss(model, inputs)
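
For reference, the user-facing setup this change targets looks roughly like the sketch below. It is illustrative and not part of the patch: the model id, output_dir, dataset name, and batch-size values are assumptions borrowed from the TRL test suite, while use_liger_kernel and loss_type mirror the parametrized test above. With use_liger_kernel=True, the loss is now computed by liger_kernel's triton_grpo_loss instead of the removed LigerFusedLinearGRPOLoss module.

    # Minimal sketch (illustrative, not part of this patch): GRPO training routed
    # through the Liger Triton GRPO loss with a non-default loss_type.
    from datasets import load_dataset
    from trl import GRPOConfig, GRPOTrainer

    dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

    def reward_func(completions, **kwargs):
        # Toy reward that only exercises the training loop, as in the tests.
        return [float(len(completion)) for completion in completions]

    training_args = GRPOConfig(
        output_dir="grpo-liger-demo",  # illustrative path
        per_device_train_batch_size=3,  # must be divisible by num_generations
        num_generations=3,  # small values keep the run cheap
        max_completion_length=8,
        use_liger_kernel=True,  # loss is computed by triton_grpo_loss
        loss_type="dapo",  # any of "grpo", "bnpo", "dr_grpo", "dapo"
        report_to="none",
    )
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",  # illustrative tiny model
        reward_funcs=reward_func,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()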