Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28764,6 +28764,53 @@ struct FuseReshapeCollapseOrExpandDimsIntoReduce final
}
};

/// Folds a `stablehlo.gather` that reads back what a directly preceding
/// `stablehlo.scatter` wrote.
///
/// The pattern only fires when the gather's operand is a scatter result, both
/// ops use the very same index value, and the gather's slice sizes and
/// dimension numbers equal the ones derived from the scatter via
/// `computeGatherSliceSizes` / `getGatherDims`. Two rewrites are possible:
///   * `detectConstantSetindexScatterOp` produced a splat attribute: the
///     gather becomes a constant holding that splat, resized to the gather's
///     result type;
///   * otherwise, and only if the scatter is flagged `unique_indices`: the
///     gather is replaced by the scatter's `updates` operand that belongs to
///     the result being gathered from.
///
/// NOTE(review): per the StableHLO spec, scatter skips out-of-bounds update
/// windows while gather clamps its start indices. Neither rewrite checks that
/// the indices are in bounds -- presumably callers guarantee this; confirm.
struct GatherOfScatterSimplify final
    : CheckedOpRewritePattern<stablehlo::GatherOp, GatherOfScatterSimplify> {
  using CheckedOpRewritePattern::CheckedOpRewritePattern;

  /// Attempts the fold on `gatherOp`; returns `failure()` without touching the
  /// IR whenever any precondition listed on the struct does not hold.
  LogicalResult matchAndRewriteImpl(stablehlo::GatherOp gatherOp,
                                    PatternRewriter &rewriter) {
    auto input = gatherOp.getOperand();
    auto scatterOp = input.getDefiningOp<stablehlo::ScatterOp>();
    if (!scatterOp) {
      return failure();
    }

    // Gather and scatter must address exactly the same windows: identical
    // index SSA value, identical slice sizes, identical dimension numbers.
    if (scatterOp.getScatterIndices() != gatherOp.getStartIndices() ||
        computeGatherSliceSizes(scatterOp) != gatherOp.getSliceSizes() ||
        getGatherDims(scatterOp->getContext(),
                      scatterOp.getScatterDimensionNumbersAttr()) !=
            gatherOp.getDimensionNumbersAttr()) {
      return failure();
    }

    // The helper is defined elsewhere; here it is asked to accept scatters
    // with multiple uses (`true`) and any scatter input (the validator always
    // answers true). Presumably it rejects scatters that are not plain
    // "set index" updates -- TODO confirm against its definition.
    SplatElementsAttr constSetIndexValue;
    if (!detectConstantSetindexScatterOp(
             scatterOp, true, [](auto) { return true; }, constSetIndexValue)
             .ok()) {
      return failure();
    }

    // Every written element is the same splat value, so the read-back is that
    // splat in the shape of the gather result.
    if (constSetIndexValue) {
      auto constResult = stablehlo::ConstantOp::create(
          rewriter, gatherOp.getLoc(),
          constSetIndexValue.resizeSplat(
              cast<ShapedType>(gatherOp.getType())));
      rewriter.replaceOp(gatherOp, constResult);
      return success();
    }

    // Without the unique-indices flag the code does not forward the updates:
    // several updates could target the same location.
    if (!scatterOp.getUniqueIndices()) {
      return failure();
    }

    // `input` has a defining op (checked above), so it is an OpResult; its
    // result number selects the corresponding `updates` operand.
    auto opNum = cast<OpResult>(input).getResultNumber();
    auto newResult = scatterOp.getUpdates()[opNum];

    // Replacing a value with one of a different type would leave invalid IR
    // behind (e.g. a gather declared with a less static, but compatible,
    // result type), so only forward the updates when the types are identical.
    if (newResult.getType() != gatherOp.getType()) {
      return failure();
    }

    rewriter.replaceOp(gatherOp, newResult);
    return success();
  }
};

/////////////// End Imported from stablehlo

// clang-format off
Expand Down Expand Up @@ -29477,7 +29524,8 @@ struct EnzymeHLOOptPass
DeleteDimsReduce,
ReduceDeleteDims,
DotGeneralInsertDimContractionSimplification,
FuseReshapeCollapseOrExpandDimsIntoReduce
FuseReshapeCollapseOrExpandDimsIntoReduce,
GatherOfScatterSimplify
>(context);

