diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
index 6bda073b4..71bb41956 100644
--- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
+++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
@@ -13571,6 +13571,53 @@ struct NoopReverse final
   }
 };
 
+// reverse(reverse(x)) -> x or reverse(x) with reduced dimensions
+// When we have two consecutive reverse operations, dimensions that appear
+// in both cancel out, and we only need to reverse the symmetric difference.
+struct ReverseReverse final
+    : CheckedOpRewritePattern<stablehlo::ReverseOp, ReverseReverse> {
+  using CheckedOpRewritePattern::CheckedOpRewritePattern;
+
+  LogicalResult matchAndRewriteImpl(stablehlo::ReverseOp op,
+                                    PatternRewriter &rewriter) const {
+    auto prevReverse = op.getOperand().getDefiningOp<stablehlo::ReverseOp>();
+    if (!prevReverse)
+      return failure();
+
+    // Get dimensions from both reverse operations
+    auto outerDims = op.getDimensions();
+    auto innerDims = prevReverse.getDimensions();
+
+    // Compute the symmetric difference of dimensions using a single set:
+    // - Dimensions in both cancel out (reverse twice = identity)
+    // - Dimensions in only one remain
+    // XOR-like operation: add if not present, remove if already present
+    llvm::SmallDenseSet<int64_t> dimSet;
+    for (int64_t dim : innerDims) {
+      dimSet.insert(dim);
+    }
+    for (int64_t dim : outerDims) {
+      auto [it, inserted] = dimSet.insert(dim);
+      if (!inserted) {
+        // Dimension was in both - cancel out
+        dimSet.erase(it);
+      }
+    }
+
+    if (dimSet.empty()) {
+      // Both reverses cancel out completely
+      rewriter.replaceOp(op, prevReverse.getOperand());
+    } else {
+      // Convert set to sorted vector for canonical form
+      SmallVector<int64_t> newDimensions(dimSet.begin(), dimSet.end());
+      llvm::sort(newDimensions);
+      rewriter.replaceOpWithNewOp<stablehlo::ReverseOp>(
+          op, prevReverse.getOperand(), newDimensions);
+    }
+    return success();
+  }
+};
+
 /// Converts gather ops to slice ops in case we have a single set of constant
 /// indices.
 struct GatherOpCanon final
@@ -26055,7 +26102,7 @@ struct EnzymeHLOOptPass
     patterns.add<
         AddSimplify, SubSimplify, AndSimplify, MaxSimplify, MinSimplify,
         OrSimplify, XorSimplify, MulSimplify, DivSimplify, RemSimplify,
-        PowSimplify, NoopSlice, NoopReverse, SliceSlice,
+        PowSimplify, NoopSlice, NoopReverse, ReverseReverse, SliceSlice,
         DynamicSliceDynamicSlice, DynamicSliceSlice, SliceDynamicSlice,
         LogSimplify, ShiftRightLogicalSimplify, NegativePadToSlice,
         SliceSimplify, ConvertSimplify, TransposeSimplify, DotGeneralSimplify,
diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td
index dc0d6f760..f3207c8b9 100644
--- a/src/enzyme_ad/jax/TransformOps/TransformOps.td
+++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td
@@ -93,6 +93,10 @@ def ApplyNoopReversePatterns : EnzymeHLOPatternOp<
     "noop_reverse"> {
   let patterns = ["NoopReverse"];
 }
+def ApplyReverseReversePatterns : EnzymeHLOPatternOp<
+    "reverse_reverse"> {
+  let patterns = ["ReverseReverse"];
+}
 def ApplySliceSlicePatterns : EnzymeHLOPatternOp<
     "slice_slice"> {
   let patterns = ["SliceSlice"];
diff --git a/test/lit_tests/reversereverse.mlir b/test/lit_tests/reversereverse.mlir
new file mode 100644
index 000000000..2e7fa6751
--- /dev/null
+++ b/test/lit_tests/reversereverse.mlir
@@ -0,0 +1,62 @@
+// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s
+
+// Test: reverse(reverse(x)) with same dimensions -> x
+// CHECK-LABEL: @reverse_reverse_same_dims
+// CHECK-NOT: stablehlo.reverse
+func.func @reverse_reverse_same_dims(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+  %0 = stablehlo.reverse %arg0, dims = [0, 1] : tensor<8x4x3xf32>
+  %1 = stablehlo.reverse %0, dims = [0, 1] : tensor<8x4x3xf32>
+  return %1 : tensor<8x4x3xf32>
+}
+
+// Test: reverse(reverse(x)) with single dimension -> x
+// CHECK-LABEL: @reverse_reverse_single_dim
+// CHECK-NOT: stablehlo.reverse
+func.func @reverse_reverse_single_dim(%arg0: tensor<8x4xf32>) -> tensor<8x4xf32> {
+  %0 = stablehlo.reverse %arg0, dims = [0] : tensor<8x4xf32>
+  %1 = stablehlo.reverse %0, dims = [0] : tensor<8x4xf32>
+  return %1 : tensor<8x4xf32>
+}
+
+// Test: reverse(reverse(x)) with partial overlap -> single reverse
+// Inner dims [0, 1], outer dims [1, 2] -> effective dims [0, 2]
+// CHECK-LABEL: @reverse_reverse_partial_overlap
+// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [0, 2] : tensor<8x4x3xf32>
+// CHECK-NEXT: return %[[R]]
+func.func @reverse_reverse_partial_overlap(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+  %0 = stablehlo.reverse %arg0, dims = [0, 1] : tensor<8x4x3xf32>
+  %1 = stablehlo.reverse %0, dims = [1, 2] : tensor<8x4x3xf32>
+  return %1 : tensor<8x4x3xf32>
+}
+
+// Test: reverse(reverse(x)) with disjoint dimensions -> combined reverse
+// CHECK-LABEL: @reverse_reverse_disjoint_dims
+// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [0, 1, 2] : tensor<8x4x3xf32>
+// CHECK-NEXT: return %[[R]]
+func.func @reverse_reverse_disjoint_dims(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+  %0 = stablehlo.reverse %arg0, dims = [0, 1] : tensor<8x4x3xf32>
+  %1 = stablehlo.reverse %0, dims = [2] : tensor<8x4x3xf32>
+  return %1 : tensor<8x4x3xf32>
+}
+
+// Test: reverse(reverse(x)) with subset dimensions
+// Inner has more dims, outer has subset -> remaining dims from inner
+// CHECK-LABEL: @reverse_reverse_subset
+// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [2] : tensor<8x4x3xf32>
+// CHECK-NEXT: return %[[R]]
+func.func @reverse_reverse_subset(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
+  %0 = stablehlo.reverse %arg0, dims = [0, 1, 2] : tensor<8x4x3xf32>
+  %1 = stablehlo.reverse %0, dims = [0, 1] : tensor<8x4x3xf32>
+  return %1 : tensor<8x4x3xf32>
+}
+
+// Test: 3 reverses - should optimize first two, then optimize result with third
+// CHECK-LABEL: @reverse_three_times
+// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [0] : tensor<8x4xf32>
+// CHECK-NEXT: return %[[R]]
+func.func @reverse_three_times(%arg0: tensor<8x4xf32>) -> tensor<8x4xf32> {
+  %0 = stablehlo.reverse %arg0, dims = [0] : tensor<8x4xf32>
+  %1 = stablehlo.reverse %0, dims = [0] : tensor<8x4xf32>
+  %2 = stablehlo.reverse %1, dims = [0] : tensor<8x4xf32>
+  return %2 : tensor<8x4xf32>
+}
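
For reference, the dimension folding that ReverseReverse performs can be reproduced with a small standalone program. The sketch below is illustrative only and is not part of the patch: it uses plain C++ standard-library containers rather than the MLIR/LLVM types above, and the helper name foldReverseDims is hypothetical. The inputs mirror the dimension lists exercised by the lit tests.

// Standalone sketch of the XOR-style dimension folding used by ReverseReverse:
// a dimension reversed twice cancels out, a dimension reversed once remains.
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

std::vector<int64_t> foldReverseDims(const std::vector<int64_t> &inner,
                                     const std::vector<int64_t> &outer) {
  std::set<int64_t> dims(inner.begin(), inner.end());
  for (int64_t d : outer) {
    auto [it, inserted] = dims.insert(d);
    if (!inserted)
      dims.erase(it); // present in both reverses -> cancels out
  }
  // std::set iterates in sorted order, matching the pattern's canonical form.
  return {dims.begin(), dims.end()};
}

int main() {
  auto print = [](const std::vector<int64_t> &dims) {
    std::cout << '[';
    for (size_t i = 0; i < dims.size(); ++i)
      std::cout << (i ? ", " : "") << dims[i];
    std::cout << "]\n";
  };
  print(foldReverseDims({0, 1}, {0, 1}));    // []        -> reverses cancel, use the operand directly
  print(foldReverseDims({0, 1}, {1, 2}));    // [0, 2]    -> partial overlap
  print(foldReverseDims({0, 1}, {2}));       // [0, 1, 2] -> disjoint dims combine
  print(foldReverseDims({0, 1, 2}, {0, 1})); // [2]       -> subset leaves the remainder
}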