Skip to content

Commit a5cb116

Browse files
ENH: at support for bool mask in Dask and JAX (#121)
Co-authored-by: Lucas Colley <[email protected]>
1 parent 37b116a commit a5cb116

File tree

5 files changed

+153
-32
lines changed

5 files changed

+153
-32
lines changed

src/array_api_extra/_lib/_at.py

+80-12
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
from types import ModuleType
1010
from typing import ClassVar, cast
1111

12-
from ._utils._compat import array_namespace, is_jax_array, is_writeable_array
12+
from ._utils._compat import (
13+
array_namespace,
14+
is_dask_array,
15+
is_jax_array,
16+
is_writeable_array,
17+
)
1318
from ._utils._typing import Array, Index
1419

1520

@@ -141,6 +146,25 @@ class at: # pylint: disable=invalid-name # numpydoc ignore=PR02
141146
not explicitly covered by ``array-api-compat``, are not supported by update
142147
methods.
143148
149+
Boolean masks are supported on Dask and jitted JAX arrays exclusively
150+
when `idx` has the same shape as `x` and `y` is 0-dimensional.
151+
Note that this support is not available in JAX's native
152+
``x.at[mask].set(y)``.
153+
154+
This pattern::
155+
156+
>>> mask = m(x)
157+
>>> x[mask] = f(x[mask])
158+
159+
Can't be replaced by `at`, as it won't work on Dask and JAX inside jax.jit::
160+
161+
>>> mask = m(x)
162+
>>> x = xpx.at(x, mask).set(f(x[mask]) # Crash on Dask and jax.jit
163+
164+
You should instead use::
165+
166+
>>> x = xp.where(m(x), f(x), x)
167+
144168
Examples
145169
--------
146170
Given either of these equivalent expressions::
@@ -189,6 +213,7 @@ def _op(
189213
self,
190214
at_op: _AtOp,
191215
in_place_op: Callable[[Array, Array | object], Array] | None,
216+
out_of_place_op: Callable[[Array, Array], Array] | None,
192217
y: Array | object,
193218
/,
194219
copy: bool | None,
@@ -210,6 +235,16 @@ def _op(
210235
211236
x[idx] = y
212237
238+
out_of_place_op : Callable[[Array, Array], Array] | None
239+
Out-of-place operation to apply when idx is a boolean mask and the backend
240+
doesn't support in-place updates::
241+
242+
x = xp.where(idx, out_of_place_op(x, y), x)
243+
244+
If None::
245+
246+
x = xp.where(idx, y, x)
247+
213248
y : array or object
214249
Right-hand side of the operation.
215250
copy : bool or None
@@ -223,6 +258,7 @@ def _op(
223258
Updated `x`.
224259
"""
225260
x, idx = self._x, self._idx
261+
xp = array_namespace(x, y) if xp is None else xp
226262

227263
if idx is _undef:
228264
msg = (
@@ -247,15 +283,41 @@ def _op(
247283
else:
248284
writeable = is_writeable_array(x)
249285

286+
# JAX inside jax.jit and Dask don't support in-place updates with boolean
287+
# mask. However we can handle the common special case of 0-dimensional y
288+
# with where(idx, y, x) instead.
289+
if (
290+
(is_dask_array(idx) or is_jax_array(idx))
291+
and idx.dtype == xp.bool
292+
and idx.shape == x.shape
293+
):
294+
y_xp = xp.asarray(y, dtype=x.dtype)
295+
if y_xp.ndim == 0:
296+
if out_of_place_op:
297+
# FIXME: suppress inf warnings on dask with lazywhere
298+
out = xp.where(idx, out_of_place_op(x, y_xp), x)
299+
# Undo int->float promotion on JAX after _AtOp.DIVIDE
300+
out = xp.astype(out, x.dtype, copy=False)
301+
else:
302+
out = xp.where(idx, y_xp, x)
303+
304+
if copy:
305+
return out
306+
x[()] = out
307+
return x
308+
# else: this will work on eager JAX and crash on jax.jit and Dask
309+
250310
if copy:
251311
if is_jax_array(x):
252312
# Use JAX's at[]
253313
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
254-
return func(y)
314+
out = func(y)
315+
# Undo int->float promotion on JAX after _AtOp.DIVIDE
316+
return xp.astype(out, x.dtype, copy=False)
317+
255318
# Emulate at[] behaviour for non-JAX arrays
256319
# with a copy followed by an update
257-
if xp is None:
258-
xp = array_namespace(x)
320+
259321
x = xp.asarray(x, copy=True)
260322
if writeable is False:
261323
# A copy of a read-only numpy array is writeable
@@ -283,7 +345,7 @@ def set(
283345
xp: ModuleType | None = None,
284346
) -> Array: # numpydoc ignore=PR01,RT01
285347
"""Apply ``x[idx] = y`` and return the update array."""
286-
return self._op(_AtOp.SET, None, y, copy=copy, xp=xp)
348+
return self._op(_AtOp.SET, None, None, y, copy=copy, xp=xp)
287349

288350
def add(
289351
self,
@@ -297,7 +359,7 @@ def add(
297359
# Note for this and all other methods based on _iop:
298360
# operator.iadd and operator.add subtly differ in behaviour, as
299361
# only iadd will trigger exceptions when y has an incompatible dtype.
300-
return self._op(_AtOp.ADD, operator.iadd, y, copy=copy, xp=xp)
362+
return self._op(_AtOp.ADD, operator.iadd, operator.add, y, copy=copy, xp=xp)
301363

302364
def subtract(
303365
self,
@@ -307,7 +369,9 @@ def subtract(
307369
xp: ModuleType | None = None,
308370
) -> Array: # numpydoc ignore=PR01,RT01
309371
"""Apply ``x[idx] -= y`` and return the updated array."""
310-
return self._op(_AtOp.SUBTRACT, operator.isub, y, copy=copy, xp=xp)
372+
return self._op(
373+
_AtOp.SUBTRACT, operator.isub, operator.sub, y, copy=copy, xp=xp
374+
)
311375

312376
def multiply(
313377
self,
@@ -317,7 +381,9 @@ def multiply(
317381
xp: ModuleType | None = None,
318382
) -> Array: # numpydoc ignore=PR01,RT01
319383
"""Apply ``x[idx] *= y`` and return the updated array."""
320-
return self._op(_AtOp.MULTIPLY, operator.imul, y, copy=copy, xp=xp)
384+
return self._op(
385+
_AtOp.MULTIPLY, operator.imul, operator.mul, y, copy=copy, xp=xp
386+
)
321387

322388
def divide(
323389
self,
@@ -327,7 +393,9 @@ def divide(
327393
xp: ModuleType | None = None,
328394
) -> Array: # numpydoc ignore=PR01,RT01
329395
"""Apply ``x[idx] /= y`` and return the updated array."""
330-
return self._op(_AtOp.DIVIDE, operator.itruediv, y, copy=copy, xp=xp)
396+
return self._op(
397+
_AtOp.DIVIDE, operator.itruediv, operator.truediv, y, copy=copy, xp=xp
398+
)
331399

332400
def power(
333401
self,
@@ -337,7 +405,7 @@ def power(
337405
xp: ModuleType | None = None,
338406
) -> Array: # numpydoc ignore=PR01,RT01
339407
"""Apply ``x[idx] **= y`` and return the updated array."""
340-
return self._op(_AtOp.POWER, operator.ipow, y, copy=copy, xp=xp)
408+
return self._op(_AtOp.POWER, operator.ipow, operator.pow, y, copy=copy, xp=xp)
341409

342410
def min(
343411
self,
@@ -349,7 +417,7 @@ def min(
349417
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array."""
350418
xp = array_namespace(self._x) if xp is None else xp
351419
y = xp.asarray(y)
352-
return self._op(_AtOp.MIN, xp.minimum, y, copy=copy, xp=xp)
420+
return self._op(_AtOp.MIN, xp.minimum, xp.minimum, y, copy=copy, xp=xp)
353421

354422
def max(
355423
self,
@@ -361,4 +429,4 @@ def max(
361429
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array."""
362430
xp = array_namespace(self._x) if xp is None else xp
363431
y = xp.asarray(y)
364-
return self._op(_AtOp.MAX, xp.maximum, y, copy=copy, xp=xp)
432+
return self._op(_AtOp.MAX, xp.maximum, xp.maximum, y, copy=copy, xp=xp)

src/array_api_extra/_lib/_utils/_compat.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
device,
99
is_array_api_strict_namespace,
1010
is_cupy_namespace,
11+
is_dask_array,
1112
is_dask_namespace,
1213
is_jax_array,
1314
is_jax_namespace,
@@ -23,6 +24,7 @@
2324
device,
2425
is_array_api_strict_namespace,
2526
is_cupy_namespace,
27+
is_dask_array,
2628
is_dask_namespace,
2729
is_jax_array,
2830
is_jax_namespace,
@@ -38,6 +40,7 @@
3840
"device",
3941
"is_array_api_strict_namespace",
4042
"is_cupy_namespace",
43+
"is_dask_array",
4144
"is_dask_namespace",
4245
"is_jax_array",
4346
"is_jax_namespace",

src/array_api_extra/_lib/_utils/_compat.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def is_jax_namespace(xp: ModuleType, /) -> bool: ...
2525
def is_numpy_namespace(xp: ModuleType, /) -> bool: ...
2626
def is_pydata_sparse_namespace(xp: ModuleType, /) -> bool: ...
2727
def is_torch_namespace(xp: ModuleType, /) -> bool: ...
28+
def is_dask_array(x: object, /) -> bool: ...
2829
def is_jax_array(x: object, /) -> bool: ...
2930
def is_writeable_array(x: object, /) -> bool: ...
3031
def size(x: Array, /) -> int | None: ...

tests/test_at.py

+67-20
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
import pickle
23
from collections.abc import Callable, Generator
34
from contextlib import contextmanager
@@ -15,6 +16,12 @@
1516
from array_api_extra._lib._utils._typing import Array, Index
1617
from array_api_extra.testing import lazy_xp_function
1718

19+
pytestmark = [
20+
pytest.mark.skip_xp_backend(
21+
Backend.SPARSE, reason="read-only backend without .at support"
22+
)
23+
]
24+
1825

1926
def at_op( # type: ignore[no-any-explicit]
2027
x: Array,
@@ -70,9 +77,6 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
7077
xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy))
7178

7279

73-
@pytest.mark.skip_xp_backend(
74-
Backend.SPARSE, reason="read-only backend without .at support"
75-
)
7680
@pytest.mark.parametrize(
7781
("kwargs", "expect_copy"),
7882
[
@@ -100,14 +104,7 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
100104
[
101105
(False, False),
102106
(False, True),
103-
pytest.param(
104-
True,
105-
False,
106-
marks=(
107-
pytest.mark.skip_xp_backend(Backend.JAX, reason="TODO special case"),
108-
pytest.mark.skip_xp_backend(Backend.DASK, reason="TODO special case"),
109-
),
110-
),
107+
(True, False), # Uses xp.where(idx, y, x) on JAX and Dask
111108
pytest.param(
112109
True,
113110
True,
@@ -176,22 +173,72 @@ def test_alternate_index_syntax():
176173
at(a, 0)[0].set(4)
177174

178175

179-
@pytest.mark.parametrize("copy", [True, False])
180-
@pytest.mark.parametrize(
181-
"op", [_AtOp.ADD, _AtOp.SUBTRACT, _AtOp.MULTIPLY, _AtOp.DIVIDE, _AtOp.POWER]
182-
)
183-
def test_iops_incompatible_dtype(op: _AtOp, copy: bool):
176+
@pytest.mark.parametrize("copy", [True, None])
177+
@pytest.mark.parametrize("bool_mask", [False, True])
178+
@pytest.mark.parametrize("op", list(_AtOp))
179+
def test_incompatible_dtype(
180+
xp: ModuleType, library: Backend, op: _AtOp, copy: bool | None, bool_mask: bool
181+
):
184182
"""Test that at() replicates the backend's behaviour for
185183
in-place operations with incompatible dtypes.
186184
187-
Note:
185+
Behavior is backend-specific, but only two behaviors are allowed:
186+
1. raise an exception, or
187+
2. return the same dtype as x, disregarding y.dtype (no broadcasting).
188+
189+
Note that __i<op>__ and __<op>__ behave differently, and we want to
190+
replicate the behavior of __i<op>__:
191+
188192
>>> a = np.asarray([1, 2, 3])
189193
>>> a / 1.5
190194
array([0. , 0.66666667, 1.33333333])
191195
>>> a /= 1.5
192196
UFuncTypeError: Cannot cast ufunc 'divide' output from dtype('float64')
193197
to dtype('int64') with casting rule 'same_kind'
194198
"""
195-
x = np.asarray([2, 4])
196-
with pytest.raises(TypeError, match="Cannot cast ufunc"):
197-
at_op(x, slice(None), op, 1.1, copy=copy)
199+
x = xp.asarray([2, 4])
200+
idx = xp.asarray([True, False]) if bool_mask else slice(None)
201+
z = None
202+
203+
if library is Backend.JAX:
204+
if bool_mask:
205+
z = at_op(x, idx, op, 1.1, copy=copy)
206+
else:
207+
with pytest.warns(FutureWarning, match="cannot safely cast"):
208+
z = at_op(x, idx, op, 1.1, copy=copy)
209+
210+
elif library is Backend.DASK:
211+
if op in (_AtOp.MIN, _AtOp.MAX):
212+
pytest.xfail(reason="need array-api-compat 1.11")
213+
z = at_op(x, idx, op, 1.1, copy=copy)
214+
215+
elif library is Backend.ARRAY_API_STRICT and op is not _AtOp.SET:
216+
with pytest.raises(Exception, match=r"cast|promote|dtype"):
217+
at_op(x, idx, op, 1.1, copy=copy)
218+
219+
elif op in (_AtOp.SET, _AtOp.MIN, _AtOp.MAX):
220+
# There is no __i<op>__ version of these operations
221+
z = at_op(x, idx, op, 1.1, copy=copy)
222+
223+
else:
224+
with pytest.raises(Exception, match=r"cast|promote|dtype"):
225+
at_op(x, idx, op, 1.1, copy=copy)
226+
227+
assert z is None or z.dtype == x.dtype
228+
229+
230+
def test_bool_mask_nd(xp: ModuleType):
231+
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
232+
idx = xp.asarray([[True, False, False], [False, True, True]])
233+
z = at_op(x, idx, _AtOp.SET, 0)
234+
xp_assert_equal(z, xp.asarray([[0, 2, 3], [4, 0, 0]]))
235+
236+
237+
@pytest.mark.skip_xp_backend(Backend.DASK, reason="FIXME need scipy's lazywhere")
238+
@pytest.mark.parametrize("bool_mask", [False, True])
239+
def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
240+
x = xp.asarray([math.inf, 1.0, 2.0])
241+
idx = ~xp.isinf(x) if bool_mask else slice(1, None)
242+
# inf - inf -> nan with a warning
243+
z = at_op(x, idx, _AtOp.SUBTRACT, math.inf)
244+
xp_assert_equal(z, xp.asarray([math.inf, -math.inf, -math.inf]))

vendor_tests/test_vendor.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ def test_vendor_compat():
77
array_namespace,
88
device,
99
is_cupy_namespace,
10+
is_dask_array,
1011
is_dask_namespace,
1112
is_jax_array,
1213
is_jax_namespace,
@@ -20,6 +21,7 @@ def test_vendor_compat():
2021
assert array_namespace(x) is xp
2122
device(x)
2223
assert not is_cupy_namespace(xp)
24+
assert not is_dask_array(x)
2325
assert not is_dask_namespace(xp)
2426
assert not is_jax_array(x)
2527
assert not is_jax_namespace(xp)

0 commit comments

Comments
 (0)