Skip to content

Commit 6e4aba6

Browse files
authored
Merge pull request #146 from crusaderky/backports
MAINT: various backports and tweaks
2 parents 573ed3c + 45c9fbb commit 6e4aba6

File tree

8 files changed

+3148
-2227
lines changed

8 files changed

+3148
-2227
lines changed

docs/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
intersphinx_mapping = {
5555
"python": ("https://docs.python.org/3", None),
5656
"array-api": ("https://data-apis.org/array-api/draft", None),
57+
"dask": ("https://docs.dask.org/en/stable", None),
5758
"numpy": ("https://numpy.org/doc/stable", None),
5859
"jax": ("https://jax.readthedocs.io/en/latest", None),
5960
}

pixi.lock

+3,118-2,204
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1-10
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,6 @@ markers = [
191191

192192
[tool.coverage]
193193
run.source = ["array_api_extra"]
194-
report.exclude_also = [
195-
'\.\.\.',
196-
'if TYPE_CHECKING:',
197-
]
198194

199195
# mypy
200196

@@ -314,10 +310,5 @@ checks = [
314310
"ES01", # most docstrings do not need an extended summary
315311
]
316312
exclude = [ # don't report on objects that match any of these regex
317-
'.*test_at.*',
318-
'.*test_funcs.*',
319-
'.*test_testing.*',
320-
'.*test_utils.*',
321-
'.*test_version.*',
322-
'.*test_vendor.*',
313+
'.*test_*',
323314
]

src/array_api_extra/_lib/_at.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,8 @@ def _op(
249249
Right-hand side of the operation.
250250
copy : bool or None
251251
Whether to copy the input array. See the class docstring for details.
252-
xp : array_namespace or None
253-
The array namespace for the input array.
252+
xp : array_namespace, optional
253+
The array namespace for the input array. Default: infer.
254254
255255
Returns
256256
-------

src/array_api_extra/_lib/_utils/_helpers.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from typing import cast
88

99
from . import _compat
10-
from ._compat import is_array_api_obj, is_numpy_array
10+
from ._compat import array_namespace, is_array_api_obj, is_numpy_array
1111
from ._typing import Array
1212

13-
__all__ = ["in1d", "mean"]
13+
__all__ = ["asarrays", "in1d", "is_python_scalar", "mean"]
1414

1515

1616
def in1d(
@@ -33,7 +33,7 @@ def in1d(
3333
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
3434
"""
3535
if xp is None:
36-
xp = _compat.array_namespace(x1, x2)
36+
xp = array_namespace(x1, x2)
3737

3838
# This code is run to make the code significantly faster
3939
if x2.shape[0] < 10 * x1.shape[0] ** 0.145:
@@ -84,7 +84,7 @@ def mean(
8484
Complex mean, https://github.com/data-apis/array-api/issues/846.
8585
"""
8686
if xp is None:
87-
xp = _compat.array_namespace(x)
87+
xp = array_namespace(x)
8888

8989
if xp.isdtype(x.dtype, "complex floating"):
9090
x_real = xp.real(x)
@@ -124,8 +124,8 @@ def asarrays(
124124
----------
125125
a, b : Array | int | float | complex | bool
126126
Input arrays or scalars. At least one must be an array.
127-
xp : ModuleType
128-
The standard-compatible namespace for the returned arrays.
127+
xp : array_namespace, optional
128+
The standard-compatible namespace for `x`. Default: infer.
129129
130130
Returns
131131
-------

src/array_api_extra/_lib/_utils/_typing.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# To be changed to a Protocol later (see data-apis/array-api#589)
66
Array = Any # type: ignore[no-any-explicit]
77
Device = Any # type: ignore[no-any-explicit]
8+
DType = Any # type: ignore[no-any-explicit]
89
Index = Any # type: ignore[no-any-explicit]
910

10-
__all__ = ["Array", "Device", "Index"]
11+
__all__ = ["Array", "DType", "Device", "Index"]

src/array_api_extra/testing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
__all__ = ["lazy_xp_function", "patch_lazy_xp_functions"]
2020

21-
if TYPE_CHECKING:
21+
if TYPE_CHECKING: # pragma: no cover
2222
# TODO move ParamSpec outside TYPE_CHECKING
2323
# depends on scikit-learn abandoning Python 3.9
2424
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972
@@ -169,7 +169,7 @@ def xp(request, monkeypatch):
169169
Pytest fixture, as acquired by the test itself or by one of its fixtures.
170170
monkeypatch : pytest.MonkeyPatch
171171
Pytest fixture, as acquired by the test itself or by one of its fixtures.
172-
xp : module
172+
xp : array_namespace
173173
Array namespace to be tested.
174174
175175
See Also

tests/conftest.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,11 @@ def xp(
115115
if library == Backend.NUMPY_READONLY:
116116
return NumPyReadOnly() # type: ignore[return-value] # pyright: ignore[reportReturnType]
117117
xp = pytest.importorskip(library.value)
118+
# Possibly wrap module with array_api_compat
119+
xp = array_namespace(xp.empty(0))
118120

121+
# On Dask and JAX, monkey-patch all functions tagged by `lazy_xp_function`
122+
# in the global scope of the module containing the test function.
119123
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
120124

121125
if library == Backend.JAX:
@@ -124,8 +128,18 @@ def xp(
124128
# suppress unused-ignore to run mypy in -e lint as well as -e dev
125129
jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore]
126130

127-
# Possibly wrap module with array_api_compat
128-
return array_namespace(xp.empty(0))
131+
return xp
132+
133+
134+
@pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask`
135+
def da(
136+
request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
137+
) -> ModuleType: # numpydoc ignore=PR01,RT01
138+
"""Variant of the `xp` fixture that only yields dask.array."""
139+
xp = pytest.importorskip("dask.array")
140+
xp = array_namespace(xp.empty(0))
141+
patch_lazy_xp_functions(request, monkeypatch, xp=xp)
142+
return xp
129143

130144

131145
@pytest.fixture

0 commit comments

Comments
 (0)