3
3
from collections .abc import Callable , Generator
4
4
from contextlib import contextmanager
5
5
from types import ModuleType
6
- from typing import Any , cast
6
+ from typing import cast
7
7
8
8
import numpy as np
9
9
import pytest
23
23
]
24
24
25
25
26
- def at_op ( # type: ignore[no-any-explicit]
26
+ def at_op (
27
27
x : Array ,
28
28
idx : Index ,
29
29
op : _AtOp ,
30
30
y : Array | object ,
31
- ** kwargs : Any , # Test the default copy=None
31
+ copy : bool | None = None ,
32
+ xp : ModuleType | None = None ,
32
33
) -> Array :
33
34
"""
34
35
Wrapper around at(x, idx).op(y, copy=copy, xp=xp).
@@ -39,30 +40,33 @@ def at_op( # type: ignore[no-any-explicit]
39
40
which is not a common use case.
40
41
"""
41
42
if isinstance (idx , (slice | tuple )):
42
- return _at_op (x , None , pickle .dumps (idx ), op , y , ** kwargs )
43
- return _at_op (x , idx , None , op , y , ** kwargs )
43
+ return _at_op (x , None , pickle .dumps (idx ), op , y , copy = copy , xp = xp )
44
+ return _at_op (x , idx , None , op , y , copy = copy , xp = xp )
44
45
45
46
46
- def _at_op ( # type: ignore[no-any-explicit]
47
+ def _at_op (
47
48
x : Array ,
48
49
idx : Index | None ,
49
50
idx_pickle : bytes | None ,
50
51
op : _AtOp ,
51
52
y : Array | object ,
52
- ** kwargs : Any ,
53
+ copy : bool | None ,
54
+ xp : ModuleType | None = None ,
53
55
) -> Array :
54
56
"""jitted helper of at_op"""
55
57
if idx_pickle :
56
58
idx = pickle .loads (idx_pickle )
57
59
meth = cast (Callable [..., Array ], getattr (at (x , idx ), op .value )) # type: ignore[no-any-explicit]
58
- return meth (y , ** kwargs )
60
+ return meth (y , copy = copy , xp = xp )
59
61
60
62
61
63
lazy_xp_function (_at_op , static_argnames = ("op" , "idx_pickle" , "copy" , "xp" ))
62
64
63
65
64
66
@contextmanager
65
- def assert_copy (array : Array , copy : bool | None ) -> Generator [None , None , None ]:
67
+ def assert_copy (
68
+ array : Array , copy : bool | None , expect_copy : bool | None = None
69
+ ) -> Generator [None , None , None ]:
66
70
if copy is False and not is_writeable_array (array ):
67
71
with pytest .raises ((TypeError , ValueError )):
68
72
yield
@@ -72,24 +76,23 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
72
76
array_orig = xp .asarray (array , copy = True )
73
77
yield
74
78
75
- if copy is None :
76
- copy = not is_writeable_array (array )
77
- xp_assert_equal (xp .all (array == array_orig ), xp .asarray (copy ))
79
+ if expect_copy is None :
80
+ expect_copy = copy
78
81
82
+ if expect_copy :
83
+ # Original has not been modified
84
+ xp_assert_equal (array , array_orig )
85
+ elif expect_copy is False :
86
+ # Original has been modified
87
+ with pytest .raises (AssertionError ):
88
+ xp_assert_equal (array , array_orig )
89
+ # Test nothing for copy=None. Dask changes behaviour depending on
90
+ # whether it's a special case of a bool mask with scalar RHS or not.
79
91
92
+
93
+ @pytest .mark .parametrize ("copy" , [False , True , None ])
80
94
@pytest .mark .parametrize (
81
- ("kwargs" , "expect_copy" ),
82
- [
83
- pytest .param ({"copy" : True }, True , id = "copy=True" ),
84
- pytest .param ({"copy" : False }, False , id = "copy=False" ),
85
- # Behavior is backend-specific
86
- pytest .param ({"copy" : None }, None , id = "copy=None" ),
87
- # Test that the copy parameter defaults to None
88
- pytest .param ({}, None , id = "no copy kwarg" ),
89
- ],
90
- )
91
- @pytest .mark .parametrize (
92
- ("op" , "y" , "expect" ),
95
+ ("op" , "y" , "expect_list" ),
93
96
[
94
97
(_AtOp .SET , 40.0 , [10.0 , 40.0 , 40.0 ]),
95
98
(_AtOp .ADD , 40.0 , [10.0 , 60.0 , 70.0 ]),
@@ -102,14 +105,13 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
102
105
],
103
106
)
104
107
@pytest .mark .parametrize (
105
- ("bool_mask" , "shaped_y " ),
108
+ ("bool_mask" , "x_ndim" , "y_ndim " ),
106
109
[
107
- (False , False ),
108
- (False , True ),
109
- (True , False ), # Uses xp.where(idx, y, x) on JAX and Dask
110
+ (False , 1 , 0 ),
111
+ (False , 1 , 1 ),
112
+ (True , 1 , 0 ), # Uses xp.where(idx, y, x) on JAX and Dask
110
113
pytest .param (
111
- True ,
112
- True ,
114
+ * (True , 1 , 1 ),
113
115
marks = (
114
116
pytest .mark .skip_xp_backend ( # test passes when copy=False
115
117
Backend .JAX , reason = "bool mask update with shaped rhs"
@@ -119,29 +121,65 @@ def assert_copy(array: Array, copy: bool | None) -> Generator[None, None, None]:
119
121
),
120
122
),
121
123
),
124
+ (False , 0 , 0 ),
125
+ (True , 0 , 0 ),
122
126
],
123
127
)
124
128
def test_update_ops (
125
129
xp : ModuleType ,
126
- kwargs : dict [str , bool | None ],
127
- expect_copy : bool | None ,
130
+ copy : bool | None ,
128
131
op : _AtOp ,
129
132
y : float ,
130
- expect : list [float ],
133
+ expect_list : list [float ],
131
134
bool_mask : bool ,
132
- shaped_y : bool ,
135
+ x_ndim : int ,
136
+ y_ndim : int ,
133
137
):
134
- x = xp .asarray ([10.0 , 20.0 , 30.0 ])
135
- idx = xp .asarray ([False , True , True ]) if bool_mask else slice (1 , None )
136
- if shaped_y :
138
+ if x_ndim == 1 :
139
+ x = xp .asarray ([10.0 , 20.0 , 30.0 ])
140
+ idx = xp .asarray ([False , True , True ]) if bool_mask else slice (1 , None )
141
+ expect : list [float ] | float = expect_list
142
+ else :
143
+ idx = xp .asarray (True ) if bool_mask else ()
144
+ # Pick an element that does change with the operation
145
+ if op is _AtOp .MIN :
146
+ x = xp .asarray (30.0 )
147
+ expect = expect_list [2 ]
148
+ else :
149
+ x = xp .asarray (20.0 )
150
+ expect = expect_list [1 ]
151
+
152
+ if y_ndim == 1 :
137
153
y = xp .asarray ([y , y ])
138
154
139
- with assert_copy (x , expect_copy ):
140
- z = at_op (x , idx , op , y , ** kwargs )
155
+ with assert_copy (x , copy ):
156
+ z = at_op (x , idx , op , y , copy = copy )
141
157
assert isinstance (z , type (x ))
142
158
xp_assert_equal (z , xp .asarray (expect ))
143
159
144
160
161
+ @pytest .mark .parametrize ("op" , list (_AtOp ))
162
+ def test_copy_default (xp : ModuleType , library : Backend , op : _AtOp ):
163
+ """
164
+ Test that the default copy behaviour is False for writeable arrays
165
+ and True for read-only ones.
166
+ """
167
+ x = xp .asarray ([1.0 , 10.0 , 20.0 ])
168
+ expect_copy = not is_writeable_array (x )
169
+ meth = cast (Callable [..., Array ], getattr (at (x )[:2 ], op .value )) # type: ignore[no-any-explicit]
170
+ with assert_copy (x , None , expect_copy ):
171
+ _ = meth (2.0 )
172
+
173
+ x = xp .asarray ([1.0 , 10.0 , 20.0 ])
174
+ # Dask's default copy value is True for bool masks,
175
+ # even if the arrays are writeable.
176
+ expect_copy = not is_writeable_array (x ) or library is Backend .DASK
177
+ idx = xp .asarray ([True , True , False ])
178
+ meth = cast (Callable [..., Array ], getattr (at (x , idx ), op .value )) # type: ignore[no-any-explicit]
179
+ with assert_copy (x , None , expect_copy ):
180
+ _ = meth (2.0 )
181
+
182
+
145
183
def test_copy_invalid ():
146
184
a = np .asarray ([1 , 2 , 3 ])
147
185
with pytest .raises (ValueError , match = "copy" ):
@@ -259,3 +297,46 @@ def test_no_inf_warnings(xp: ModuleType, bool_mask: bool):
259
297
# inf - inf -> nan with a warning
260
298
z = at_op (x , idx , _AtOp .SUBTRACT , math .inf )
261
299
xp_assert_equal (z , xp .asarray ([math .inf , - math .inf , - math .inf ]))
300
+
301
+
302
+ @pytest .mark .parametrize (
303
+ "copy" ,
304
+ [
305
+ None ,
306
+ pytest .param (
307
+ False ,
308
+ marks = [
309
+ pytest .mark .skip_xp_backend (
310
+ Backend .NUMPY , reason = "np.generic is read-only"
311
+ ),
312
+ pytest .mark .skip_xp_backend (
313
+ Backend .NUMPY_READONLY , reason = "read-only backend"
314
+ ),
315
+ pytest .mark .skip_xp_backend (Backend .JAX , reason = "read-only backend" ),
316
+ pytest .mark .skip_xp_backend (Backend .SPARSE , reason = "read-only backend" ),
317
+ ],
318
+ ),
319
+ ],
320
+ )
321
+ @pytest .mark .parametrize ("bool_mask" , [False , True ])
322
+ def test_gh134 (xp : ModuleType , bool_mask : bool , copy : bool | None ):
323
+ """
324
+ Test that xpx.at doesn't encroach in a bug of dask.array.Array.__setitem__, which
325
+ blindly assumes that chunk contents are writeable np.ndarray objects:
326
+
327
+ https://github.com/dask/dask/issues/11722
328
+
329
+ In other words: when special-casing bool masks for Dask, unless the user explicitly
330
+ asks for copy=False, do not needlessly write back to the input.
331
+ """
332
+ x = xp .zeros (1 )
333
+
334
+ # In numpy, we have a writeable np.ndarray in input and a read-only np.generic in
335
+ # output. As both are Arrays, this behaviour is Array API compliant.
336
+ # In Dask, we have a writeable da.Array on both sides, and if you call __setitem__
337
+ # on it all seems fine, but when you compute() your graph is corrupted.
338
+ y = x [0 ]
339
+
340
+ idx = xp .asarray (True ) if bool_mask else ()
341
+ z = at_op (y , idx , _AtOp .SET , 1 , copy = copy )
342
+ xp_assert_equal (z , xp .asarray (1 , dtype = x .dtype ))
0 commit comments