Skip to content

Commit cd85158

Browse files
author
Rafael Valle
committed
loss_scaler.py: patching loss scaler for compatibility with current pytorch
1 parent 2da7a2e commit cd85158

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

loss_scaler.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,10 @@ def has_overflow(self, params):
5151

5252
# `x` is a torch.Tensor
5353
def _has_inf_or_nan(x):
54-
inf_count = torch.sum(x.abs() == float('inf'))
55-
if inf_count > 0:
54+
cpu_sum = float(x.float().sum())
55+
if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
5656
return True
57-
nan_count = torch.sum(x != x)
58-
return nan_count > 0
57+
return False
5958

6059
# `overflow` is boolean indicating whether we overflowed in gradient
6160
def update_scale(self, overflow):

0 commit comments

Comments
 (0)