Skip to content

Commit fc63f65

Browse files
committed
BUG: at should not force overwrite in Dask when copy=None
1 parent 6e4aba6 commit fc63f65

File tree

2 files changed

+133
-56
lines changed

2 files changed

+133
-56
lines changed

src/array_api_extra/_lib/_at.py

+13-17
Original file line numberDiff line numberDiff line change
@@ -275,16 +275,11 @@ def _op(
275275
msg = f"copy must be True, False, or None; got {copy!r}"
276276
raise ValueError(msg)
277277

278-
if copy is None:
279-
writeable = is_writeable_array(x)
280-
copy = not writeable
281-
elif copy:
282-
writeable = None
283-
else:
284-
writeable = is_writeable_array(x)
278+
writeable = None if copy else is_writeable_array(x)
285279

286-
# JAX inside jax.jit and Dask don't support in-place updates with boolean
287-
# mask. However we can handle the common special case of 0-dimensional y
280+
# JAX inside jax.jit doesn't support in-place updates with boolean
281+
# masks; Dask exclusively supports __setitem__ but not iops.
282+
# We can handle the common special case of 0-dimensional y
288283
# with where(idx, y, x) instead.
289284
if (
290285
(is_dask_array(idx) or is_jax_array(idx))
@@ -293,21 +288,22 @@ def _op(
293288
):
294289
y_xp = xp.asarray(y, dtype=x.dtype)
295290
if y_xp.ndim == 0:
296-
if out_of_place_op:
291+
if out_of_place_op: # add(), subtract(), ...
297292
# FIXME: suppress inf warnings on dask with lazywhere
298293
out = xp.where(idx, out_of_place_op(x, y_xp), x)
299294
# Undo int->float promotion on JAX after _AtOp.DIVIDE
300295
out = xp.astype(out, x.dtype, copy=False)
301-
else:
296+
else: # set()
302297
out = xp.where(idx, y_xp, x)
303298

304-
if copy:
305-
return out
306-
x[()] = out
307-
return x
299+
if copy is False:
300+
x[()] = out
301+
return x
302+
return out
303+
308304
# else: this will work on eager JAX and crash on jax.jit and Dask
309305

310-
if copy:
306+
if copy or (copy is None and not writeable):
311307
if is_jax_array(x):
312308
# Use JAX's at[]
313309
func = cast(Callable[[Array], Array], getattr(x.at[idx], at_op.value))
@@ -331,7 +327,7 @@ def _op(
331327
msg = f"Can't update read-only array {x}"
332328
raise ValueError(msg)
333329

334-
if in_place_op:
330+
if in_place_op: # add(), subtract(), ...
335331
x[self._idx] = in_place_op(x[self._idx], y)
336332
else: # set()
337333
x[self._idx] = y

tests/test_at.py

+120-39
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Callable, Generator
44
from contextlib import contextmanager
55
from types import ModuleType
6-
from typing import Any, cast
6+
from typing import cast
77

88
import numpy as np
99
import pytest
@@ -23,12 +23,13 @@
2323
]
2424

2525

26-
def at_op( # type: ignore[no-any-explicit]
26+
def at_op(
2727
x: Array,
2828
idx: Index,
2929
op: _AtOp,
3030
y: Array | object,
31-
**kwargs: Any, # Test the default copy=None
31+
copy: bool | None = None,
32+
xp: ModuleType | None = None,
3233
) -> Array:
3334
"""
3435
Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
@@ -39,30 +40,33 @@ def at_op( # type: ignore[no-any-explicit]
3940
which is not a common use case.
4041
"""
4142
if isinstance(idx, (slice | tuple)):
42-
return _at_op(x, None, pickle.dumps(idx), op, y, **kwargs)
43-
return _at_op(x, idx, None, op, y, **kwargs)
43+
return _at_op(x, None, pickle.dumps(idx), op, y, copy=copy, xp=xp)
44+
return _at_op(x, idx, None, op, y, copy=copy, xp=xp)
4445

4546

46-
def _at_op( # type: ignore[no-any-explicit]
47+
def _at_op(
4748
x: Array,
4849
idx: Index | None,
4950
idx_pickle: bytes | None,
5051
op: _AtOp,
5152
y: Array | object,
52-
**kwargs: Any,
53+
copy: bool | None,
54+
xp: ModuleType | None = None,
5355
) -> Array:
5456
"""jitted helper of at_op"""
5557
if idx_pickle:
5658
idx = pickle.loads(idx_pickle)
5759
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit]
58-
return meth(y, **kwargs)
60+
return meth(y, copy=copy, xp=xp)
5961

6062

6163
lazy_xp_function(_at_op, static_argnames=("op", "idx_pickle", "copy", "xp"))
6264

6365

6466
@contextmanager
65-
def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
67+
def assert_copy(
68+
array: Array, copy: bool | None, expect_copy: bool | None = None
69+
) -> Generator[None, None, None]:
6670
if copy is False and not is_writeable_array(array):
6771
with pytest.raises((TypeError, ValueError)):
6872
yield
@@ -72,24 +76,23 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
7276
array_orig = xp.asarray(array, copy=True)
7377
yield
7478

75-
if copy is None:
76-
copy = not is_writeable_array(array)
77-
xp_assert_equal(xp.all(array == array_orig), xp.asarray(copy))
79+
if expect_copy is None:
80+
expect_copy = copy
7881

82+
if expect_copy:
83+
# Original has not been modified
84+
xp_assert_equal(array, array_orig)
85+
elif expect_copy is False:
86+
# Original has been modified
87+
with pytest.raises(AssertionError):
88+
xp_assert_equal(array, array_orig)
89+
# Test nothing for copy=None. Dask changes behaviour depending on
90+
# whether it's a special case of a bool mask with scalar RHS or not.
7991

92+
93+
@pytest.mark.parametrize("copy", [False, True, None])
8094
@pytest.mark.parametrize(
81-
("kwargs", "expect_copy"),
82-
[
83-
pytest.param({"copy": True}, True, id="copy=True"),
84-
pytest.param({"copy": False}, False, id="copy=False"),
85-
# Behavior is backend-specific
86-
pytest.param({"copy": None}, None, id="copy=None"),
87-
# Test that the copy parameter defaults to None
88-
pytest.param({}, None, id="no copy kwarg"),
89-
],
90-
)
91-
@pytest.mark.parametrize(
92-
("op", "y", "expect"),
95+
("op", "y", "expect_list"),
9396
[
9497
(_AtOp.SET, 40.0, [10.0, 40.0, 40.0]),
9598
(_AtOp.ADD, 40.0, [10.0, 60.0, 70.0]),
@@ -102,14 +105,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
102105
],
103106
)
104107
@pytest.mark.parametrize(
105-
("bool_mask", "shaped_y"),
108+
("bool_mask", "x_ndim", "y_ndim"),
106109
[
107-
(False, False),
108-
(False, True),
109-
(True, False), # Uses xp.where(idx, y, x) on JAX and Dask
110+
(False, 1, 0),
111+
(False, 1, 1),
112+
(True, 1, 0), # Uses xp.where(idx, y, x) on JAX and Dask
110113
pytest.param(
111-
True,
112-
True,
114+
*(True, 1, 1),
113115
marks=(
114116
pytest.mark.skip_xp_backend( # test passes when copy=False
115117
Backend.JAX, reason="bool mask update with shaped rhs"
@@ -119,29 +121,65 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
119121
),
120122
),
121123
),
124+
(False, 0, 0),
125+
(True, 0, 0),
122126
],
123127
)
124128
def test_update_ops(
125129
xp: ModuleType,
126-
kwargs: dict[str, bool | None],
127-
expect_copy: bool | None,
130+
copy: bool | None,
128131
op: _AtOp,
129132
y: float,
130-
expect: list[float],
133+
expect_list: list[float],
131134
bool_mask: bool,
132-
shaped_y: bool,
135+
x_ndim: int,
136+
y_ndim: int,
133137
):
134-
x = xp.asarray([10.0, 20.0, 30.0])
135-
idx = xp.asarray([False, True, True]) if bool_mask else slice(1, None)
136-
if shaped_y:
138+
if x_ndim == 1:
139+
x = xp.asarray([10.0, 20.0, 30.0])
140+
idx = xp.asarray([False, True, True]) if bool_mask else slice(1, None)
141+
expect: list[float] | float = expect_list
142+
else:
143+
idx = xp.asarray(True) if bool_mask else ()
144+
# Pick an element that does change with the operation
145+
if op is _AtOp.MIN:
146+
x = xp.asarray(30.0)
147+
expect = expect_list[2]
148+
else:
149+
x = xp.asarray(20.0)
150+
expect = expect_list[1]
151+
152+
if y_ndim == 1:
137153
y = xp.asarray([y, y])
138154

139-
with assert_copy(x, expect_copy):
140-
z = at_op(x, idx, op, y, **kwargs)
155+
with assert_copy(x, copy):
156+
z = at_op(x, idx, op, y, copy=copy)
141157
assert isinstance(z, type(x))
142158
xp_assert_equal(z, xp.asarray(expect))
143159

144160

161+
@pytest.mark.parametrize("op", list(_AtOp))
162+
def test_copy_default(xp: ModuleType, library: Backend, op: _AtOp):
163+
"""
164+
Test that the default copy behaviour is False for writeable arrays
165+
and True for read-only ones.
166+
"""
167+
x = xp.asarray([1.0, 10.0, 20.0])
168+
expect_copy = not is_writeable_array(x)
169+
meth = cast(Callable[..., Array], getattr(at(x)[:2], op.value)) # type: ignore[no-any-explicit]
170+
with assert_copy(x, None, expect_copy):
171+
_ = meth(2.0)
172+
173+
x = xp.asarray([1.0, 10.0, 20.0])
174+
# Dask's default copy value is True for bool masks,
175+
# even if the arrays are writeable.
176+
expect_copy = not is_writeable_array(x) or library is Backend.DASK
177+
idx = xp.asarray([True, True, False])
178+
meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[no-any-explicit]
179+
with assert_copy(x, None, expect_copy):
180+
_ = meth(2.0)
181+
182+
145183
def test_copy_invalid():
146184
a = np.asarray([1, 2, 3])
147185
with pytest.raises(ValueError, match="copy"):
@@ -259,3 +297,46 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
259297
# inf - inf -> nan with a warning
260298
z = at_op(x, idx, _AtOp.SUBTRACT, math.inf)
261299
xp_assert_equal(z, xp.asarray([math.inf, -math.inf, -math.inf]))
300+
301+
302+
@pytest.mark.parametrize(
303+
"copy",
304+
[
305+
None,
306+
pytest.param(
307+
False,
308+
marks=[
309+
pytest.mark.skip_xp_backend(
310+
Backend.NUMPY, reason="np.generic is read-only"
311+
),
312+
pytest.mark.skip_xp_backend(
313+
Backend.NUMPY_READONLY, reason="read-only backend"
314+
),
315+
pytest.mark.skip_xp_backend(Backend.JAX, reason="read-only backend"),
316+
pytest.mark.skip_xp_backend(Backend.SPARSE, reason="read-only backend"),
317+
],
318+
),
319+
],
320+
)
321+
@pytest.mark.parametrize("bool_mask", [False, True])
322+
def test_gh134(xp: ModuleType, bool_mask: bool, copy: bool | None):
323+
"""
324+
Test that xpx.at doesn't encroach in a bug of dask.array.Array.__setitem__, which
325+
blindly assumes that chunk contents are writeable np.ndarray objects:
326+
327+
https://github.com/dask/dask/issues/11722
328+
329+
In other words: when special-casing bool masks for Dask, unless the user explicitly
330+
asks for copy=False, do not needlessly write back to the input.
331+
"""
332+
x = xp.zeros(1)
333+
334+
# In numpy, we have a writeable np.ndarray in input and a read-only np.generic in
335+
# output. As both are Arrays, this behaviour is Array API compliant.
336+
# In Dask, we have a writeable da.Array on both sides, and if you call __setitem__
337+
# on it all seems fine, but when you compute() your graph is corrupted.
338+
y = x[0]
339+
340+
idx = xp.asarray(True) if bool_mask else ()
341+
z = at_op(y, idx, _AtOp.SET, 1, copy=copy)
342+
xp_assert_equal(z, xp.asarray(1, dtype=x.dtype))

0 commit comments

Comments
 (0)