
Conversation

@vignesh1507
Contributor


Fixes a runtime error in train_rl.py where a 0-d PyTorch tensor (total_tokens) was compared directly to 0 in an if-statement. Comparing a tensor truth value in Python is ambiguous and can raise an exception ("The truth value of a tensor is ambiguous"). The change uses total_tokens.item() for a safe scalar comparison and creates the fallback tensor with a matching dtype/device.
Motivation

In some RL training configurations (e.g., empty bins or microbatches with no non-padded tokens), loss_mask.sum() produces a 0-d tensor with value zero. The original code used if total_tokens == 0:, which evaluates the truth value of a tensor and can raise an exception. This fix makes the check an explicit scalar comparison, avoiding failures during the forward pass.
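As a minimal sketch of the situation described above (assuming a standalone PyTorch environment; the shape is illustrative, not taken from train_rl.py):

```python
import torch

# A fully padded microbatch: the loss mask zeroes out every position.
loss_mask = torch.zeros(4, 8)      # illustrative (batch, seq) shape
total_tokens = loss_mask.sum()     # 0-d tensor with value 0.0

# Extracting the Python scalar first makes the zero check unambiguous.
print(total_tokens.dim())          # 0
print(total_tokens.item() == 0)    # True
```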
What I changed

In loss_func in train_rl.py:

Replaced: if total_tokens == 0:
With: if total_tokens.item() == 0:, and created the fallback total_tokens with the same device and dtype as loss_mask_flat: total_tokens = torch.tensor(1.0, device=loss_mask_flat.device, dtype=loss_mask_flat.dtype)
No functional changes were made to the loss computation beyond avoiding the exception and ensuring dtype/device consistency for the fallback scalar.
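Put together, the guarded check reads roughly as follows (a sketch based on the description above; the helper name is hypothetical and the surrounding loss_func code is not reproduced):

```python
import torch

def guarded_total_tokens(loss_mask_flat: torch.Tensor) -> torch.Tensor:
    """Sketch of the fallback logic described above, not the full loss_func."""
    total_tokens = loss_mask_flat.sum()
    # .item() turns the 0-d tensor into a Python scalar for a safe comparison
    if total_tokens.item() == 0:
        total_tokens = torch.tensor(
            1.0, device=loss_mask_flat.device, dtype=loss_mask_flat.dtype
        )
    return total_tokens
```

For a fully masked input, guarded_total_tokens(torch.zeros(10)) returns a tensor holding 1.0 with the mask's dtype and device; otherwise it returns the ordinary token count.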

Files modified

train_rl.py
One-line logic change in loss_func to safely compare the 0-d tensor and to create the fallback tensor with matching dtype/device.
Behavioral impact

Prevents a runtime exception when total_tokens is zero (empty or fully padded microbatches).
Keeps the original behavior of using 1 as the fallback token count to avoid division-by-zero.
Backwards compatible with existing training setups.
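The division-by-zero protection can be seen in a small end-to-end check (illustrative values only, not actual training data):

```python
import torch

loss_flat = torch.zeros(6)          # illustrative per-token losses
loss_mask_flat = torch.zeros(6)     # fully padded: no valid tokens

total_tokens = loss_mask_flat.sum()
if total_tokens.item() == 0:
    total_tokens = torch.tensor(
        1.0, device=loss_mask_flat.device, dtype=loss_mask_flat.dtype
    )

# Without the fallback this would be 0/0 = NaN; with it, the loss is 0.0.
loss = (loss_flat * loss_mask_flat).sum() / total_tokens
print(torch.isnan(loss).item())     # False
```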

Notes

Cosmetic typos in comments were observed (e.g., "Without out" -> "Without", "determinisic" -> "deterministic"). Those do not affect runtime; I left them untouched in this change but can include a follow-up cleanup commit if preferred.
@vignesh1507 vignesh1507 requested a review from a team as a code owner November 2, 2025 03:40
@copy-pr-bot bot commented Nov 2, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@tdene tdene added the Expert Review Apply this label to indicate that your PR is ready for expert review. label Nov 2, 2025
@tdene tdene (Contributor) left a comment
Thank you for the contribution!

# Avoid division by zero for empty bins
if total_tokens == 0:
    total_tokens = torch.tensor(1.0, device=loss_mask_flat.device)
# total_tokens is a 0-d tensor; use .item() to compare its scalar value safely in Python
Contributor
I think this comment is unnecessary, could you remove?

total_tokens = torch.tensor(1.0, device=loss_mask_flat.device)
# total_tokens is a 0-d tensor; use .item() to compare its scalar value safely in Python
if total_tokens.item() == 0:
    total_tokens = torch.tensor(1.0, device=loss_mask_flat.device, dtype=loss_mask_flat.dtype)
Contributor

Could you reformat this line so that it is less than 100 characters in length?

I know the RL files are not good at always following this linting convention, but we should at least try to make any new code follow it. We'll do a comprehensive cleanup pass later, and then add RL files to the auto-linter.
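For what it's worth, one way to bring that assignment under the 100-character limit is to break it at the call's parentheses (a formatting sketch only; the final layout is up to the author, and loss_mask_flat is a stand-in here so the snippet runs on its own):

```python
import torch

# hypothetical stand-in for loss_mask_flat from train_rl.py
loss_mask_flat = torch.zeros(4)

total_tokens = torch.tensor(
    1.0, device=loss_mask_flat.device, dtype=loss_mask_flat.dtype
)
```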

