tests/test_grpo_trainer.py (4 changes: 3 additions & 1 deletion)
@@ -1528,7 +1528,8 @@ def reward_func(completions, **kwargs):
     )
     @require_vision
     @require_liger_kernel
-    def test_training_vlm_and_liger(self, model_id):
+    @pytest.mark.parametrize("loss_type", ["grpo", "bnpo", "dr_grpo", "dapo"])
+    def test_training_vlm_and_liger(self, model_id, loss_type):
         dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")

         def reward_func(completions, **kwargs):
@@ -1542,6 +1543,7 @@ def reward_func(completions, **kwargs):
             num_generations=3,  # reduce the number of generations to reduce memory usage
             max_completion_length=8,  # reduce the completion length to reduce memory usage
             use_liger_kernel=True,  # enable Liger kernel
+            loss_type=loss_type,
             report_to="none",
         )
         trainer = GRPOTrainer(
trl/trainer/grpo_trainer.py (136 changes: 44 additions & 92 deletions)
@@ -62,7 +62,6 @@
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_liger_kernel_available, is_vllm_available
 from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
-from ..models.utils import _ForwardRedirection
 from .base_trainer import BaseTrainer
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
@@ -90,7 +89,7 @@
     from peft import PeftConfig, PeftModel

 if is_liger_kernel_available():
-    from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
+    from liger_kernel.transformers.grpo_loss import triton_grpo_loss

 if is_vllm_available():
     from vllm import LLM, SamplingParams
@@ -514,19 +513,6 @@ def cast_outputs_to_original_dtype(module, args, output):
                 raise ImportError(
                     "Liger is required to use `use_liger_kernel` as the GRPO loss. Run `pip install liger-kernel`."
                 )
-            # redirect the model.module forward to the model forward to ensure pre-forward hooks are called
-            self._forward_redirection = _ForwardRedirection()
-
-            self.liger_grpo_loss = LigerFusedLinearGRPOLoss(
-                beta=self.beta,
-                epsilon_low=self.epsilon_low,
-                epsilon_high=self.epsilon_high,
-                temperature=self.temperature,
-                use_ref_model=self.beta != 0.0,
-                loss_type=self.loss_type,
-                max_completion_length=self.max_completion_length,
-            )

         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
         self._total_train_tokens = 0
@@ -775,51 +761,6 @@ def _get_eval_sampler(self, eval_dataset) -> Sampler:
             seed=self.args.seed,
         )

