Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 5ddba5b

Browse files
committedDec 1, 2024·
ENH: allow python scalars in binary elementwise functions
Allow func(array, scalar) and func(scalar, array), raise on func(scalar, scalar) if API_VERSION>=2024.12 cross-ref data-apis/array-api#807 To make sure it is all uniform, 1. Generate all binary "ufuncs" in a uniform way, with a decorator 2. Make binary "ufuncs" follow the same logic of the binary operators 3. Reuse the test loop of Array.__binop__ for binary "ufuncs" 4. (minor) in tests, reuse canonical names for dtype categories ("integer or boolean" vs "integer_or_boolean")
1 parent d086c61 commit 5ddba5b

File tree

5 files changed

+266
-512
lines changed

5 files changed

+266
-512
lines changed
 

‎array_api_strict/_array_object.py

+2
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ def _check_device(self, other):
234234
elif isinstance(other, Array):
235235
if self.device != other.device:
236236
raise ValueError(f"Arrays from two different devices ({self.device} and {other.device}) can not be combined.")
237+
else:
238+
raise TypeError(f"Cannot combine an Array with {type(other)}.")
237239

238240
# Helper function to match the type promotion rules in the spec
239241
def _promote_scalar(self, scalar):

‎array_api_strict/_elementwise_functions.py

+117-467
Large diffs are not rendered by default.

