Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13571,6 +13571,53 @@ struct NoopReverse final
}
};

// reverse(reverse(x)) -> x or reverse(x) with reduced dimensions
// Two stacked reverse ops cancel along any dimension listed by both of them,
// so the fused op only needs the symmetric difference of the two dim lists.
struct ReverseReverse final
    : CheckedOpRewritePattern<stablehlo::ReverseOp, ReverseReverse> {
  using CheckedOpRewritePattern::CheckedOpRewritePattern;

  LogicalResult matchAndRewriteImpl(stablehlo::ReverseOp op,
                                    PatternRewriter &rewriter) const {
    // Only fires when the operand is itself produced by a reverse.
    auto inner = op.getOperand().getDefiningOp<stablehlo::ReverseOp>();
    if (!inner)
      return failure();

    // Materialize both dimension lists as sets for O(1) membership tests.
    llvm::SmallDenseSet<int64_t> outerSet;
    for (int64_t d : op.getDimensions())
      outerSet.insert(d);
    llvm::SmallDenseSet<int64_t> innerSet;
    for (int64_t d : inner.getDimensions())
      innerSet.insert(d);

    // Symmetric difference: a dimension survives iff exactly one of the two
    // reverses lists it (reversing twice along an axis is the identity).
    SmallVector<int64_t> survivors;
    for (int64_t d : innerSet)
      if (!outerSet.count(d))
        survivors.push_back(d);
    for (int64_t d : outerSet)
      if (!innerSet.count(d))
        survivors.push_back(d);

    if (survivors.empty()) {
      // Everything cancels: forward the innermost operand unchanged.
      rewriter.replaceOp(op, inner.getOperand());
      return success();
    }

    // Sort so the rewritten op carries a canonical dimension list.
    llvm::sort(survivors);
    rewriter.replaceOpWithNewOp<stablehlo::ReverseOp>(op, inner.getOperand(),
                                                      survivors);
    return success();
  }
};

/// Converts gather ops to slice ops in case we have a single set of constant
/// indices.
struct GatherOpCanon final
Expand Down Expand Up @@ -26055,7 +26102,7 @@ struct EnzymeHLOOptPass
patterns.add<
AddSimplify, SubSimplify, AndSimplify, MaxSimplify, MinSimplify,
OrSimplify, XorSimplify, MulSimplify, DivSimplify, RemSimplify,
PowSimplify, NoopSlice, NoopReverse, SliceSlice,
PowSimplify, NoopSlice, NoopReverse, ReverseReverse, SliceSlice,
DynamicSliceDynamicSlice, DynamicSliceSlice, SliceDynamicSlice,
LogSimplify, ShiftRightLogicalSimplify, NegativePadToSlice,
SliceSimplify, ConvertSimplify, TransposeSimplify, DotGeneralSimplify,
Expand Down
4 changes: 4 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def ApplyNoopReversePatterns : EnzymeHLOPatternOp<
"noop_reverse"> {
let patterns = ["NoopReverse"];
}
// Exposes the ReverseReverse canonicalization (folding of two stacked
// stablehlo.reverse ops into at most one) to the transform dialect under
// the "reverse_reverse" pattern tag.
def ApplyReverseReversePatterns : EnzymeHLOPatternOp<
"reverse_reverse"> {
let patterns = ["ReverseReverse"];
}
def ApplySliceSlicePatterns : EnzymeHLOPatternOp<
"slice_slice"> {
let patterns = ["SliceSlice"];
Expand Down
62 changes: 62 additions & 0 deletions test/lit_tests/reversereverse.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// RUN: enzymexlamlir-opt --enzyme-hlo-opt %s | FileCheck %s

// Lit tests for the ReverseReverse pattern: two stacked stablehlo reverse
// ops fold to a single reverse over the symmetric difference of their
// dimension lists, or vanish entirely when the lists are equal.
// (FileCheck matches against the optimizer's output, not this source.)

// Test: reverse(reverse(x)) with same dimensions -> x
// CHECK-LABEL: @reverse_reverse_same_dims
// CHECK-NOT: stablehlo.reverse
func.func @reverse_reverse_same_dims(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = stablehlo.reverse %arg0, dims = [0, 1] : tensor<8x4x3xf32>
%1 = stablehlo.reverse %0, dims = [0, 1] : tensor<8x4x3xf32>
return %1 : tensor<8x4x3xf32>
}

// Test: reverse(reverse(x)) with single dimension -> x
// CHECK-LABEL: @reverse_reverse_single_dim
// CHECK-NOT: stablehlo.reverse
func.func @reverse_reverse_single_dim(%arg0: tensor<8x4xf32>) -> tensor<8x4xf32> {
%0 = stablehlo.reverse %arg0, dims = [0] : tensor<8x4xf32>
%1 = stablehlo.reverse %0, dims = [0] : tensor<8x4xf32>
return %1 : tensor<8x4xf32>
}

// Test: reverse(reverse(x)) with partial overlap -> single reverse
// Inner dims [0, 1], outer dims [1, 2] -> effective dims [0, 2]
// CHECK-LABEL: @reverse_reverse_partial_overlap
// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [0, 2] : tensor<8x4x3xf32>
// CHECK-NEXT: return %[[R]]
func.func @reverse_reverse_partial_overlap(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = stablehlo.reverse %arg0, dims = [0, 1] : tensor<8x4x3xf32>
%1 = stablehlo.reverse %0, dims = [1, 2] : tensor<8x4x3xf32>
return %1 : tensor<8x4x3xf32>
}

// Test: reverse(reverse(x)) with disjoint dimensions -> combined reverse
// CHECK-LABEL: @reverse_reverse_disjoint_dims
// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [0, 1, 2] : tensor<8x4x3xf32>
// CHECK-NEXT: return %[[R]]
func.func @reverse_reverse_disjoint_dims(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = stablehlo.reverse %arg0, dims = [0, 1] : tensor<8x4x3xf32>
%1 = stablehlo.reverse %0, dims = [2] : tensor<8x4x3xf32>
return %1 : tensor<8x4x3xf32>
}

// Test: reverse(reverse(x)) with subset dimensions
// Inner has more dims, outer has subset -> remaining dims from inner
// CHECK-LABEL: @reverse_reverse_subset
// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [2] : tensor<8x4x3xf32>
// CHECK-NEXT: return %[[R]]
func.func @reverse_reverse_subset(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> {
%0 = stablehlo.reverse %arg0, dims = [0, 1, 2] : tensor<8x4x3xf32>
%1 = stablehlo.reverse %0, dims = [0, 1] : tensor<8x4x3xf32>
return %1 : tensor<8x4x3xf32>
}

// Test: 3 reverses - should optimize first two, then optimize result with third
// CHECK-LABEL: @reverse_three_times
// CHECK: %[[R:.*]] = stablehlo.reverse %arg0, dims = [0] : tensor<8x4xf32>
// CHECK-NEXT: return %[[R]]
func.func @reverse_three_times(%arg0: tensor<8x4xf32>) -> tensor<8x4xf32> {
%0 = stablehlo.reverse %arg0, dims = [0] : tensor<8x4xf32>
%1 = stablehlo.reverse %0, dims = [0] : tensor<8x4xf32>
%2 = stablehlo.reverse %1, dims = [0] : tensor<8x4xf32>
return %2 : tensor<8x4xf32>
}