-    @profiling_decorator
-    def _get_last_hidden_state(
-        self,
-        unwrapped_model,
-        input_ids,
-        attention_mask,
-        logits_to_keep,
-        pixel_values=None,
-        image_grid_thw=None,
-        pixel_attention_mask=None,
-        image_sizes=None,
-    ):
-        if is_peft_model(unwrapped_model):
-            unwrapped_model = unwrapped_model.base_model.model
-
-        # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
-        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-
-        # For Qwen models:
-        if image_grid_thw is not None and pixel_values is not None:
-            model_inputs["image_grid_thw"] = image_grid_thw
-        # For Gemma, SmolVLM2, LLaVa-Next etc.:
-        if pixel_values is not None:
-            model_inputs["pixel_values"] = pixel_values
-        # For SmolVLM2
-        if pixel_attention_mask is not None:
-            model_inputs["pixel_attention_mask"] = pixel_attention_mask
-        # For LLaVa-Next
-        if image_sizes is not None:
-            model_inputs["image_sizes"] = image_sizes
-
-        # Only add logits_to_keep if the model supports it
-        if "logits_to_keep" in self.model_kwarg_keys:
-            # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
-            model_inputs["logits_to_keep"] = logits_to_keep + 1
-
-        model_inputs["use_cache"] = False  # only used in generation; set False to suppress warnings
-
-        last_hidden_state = unwrapped_model.model(**model_inputs).last_hidden_state
-        # Exclude the last value: it corresponds to the next token pred
-        last_hidden_state = last_hidden_state[:, :-1, :]  # (B, L-1, H)
-        # Only keep the last logits_to_keep. For model that support logits_to_keep, this is a no-op.
-        last_hidden_state = last_hidden_state[:, -logits_to_keep:, :]  # (B, logits_to_keep, H)
-        return last_hidden_state
-
     def get_high_entropy_mask(self, entropies: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
         """
         Returns a binary mask identifying tokens whose entropy exceeds a given quantile threshold.
@@ -1750,45 +1691,58 @@ def _generate_and_score_completions(
         output["num_images"] = num_images
         return output

-    def compute_liger_loss(self, unwrapped_model, inputs):
-        # Compute the per-token log probabilities for the model
+    def compute_liger_loss(self, model, inputs):
+        # Compute logits directly and run Triton GRPO loss
         prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
-        completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
+        completion_ids = inputs["completion_ids"].contiguous()
+        completion_mask = inputs["completion_mask"].contiguous()
         input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        logits_to_keep = completion_ids.size(1)

-        # Get the last hidden state of the model
-        last_hidden_state = self._get_last_hidden_state(
-            unwrapped_model,
-            input_ids,
-            attention_mask,
-            logits_to_keep,
-            inputs.get("pixel_values"),
-            inputs.get("image_grid_thw"),
-            inputs.get("pixel_attention_mask"),
-            inputs.get("image_sizes"),
-        )
+        model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
+        optional_keys = [
+            "pixel_values",
+            "image_grid_thw",
+            "pixel_attention_mask",
+            "image_sizes",
+            "token_type_ids",
+        ]
+        for key in optional_keys:
+            if key in inputs:
+                model_inputs[key] = inputs[key]
+        model_inputs["use_cache"] = False
+        if "logits_to_keep" in self.model_kwarg_keys:
+            model_inputs["logits_to_keep"] = logits_to_keep + 1
+
+        logits = model(**model_inputs).logits
+        logits = logits[:, -(logits_to_keep + 1) :, :].contiguous()

-        # compute loss and metrics using liger grpo loss
-        loss, metrics = self.liger_grpo_loss(
-            _input=last_hidden_state,
-            lin_weight=unwrapped_model.lm_head.weight,
-            selected_token_ids=completion_ids,
-            attention_mask=completion_mask,
+        loss, metrics = triton_grpo_loss(
+            logits=logits,
+            old_logp=inputs.get("old_per_token_logps"),
+            ref_logp=inputs.get("ref_per_token_logps"),
+            completion_ids=completion_ids,
             advantages=inputs["advantages"],
-            bias=unwrapped_model.lm_head.bias,
-            old_per_token_logps=inputs.get("old_per_token_logps"),
-            ref_per_token_logps=inputs.get("ref_per_token_logps"),
+            completion_mask=completion_mask,
+            temperature=self.temperature,
+            beta=self.beta,
+            eps_low=self.epsilon_low,
+            eps_high=self.epsilon_high,
+            inplace=True,
+            loss_type=self.loss_type,
+            max_completion_length=self.max_completion_length,
+            importance_sampling_level=self.importance_sampling_level,
+            reduce=True,
         )
-        # Extract metrics from the liger_grpo_loss output
-        # KL divergence is the first metric when beta is non-zero
-        mean_kl = metrics[0] if self.beta != 0.0 else None
-        clip_ratio = metrics[-1]

         mode = "train" if self.model.training else "eval"
+        metric_offset = 0
         if self.beta != 0.0:
-            self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).mean().item())
+            kl_metric = metrics[0]
+            self._metrics[mode]["kl"].append(self.accelerator.gather(kl_metric).mean().item())
+            metric_offset = 1
+        clip_ratio = metrics[metric_offset]
         self._metrics[mode]["clip_ratio"].append(self.accelerator.gather(clip_ratio).mean().item())
         return loss / self.current_gradient_accumulation_steps

@@ -1797,9 +1751,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         if return_outputs:
            raise ValueError("The GRPOTrainer does not support returning outputs")
         if self.use_liger_kernel:
-            # Compute the loss using the liger grpo loss
-            unwrapped_model = self.accelerator.unwrap_model(model)
-            return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
+            return self.compute_liger_loss(model, inputs)
         else:
             return self._compute_loss(model, inputs)