@@ -759,6 +759,32 @@ where D: Dimension
759
759
}
760
760
}
761
761
762
+ /// Attempt to merge axes if possible, starting from the back
763
+ ///
764
+ /// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
765
+ /// to merge all axes one by one into Axis(3); when/if this fails,
766
+ /// it attempts to merge the rest of the axes together into the next
767
+ /// axis in line, for example a result could be:
768
+ ///
769
+ /// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
770
+ /// mean axes were merged.
771
+ pub ( crate ) fn merge_axes_from_the_back < D > ( dim : & mut D , strides : & mut D )
772
+ where D : Dimension
773
+ {
774
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
775
+ match dim. ndim ( ) {
776
+ 0 | 1 => { }
777
+ n => {
778
+ let mut last = n - 1 ;
779
+ for i in ( 0 ..last) . rev ( ) {
780
+ if !merge_axes ( dim, strides, Axis ( i) , Axis ( last) ) {
781
+ last = i;
782
+ }
783
+ }
784
+ }
785
+ }
786
+ }
787
+
762
788
/// Move the axis which has the smallest absolute stride and a length
763
789
/// greater than one to be the last axis.
764
790
pub fn move_min_stride_axis_to_last < D > ( dim : & mut D , strides : & mut D )
@@ -822,6 +848,30 @@ where D: Dimension
822
848
* strides = new_strides;
823
849
}
824
850
851
+ /// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
852
+ /// stride
853
+ ///
854
+ /// The axes are sorted according to the .abs() of their stride.
855
+ pub ( crate ) fn sort_axes_to_standard < D > ( dim : & mut D , strides : & mut D )
856
+ where D : Dimension
857
+ {
858
+ debug_assert ! ( dim. ndim( ) > 1 ) ;
859
+ debug_assert_eq ! ( dim. ndim( ) , strides. ndim( ) ) ;
860
+ // bubble sort axes
861
+ let mut changed = true ;
862
+ while changed {
863
+ changed = false ;
864
+ for i in 0 ..dim. ndim ( ) - 1 {
865
+ // make sure higher stride axes sort before.
866
+ if strides. get_stride ( Axis ( i) ) . abs ( ) < strides. get_stride ( Axis ( i + 1 ) ) . abs ( ) {
867
+ changed = true ;
868
+ dim. slice_mut ( ) . swap ( i, i + 1 ) ;
869
+ strides. slice_mut ( ) . swap ( i, i + 1 ) ;
870
+ }
871
+ }
872
+ }
873
+ }
874
+
825
875
#[ cfg( test) ]
826
876
mod test
827
877
{
@@ -831,6 +881,7 @@ mod test
831
881
can_index_slice_not_custom,
832
882
extended_gcd,
833
883
max_abs_offset_check_overflow,
884
+ merge_axes_from_the_back,
834
885
slice_min_max,
835
886
slices_intersect,
836
887
solve_linear_diophantine_eq,
@@ -1215,4 +1266,27 @@ mod test
1215
1266
assert_eq ! ( d, dans) ;
1216
1267
assert_eq ! ( s, sans) ;
1217
1268
}
1269
+
1270
+ #[ test]
1271
+ fn test_merge_axes_from_the_back ( )
1272
+ {
1273
+ let dyndim = Dim :: < & [ usize ] > ;
1274
+
1275
+ let mut d = Dim ( [ 3 , 4 , 5 ] ) ;
1276
+ let mut s = Dim ( [ 20 , 5 , 1 ] ) ;
1277
+ merge_axes_from_the_back ( & mut d, & mut s) ;
1278
+ assert_eq ! ( d, Dim ( [ 1 , 1 , 60 ] ) ) ;
1279
+ assert_eq ! ( s, Dim ( [ 20 , 5 , 1 ] ) ) ;
1280
+
1281
+ let mut d = Dim ( [ 3 , 4 , 5 , 2 ] ) ;
1282
+ let mut s = Dim ( [ 80 , 20 , 2 , 1 ] ) ;
1283
+ merge_axes_from_the_back ( & mut d, & mut s) ;
1284
+ assert_eq ! ( d, Dim ( [ 1 , 12 , 1 , 10 ] ) ) ;
1285
+ assert_eq ! ( s, Dim ( [ 80 , 20 , 2 , 1 ] ) ) ;
1286
+ let mut d = d. into_dyn ( ) ;
1287
+ let mut s = s. into_dyn ( ) ;
1288
+ squeeze ( & mut d, & mut s) ;
1289
+ assert_eq ! ( d, dyndim( & [ 12 , 10 ] ) ) ;
1290
+ assert_eq ! ( s, dyndim( & [ 20 , 1 ] ) ) ;
1291
+ }
1218
1292
}
0 commit comments