Skip to content

Commit c23ac01

Browse files
authored
BUG: isclose integer overflow (#130)
1 parent e3e9a83 commit c23ac01

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

src/array_api_extra/_lib/_funcs.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,13 @@ def isclose(
335335
atol = int(atol)
336336
if rtol == 0:
337337
return xp.abs(a - b) <= atol
338-
nrtol = int(1.0 / rtol)
338+
339+
try:
340+
nrtol = xp.asarray(int(1.0 / rtol), dtype=b.dtype)
341+
except OverflowError:
342+
# rtol * max_int(dtype) < 1, so it's inconsequential
343+
return xp.abs(a - b) <= atol
344+
339345
return xp.abs(a - b) <= (atol + xp.abs(b) // nrtol)
340346

341347

tests/test_funcs.py

+7
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,13 @@ def test_tolerance(self, dtype: str, xp: ModuleType):
354354
xp_assert_equal(isclose(a, b, rtol=0), xp.asarray([False, False]))
355355
xp_assert_equal(isclose(a, b, atol=1, rtol=0), xp.asarray([True, False]))
356356

357+
@pytest.mark.parametrize("dtype", ["int8", "uint8"])
358+
def test_tolerance_integer_overflow(self, dtype: str, xp: ModuleType):
359+
"""1/rtol is too large for dtype"""
360+
a = xp.asarray([100, 100], dtype=getattr(xp, dtype))
361+
b = xp.asarray([100, 101], dtype=getattr(xp, dtype))
362+
xp_assert_equal(isclose(a, b), xp.asarray([True, False]))
363+
357364
def test_very_small_numbers(self, xp: ModuleType):
358365
a = xp.asarray([1e-9, 1e-9])
359366
b = xp.asarray([1.0001e-9, 1.00001e-9])

0 commit comments

Comments
 (0)