patterns.add<ReshapeElementwise>(true, true, context);
Expand Down
5 changes: 5 additions & 0 deletions src/enzyme_ad/jax/TransformOps/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2721,3 +2721,8 @@ def ApplyWhileElementwiseReductionToReducePatterns : EnzymeHLOPatternOp<
"while_elementwise_reduction_to_reduce"> {
let patterns = ["WhileElementwiseReductionToReduce"];
}

// EnzymeHLOPatternOp record keyed by the string "gather_of_scatter_simplify";
// its pattern list names the single C++ rewrite pattern
// `GatherOfScatterSimplify`.
def ApplyGatherOfScatterSimplifyPatterns : EnzymeHLOPatternOp<
"gather_of_scatter_simplify"> {
let patterns = ["GatherOfScatterSimplify"];
}
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ def optimization_passes(
f"scatter_const_fold({max_constant_threshold})",
"cse_gather",
"cse_scatter",
"gather_of_scatter_simplify",
]

if enable_pad_optimization_passes:
Expand Down
78 changes: 78 additions & 0 deletions test/lit_tests/gather_scatter.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s

// Case 1: constant write, no unique_indices attribute.
// The scatter's update region ignores both block arguments and returns the
// scalar constant %cst_1; the updates operand %cst is a splat of the same
// value. The gather reads %2 back through the very same index value %1 with
// matching dimension numbers and slice sizes. The expectations after this
// module require the gather to be gone and the first returned value to be a
// constant, while the scatter itself stays because %5 still uses it.
module {
func.func @main(%arg0: tensor<7x6xf64>) -> (tensor<4x3xf64>, tensor<7x6xf64>) {
%c = stablehlo.constant dense<1> : tensor<3x1xi64>
%cst = stablehlo.constant dense<1.000000e+00> : tensor<3x4xf64>
%cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
%c_0 = stablehlo.constant dense<[[1], [3], [2]]> : tensor<3x1xi64>
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<7x6xf64>) -> tensor<6x7xf64>
%1 = stablehlo.subtract %c_0, %c : tensor<3x1xi64>
%2 = "stablehlo.scatter"(%0, %1, %cst) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>}> ({
^bb0(%arg1: tensor<f64>, %arg2: tensor<f64>):
stablehlo.return %cst_1 : tensor<f64>
}) : (tensor<6x7xf64>, tensor<3x1xi64>, tensor<3x4xf64>) -> tensor<6x7xf64>
%3 = "stablehlo.gather"(%2, %1) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 4>}> : (tensor<6x7xf64>, tensor<3x1xi64>) -> tensor<3x4xf64>
%4 = stablehlo.transpose %3, dims = [1, 0] : (tensor<3x4xf64>) -> tensor<4x3xf64>
%5 = stablehlo.transpose %2, dims = [1, 0] : (tensor<6x7xf64>) -> tensor<7x6xf64>
return %4, %5 : tensor<4x3xf64>, tensor<7x6xf64>
}
}

// CHECK: func.func @main
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%0, %c, %cst_0)
// CHECK-NOT: "stablehlo.gather"
// CHECK: %[[TS:.*]] = stablehlo.transpose %[[SCATTER]]
// CHECK: return %cst, %[[TS]] : tensor<4x3xf64>, tensor<7x6xf64>

// Case 2: non-constant write with unique_indices = true.
// The update region returns its second block argument (the incoming update),
// and the updates operand %1 is the transpose of %arg1. The gather reads %3
// back through the same index value %2. The expectations after this module
// require the gather to be gone and the first returned value to be %arg1
// itself, i.e. the gather was replaced by the updates and the two transposes
// around it folded away.
module {
func.func @main(%arg0: tensor<7x6xf64>, %arg1: tensor<4x3xf64>) -> (tensor<4x3xf64>, tensor<7x6xf64>) {
%c = stablehlo.constant dense<1> : tensor<3x1xi64>
%c_0 = stablehlo.constant dense<[[1], [3], [2]]> : tensor<3x1xi64>
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<7x6xf64>) -> tensor<6x7xf64>
%1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<4x3xf64>) -> tensor<3x4xf64>
%2 = stablehlo.subtract %c_0, %c : tensor<3x1xi64>
%3 = "stablehlo.scatter"(%0, %2, %1) <{scatter_dimension_numbers = #stablehlo.scatter<update_window_dims = [1], inserted_window_dims = [0], scatter_dims_to_operand_dims = [0], index_vector_dim = 1>, unique_indices = true}> ({
^bb0(%arg2: tensor<f64>, %arg3: tensor<f64>):
stablehlo.return %arg3 : tensor<f64>
}) : (tensor<6x7xf64>, tensor<3x1xi64>, tensor<3x4xf64>) -> tensor<6x7xf64>
%4 = "stablehlo.gather"(%3, %2) <{dimension_numbers = #stablehlo.gather<offset_dims = [1], collapsed_slice_dims = [0], start_index_map = [0], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 4>}> : (tensor<6x7xf64>, tensor<3x1xi64>) -> tensor<3x4xf64>
%5 = stablehlo.transpose %4, dims = [1, 0] : (tensor<3x4xf64>) -> tensor<4x3xf64>
%6 = stablehlo.transpose %3, dims = [1, 0] : (tensor<6x7xf64>) -> tensor<7x6xf64>
return %5, %6 : tensor<4x3xf64>, tensor<7x6xf64>
}
}

