1+ use super :: Error ;
12use crate :: geometry:: Mbr ;
23use crate :: geometry:: PointND ;
34use async_lock:: Mutex ;
@@ -162,7 +163,7 @@ where
162163 . map ( |item| item. point [ coord] )
163164 . minmax ( )
164165 . into_option ( )
165- . unwrap ( ) ;
166+ . unwrap ( ) ; // Won't panic because items has at least two elements.
166167
167168 mem:: drop ( enter) ;
168169
@@ -501,7 +502,12 @@ fn rcb_thread<const D: usize, W>(
501502 futures_lite:: future:: block_on ( task) ;
502503}
503504
504- fn rcb < const D : usize , P , W > ( partition : & mut [ usize ] , points : P , weights : W , iter_count : usize )
505+ fn rcb < const D : usize , P , W > (
506+ partition : & mut [ usize ] ,
507+ points : P ,
508+ weights : W ,
509+ iter_count : usize ,
510+ ) -> Result < ( ) , Error >
505511where
506512 P : rayon:: iter:: IntoParallelIterator < Item = PointND < D > > ,
507513 P :: Iter : rayon:: iter:: IndexedParallelIterator ,
@@ -514,8 +520,18 @@ where
514520 let points = points. into_par_iter ( ) ;
515521 let weights = weights. into_par_iter ( ) ;
516522
517- assert_eq ! ( points. len( ) , weights. len( ) ) ;
518- assert_eq ! ( points. len( ) , partition. len( ) ) ;
523+ if weights. len ( ) != partition. len ( ) {
524+ return Err ( Error :: InputLenMismatch {
525+ expected : partition. len ( ) ,
526+ actual : weights. len ( ) ,
527+ } ) ;
528+ }
529+ if points. len ( ) != partition. len ( ) {
530+ return Err ( Error :: InputLenMismatch {
531+ expected : partition. len ( ) ,
532+ actual : points. len ( ) ,
533+ } ) ;
534+ }
519535
520536 let init_span = tracing:: info_span!( "convert input and make initial data structures" ) ;
521537 let enter = init_span. enter ( ) ;
@@ -546,6 +562,8 @@ where
546562 s. spawn ( move |_| rcb_thread ( iteration_ctxs, chunk, iter_count, 0.05 ) ) ;
547563 }
548564 } ) ;
565+
566+ Ok ( ( ) )
549567}
550568
551569/// # Recursive Coordinate Bisection algorithm
@@ -602,15 +620,18 @@ where
602620 W :: Iter : rayon:: iter:: IndexedParallelIterator ,
603621{
604622 type Metadata = ( ) ;
605- type Error = std :: convert :: Infallible ;
623+ type Error = Error ;
606624
607625 fn partition (
608626 & mut self ,
609627 part_ids : & mut [ usize ] ,
610628 ( points, weights) : ( P , W ) ,
611629 ) -> Result < Self :: Metadata , Self :: Error > {
612- rcb ( part_ids, points, weights, self . iter_count ) ;
613- Ok ( ( ) )
630+ if part_ids. len ( ) < 2 {
631+ // Would make Itertools::minmax().into_option() return None.
632+ return Ok ( ( ) ) ;
633+ }
634+ rcb ( part_ids, points, weights, self . iter_count )
614635 }
615636}
616637
@@ -647,7 +668,12 @@ pub fn axis_sort<const D: usize>(
647668/// The global shape of the data is first considered and the separator is computed to
648669/// be parallel to the inertia axis of the global shape, which aims to lead to better shaped
649670/// partitions.
650- fn rib < const D : usize , W > ( partition : & mut [ usize ] , points : & [ PointND < D > ] , weights : W , n_iter : usize )
671+ fn rib < const D : usize , W > (
672+ partition : & mut [ usize ] ,
673+ points : & [ PointND < D > ] ,
674+ weights : W ,
675+ n_iter : usize ,
676+ ) -> Result < ( ) , Error >
651677where
652678 Const < D > : DimSub < Const < 1 > > ,
653679 DefaultAllocator : Allocator < f64 , Const < D > , Const < D > , Buffer = ArrayStorage < f64 , D , D > >
@@ -658,15 +684,8 @@ where
658684 W :: Item : num:: ToPrimitive ,
659685 W :: Iter : rayon:: iter:: IndexedParallelIterator ,
660686{
661- let weights = weights. into_par_iter ( ) ;
662-
663- assert_eq ! ( points. len( ) , weights. len( ) ) ;
664- assert_eq ! ( points. len( ) , partition. len( ) ) ;
665-
666687 let mbr = Mbr :: from_points ( points) ;
667-
668688 let points = points. par_iter ( ) . map ( |p| mbr. mbr_to_aabb ( p) ) ;
669-
670689 // When the rotation is done, we just apply RCB
671690 rcb ( partition, points, weights, n_iter)
672691}
@@ -727,15 +746,14 @@ where
727746 W :: Iter : rayon:: iter:: IndexedParallelIterator ,
728747{
729748 type Metadata = ( ) ;
730- type Error = std :: convert :: Infallible ;
749+ type Error = Error ;
731750
732751 fn partition (
733752 & mut self ,
734753 part_ids : & mut [ usize ] ,
735754 ( points, weights) : ( & ' a [ PointND < D > ] , W ) ,
736755 ) -> Result < Self :: Metadata , Self :: Error > {
737- rib ( part_ids, points, weights, self . iter_count ) ;
738- Ok ( ( ) )
756+ rib ( part_ids, points, weights, self . iter_count )
739757 }
740758}
741759
@@ -795,7 +813,8 @@ mod tests {
795813 . num_threads ( 1 ) // make the test deterministic
796814 . build ( )
797815 . unwrap ( )
798- . install ( || rcb ( & mut partition, points, weights, 2 ) ) ;
816+ . install ( || rcb ( & mut partition, points, weights, 2 ) )
817+ . unwrap ( ) ;
799818
800819 assert_eq ! ( partition[ 0 ] , partition[ 6 ] ) ;
801820 assert_eq ! ( partition[ 1 ] , partition[ 7 ] ) ;
@@ -825,7 +844,7 @@ mod tests {
825844 let weights: Vec < f64 > = ( 0 ..points. len ( ) ) . map ( |_| rand:: random ( ) ) . collect ( ) ;
826845
827846 let mut partition = vec ! [ 0 ; points. len( ) ] ;
828- rcb ( & mut partition, points, weights. par_iter ( ) . cloned ( ) , 3 ) ;
847+ rcb ( & mut partition, points, weights. par_iter ( ) . cloned ( ) , 3 ) . unwrap ( ) ;
829848
830849 let mut loads: HashMap < usize , f64 > = HashMap :: new ( ) ;
831850 let mut sizes: HashMap < usize , usize > = HashMap :: new ( ) ;
0 commit comments