From e436e1fa63445e398f76eff9c5ec3f3fb7403171 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 25 Nov 2025 05:07:01 +0000 Subject: [PATCH 1/3] Initial plan From da57239a2c753336f803e42688e4e111480a931c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 25 Nov 2025 05:13:56 +0000 Subject: [PATCH 2/3] Add ReverseReverse optimization pattern and test Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com> --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 53 +++++++++++++++- .../jax/TransformOps/TransformOps.td | 4 ++ test/lit_tests/reversereverse.mlir | 62 +++++++++++++++++++ 3 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 test/lit_tests/reversereverse.mlir diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 6bda073b4..7c5d107c6 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -13571,6 +13571,57 @@ struct NoopReverse final } }; +// reverse(reverse(x)) -> x or reverse(x) with reduced dimensions +// When we have two consecutive reverse operations, dimensions that appear +// in both cancel out, and we only need to reverse the symmetric difference. +struct ReverseReverse final + : CheckedOpRewritePattern { + using CheckedOpRewritePattern::CheckedOpRewritePattern; + + LogicalResult matchAndRewriteImpl(stablehlo::ReverseOp op, + PatternRewriter &rewriter) const { + auto prevReverse = op.getOperand().getDefiningOp(); + if (!prevReverse) + return failure(); + + // Get dimensions from both reverse operations + auto outerDims = op.getDimensions(); + auto innerDims = prevReverse.getDimensions(); + + // Compute the symmetric difference of dimensions: + // - Dimensions in both cancel out (reverse twice = identity) + // - Dimensions in only one remain + llvm::SmallDenseSet innerDimSet(innerDims.begin(), innerDims.end()); + llvm::SmallDenseSet outerDimSet(outerDims.begin(), outerDims.end()); + + SmallVector newDimensions; + + // Add dimensions that are only in inner (not in outer) + for (int64_t dim : innerDims) { + if (!outerDimSet.contains(dim)) + newDimensions.push_back(dim); + } + + // Add dimensions that are only in outer (not in inner) + for (int64_t dim : outerDims) { + if (!innerDimSet.contains(dim)) + newDimensions.push_back(dim); + } + + // Sort the dimensions to maintain canonical form + llvm::sort(newDimensions); + + if (newDimensions.empty()) { + // Both reverses cancel out completely + rewriter.replaceOp(op, prevReverse.getOperand()); + } else { + rewriter.replaceOpWithNewOp( + op, prevReverse.getOperand(), newDimensions); + } + return success(); + } +}; + /// Converts gather ops to slice ops in case we have a single set of constant /// indices. struct GatherOpCanon final @@ -26055,7 +26106,7 @@ struct EnzymeHLOOptPass patterns.add< AddSimplify, SubSimplify, AndSimplify, MaxSimplify, MinSimplify, OrSimplify, XorSimplify, MulSimplify, DivSimplify, RemSimplify, - PowSimplify, NoopSlice, NoopReverse, SliceSlice, + PowSimplify, NoopSlice, NoopReverse, ReverseReverse, SliceSlice, DynamicSliceDynamicSlice, DynamicSliceSlice, SliceDynamicSlice, LogSimplify, ShiftRightLogicalSimplify, NegativePadToSlice, SliceSimplify, ConvertSimplify, TransposeSimplify, DotGeneralSimplify, diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index dc0d6f760..f3207c8b9 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -93,6 +93,10 @@ def ApplyNoopReversePatterns : EnzymeHLOPatternOp< "noop_reverse"> { let patterns = ["NoopReverse"]; } +def ApplyReverseReversePatterns : EnzymeHLOPatternOp< + "reverse_reverse"> { + let patterns = ["ReverseReverse"]; +} def ApplySliceSlicePatterns : EnzymeHLOPatternOp< "slice_slice"> { let patterns = ["SliceSlice"]; diff --git a/test/lit_tests/reversereverse.mlir b/test/lit_tests/reversereverse.mlir new file mode 100644 index 000000000..2e7fa6751 --- /dev/null +++ b/test/lit_tests/reversereverse.mlir @@ -0,0 +1,62 @@ +// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s + +// Test: reverse(reverse(x)) with same dimensions -> x +// CHECK-LABEL: @reverse_reverse_same_dims +// CHECK-NOT: stablehlo.reverse +func.func @reverse_reverse_same_dims(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { + %0 = stablehlo.reverse %arg0, dims = [0, 1] : tensor<8x4x3xf32> + %1 = stablehlo.reverse %0, dims = [0, 1] : tensor<8x4x3xf32> + return %1 : tensor<8x4x3xf32> +} + +// Test: reverse(reverse(x)) with single dimension -> x +// CHECK-LABEL: @reverse_reverse_single_dim +// CHECK-NOT: stablehlo.reverse +func.func @reverse_reverse_single_dim(%arg0: tensor<8x4xf32>) -> tensor<8x4xf32> { + %0 = stablehlo.reverse %arg0, dims = [0] : tensor<8x4xf32> + %1 = stablehlo.reverse %0, dims = [0] : tensor<8x4xf32> + return %1 : tensor<8x4xf32> +} + +// Test: reverse(reverse(x)) with partial overlap -> single reverse +// Inner dims [0, 1], outer dims [1, 2] -> effective dims [0, 2] +// CHECK-LABEL: @reverse_reverse_partial_overlap +// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [0, 2] : tensor<8x4x3xf32> +// CHECK-NEXT: return %[[R]] +func.func @reverse_reverse_partial_overlap(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { + %0 = stablehlo.reverse %arg0, dims = [0, 1] : tensor<8x4x3xf32> + %1 = stablehlo.reverse %0, dims = [1, 2] : tensor<8x4x3xf32> + return %1 : tensor<8x4x3xf32> +} + +// Test: reverse(reverse(x)) with disjoint dimensions -> combined reverse +// CHECK-LABEL: @reverse_reverse_disjoint_dims +// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [0, 1, 2] : tensor<8x4x3xf32> +// CHECK-NEXT: return %[[R]] +func.func @reverse_reverse_disjoint_dims(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { + %0 = stablehlo.reverse %arg0, dims = [0, 1] : tensor<8x4x3xf32> + %1 = stablehlo.reverse %0, dims = [2] : tensor<8x4x3xf32> + return %1 : tensor<8x4x3xf32> +} + +// Test: reverse(reverse(x)) with subset dimensions +// Inner has more dims, outer has subset -> remaining dims from inner +// CHECK-LABEL: @reverse_reverse_subset +// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [2] : tensor<8x4x3xf32> +// CHECK-NEXT: return %[[R]] +func.func @reverse_reverse_subset(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { + %0 = stablehlo.reverse %arg0, dims = [0, 1, 2] : tensor<8x4x3xf32> + %1 = stablehlo.reverse %0, dims = [0, 1] : tensor<8x4x3xf32> + return %1 : tensor<8x4x3xf32> +} + +// Test: 3 reverses - should optimize first two, then optimize result with third +// CHECK-LABEL: @reverse_three_times +// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [0] : tensor<8x4xf32> +// CHECK-NEXT: return %[[R]] +func.func @reverse_three_times(%arg0: tensor<8x4xf32>) -> tensor<8x4xf32> { + %0 = stablehlo.reverse %arg0, dims = [0] : tensor<8x4xf32> + %1 = stablehlo.reverse %0, dims = [0] : tensor<8x4xf32> + %2 = stablehlo.reverse %1, dims = [0] : tensor<8x4xf32> + return %2 : tensor<8x4xf32> +} From ed05ea618eedc97c0d574debb44dc10949dd8621 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 25 Nov 2025 05:15:29 +0000 Subject: [PATCH 3/3] Optimize ReverseReverse pattern to use single set for symmetric difference Co-authored-by: wsmoses <1260124+wsmoses@users.noreply.github.com> --- src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp | 30 ++++++++++------------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 7c5d107c6..71bb41956 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -13588,33 +13588,29 @@ struct ReverseReverse final auto outerDims = op.getDimensions(); auto innerDims = prevReverse.getDimensions(); - // Compute the symmetric difference of dimensions: + // Compute the symmetric difference of dimensions using a single set: // - Dimensions in both cancel out (reverse twice = identity) // - Dimensions in only one remain - llvm::SmallDenseSet innerDimSet(innerDims.begin(), innerDims.end()); - llvm::SmallDenseSet outerDimSet(outerDims.begin(), outerDims.end()); - - SmallVector newDimensions; - - // Add dimensions that are only in inner (not in outer) + // XOR-like operation: add if not present, remove if already present + llvm::SmallDenseSet dimSet; for (int64_t dim : innerDims) { - if (!outerDimSet.contains(dim)) - newDimensions.push_back(dim); + dimSet.insert(dim); } - - // Add dimensions that are only in outer (not in inner) for (int64_t dim : outerDims) { - if (!innerDimSet.contains(dim)) - newDimensions.push_back(dim); + auto [it, inserted] = dimSet.insert(dim); + if (!inserted) { + // Dimension was in both - cancel out + dimSet.erase(it); + } } - // Sort the dimensions to maintain canonical form - llvm::sort(newDimensions); - - if (newDimensions.empty()) { + if (dimSet.empty()) { // Both reverses cancel out completely rewriter.replaceOp(op, prevReverse.getOperand()); } else { + // Convert set to sorted vector for canonical form + SmallVector newDimensions(dimSet.begin(), dimSet.end()); + llvm::sort(newDimensions); rewriter.replaceOpWithNewOp( op, prevReverse.getOperand(), newDimensions); }