Skip to content

Commit 019d935

Browse files
authored
Merge pull request #66 from asmeurer/repeat-fix
Fix issue with repeat()
2 parents 05c8b0f + 6d780a8 commit 019d935

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

array_api_strict/_manipulation_functions.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from ._array_object import Array
44
from ._creation_functions import asarray
5-
from ._data_type_functions import result_type
6-
from ._dtypes import _integer_dtypes
5+
from ._data_type_functions import astype, result_type
6+
from ._dtypes import _integer_dtypes, int64, uint64
77
from ._flags import requires_api_version, get_array_api_strict_flags
88

99
from typing import TYPE_CHECKING
@@ -94,7 +94,13 @@ def repeat(
9494
else:
9595
raise TypeError("repeats must be an int or array")
9696

97-
return Array._new(np.repeat(x._array, repeats, axis=axis))
97+
if repeats.dtype == uint64:
98+
# NumPy does not allow uint64 because can't be cast down to x.dtype
99+
# with 'safe' casting. However, repeats values larger than 2**63 are
100+
# infeasable, and even if they are present by mistake, this will
101+
# lead to underflow and an error.
102+
repeats = astype(repeats, int64)
103+
return Array._new(np.repeat(x._array, repeats._array, axis=axis))
98104

99105
# Note: the optional argument is called 'shape', not 'newshape'
100106
def reshape(x: Array,

0 commit comments

Comments
 (0)