diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 8e0ae98a48..dc70b50a67 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -1118,7 +1118,8 @@ def compute_loss(
 
         # Set aside labels as it will be dropped by super().compute_loss() if a custom `compute_loss_func` is used.
         # This can be removed when this issue is fixed.
-        labels = inputs["labels"]
+        # When using CP or SP, labels are pre-shifted, we must use shift_labels instead.
+        labels = inputs["labels"] if "shift_labels" not in inputs else None
 
         # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
         inputs["use_cache"] = False
@@ -1172,7 +1173,7 @@ def compute_loss(
         # Compute accuracy from logits using argmax (traditional method)
         with torch.no_grad():
             if "shift_labels" in inputs:
-                # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because:
+                # When using CP or SP, labels are pre-shifted. We must use these (and cannot manually shift) because:
                 #   - The first discarded token from inputs["labels"] actually belongs to process n-1
                 #   - The last logits require the label from process n+1
                 shift_logits = outputs.logits.contiguous()
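
For context, the following is a minimal, self-contained sketch of the reasoning behind both hunks; it is not TRL's implementation, and `cp_shard` is a hypothetical helper standing in for however context/sequence parallelism splits the sequence across ranks. The point it illustrates: labels are shifted once on the full sequence and then sharded, so each rank's logits already line up with targets that may physically live on the next rank, and shifting the local chunk again would be wrong.

import torch

def cp_shard(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Hypothetical helper: give each rank a contiguous chunk along the sequence dim.
    return t.chunk(world_size, dim=1)[rank]

vocab_size, seq_len, world_size = 11, 8, 2
torch.manual_seed(0)
logits = torch.randn(1, seq_len, vocab_size)         # full-sequence logits, illustration only
labels = torch.randint(0, vocab_size, (1, seq_len))  # full-sequence labels

# Shift once on the *full* sequence, then shard: this is what each rank would
# receive as inputs["shift_labels"] under CP/SP.
shift_labels = torch.roll(labels, shifts=-1, dims=1)
shift_labels[:, -1] = -100  # the last position has no next-token target

for rank in range(world_size):
    local_logits = cp_shard(logits, rank, world_size)
    local_shift_labels = cp_shard(shift_labels, rank, world_size)

    # Correct: local_logits[:, t] is already paired with the label of position t+1,
    # even when position t+1 belongs to rank+1.
    preds = local_logits.argmax(dim=-1)
    mask = local_shift_labels != -100
    accuracy = (preds[mask] == local_shift_labels[mask]).float().mean()
    print(f"rank {rank}: token accuracy on local shard = {accuracy:.3f}")

    # Shifting the *local* labels chunk instead would discard a token whose label
    # belongs to rank-1 and could never see the label held by rank+1.

The same reasoning appears to motivate the first hunk: when `shift_labels` is present, `labels` is left as `None` so the pre-shifted targets are used rather than a manually shifted copy of `inputs["labels"]`.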