From e06f332f616078fdab9037aab93cf56c101dac9e Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Wed, 6 Aug 2025 22:26:39 +0000
Subject: [PATCH 1/2] fix loss masking

---
 fast_llm/functional/cross_entropy.py | 25 ++++++++++++++-----------
 tests/layers/test_lm_head.py         |  6 +++---
 2 files changed, 17 insertions(+), 14 deletions(-)

diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index d56dce98..c19cc946 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -35,12 +35,11 @@ def _torch_cross_entropy_forward_backward(
                 logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target
             )
         else:
-            loss = (
-                torch.nn.functional.cross_entropy(
-                    logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
-                )
-                * loss_mask
-            ).mean()
+            per_token_loss = torch.nn.functional.cross_entropy(
+                logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
+            )
+            loss = (per_token_loss * loss_mask).sum() / loss_mask.sum()
+
         if grad_output is None:
             grad = None
         else:
@@ -133,7 +132,9 @@ def _fused_cross_entropy_forward_backward(
         if logits_scale_factor != 1.0:
             grad *= logits_scale_factor
         if loss_mask is not None:
-            grad *= loss_mask
+            # Take into account the modified denominator due to loss masking.
+            loss_masking_grad_factor = logits.size(0) / loss_mask.sum() if loss_mask.sum() > 0 else 1.0
+            grad *= loss_mask * loss_masking_grad_factor
         grad = grad.to(logits.dtype)
 
     # loss = mean(log(sum_exp_logits) - sum(probabilities * logits))
@@ -149,9 +150,11 @@ def _fused_cross_entropy_forward_backward(
     if loss_mask is not None:
         per_sample_loss = per_sample_loss * loss_mask
 
-    loss = per_sample_loss.mean()
+    loss = per_sample_loss.sum() / (
+        loss_mask.sum() if loss_mask is not None else torch.tensor(per_sample_loss.numel())
+    )
     if target_format != TargetFormat.labels and group is not None:
-        all_reduce(loss, op=ReduceOp.MEAN, group=group)
+        all_reduce(loss, op=ReduceOp.AVG, group=group)
 
     return loss, grad
 
@@ -274,10 +277,10 @@ def _torch_reverse_kl_forward_backward(
             loss_per_sample = torch.nn.functional.kl_div(
                 teacher_log_probs, student_log_probs, reduction="none", log_target=True
             ).sum(dim=-1)
-            loss = (loss_per_sample * loss_mask).mean()
+            loss = (loss_per_sample * loss_mask).sum() / loss_mask.sum()
 
         if group is not None and target_format != TargetFormat.labels:
-            all_reduce(loss, op=ReduceOp.MEAN, group=group)
+            all_reduce(loss, op=ReduceOp.AVG, group=group)
 
         if grad_output is not None:
             loss.backward(torch.full_like(loss, grad_output))
diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py
index 9a878c49..5c1acaa3 100644
--- a/tests/layers/test_lm_head.py
+++ b/tests/layers/test_lm_head.py
@@ -23,7 +23,7 @@ def _reverse_kl_loss(
 ):
     scaled_target = target / teacher_softmax_temperature
-    scaled_target = torch.clamp(target, min=-50, max=50)
+    scaled_target = torch.clamp(scaled_target, min=-50, max=50)
     teacher_log_probs = torch.log_softmax(scaled_target, dim=-1)
 
     with torch.enable_grad():
@@ -42,7 +42,7 @@ def _reverse_kl_loss(
         loss_per_sample = torch.nn.functional.kl_div(
             teacher_log_probs, student_log_probs, reduction="none", log_target=True
         ).sum(dim=-1)
-        loss = (loss_per_sample * loss_mask.flatten()).mean()
+        loss = (loss_per_sample * loss_mask.flatten()).sum() / loss_mask.sum()
 
     return loss
 
@@ -84,7 +84,7 @@ def _lm_head(
         )
         if loss_mask is not None:
             loss = loss * loss_mask.flatten()
-        loss = loss.mean()
+        loss = loss.sum() / (loss_mask.sum() if loss_mask is not None else torch.tensor(loss.numel()))
     else:
         loss = torch.nn.functional.cross_entropy(logits.flatten(0, -2), target.flatten())
     loss.backward(torch.full_like(loss, grad_output))

From 18ddbe59cfa9916cf9d84c2e5d119b0c6a4e5c42 Mon Sep 17 00:00:00 2001
From: Toolkit User
Date: Thu, 7 Aug 2025 15:36:13 +0000
Subject: [PATCH 2/2] prevent nan with fully-masked micro-sequence

---
 fast_llm/functional/cross_entropy.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/fast_llm/functional/cross_entropy.py b/fast_llm/functional/cross_entropy.py
index c19cc946..f4499970 100644
--- a/fast_llm/functional/cross_entropy.py
+++ b/fast_llm/functional/cross_entropy.py
@@ -38,7 +38,7 @@ def _torch_cross_entropy_forward_backward(
             per_token_loss = torch.nn.functional.cross_entropy(
                 logits_ if logits_scale_factor == 1 else logits_ * logits_scale_factor, target, reduction="none"
             )
-            loss = (per_token_loss * loss_mask).sum() / loss_mask.sum()
+            loss = (per_token_loss * loss_mask).sum() / (loss_mask.sum() if loss_mask.sum() > 0 else 1.0)
 
         if grad_output is None:
             grad = None
@@ -150,9 +150,8 @@ def _fused_cross_entropy_forward_backward(
     if loss_mask is not None:
         per_sample_loss = per_sample_loss * loss_mask
 
-    loss = per_sample_loss.sum() / (
-        loss_mask.sum() if loss_mask is not None else torch.tensor(per_sample_loss.numel())
-    )
+    loss_mask_sum = loss_mask.sum() if loss_mask is not None else torch.tensor(per_sample_loss.numel())
+    loss = per_sample_loss.sum() / (loss_mask_sum if loss_mask_sum > 0 else 1.0)
     if target_format != TargetFormat.labels and group is not None:
         all_reduce(loss, op=ReduceOp.AVG, group=group)
 
@@ -277,7 +276,7 @@ def _torch_reverse_kl_forward_backward(
             loss_per_sample = torch.nn.functional.kl_div(
                 teacher_log_probs, student_log_probs, reduction="none", log_target=True
            ).sum(dim=-1)
-            loss = (loss_per_sample * loss_mask).sum() / loss_mask.sum()
+            loss = (loss_per_sample * loss_mask).sum() / (loss_mask.sum() if loss_mask.sum() > 0 else 1.0)
 
         if group is not None and target_format != TargetFormat.labels:
             all_reduce(loss, op=ReduceOp.AVG, group=group)
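
Note (reviewer sketch, not part of the patch series): both patches converge on the same convention, a masked mean whose denominator is the number of unmasked tokens, guarded so that a fully-masked micro-sequence divides by 1 instead of 0. A minimal standalone sketch of that convention, assuming a 0/1 float loss_mask broadcastable to the per-token loss; the helper name masked_mean_loss is illustrative and does not appear in the diff:

import torch

def masked_mean_loss(per_token_loss: torch.Tensor, loss_mask: torch.Tensor | None) -> torch.Tensor:
    # Unmasked case: plain mean over all tokens.
    if loss_mask is None:
        return per_token_loss.mean()
    # Masked case: average over unmasked tokens only. Guard the denominator so a
    # fully-masked input yields 0 (the numerator is already 0) instead of 0/0 = NaN.
    mask_sum = loss_mask.sum()
    return (per_token_loss * loss_mask).sum() / (mask_sum if mask_sum > 0 else 1.0)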