5 changes: 3 additions & 2 deletions trl/trainer/sft_trainer.py
@@ -1118,7 +1118,8 @@ def compute_loss(

         # Set aside labels as it will be dropped by super().compute_loss() if a custom `compute_loss_func` is used.
         # This can be removed when this issue is fixed.
-        labels = inputs["labels"]
+        # When using CP or SP, labels are pre-shifted, so we must use shift_labels instead.
+        labels = inputs["labels"] if "shift_labels" not in inputs else None

         # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
         inputs["use_cache"] = False
@@ -1172,7 +1173,7 @@ def compute_loss(
         # Compute accuracy from logits using argmax (traditional method)
         with torch.no_grad():
             if "shift_labels" in inputs:
-                # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because:
+                # When using CP or SP, labels are pre-shifted. We must use these (and cannot manually shift) because:
                 #   - The first discarded token from inputs["labels"] actually belongs to process n-1
                 #   - The last logits require the label from process n+1
                 shift_logits = outputs.logits.contiguous()
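For context, a hedged sketch of the accuracy logic this hunk's comment describes (the helper, shapes, and example batch are illustrative assumptions, not TRL's implementation). The key point is that the CP/SP branch never slices the local tensors: the shift was applied globally before sharding, so slicing here would misalign targets at shard boundaries.

import torch

def token_accuracy(logits: torch.Tensor, inputs: dict) -> torch.Tensor:
    # Argmax token accuracy that honours pre-shifted labels under CP/SP.
    with torch.no_grad():
        if "shift_labels" in inputs:
            # Position i on this rank already holds the target for logits[i];
            # shifting again would drop a token that belongs to rank n-1 and
            # lose the target for the last logit, which lives on rank n+1.
            shift_logits = logits.contiguous()
            shift_labels = inputs["shift_labels"].contiguous()
        else:
            # Unsharded case: apply the usual causal shift locally.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = inputs["labels"][:, 1:].contiguous()
        preds = shift_logits.argmax(dim=-1)
        mask = shift_labels != -100  # ignore masked / padded targets
        return (preds[mask] == shift_labels[mask]).float().mean()

# Example: batch of 1, sequence of 3, vocab of 5, with pre-shifted labels.
logits = torch.randn(1, 3, 5)
batch = {"shift_labels": torch.tensor([[2, 4, -100]])}
print(token_accuracy(logits, batch))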