Skip to content

Commit 03f0b3e

Browse files
ENH: new function broadcast_shapes (#133)
* ENH: `broadcast_shapes` * xref broadcast_arrays * mixed nan-none test * add comment * Update src/array_api_extra/_lib/_funcs.py --------- Co-authored-by: Lucas Colley <[email protected]>
1 parent a71bd2e commit 03f0b3e

File tree

5 files changed

+127
-0
lines changed

5 files changed

+127
-0
lines changed

docs/api-reference.md

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
99
at
1010
atleast_nd
11+
broadcast_shapes
1112
cov
1213
create_diagonal
1314
expand_dims

docs/conf.py

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
intersphinx_mapping = {
5555
"python": ("https://docs.python.org/3", None),
5656
"array-api": ("https://data-apis.org/array-api/draft", None),
57+
"numpy": ("https://numpy.org/doc/stable", None),
5758
"jax": ("https://jax.readthedocs.io/en/latest", None),
5859
}
5960

src/array_api_extra/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ._lib._at import at
55
from ._lib._funcs import (
66
atleast_nd,
7+
broadcast_shapes,
78
cov,
89
create_diagonal,
910
expand_dims,
@@ -20,6 +21,7 @@
2021
"__version__",
2122
"at",
2223
"atleast_nd",
24+
"broadcast_shapes",
2325
"cov",
2426
"create_diagonal",
2527
"expand_dims",

src/array_api_extra/_lib/_funcs.py

+64
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
__all__ = [
1919
"atleast_nd",
20+
"broadcast_shapes",
2021
"cov",
2122
"create_diagonal",
2223
"expand_dims",
@@ -71,6 +72,69 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
7172
return x
7273

7374

75+
# `float` in signature to accept `math.nan` for Dask.
76+
# `int`s are still accepted as `float` is a superclass of `int` in typing
77+
def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...]:
78+
"""
79+
Compute the shape of the broadcasted arrays.
80+
81+
Duplicates :func:`numpy.broadcast_shapes`, with additional support for
82+
None and NaN sizes.
83+
84+
This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape``
85+
without needing to worry about the backend potentially deep copying
86+
the arrays.
87+
88+
Parameters
89+
----------
90+
*shapes : tuple[int | None, ...]
91+
Shapes of the arrays to broadcast.
92+
93+
Returns
94+
-------
95+
tuple[int | None, ...]
96+
The shape of the broadcasted arrays.
97+
98+
See Also
99+
--------
100+
numpy.broadcast_shapes : Equivalent NumPy function.
101+
array_api.broadcast_arrays : Function to broadcast actual arrays.
102+
103+
Notes
104+
-----
105+
This function accepts the Array API's ``None`` for unknown sizes,
106+
as well as Dask's non-standard ``math.nan``.
107+
Regardless of input, the output always contains ``None`` for unknown sizes.
108+
109+
Examples
110+
--------
111+
>>> import array_api_extra as xpx
112+
>>> xpx.broadcast_shapes((2, 3), (2, 1))
113+
(2, 3)
114+
>>> xpx.broadcast_shapes((4, 2, 3), (2, 1), (1, 3))
115+
(4, 2, 3)
116+
"""
117+
if not shapes:
118+
return () # Match numpy output
119+
120+
ndim = max(len(shape) for shape in shapes)
121+
out: list[int | None] = []
122+
for axis in range(-ndim, 0):
123+
sizes = {shape[axis] for shape in shapes if axis >= -len(shape)}
124+
# Dask uses NaN for unknown shape, which predates the Array API spec for None
125+
none_size = None in sizes or math.nan in sizes
126+
sizes -= {1, None, math.nan}
127+
if len(sizes) > 1:
128+
msg = (
129+
"shape mismatch: objects cannot be broadcast to a single shape: "
130+
f"{shapes}."
131+
)
132+
raise ValueError(msg)
133+
out.append(None if none_size else cast(int, sizes.pop()) if sizes else 1)
134+
135+
return tuple(out)
136+
137+
74138
def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
75139
"""
76140
Estimate a covariance matrix.

tests/test_funcs.py

+59
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import math
23
import warnings
34
from types import ModuleType
45

@@ -8,6 +9,7 @@
89
from array_api_extra import (
910
at,
1011
atleast_nd,
12+
broadcast_shapes,
1113
cov,
1214
create_diagonal,
1315
expand_dims,
@@ -113,6 +115,63 @@ def test_xp(self, xp: ModuleType):
113115
xp_assert_equal(y, xp.ones((1,)))
114116

115117

118+
class TestBroadcastShapes:
119+
@pytest.mark.parametrize(
120+
"args",
121+
[
122+
(),
123+
((),),
124+
((), ()),
125+
((1,),),
126+
((1,), (1,)),
127+
((2,), (1,)),
128+
((3, 1, 4), (2, 1)),
129+
((1, 1, 4), (2, 1)),
130+
((1,), ()),
131+
((), (2,), ()),
132+
((0,),),
133+
((0,), (1,)),
134+
((2, 0), (1, 1)),
135+
((2, 0, 3), (2, 1, 1)),
136+
],
137+
)
138+
def test_simple(self, args: tuple[tuple[int, ...], ...]):
139+
expect = np.broadcast_shapes(*args)
140+
actual = broadcast_shapes(*args)
141+
assert actual == expect
142+
143+
@pytest.mark.parametrize(
144+
"args",
145+
[
146+
((2,), (3,)),
147+
((2, 3), (1, 2)),
148+
((2,), (0,)),
149+
((2, 0, 2), (1, 3, 1)),
150+
],
151+
)
152+
def test_fail(self, args: tuple[tuple[int, ...], ...]):
153+
match = "cannot be broadcast to a single shape"
154+
with pytest.raises(ValueError, match=match):
155+
_ = np.broadcast_shapes(*args)
156+
with pytest.raises(ValueError, match=match):
157+
_ = broadcast_shapes(*args)
158+
159+
@pytest.mark.parametrize(
160+
"args",
161+
[
162+
((None,), (None,)),
163+
((math.nan,), (None,)),
164+
((1, None, 2, 4), (2, 3, None, 1), (2, None, None, 4)),
165+
((1, math.nan, 2), (4, 2, 3, math.nan), (4, 2, None, None)),
166+
((math.nan, 1), (None, 2), (None, 2)),
167+
],
168+
)
169+
def test_none(self, args: tuple[tuple[float | None, ...], ...]):
170+
expect = args[-1]
171+
actual = broadcast_shapes(*args[:-1])
172+
assert actual == expect
173+
174+
116175
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
117176
class TestCov:
118177
def test_basic(self, xp: ModuleType):

0 commit comments

Comments
 (0)