Skip to content

Commit e585c97

Browse files
authoredFeb 20, 2025··
Merge pull request #145 from crusaderky/test_helpers
TST: refactor test_utils
2 parents 6e4aba6 + 7b90a4e commit e585c97

File tree

2 files changed

+150
-153
lines changed

2 files changed

+150
-153
lines changed
 

‎tests/test_helpers.py

+150
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
from types import ModuleType
2+
3+
import numpy as np
4+
import pytest
5+
6+
from array_api_extra._lib import Backend
7+
from array_api_extra._lib._testing import xp_assert_equal
8+
from array_api_extra._lib._utils._compat import device as get_device
9+
from array_api_extra._lib._utils._helpers import asarrays, in1d
10+
from array_api_extra._lib._utils._typing import Device
11+
from array_api_extra.testing import lazy_xp_function
12+
13+
# mypy: disable-error-code=no-untyped-usage
14+
15+
# FIXME calls xp.unique_values without size
16+
lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp"))
17+
18+
19+
class TestIn1D:
20+
@pytest.mark.xfail_xp_backend(
21+
Backend.SPARSE, reason="no unique_inverse, no device kwarg in asarray()"
22+
)
23+
# cover both code paths
24+
@pytest.mark.parametrize(
25+
"n",
26+
[
27+
pytest.param(9, id="fast path"),
28+
pytest.param(
29+
15,
30+
id="slow path",
31+
marks=pytest.mark.xfail_xp_backend(
32+
Backend.DASK, reason="NaN-shaped array"
33+
),
34+
),
35+
],
36+
)
37+
def test_no_invert_assume_unique(self, xp: ModuleType, n: int):
38+
x1 = xp.asarray([3, 8, 20])
39+
x2 = xp.arange(n)
40+
expected = xp.asarray([True, True, False])
41+
actual = in1d(x1, x2)
42+
xp_assert_equal(actual, expected)
43+
44+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in asarray")
45+
def test_device(self, xp: ModuleType, device: Device):
46+
x1 = xp.asarray([3, 8, 20], device=device)
47+
x2 = xp.asarray([2, 3, 4], device=device)
48+
assert get_device(in1d(x1, x2)) == device
49+
50+
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="xp=xp")
51+
@pytest.mark.xfail_xp_backend(
52+
Backend.SPARSE, reason="no arange, no device kwarg in asarray"
53+
)
54+
def test_xp(self, xp: ModuleType):
55+
x1 = xp.asarray([1, 6])
56+
x2 = xp.arange(5)
57+
expected = xp.asarray([True, False])
58+
actual = in1d(x1, x2, xp=xp)
59+
xp_assert_equal(actual, expected)
60+
61+
62+
class TestAsArrays:
63+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
64+
@pytest.mark.parametrize(
65+
("dtype", "b", "defined"),
66+
[
67+
# Well-defined cases of dtype promotion from Python scalar to Array
68+
# bool vs. bool
69+
("bool", True, True),
70+
# int vs. xp.*int*, xp.float*, xp.complex*
71+
("int16", 1, True),
72+
("uint8", 1, True),
73+
("float32", 1, True),
74+
("float64", 1, True),
75+
("complex64", 1, True),
76+
("complex128", 1, True),
77+
# float vs. xp.float, xp.complex
78+
("float32", 1.0, True),
79+
("float64", 1.0, True),
80+
("complex64", 1.0, True),
81+
("complex128", 1.0, True),
82+
# complex vs. xp.complex
83+
("complex64", 1.0j, True),
84+
("complex128", 1.0j, True),
85+
# Undefined cases
86+
("bool", 1, False),
87+
("int64", 1.0, False),
88+
("float64", 1.0j, False),
89+
],
90+
)
91+
def test_array_vs_scalar(
92+
self, dtype: str, b: int | float | complex, defined: bool, xp: ModuleType
93+
):
94+
a = xp.asarray(1, dtype=getattr(xp, dtype))
95+
96+
xa, xb = asarrays(a, b, xp)
97+
assert xa.dtype == a.dtype
98+
if defined:
99+
assert xb.dtype == a.dtype
100+
else:
101+
assert xb.dtype == xp.asarray(b).dtype
102+
103+
xbr, xar = asarrays(b, a, xp)
104+
assert xar.dtype == xa.dtype
105+
assert xbr.dtype == xb.dtype
106+
107+
def test_scalar_vs_scalar(self, xp: ModuleType):
108+
a, b = asarrays(1, 2.2, xp=xp)
109+
assert a.dtype == xp.asarray(1).dtype # Default dtype
110+
assert b.dtype == xp.asarray(2.2).dtype # Default dtype; not broadcasted
111+
112+
ALL_TYPES: tuple[str, ...] = (
113+
"int8",
114+
"int16",
115+
"int32",
116+
"int64",
117+
"uint8",
118+
"uint16",
119+
"uint32",
120+
"uint64",
121+
"float32",
122+
"float64",
123+
"complex64",
124+
"complex128",
125+
"bool",
126+
)
127+
128+
@pytest.mark.parametrize("a_type", ALL_TYPES)
129+
@pytest.mark.parametrize("b_type", ALL_TYPES)
130+
def test_array_vs_array(self, a_type: str, b_type: str, xp: ModuleType):
131+
"""
132+
Test that when both inputs of asarray are already Array API objects,
133+
they are returned unchanged.
134+
"""
135+
a = xp.asarray(1, dtype=getattr(xp, a_type))
136+
b = xp.asarray(1, dtype=getattr(xp, b_type))
137+
xa, xb = asarrays(a, b, xp)
138+
assert xa.dtype == a.dtype
139+
assert xb.dtype == b.dtype
140+
141+
@pytest.mark.parametrize("dtype", [np.float64, np.complex128])
142+
def test_numpy_generics(self, dtype: type):
143+
"""
144+
Test special case of np.float64 and np.complex128,
145+
which are subclasses of float and complex.
146+
"""
147+
a = dtype(0)
148+
xa, xb = asarrays(a, 0, xp=np)
149+
assert xa.dtype == dtype
150+
assert xb.dtype == dtype

‎tests/test_utils.py

-153
This file was deleted.

0 commit comments

Comments
 (0)
Please sign in to comment.