// CHECK: func.func @main
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%0, %c, %1)
// CHECK-NOT: "stablehlo.gather"
// CHECK: %[[TS:.*]] = stablehlo.transpose %[[SCATTER]]
// CHECK: return %arg1, %[[TS]] : tensor<4x3xf64>, tensor<7x6xf64>

// Case 3: element-wise (scalar window) write with unique_indices = true.
// Index vectors have two components built at runtime from %arg2 and %arg3
// (broadcast, reshape, concatenate, subtract), both operand dimensions are
// inserted/collapsed, and the update region returns the incoming update. The
// gather reads %8 back through the same index value %7 with slice sizes 1x1.
// The expectations after this module require the gather to be gone and the
// first returned value to be %arg1 itself, i.e. the reshape pair around the
// forwarded updates folded away as well.
module {
func.func @main(%arg0: tensor<7x6xf64>, %arg1: tensor<4x3xf64>, %arg2: tensor<3xi64>, %arg3: tensor<4xi64>) -> (tensor<4x3xf64>, tensor<7x6xf64>) {
%c = stablehlo.constant dense<1> : tensor<12x2xi64>
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<7x6xf64>) -> tensor<6x7xf64>
%1 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<3xi64>) -> tensor<4x3xi64>
%2 = stablehlo.broadcast_in_dim %arg3, dims = [0] : (tensor<4xi64>) -> tensor<4x3xi64>
%3 = stablehlo.reshape %2 : (tensor<4x3xi64>) -> tensor<12x1xi64>
%4 = stablehlo.reshape %1 : (tensor<4x3xi64>) -> tensor<12x1xi64>
%5 = stablehlo.concatenate %4, %3, dim = 1 : (tensor<12x1xi64>, tensor<12x1xi64>) -> tensor<12x2xi64>
%6 = stablehlo.reshape %arg1 : (tensor<4x3xf64>) -> tensor<12xf64>
%7 = stablehlo.subtract %5, %c : tensor<12x2xi64>
%8 = "stablehlo.scatter"(%0, %7, %6) <{scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
^bb0(%arg4: tensor<f64>, %arg5: tensor<f64>):
stablehlo.return %arg5 : tensor<f64>
}) : (tensor<6x7xf64>, tensor<12x2xi64>, tensor<12xf64>) -> tensor<6x7xf64>
%9 = "stablehlo.gather"(%8, %7) <{dimension_numbers = #stablehlo.gather<collapsed_slice_dims = [0, 1], start_index_map = [0, 1], index_vector_dim = 1>, indices_are_sorted = false, slice_sizes = array<i64: 1, 1>}> : (tensor<6x7xf64>, tensor<12x2xi64>) -> tensor<12xf64>
%10 = stablehlo.reshape %9 : (tensor<12xf64>) -> tensor<4x3xf64>
%11 = stablehlo.transpose %8, dims = [1, 0] : (tensor<6x7xf64>) -> tensor<7x6xf64>
return %10, %11 : tensor<4x3xf64>, tensor<7x6xf64>
}
}

// CHECK: func.func @main
// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%0, %7, %6)
// CHECK-NOT: "stablehlo.gather"
// CHECK: %[[TS:.*]] = stablehlo.transpose %[[SCATTER]]
// CHECK: return %arg1, %[[TS]] : tensor<4x3xf64>, tensor<7x6xf64>
Loading