‎array_api_strict/_helpers.py

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Private helper routines.
2+
"""
3+
4+
from ._flags import get_array_api_strict_flags
5+
from ._dtypes import _dtype_categories
6+
7+
_py_scalars = (bool, int, float, complex)
8+
9+
10+
def _maybe_normalize_py_scalars(x1, x2, dtype_category, func_name):
11+
12+
flags = get_array_api_strict_flags()
13+
if flags["api_version"] < "2024.12":
14+
# scalars will fail at the call site
15+
return x1, x2
16+
17+
_allowed_dtypes = _dtype_categories[dtype_category]
18+
19+
if isinstance(x1, _py_scalars):
20+
if isinstance(x2, _py_scalars):
21+
raise TypeError(f"Two scalars not allowed, got {type(x1) = } and {type(x2) =}")
22+
# x2 must be an array
23+
if x2.dtype not in _allowed_dtypes:
24+
raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x2.dtype}.")
25+
x1 = x2._promote_scalar(x1)
26+
27+
elif isinstance(x2, _py_scalars):
28+
# x1 must be an array
29+
if x1.dtype not in _allowed_dtypes:
30+
raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. Got {x1.dtype}.")
31+
x2 = x1._promote_scalar(x2)
32+
else:
33+
if x1.dtype not in _allowed_dtypes or x2.dtype not in _allowed_dtypes:
34+
raise TypeError(f"Only {dtype_category} dtypes are allowed {func_name}. "
35+
f"Got {x1.dtype} and {x2.dtype}.")
36+
return x1, x2
37+

‎array_api_strict/tests/test_array_object.py

+57-44
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,60 @@ def test_promoted_scalar_inherits_device():
9696

9797
assert y.device == device1
9898

99+
100+
BIG_INT = int(1e30)
101+
102+
def _check_op_array_scalar(dtypes, a, s, func, func_name, BIG_INT=BIG_INT):
103+
# Test array op scalar. From the spec, the following combinations
104+
# are supported:
105+
106+
# - Python bool for a bool array dtype,
107+
# - a Python int within the bounds of the given dtype for integer array dtypes,
108+
# - a Python int or float for real floating-point array dtypes
109+
# - a Python int, float, or complex for complex floating-point array dtypes
110+
111+
if ((dtypes == "all"
112+
or dtypes == "numeric" and a.dtype in _numeric_dtypes
113+
or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes
114+
or dtypes == "integer" and a.dtype in _integer_dtypes
115+
or dtypes == "integer or boolean" and a.dtype in _integer_or_boolean_dtypes
116+
or dtypes == "boolean" and a.dtype in _boolean_dtypes
117+
or dtypes == "floating-point" and a.dtype in _floating_dtypes
118+
or dtypes == "real floating-point" and a.dtype in _real_floating_dtypes
119+
)
120+
# bool is a subtype of int, which is why we avoid
121+
# isinstance here.
122+
and (a.dtype in _boolean_dtypes and type(s) == bool
123+
or a.dtype in _integer_dtypes and type(s) == int
124+
or a.dtype in _real_floating_dtypes and type(s) in [float, int]
125+
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
126+
)):
127+
if a.dtype in _integer_dtypes and s == BIG_INT:
128+
with assert_raises(OverflowError):
129+
func(s)
130+
return False
131+
132+
else:
133+
# Only test for no error
134+
with suppress_warnings() as sup:
135+
# ignore warnings from pow(BIG_INT)
136+
sup.filter(RuntimeWarning,
137+
"invalid value encountered in power")
138+
func(s)
139+
return True
140+
141+
else:
142+
with assert_raises(TypeError):
143+
func(s)
144+
return False
145+
146+
99147
def test_operators():
100148
# For every operator, we test that it works for the required type
101149
# combinations and raises TypeError otherwise
102150
binary_op_dtypes = {
103151
"__add__": "numeric",
104-
"__and__": "integer_or_boolean",
152+
"__and__": "integer or boolean",
105153
"__eq__": "all",
106154
"__floordiv__": "real numeric",
107155
"__ge__": "real numeric",
@@ -112,12 +160,12 @@ def test_operators():
112160
"__mod__": "real numeric",
113161
"__mul__": "numeric",
114162
"__ne__": "all",
115-
"__or__": "integer_or_boolean",
163+
"__or__": "integer or boolean",
116164
"__pow__": "numeric",
117165
"__rshift__": "integer",
118166
"__sub__": "numeric",
119-
"__truediv__": "floating",
120-
"__xor__": "integer_or_boolean",
167+
"__truediv__": "floating-point",
168+
"__xor__": "integer or boolean",
121169
}
122170
# Recompute each time because of in-place ops
123171
def _array_vals():
@@ -128,8 +176,6 @@ def _array_vals():
128176
for d in _floating_dtypes:
129177
yield asarray(1.0, dtype=d)
130178

131-
132-
BIG_INT = int(1e30)
133179
for op, dtypes in binary_op_dtypes.items():
134180
ops = [op]
135181
if op not in ["__eq__", "__ne__", "__le__", "__ge__", "__lt__", "__gt__"]:
@@ -139,40 +185,7 @@ def _array_vals():
139185
for s in [1, 1.0, 1j, BIG_INT, False]:
140186
for _op in ops:
141187
for a in _array_vals():
142-
# Test array op scalar. From the spec, the following combinations
143-
# are supported:
144-
145-
# - Python bool for a bool array dtype,
146-
# - a Python int within the bounds of the given dtype for integer array dtypes,
147-
# - a Python int or float for real floating-point array dtypes
148-
# - a Python int, float, or complex for complex floating-point array dtypes
149-
150-
if ((dtypes == "all"
151-
or dtypes == "numeric" and a.dtype in _numeric_dtypes
152-
or dtypes == "real numeric" and a.dtype in _real_numeric_dtypes
153-
or dtypes == "integer" and a.dtype in _integer_dtypes
154-
or dtypes == "integer_or_boolean" and a.dtype in _integer_or_boolean_dtypes
155-
or dtypes == "boolean" and a.dtype in _boolean_dtypes
156-
or dtypes == "floating" and a.dtype in _floating_dtypes
157-
)
158-
# bool is a subtype of int, which is why we avoid
159-
# isinstance here.
160-
and (a.dtype in _boolean_dtypes and type(s) == bool
161-
or a.dtype in _integer_dtypes and type(s) == int
162-
or a.dtype in _real_floating_dtypes and type(s) in [float, int]
163-
or a.dtype in _complex_floating_dtypes and type(s) in [complex, float, int]
164-
)):
165-
if a.dtype in _integer_dtypes and s == BIG_INT:
166-
assert_raises(OverflowError, lambda: getattr(a, _op)(s))
167-
else:
168-
# Only test for no error
169-
with suppress_warnings() as sup:
170-
# ignore warnings from pow(BIG_INT)
171-
sup.filter(RuntimeWarning,
172-
"invalid value encountered in power")
173-
getattr(a, _op)(s)
174-
else:
175-
assert_raises(TypeError, lambda: getattr(a, _op)(s))
188+
_check_op_array_scalar(dtypes, a, s, getattr(a, _op), _op)
176189

177190
# Test array op array.
178191
for _op in ops:
@@ -203,18 +216,18 @@ def _array_vals():
203216
or (dtypes == "real numeric" and x.dtype in _real_numeric_dtypes and y.dtype in _real_numeric_dtypes)
204217
or (dtypes == "numeric" and x.dtype in _numeric_dtypes and y.dtype in _numeric_dtypes)
205218
or dtypes == "integer" and x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
206-
or dtypes == "integer_or_boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
219+
or dtypes == "integer or boolean" and (x.dtype in _integer_dtypes and y.dtype in _integer_dtypes
207220
or x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes)
208221
or dtypes == "boolean" and x.dtype in _boolean_dtypes and y.dtype in _boolean_dtypes
209-
or dtypes == "floating" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes
222+
or dtypes == "floating-point" and x.dtype in _floating_dtypes and y.dtype in _floating_dtypes
210223
):
211224
getattr(x, _op)(y)
212225
else:
213226
assert_raises(TypeError, lambda: getattr(x, _op)(y))
214227

215228
unary_op_dtypes = {
216229
"__abs__": "numeric",
217-
"__invert__": "integer_or_boolean",
230+
"__invert__": "integer or boolean",
218231
"__neg__": "numeric",
219232
"__pos__": "numeric",
220233
}
@@ -223,7 +236,7 @@ def _array_vals():
223236
if (
224237
dtypes == "numeric"
225238
and a.dtype in _numeric_dtypes
226-
or dtypes == "integer_or_boolean"
239+
or dtypes == "integer or boolean"
227240
and a.dtype in _integer_or_boolean_dtypes
228241
):
229242
# Only test for no error

‎array_api_strict/tests/test_elementwise_functions.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from inspect import signature, getmodule
22

3-
from numpy.testing import assert_raises
3+
from pytest import raises as assert_raises
4+
from numpy.testing import suppress_warnings
45

56
import pytest
67

@@ -19,6 +20,8 @@
1920
)
2021
from .._flags import set_array_api_strict_flags
2122

23+
from .test_array_object import _check_op_array_scalar, BIG_INT
24+
2225
import array_api_strict
2326

2427

@@ -120,6 +123,7 @@ def test_missing_functions():
120123
# Ensure the above dictionary is complete.
121124
import array_api_strict._elementwise_functions as mod
122125
mod_funcs = [n for n in dir(mod) if getmodule(getattr(mod, n)) is mod]
126+
mod_funcs = [n for n in mod_funcs if not n.startswith("_")]
123127
assert set(mod_funcs) == set(elementwise_function_input_types)
124128

125129

@@ -202,3 +206,51 @@ def test_bitwise_shift_error():
202206
assert_raises(
203207
ValueError, lambda: bitwise_right_shift(asarray([1, 1]), asarray([1, -1]))
204208
)
209+
210+
211+
212+
def test_scalars():
213+
# mirror test_array_object.py::test_operators()
214+
#
215+
# Also check that binary functions accept (array, scalar) and (scalar, array)
216+
# arguments, and reject (scalar, scalar) arguments.
217+
218+
# Use the latest version of the standard so that scalars are actually allowed
219+
with pytest.warns(UserWarning):
220+
set_array_api_strict_flags(api_version="2024.12")
221+
222+
def _array_vals():
223+
for d in _integer_dtypes:
224+
yield asarray(1, dtype=d)
225+
for d in _boolean_dtypes:
226+
yield asarray(False, dtype=d)
227+
for d in _floating_dtypes:
228+
yield asarray(1.0, dtype=d)
229+
230+
231+
for func_name, dtypes in elementwise_function_input_types.items():
232+
func = getattr(_elementwise_functions, func_name)
233+
if nargs(func) != 2:
234+
continue
235+
236+
for s in [1, 1.0, 1j, BIG_INT, False]:
237+
for a in _array_vals():
238+
for func1 in [lambda s: func(a, s), lambda s: func(s, a)]:
239+
allowed = _check_op_array_scalar(dtypes, a, s, func1, func_name)
240+
241+
# only check `func(array, scalar) == `func(array, array)` if
242+
# the former is legal under the promotion rules
243+
if allowed:
244+
conv_scalar = asarray(s, dtype=a.dtype)
245+
246+
with suppress_warnings() as sup:
247+
# ignore warnings from pow(BIG_INT)
248+
sup.filter(RuntimeWarning,
249+
"invalid value encountered in power")
250+
assert func(s, a) == func(conv_scalar, a)
251+
assert func(a, s) == func(a, conv_scalar)
252+
253+
with assert_raises(TypeError):
254+
func(s, s)
255+
256+

0 commit comments

Comments
 (0)
Please sign in to comment.