# test_dask.py
from contextlib import contextmanager

import array_api_strict
import numpy as np
import pytest

try:
    import dask
    import dask.array as da
except ImportError:
    # Skip the whole module when dask is not installed
    pytestmark = pytest.skip(allow_module_level=True, reason="dask not found")

from array_api_compat import array_namespace


@pytest.fixture
def xp():
    """Fixture returning the wrapped dask namespace"""
    return array_namespace(da.empty(0))


@contextmanager
def assert_no_compute():
    """
    Context manager that raises if anything inside it calls compute()
    or persist(), e.g. when triggered implicitly by __bool__, __array__, etc.
    """

    def get(dsk, *args, **kwargs):
        raise AssertionError("Called compute() or persist()")

    # Temporarily replace the dask scheduler with one that fails on any graph
    # execution; exiting the block restores the previous scheduler.
    with dask.config.set(scheduler=get):
        yield


def test_assert_no_compute():
    """Test the assert_no_compute context manager"""
    a = da.asarray(True)
    with pytest.raises(AssertionError, match="Called compute"):
        with assert_no_compute():
            bool(a)

    # Exiting the context manager restores the original scheduler
    assert bool(a) is True
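

# Implicit computation can be triggered by things like bool(a), float(a),
# np.asarray(a) or iteration; the tests below guard against any of that
# happening inside the wrapped namespace functions.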
# Test no_compute for functions that use generic _aliases with xp=np
def test_unary_ops_no_compute(xp):
    with assert_no_compute():
        a = xp.asarray([1.5, -1.5])
        xp.ceil(a)
        xp.floor(a)
        xp.trunc(a)
        xp.sign(a)
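

# matmul and tensordot go through the generic wrappers as well; A and B are
# created outside the guard so that only the products themselves are checked
# for laziness.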
def test_matmul_tensordot_no_compute(xp):
    A = da.ones((4, 4), chunks=2)
    B = da.zeros((4, 4), chunks=2)
    with assert_no_compute():
        xp.matmul(A, B)
        xp.tensordot(A, B)


# Test no_compute for functions that are fully bespoke for dask
def test_asarray_no_compute(xp):
    with assert_no_compute():
        a = xp.arange(10)
        xp.asarray(a)
        xp.asarray(a, dtype=np.int16)
        xp.asarray(a, dtype=a.dtype)
        xp.asarray(a, copy=True)
        xp.asarray(a, copy=True, dtype=np.int16)
        xp.asarray(a, copy=True, dtype=a.dtype)
        xp.asarray(a, copy=False)
        xp.asarray(a, copy=False, dtype=a.dtype)
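

# astype, with or without a copy, should only add a cast to the task graph and
# never materialize chunks.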
@pytest.mark.parametrize("copy", [True, False])
def test_astype_no_compute(xp, copy):
with assert_no_compute():
a = xp.arange(10)
xp.astype(a, np.int16, copy=copy)
xp.astype(a, a.dtype, copy=copy)
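

# clip with no bounds, one bound, or both bounds should all stay lazy;
# per the array API, clip(a) with no bounds is essentially a pass-through.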
def test_clip_no_compute(xp):
    with assert_no_compute():
        a = xp.arange(10)
        xp.clip(a)
        xp.clip(a, 1)
        xp.clip(a, 1, 8)
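

# chunks=10 keeps the whole array in a single chunk, while chunks=5 forces the
# multi-chunk code path; both must build the sorted result without computing.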
@pytest.mark.parametrize("chunks", (5, 10))
def test_sort_argsort_nocompute(xp, chunks):
with assert_no_compute():
a = xp.arange(10, chunks=chunks)
xp.sort(a)
xp.argsort(a)
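

# 100_000_000_000 float64 elements is roughly 800 GB, so materializing any of
# these arrays would exhaust memory; passing means only graphs were built.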
def test_generators_are_lazy(xp):
    """
    Test that array creation functions are fully lazy, e.g. that
    da.ones(n) is not implemented as da.asarray(np.ones(n))
    """
    size = 100_000_000_000  # 800 GB
    chunks = size // 10  # 10x 80 GB chunks

    with assert_no_compute():
        xp.zeros(size, chunks=chunks)
        xp.ones(size, chunks=chunks)
        xp.empty(size, chunks=chunks)
        xp.full(size, fill_value=123, chunks=chunks)
        a = xp.arange(size, chunks=chunks)
        xp.zeros_like(a)
        xp.ones_like(a)
        xp.empty_like(a)
        xp.full_like(a, fill_value=123)
@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_chunks(xp, func, axis):
"""Test that sort and argsort are functionally correct when
the array is chunked along the sort axis, e.g. the sort is
not just local to each chunk.
"""
a = da.random.random((10, 10), chunks=(5, 5))
actual = getattr(xp, func)(a, axis=axis)
expect = getattr(np, func)(a.compute(), axis=axis)
np.testing.assert_array_equal(actual, expect)
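

# 20_000 x 20_000 float64 is ~3 GiB; with chunks="auto" dask targets ~128 MiB
# chunks, whereas (1, -1) chunking forces two ~1.5 GiB chunks along the rows.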
@pytest.mark.parametrize(
    "shape,chunks",
    [
        # 3 GiB; 128 MiB per chunk; must rechunk before sorting.
        # Sort chunks can be 128 MiB each; no need for final rechunk.
        ((20_000, 20_000), "auto"),
        # 3 GiB; 128 MiB per chunk; must rechunk before sorting.
        # Must sort on two 1.5 GiB chunks; benefits from final rechunk.
        ((2, 2**30 * 3 // 16), "auto"),
        # 3 GiB; 1.5 GiB per chunk; no need to rechunk before sorting.
        # Surely the user must know what they're doing, so don't
        # perform the final rechunk.
        ((2, 2**30 * 3 // 16), (1, -1)),
    ],
)
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_chunk_size(xp, func, shape, chunks):
    """
    Test that sort and argsort produce reasonably-sized chunks
    in the output array, even if they had to go through a single
    huge chunk to perform the operation.
    """
    a = da.random.random(shape, chunks=chunks)
    b = getattr(xp, func)(a)
    max_chunk_size = max(b.chunks[0]) * max(b.chunks[1]) * b.dtype.itemsize
    assert (
        max_chunk_size <= 128 * 1024 * 1024  # 128 MiB
        or b.chunks == a.chunks
    )
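

# _meta is the zero-sized array dask uses to track the backing array type; the
# blocks here are array_api_strict arrays, and sort/argsort must preserve that.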
@pytest.mark.parametrize("func", ["sort", "argsort"])
def test_sort_argsort_meta(xp, func):
"""Test meta-namespace other than numpy"""
typ = type(array_api_strict.asarray(0))
a = da.random.random(10)
b = a.map_blocks(array_api_strict.asarray)
assert isinstance(b._meta, typ)
c = getattr(xp, func)(b)
assert isinstance(c._meta, typ)
d = c.compute()
# Note: np.sort(array_api_strict.asarray(0)) would return a numpy array
assert isinstance(d, typ)
np.testing.assert_array_equal(d, getattr(np, func)(a.compute()))