Skip to content

Commit 21df00f

Browse files
committedMar 31, 2024
FEAT: Add dimension merge function to merge contiguous axes
1 parent 9438d93 commit 21df00f

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed
 

‎src/dimension/mod.rs

+74
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,32 @@ where D: Dimension
759759
}
760760
}
761761

762+
/// Attempt to merge axes if possible, starting from the back
763+
///
764+
/// Given axes [Axis(0), Axis(1), Axis(2), Axis(3)] this attempts
765+
/// to merge all axes one by one into Axis(3); when/if this fails,
766+
/// it attempts to merge the rest of the axes together into the next
767+
/// axis in line, for example a result could be:
768+
///
769+
/// [1, Axis(0) + Axis(1), 1, Axis(2) + Axis(3)] where `+` would
770+
/// mean axes were merged.
771+
pub(crate) fn merge_axes_from_the_back<D>(dim: &mut D, strides: &mut D)
772+
where D: Dimension
773+
{
774+
debug_assert_eq!(dim.ndim(), strides.ndim());
775+
match dim.ndim() {
776+
0 | 1 => {}
777+
n => {
778+
let mut last = n - 1;
779+
for i in (0..last).rev() {
780+
if !merge_axes(dim, strides, Axis(i), Axis(last)) {
781+
last = i;
782+
}
783+
}
784+
}
785+
}
786+
}
787+
762788
/// Move the axis which has the smallest absolute stride and a length
763789
/// greater than one to be the last axis.
764790
pub fn move_min_stride_axis_to_last<D>(dim: &mut D, strides: &mut D)
@@ -822,6 +848,30 @@ where D: Dimension
822848
*strides = new_strides;
823849
}
824850

851+
/// Sort axes to standard/row major order, i.e Axis(0) has biggest stride and Axis(n - 1) least
852+
/// stride
853+
///
854+
/// The axes are sorted according to the .abs() of their stride.
855+
pub(crate) fn sort_axes_to_standard<D>(dim: &mut D, strides: &mut D)
856+
where D: Dimension
857+
{
858+
debug_assert!(dim.ndim() > 1);
859+
debug_assert_eq!(dim.ndim(), strides.ndim());
860+
// bubble sort axes
861+
let mut changed = true;
862+
while changed {
863+
changed = false;
864+
for i in 0..dim.ndim() - 1 {
865+
// make sure higher stride axes sort before.
866+
if strides.get_stride(Axis(i)).abs() < strides.get_stride(Axis(i + 1)).abs() {
867+
changed = true;
868+
dim.slice_mut().swap(i, i + 1);
869+
strides.slice_mut().swap(i, i + 1);
870+
}
871+
}
872+
}
873+
}
874+
825875
#[cfg(test)]
826876
mod test
827877
{
@@ -831,6 +881,7 @@ mod test
831881
can_index_slice_not_custom,
832882
extended_gcd,
833883
max_abs_offset_check_overflow,
884+
merge_axes_from_the_back,
834885
slice_min_max,
835886
slices_intersect,
836887
solve_linear_diophantine_eq,
@@ -1215,4 +1266,27 @@ mod test
12151266
assert_eq!(d, dans);
12161267
assert_eq!(s, sans);
12171268
}
1269+
1270+
#[test]
1271+
fn test_merge_axes_from_the_back()
1272+
{
1273+
let dyndim = Dim::<&[usize]>;
1274+
1275+
let mut d = Dim([3, 4, 5]);
1276+
let mut s = Dim([20, 5, 1]);
1277+
merge_axes_from_the_back(&mut d, &mut s);
1278+
assert_eq!(d, Dim([1, 1, 60]));
1279+
assert_eq!(s, Dim([20, 5, 1]));
1280+
1281+
let mut d = Dim([3, 4, 5, 2]);
1282+
let mut s = Dim([80, 20, 2, 1]);
1283+
merge_axes_from_the_back(&mut d, &mut s);
1284+
assert_eq!(d, Dim([1, 12, 1, 10]));
1285+
assert_eq!(s, Dim([80, 20, 2, 1]));
1286+
let mut d = d.into_dyn();
1287+
let mut s = s.into_dyn();
1288+
squeeze(&mut d, &mut s);
1289+
assert_eq!(d, dyndim(&[12, 10]));
1290+
assert_eq!(s, dyndim(&[20, 1]));
1291+
}
12181292
}

0 commit comments

Comments
 (0)