|
1 | 1 | import contextlib
|
| 2 | +import math |
2 | 3 | import warnings
|
3 | 4 | from types import ModuleType
|
4 | 5 |
|
|
8 | 9 | from array_api_extra import (
|
9 | 10 | at,
|
10 | 11 | atleast_nd,
|
| 12 | + broadcast_shapes, |
11 | 13 | cov,
|
12 | 14 | create_diagonal,
|
13 | 15 | expand_dims,
|
@@ -113,6 +115,63 @@ def test_xp(self, xp: ModuleType):
|
113 | 115 | xp_assert_equal(y, xp.ones((1,)))
|
114 | 116 |
|
115 | 117 |
|
| 118 | +class TestBroadcastShapes: |
| 119 | + @pytest.mark.parametrize( |
| 120 | + "args", |
| 121 | + [ |
| 122 | + (), |
| 123 | + ((),), |
| 124 | + ((), ()), |
| 125 | + ((1,),), |
| 126 | + ((1,), (1,)), |
| 127 | + ((2,), (1,)), |
| 128 | + ((3, 1, 4), (2, 1)), |
| 129 | + ((1, 1, 4), (2, 1)), |
| 130 | + ((1,), ()), |
| 131 | + ((), (2,), ()), |
| 132 | + ((0,),), |
| 133 | + ((0,), (1,)), |
| 134 | + ((2, 0), (1, 1)), |
| 135 | + ((2, 0, 3), (2, 1, 1)), |
| 136 | + ], |
| 137 | + ) |
| 138 | + def test_simple(self, args: tuple[tuple[int, ...], ...]): |
| 139 | + expect = np.broadcast_shapes(*args) |
| 140 | + actual = broadcast_shapes(*args) |
| 141 | + assert actual == expect |
| 142 | + |
| 143 | + @pytest.mark.parametrize( |
| 144 | + "args", |
| 145 | + [ |
| 146 | + ((2,), (3,)), |
| 147 | + ((2, 3), (1, 2)), |
| 148 | + ((2,), (0,)), |
| 149 | + ((2, 0, 2), (1, 3, 1)), |
| 150 | + ], |
| 151 | + ) |
| 152 | + def test_fail(self, args: tuple[tuple[int, ...], ...]): |
| 153 | + match = "cannot be broadcast to a single shape" |
| 154 | + with pytest.raises(ValueError, match=match): |
| 155 | + _ = np.broadcast_shapes(*args) |
| 156 | + with pytest.raises(ValueError, match=match): |
| 157 | + _ = broadcast_shapes(*args) |
| 158 | + |
| 159 | + @pytest.mark.parametrize( |
| 160 | + "args", |
| 161 | + [ |
| 162 | + ((None,), (None,)), |
| 163 | + ((math.nan,), (None,)), |
| 164 | + ((1, None, 2, 4), (2, 3, None, 1), (2, None, None, 4)), |
| 165 | + ((1, math.nan, 2), (4, 2, 3, math.nan), (4, 2, None, None)), |
| 166 | + ((math.nan, 1), (None, 2), (None, 2)), |
| 167 | + ], |
| 168 | + ) |
| 169 | + def test_none(self, args: tuple[tuple[float | None, ...], ...]): |
| 170 | + expect = args[-1] |
| 171 | + actual = broadcast_shapes(*args[:-1]) |
| 172 | + assert actual == expect |
| 173 | + |
| 174 | + |
116 | 175 | @pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
|
117 | 176 | class TestCov:
|
118 | 177 | def test_basic(self, xp: ModuleType):
|
|
0 commit comments