Skip to content

Commit ae90079

Browse files
committed
extended tests
1 parent e6df613 commit ae90079

File tree

2 files changed

+142
-0
lines changed

2 files changed

+142
-0
lines changed

src/lib.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,80 @@ fn first_true_1d_f(py: Python, array: PyReadonlyArray1<bool>) -> isize {
180180
}
181181

182182

183+
#[pyfunction]
184+
#[pyo3(signature = (array, forward=true))]
185+
fn first_true_1d(py: Python, array: PyReadonlyArray1<bool>, forward: bool) -> isize {
186+
if let Ok(slice) = array.as_slice() {
187+
py.allow_threads(|| {
188+
let len = slice.len();
189+
let ptr = slice.as_ptr() as *const u8;
190+
let ones = u8x32::splat(1);
191+
192+
if forward {
193+
let mut i = 0;
194+
unsafe {
195+
// Process 32 bytes at a time with SIMD
196+
while i + 32 <= len {
197+
let bytes = &*(ptr.add(i) as *const [u8; 32]);
198+
let chunk = u8x32::from(*bytes);
199+
let equal_one = chunk.cmp_eq(ones);
200+
if equal_one.any() {
201+
break;
202+
}
203+
i += 32;
204+
}
205+
// Handle final remainder
206+
while i < len.min(i + 32) {
207+
if *ptr.add(i) != 0 {
208+
return i as isize;
209+
}
210+
i += 1;
211+
}
212+
}
213+
} else {
214+
// Backward search
215+
let mut i = len;
216+
unsafe {
217+
// Process 32 bytes at a time with SIMD (backwards)
218+
while i >= 32 {
219+
i -= 32;
220+
let bytes = &*(ptr.add(i) as *const [u8; 32]);
221+
let chunk = u8x32::from(*bytes);
222+
let equal_one = chunk.cmp_eq(ones);
223+
if equal_one.any() {
224+
// Found a true in this chunk, search backwards within it
225+
for j in (0..32).rev() {
226+
if i + j < len && *ptr.add(i + j) != 0 {
227+
return (i + j) as isize;
228+
}
229+
}
230+
}
231+
}
232+
// Handle remaining bytes at the beginning
233+
if i > 0 {
234+
for j in (0..i).rev() {
235+
if *ptr.add(j) != 0 {
236+
return j as isize;
237+
}
238+
}
239+
}
240+
}
241+
}
242+
-1
243+
})
244+
} else {
245+
let array_view = array.as_array();
246+
py.allow_threads(|| {
247+
if forward {
248+
array_view.iter().position(|&v| v).map(|i| i as isize).unwrap_or(-1)
249+
} else {
250+
array_view.iter().rposition(|&v| v).map(|i| i as isize).unwrap_or(-1)
251+
}
252+
})
253+
}
254+
}
255+
256+
183257
#[pymodule]
184258
fn arrayredox(m: &Bound<'_, PyModule>) -> PyResult<()> {
185259
m.add_function(wrap_pyfunction!(first_true_1d_a, m)?)?;
@@ -188,5 +262,6 @@ fn arrayredox(m: &Bound<'_, PyModule>) -> PyResult<()> {
188262
m.add_function(wrap_pyfunction!(first_true_1d_d, m)?)?;
189263
m.add_function(wrap_pyfunction!(first_true_1d_e, m)?)?;
190264
m.add_function(wrap_pyfunction!(first_true_1d_f, m)?)?;
265+
m.add_function(wrap_pyfunction!(first_true_1d, m)?)?;
191266
Ok(())
192267
}

tests/test_basic.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,71 @@ def test_first_true_1d():
1010
assert first_true_1d(b) == -1
1111

1212

13+
#--------------------------------------------------------------------------
14+
def test_first_true_1d_a() -> None:
15+
a1 = np.arange(100) == 50
16+
post = first_true_1d(a1, forward=True)
17+
assert post == 50
1318

19+
# def test_first_true_1d_b() -> None:
20+
# with self.assertRaises(TypeError):
21+
# a1 = [2, 4, 5,]
22+
# first_true_1d(a1, forward=True)
23+
24+
# def test_first_true_1d_c() -> None:
25+
# with self.assertRaises(ValueError):
26+
# a1 = np.arange(100) == 50
27+
# first_true_1d(a1, forward=a1)
28+
29+
def test_first_true_1d_d() -> None:
30+
a1 = np.arange(100) < 0
31+
post = first_true_1d(a1, forward=True)
32+
assert post == -1
33+
34+
# def test_first_true_1d_e() -> None:
35+
# a1 = np.arange(100)
36+
# # only a Boolean array
37+
# with self.assertRaises(ValueError):
38+
# post = first_true_1d(a1, forward=True)
39+
40+
# def test_first_true_1d_f() -> None:
41+
# a1 = (np.arange(100) == 0)[:50:2]
42+
# # only a contiguous array
43+
# with self.assertRaises(ValueError):
44+
# post = first_true_1d(a1, forward=True)
45+
46+
# def test_first_true_1d_g() -> None:
47+
# a1 = (np.arange(100) == 0).reshape(10, 10)
48+
# # only a contiguous array
49+
# with self.assertRaises(ValueError):
50+
# post = first_true_1d(a1, forward=True)
51+
52+
def test_first_true_1d_reverse_a() -> None:
53+
a1 = np.arange(100) == 50
54+
post = first_true_1d(a1, forward=False)
55+
assert post == 50
56+
57+
def test_first_true_1d_reverse_b() -> None:
58+
a1 = np.arange(100) == 0
59+
post = first_true_1d(a1, forward=False)
60+
assert post == 0
61+
62+
def test_first_true_1d_reverse_c() -> None:
63+
a1 = np.arange(100) == -1
64+
post = first_true_1d(a1, forward=False)
65+
assert post == -1
66+
67+
def test_first_true_1d_reverse_d() -> None:
68+
a1 = np.arange(100) == 99
69+
post = first_true_1d(a1, forward=False)
70+
assert post == 99
71+
72+
def test_first_true_1d_multi_a() -> None:
73+
a1 = np.isin(np.arange(100), (50, 70, 90))
74+
assert first_true_1d(a1, forward=True) == 50
75+
assert first_true_1d(a1, forward=False) == 90
76+
77+
def test_first_true_1d_multi_b() -> None:
78+
a1 = np.isin(np.arange(100), (10, 30, 50))
79+
assert first_true_1d(a1, forward=True) == 10
80+
assert first_true_1d(a1, forward=False) == 50

0 commit comments

Comments
 (0)