We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2da7a2e · commit cd85158 · Copy full SHA for cd85158
loss_scaler.py
@@ -51,11 +51,10 @@ def has_overflow(self, params):
51
52
# `x` is a torch.Tensor
53
def _has_inf_or_nan(x):
54
- inf_count = torch.sum(x.abs() == float('inf'))
55
- if inf_count > 0:
+ cpu_sum = float(x.float().sum())
+ if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
56
return True
57
- nan_count = torch.sum(x != x)
58
- return nan_count > 0
+ return False
59
60
# `overflow` is boolean indicating whether we overflowed in gradient
61
def update_scale(self, overflow):
0 commit comments