Skip to content

Commit b52cb84

Browse files
authored
Merge pull request #375 from ev-br/searchsorted_skip_dask
BUG: torch: expand_dims axis is keyword or positional argument
2 parents 11f4d3f + 35b06e5 commit b52cb84

File tree

3 files changed

+6
-1
lines changed

3 files changed

+6
-1
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,7 @@ def triu(x: Array, /, *, k: int = 0) -> Array:
694694
return torch.triu(x, k)
695695

696696
# Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742
697-
def expand_dims(x: Array, /, *, axis: int | tuple[int, ...]) -> Array:
697+
def expand_dims(x: Array, /, axis: int | tuple[int, ...]) -> Array:
698698
if isinstance(axis, int):
699699
return torch.unsqueeze(x, axis)
700700
else:

dask-xfails.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,13 @@ array_api_tests/test_signatures.py::test_func_signature[broadcast_shapes]
135135
array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_broadcast_shapes
136136
array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_empty
137137
array_api_tests/test_data_type_functions.py::TestBroadcastShapes::test_error
138+
array_api_tests/test_searching_functions.py::test_searchsorted_with_scalars
138139

139140
array_api_tests/test_linalg.py::test_eig
140141
array_api_tests/test_linalg.py::test_eigvals
141142

143+
array_api_tests/test_manipulation_functions.py::TestExpandDims::test_expand_dims_tuples
144+
142145
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.)
143146
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]
144147
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i < 0) -> -infinity]

numpy-1-26-xfails.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_where_with_sca
5050
array_api_tests/test_creation_functions.py::test_meshgrid
5151
array_api_tests/test_data_type_functions.py::test_broadcast_arrays
5252

53+
# observed with numpy==1.26 only, looks like is fixed on numpy 2.x
54+
array_api_tests/test_set_functions.py::TestIsin::test_isin_scalars
5355

5456
# Stubs have a comment: (**note**: libraries may return ``NaN`` to match Python behavior.); Apparently, NumPy does just that
5557
array_api_tests/test_special_cases.py::test_binary[floor_divide(x1_i is +infinity and isfinite(x2_i) and x2_i > 0) -> +infinity]

0 commit comments

Comments
 (0)