Skip to content

Commit 330c006

Browse files
committed
[WIP] ENH: dask+cupy, dask+sparse etc. namespaces
1 parent b6900df commit 330c006

File tree

4 files changed

+80
-6
lines changed

4 files changed

+80
-6
lines changed

array_api_compat/common/_helpers.py

+18-4
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,9 @@ def is_dask_namespace(xp: Namespace) -> bool:
352352
"""
353353
Returns True if `xp` is a Dask namespace.
354354
355-
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
355+
This includes ``dask.array`` itself, the version wrapped by array-api-compat,
356+
and the bespoke namespaces generated by
357+
``array_api_compat.dask.array.wrap_namespace``.
356358
357359
See Also
358360
--------
@@ -366,7 +368,11 @@ def is_dask_namespace(xp: Namespace) -> bool:
366368
is_pydata_sparse_namespace
367369
is_array_api_strict_namespace
368370
"""
369-
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
371+
da_compat_name = _compat_module_name() + '.dask.array'
372+
return (
373+
xp.__name__ in {'dask.array', da_compat_name}
374+
or xp.__name__.startswith(da_compat_name + '.')
375+
)
370376

371377

372378
def is_jax_namespace(xp: Namespace) -> bool:
@@ -543,8 +549,16 @@ def your_function(x, y):
543549
elif is_dask_array(x):
544550
if _use_compat:
545551
_check_api_version(api_version)
546-
from ..dask import array as dask_namespace
547-
namespaces.add(dask_namespace)
552+
from ..dask.array import wrap_namespace
553+
554+
# The meta-namespace is only used to generate the meta-array, so it
555+
# would be useless to create a namespace such as e.g.
556+
# array_api_compat.dask.array.array_api_compat.cupy.
557+
# It would get worse once you vendor array-api-compat!
558+
# So keep it clean with array_api_compat.dask.array.cupy.
559+
mxp = array_namespace(x._meta, use_compat=False)
560+
xp = wrap_namespace(mxp)
561+
namespaces.add(xp)
548562
else:
549563
import dask.array as da
550564
namespaces.add(da)

array_api_compat/dask/array/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# These imports may overwrite names from the import * above.
44
from ._aliases import * # noqa: F403
5+
from ._meta import wrap_namespace # noqa: F401
56

67
__array_api_version__ = '2024.12'
78

array_api_compat/dask/array/_aliases.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def asarray(
146146
dtype: Optional[DType] = None,
147147
device: Optional[Device] = None,
148148
copy: Optional[Union[bool, np._CopyMode]] = None,
149+
like: Optional[Array] = None,
149150
**kwargs,
150151
) -> Array:
151152
"""
@@ -161,7 +162,11 @@ def asarray(
161162
if copy is False:
162163
raise ValueError("Unable to avoid copy when changing dtype")
163164
obj = obj.astype(dtype)
164-
return obj.copy() if copy else obj
165+
if copy:
166+
obj = obj.copy()
167+
if like is not None:
168+
obj = da.asarray(obj, like=like)
169+
return obj
165170

166171
if copy is False:
167172
raise NotImplementedError(
@@ -170,7 +175,11 @@ def asarray(
170175

171176
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
172177
# see https://github.com/dask/dask/pull/11524/
173-
obj = np.array(obj, dtype=dtype, copy=True)
178+
if like is not None:
179+
mxp = array_namespace(like)
180+
obj = mxp.asarray(obj, dtype=dtype, copy=True)
181+
else:
182+
obj = np.array(obj, dtype=dtype, copy=True)
174183
return da.from_array(obj)
175184

176185

array_api_compat/dask/array/_meta.py

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import functools
2+
import sys
3+
import types
4+
5+
from ...common._helpers import is_numpy_namespace
6+
from ...common._typing import Namespace
7+
8+
__all__ = ['wrap_namespace']
9+
_all_ignore = ['functools', 'sys', 'types', 'is_numpy_namespace']
10+
11+
12+
def wrap_namespace(xp: Namespace) -> Namespace:
13+
"""Create a bespoke Dask namespace that wraps around another namespace.
14+
15+
Parameters
16+
----------
17+
xp : namespace
18+
Namespace to be wrapped by Dask
19+
20+
Returns
21+
-------
22+
namespace :
23+
A module object that duplicates array_api_compat.dask.array, with the
24+
difference that all creation functions will create an array with the same
25+
meta namespace as the input.
26+
"""
27+
from .. import array as da_compat
28+
29+
if is_numpy_namespace(xp):
30+
return da_compat
31+
32+
mod_name = f'{da_compat.__name__}.{xp.__name__}'
33+
try:
34+
return sys.modules[mod_name]
35+
except KeyError:
36+
pass
37+
38+
mod = types.ModuleType(mod_name)
39+
sys.modules[mod_name] = mod
40+
41+
meta = xp.empty(())
42+
for name, v in da_compat.__dict__.items():
43+
if name.startswith('_'):
44+
continue
45+
if name in {'arange', 'asarray', 'empty', 'eye', 'from_dlpack',
46+
'full', 'linspace', 'ones', 'zeros'}:
47+
v = functools.wraps(v)(functools.partial(v, like=meta))
48+
setattr(mod, name, v)
49+
50+
return mod

0 commit comments

Comments
 (0)