Skip to content

Commit 7f39b03

Browse files
committed
[WIP] ENH: dask+cupy, dask+sparse etc. namespaces
1 parent e14754b commit 7f39b03

File tree

4 files changed

+70
-5
lines changed

4 files changed

+70
-5
lines changed

array_api_compat/common/_helpers.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,11 @@ def is_dask_namespace(xp) -> bool:
368368
is_pydata_sparse_namespace
369369
is_array_api_strict_namespace
370370
"""
371-
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
371+
names = {'dask.array', _compat_module_name() + '.dask.array'}
372+
return (
373+
xp.__name__ in names
374+
or any(xp.__name__.startswith(name + '.') for name in names)
375+
)
372376

373377

374378
def is_jax_namespace(xp) -> bool:
@@ -541,8 +545,10 @@ def your_function(x, y):
541545
elif is_dask_array(x):
542546
if _use_compat:
543547
_check_api_version(api_version)
544-
from ..dask import array as dask_namespace
545-
namespaces.add(dask_namespace)
548+
from ..dask.array import wrap_namespace
549+
mxp = array_namespace(x._meta, use_compat=False)
550+
xp = wrap_namespace(mxp)
551+
namespaces.add(xp)
546552
else:
547553
import dask.array as da
548554
namespaces.add(da)

array_api_compat/dask/array/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# These imports may overwrite names from the import * above.
44
from ._aliases import * # noqa: F403
5+
from ._meta import wrap_namespace # noqa: F401
56

67
__array_api_version__ = '2024.12'
78

array_api_compat/dask/array/_aliases.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def asarray(
157157
dtype: Optional[Dtype] = None,
158158
device: Optional[Device] = None,
159159
copy: Optional[Union[bool, np._CopyMode]] = None,
160+
like: Optional[Array] = None,
160161
**kwargs,
161162
) -> Array:
162163
"""
@@ -172,7 +173,11 @@ def asarray(
172173
if copy is False:
173174
raise ValueError("Unable to avoid copy when changing dtype")
174175
obj = obj.astype(dtype)
175-
return obj.copy() if copy else obj
176+
if copy:
177+
obj = obj.copy()
178+
if like is not None:
179+
obj = da.asarray(obj, like=like)
180+
return obj
176181

177182
if copy is False:
178183
raise NotImplementedError(
@@ -181,7 +186,11 @@ def asarray(
181186

182187
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
183188
# see https://github.com/dask/dask/pull/11524/
184-
obj = np.array(obj, dtype=dtype, copy=True)
189+
if like is not None:
190+
mxp = array_namespace(like)
191+
obj = mxp.asarray(obj, dtype=dtype, copy=True)
192+
else:
193+
obj = np.array(obj, dtype=dtype, copy=True)
185194
return da.from_array(obj)
186195

187196

array_api_compat/dask/array/_meta.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import functools
2+
import sys
3+
import types
4+
5+
from ...common import is_numpy_namespace
6+
7+
__all__ = ['wrap_namespace']
8+
_all_ignore = ['functools', 'sys', 'types', 'is_numpy_namespace']
9+
10+
11+
def wrap_namespace(xp):
12+
"""Create a bespoke Dask namespace that wraps around another namespace.
13+
14+
Parameters
15+
----------
16+
xp : namespace
17+
Namespace to be wrapped by Dask
18+
19+
Returns
20+
-------
21+
namespace :
22+
A module object that duplicates array_api_compat.dask.array, with the
23+
difference that all creation functions will create an array with the same
24+
meta namespace as the input.
25+
"""
26+
from .. import array as da_compat
27+
28+
if is_numpy_namespace(xp):
29+
return da_compat
30+
31+
mod_name = f'{da_compat.__name__}.{xp.__name__}'
32+
try:
33+
return sys.modules[mod_name]
34+
except KeyError:
35+
pass
36+
37+
mod = types.ModuleType(mod_name)
38+
sys.modules[mod_name] = mod
39+
40+
meta = xp.empty(())
41+
for name, v in da_compat.__dict__.items():
42+
if name.startswith('_'):
43+
continue
44+
if name in {'arange', 'asarray', 'empty', 'eye', 'from_dlpack',
45+
'full', 'linspace', 'ones', 'zeros'}:
46+
v = functools.wraps(v)(functools.partial(v, like=meta))
47+
setattr(mod, name, v)
48+
49+
return mod

0 commit comments

Comments
 (0)