Fix ambiguous tensor truth-value check in train_rl.loss_func (use .item()) #2085
base: main
Conversation
Fixes a runtime error in train_rl.py where a 0-d PyTorch tensor (total_tokens) was compared directly to 0 in an if-statement. Evaluating a tensor's truth value in Python is ambiguous and can raise an exception ("The truth value of a tensor is ambiguous"). The change uses total_tokens.item() for a safe scalar comparison and creates the fallback tensor with a matching dtype/device.
Motivation
In some RL training configurations (e.g., empty bins / no non-padded tokens), loss_mask.sum() produces a 0-d tensor equal to zero. The original code used if total_tokens == 0:, which evaluates the truth value of a tensor and can crash. This fix makes the check explicit and robust, avoiding failures during forward passes.
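As a minimal sketch of the failure mode (toy tensors for illustration, not the actual training code): summing an all-zero mask yields a 0-d tensor, and extracting the scalar with .item() makes the comparison unambiguous.

    import torch

    # Toy stand-in for a fully padded microbatch: every position masked out.
    loss_mask_flat = torch.zeros(8)
    total_tokens = loss_mask_flat.sum()  # 0-d tensor: tensor(0.)

    # The original pattern evaluated the tensor's truth value: if total_tokens == 0: ...
    # Comparing the extracted Python scalar instead is always unambiguous:
    if total_tokens.item() == 0:
        total_tokens = torch.tensor(
            1.0, device=loss_mask_flat.device, dtype=loss_mask_flat.dtype
        )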
What I changed
In loss_func in train_rl.py:
Replaced: if total_tokens == 0:
With: if total_tokens.item() == 0:, and created the fallback total_tokens with the same device and dtype as loss_mask_flat: total_tokens = torch.tensor(1.0, device=loss_mask_flat.device, dtype=loss_mask_flat.dtype)
No functional changes to the loss computation beyond avoiding the exception and ensuring dtype/device consistency for the fallback scalar.
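Putting it together, the relevant fragment of loss_func after the change would look roughly like the sketch below; the final reduction line is a paraphrase of the surrounding code (the losses_flat name is an assumption, not taken from the file):

    # Avoid division by zero for empty bins
    if total_tokens.item() == 0:
        total_tokens = torch.tensor(
            1.0, device=loss_mask_flat.device, dtype=loss_mask_flat.dtype
        )
    # Hypothetical reduction for illustration; the file divides by total_tokens,
    # so the fallback of 1 keeps the division well-defined for empty bins.
    loss = (losses_flat * loss_mask_flat).sum() / total_tokens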
Files modified
train_rl.py
One-line logic change in loss_func to safely compare the 0-d tensor and to create the fallback tensor with matching dtype/device.
Behavioral impact
Prevents a runtime exception when total_tokens is zero (empty or fully padded microbatches).
Keeps the original behavior of using 1 as the fallback token count to avoid division-by-zero.
Backwards compatible with existing training setups.
Notes
Cosmetic typos were observed in comments (e.g., "Without out" -> "Without", "determinisic" -> "deterministic"). These do not affect runtime; I left them untouched in this change, but I can include a follow-up cleanup commit if preferred.
tdene left a comment:
Thank you for the contribution!
     # Avoid division by zero for empty bins
    -if total_tokens == 0:
    -    total_tokens = torch.tensor(1.0, device=loss_mask_flat.device)
    +# total_tokens is a 0-d tensor; use .item() to compare its scalar value safely in Python
I think this comment is unnecessary, could you remove it?
    -    total_tokens = torch.tensor(1.0, device=loss_mask_flat.device)
    +# total_tokens is a 0-d tensor; use .item() to compare its scalar value safely in Python
    +if total_tokens.item() == 0:
    +    total_tokens = torch.tensor(1.0, device=loss_mask_flat.device, dtype=loss_mask_flat.dtype)
Could you reformat this line so that it is less than 100 characters in length?
I know the RL files don't always follow this linting convention, but we should at least try to make any new code follow it. We'll do a comprehensive cleanup pass later, and then add the RL files to the auto-linter.
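One straightforward way to wrap the line under 100 characters (a sketch; any equivalent formatting would satisfy the limit):

    if total_tokens.item() == 0:
        total_tokens = torch.tensor(
            1.0, device=loss_mask_flat.device, dtype=loss_mask_flat.dtype
        )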