Skip to content

Commit 90727de

Browse files
committed
rcb: replace panics with errors
1 parent 04b6a2f commit 90727de

File tree

1 file changed

+39
-20
lines changed

1 file changed

+39
-20
lines changed

src/algorithms/recursive_bisection.rs

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use super::Error;
12
use crate::geometry::Mbr;
23
use crate::geometry::PointND;
34
use async_lock::Mutex;
@@ -162,7 +163,7 @@ where
162163
.map(|item| item.point[coord])
163164
.minmax()
164165
.into_option()
165-
.unwrap();
166+
.unwrap(); // Won't panic because items has at least two elements.
166167

167168
mem::drop(enter);
168169

@@ -501,7 +502,12 @@ fn rcb_thread<const D: usize, W>(
501502
futures_lite::future::block_on(task);
502503
}
503504

504-
fn rcb<const D: usize, P, W>(partition: &mut [usize], points: P, weights: W, iter_count: usize)
505+
fn rcb<const D: usize, P, W>(
506+
partition: &mut [usize],
507+
points: P,
508+
weights: W,
509+
iter_count: usize,
510+
) -> Result<(), Error>
505511
where
506512
P: rayon::iter::IntoParallelIterator<Item = PointND<D>>,
507513
P::Iter: rayon::iter::IndexedParallelIterator,
@@ -514,8 +520,18 @@ where
514520
let points = points.into_par_iter();
515521
let weights = weights.into_par_iter();
516522

517-
assert_eq!(points.len(), weights.len());
518-
assert_eq!(points.len(), partition.len());
523+
if weights.len() != partition.len() {
524+
return Err(Error::InputLenMismatch {
525+
expected: partition.len(),
526+
actual: weights.len(),
527+
});
528+
}
529+
if points.len() != partition.len() {
530+
return Err(Error::InputLenMismatch {
531+
expected: partition.len(),
532+
actual: points.len(),
533+
});
534+
}
519535

520536
let init_span = tracing::info_span!("convert input and make initial data structures");
521537
let enter = init_span.enter();
@@ -546,6 +562,8 @@ where
546562
s.spawn(move |_| rcb_thread(iteration_ctxs, chunk, iter_count, 0.05));
547563
}
548564
});
565+
566+
Ok(())
549567
}
550568

551569
/// # Recursive Coordinate Bisection algorithm
@@ -602,15 +620,18 @@ where
602620
W::Iter: rayon::iter::IndexedParallelIterator,
603621
{
604622
type Metadata = ();
605-
type Error = std::convert::Infallible;
623+
type Error = Error;
606624

607625
fn partition(
608626
&mut self,
609627
part_ids: &mut [usize],
610628
(points, weights): (P, W),
611629
) -> Result<Self::Metadata, Self::Error> {
612-
rcb(part_ids, points, weights, self.iter_count);
613-
Ok(())
630+
if part_ids.len() < 2 {
631+
// Would make Itertools::minmax().into_option() return None.
632+
return Ok(());
633+
}
634+
rcb(part_ids, points, weights, self.iter_count)
614635
}
615636
}
616637

@@ -647,7 +668,12 @@ pub fn axis_sort<const D: usize>(
647668
/// The global shape of the data is first considered and the separator is computed to
648669
/// be parallel to the inertia axis of the global shape, which aims to lead to better shaped
649670
/// partitions.
650-
fn rib<const D: usize, W>(partition: &mut [usize], points: &[PointND<D>], weights: W, n_iter: usize)
671+
fn rib<const D: usize, W>(
672+
partition: &mut [usize],
673+
points: &[PointND<D>],
674+
weights: W,
675+
n_iter: usize,
676+
) -> Result<(), Error>
651677
where
652678
Const<D>: DimSub<Const<1>>,
653679
DefaultAllocator: Allocator<f64, Const<D>, Const<D>, Buffer = ArrayStorage<f64, D, D>>
@@ -658,15 +684,8 @@ where
658684
W::Item: num::ToPrimitive,
659685
W::Iter: rayon::iter::IndexedParallelIterator,
660686
{
661-
let weights = weights.into_par_iter();
662-
663-
assert_eq!(points.len(), weights.len());
664-
assert_eq!(points.len(), partition.len());
665-
666687
let mbr = Mbr::from_points(points);
667-
668688
let points = points.par_iter().map(|p| mbr.mbr_to_aabb(p));
669-
670689
// When the rotation is done, we just apply RCB
671690
rcb(partition, points, weights, n_iter)
672691
}
@@ -727,15 +746,14 @@ where
727746
W::Iter: rayon::iter::IndexedParallelIterator,
728747
{
729748
type Metadata = ();
730-
type Error = std::convert::Infallible;
749+
type Error = Error;
731750

732751
fn partition(
733752
&mut self,
734753
part_ids: &mut [usize],
735754
(points, weights): (&'a [PointND<D>], W),
736755
) -> Result<Self::Metadata, Self::Error> {
737-
rib(part_ids, points, weights, self.iter_count);
738-
Ok(())
756+
rib(part_ids, points, weights, self.iter_count)
739757
}
740758
}
741759

@@ -795,7 +813,8 @@ mod tests {
795813
.num_threads(1) // make the test deterministic
796814
.build()
797815
.unwrap()
798-
.install(|| rcb(&mut partition, points, weights, 2));
816+
.install(|| rcb(&mut partition, points, weights, 2))
817+
.unwrap();
799818

800819
assert_eq!(partition[0], partition[6]);
801820
assert_eq!(partition[1], partition[7]);
@@ -825,7 +844,7 @@ mod tests {
825844
let weights: Vec<f64> = (0..points.len()).map(|_| rand::random()).collect();
826845

827846
let mut partition = vec![0; points.len()];
828-
rcb(&mut partition, points, weights.par_iter().cloned(), 3);
847+
rcb(&mut partition, points, weights.par_iter().cloned(), 3).unwrap();
829848

830849
let mut loads: HashMap<usize, f64> = HashMap::new();
831850
let mut sizes: HashMap<usize, usize> = HashMap::new();

0 commit comments

Comments
 (0)