diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index b9179e9936..309b6a6d26 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -28764,6 +28764,53 @@ struct FuseReshapeCollapseOrExpandDimsIntoReduce final } }; +struct GatherOfScatterSimplify final + : CheckedOpRewritePattern { + using CheckedOpRewritePattern::CheckedOpRewritePattern; + + LogicalResult matchAndRewriteImpl(stablehlo::GatherOp gatherOp, + PatternRewriter &rewriter) { + auto input = gatherOp.getOperand(); + auto scatterOp = input.getDefiningOp(); + + if (!scatterOp || + scatterOp.getScatterIndices() != gatherOp.getStartIndices() || + computeGatherSliceSizes(scatterOp) != gatherOp.getSliceSizes() || + getGatherDims(scatterOp->getContext(), + scatterOp.getScatterDimensionNumbersAttr()) != + gatherOp.getDimensionNumbersAttr()) { + return failure(); + } + + auto opResult = cast(input); + auto opNum = opResult.getResultNumber(); + + SplatElementsAttr constSetIndexValue; + if (!detectConstantSetindexScatterOp( + scatterOp, true, [](auto input) { return true; }, + constSetIndexValue) + .ok()) { + return failure(); + } + + if (constSetIndexValue) { + auto constResult = stablehlo::ConstantOp::create( + rewriter, gatherOp.getLoc(), + constSetIndexValue.resizeSplat(cast(gatherOp.getType()))); + rewriter.replaceOp(gatherOp, constResult); + return success(); + } + + if (!scatterOp.getUniqueIndices()) { + return failure(); + } + + auto newResult = scatterOp.getUpdates()[opNum]; + rewriter.replaceOp(gatherOp, newResult); + return success(); + } +}; + /////////////// End Imported from stablehlo // clang-format off @@ -29477,7 +29524,8 @@ struct EnzymeHLOOptPass DeleteDimsReduce, ReduceDeleteDims, DotGeneralInsertDimContractionSimplification, - FuseReshapeCollapseOrExpandDimsIntoReduce + FuseReshapeCollapseOrExpandDimsIntoReduce, + GatherOfScatterSimplify >(context); patterns.add(true, true, context); diff --git a/src/enzyme_ad/jax/TransformOps/TransformOps.td b/src/enzyme_ad/jax/TransformOps/TransformOps.td index c192bfa2dd..b8cc08adb5 100644 --- a/src/enzyme_ad/jax/TransformOps/TransformOps.td +++ b/src/enzyme_ad/jax/TransformOps/TransformOps.td @@ -2721,3 +2721,8 @@ def ApplyWhileElementwiseReductionToReducePatterns : EnzymeHLOPatternOp< "while_elementwise_reduction_to_reduce"> { let patterns = ["WhileElementwiseReductionToReduce"]; } + +def ApplyGatherOfScatterSimplifyPatterns : EnzymeHLOPatternOp< + "gather_of_scatter_simplify"> { + let patterns = ["GatherOfScatterSimplify"]; +} diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 7e5b5b1ccd..de8dcc22b9 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -472,6 +472,7 @@ def optimization_passes( f"scatter_const_fold({max_constant_threshold})", "cse_gather", "cse_scatter", + "gather_of_scatter_simplify", ] if enable_pad_optimization_passes: diff --git a/test/lit_tests/gather_scatter.mlir b/test/lit_tests/gather_scatter.mlir new file mode 100644 index 0000000000..a95682cd78 --- /dev/null +++ b/test/lit_tests/gather_scatter.mlir @@ -0,0 +1,78 @@ +// RUN: enzymexlamlir-opt %s --enzyme-hlo-opt | FileCheck %s + +module { + func.func @main(%arg0: tensor<7x6xf64>) -> (tensor<4x3xf64>, tensor<7x6xf64>) { + %c = stablehlo.constant dense<1> : tensor<3x1xi64> + %cst = stablehlo.constant dense<1.000000e+00> : tensor<3x4xf64> + %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor + %c_0 = stablehlo.constant dense<[[1], [3], [2]]> : tensor<3x1xi64> + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<7x6xf64>) -> tensor<6x7xf64> + %1 = stablehlo.subtract %c_0, %c : tensor<3x1xi64> + %2 = "stablehlo.scatter"(%0, %1, %cst) <{scatter_dimension_numbers = #stablehlo.scatter}> ({ + ^bb0(%arg1: tensor, %arg2: tensor): + stablehlo.return %cst_1 : tensor + }) : (tensor<6x7xf64>, tensor<3x1xi64>, tensor<3x4xf64>) -> tensor<6x7xf64> + %3 = "stablehlo.gather"(%2, %1) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<6x7xf64>, tensor<3x1xi64>) -> tensor<3x4xf64> + %4 = stablehlo.transpose %3, dims = [1, 0] : (tensor<3x4xf64>) -> tensor<4x3xf64> + %5 = stablehlo.transpose %2, dims = [1, 0] : (tensor<6x7xf64>) -> tensor<7x6xf64> + return %4, %5 : tensor<4x3xf64>, tensor<7x6xf64> + } +} + +// CHECK: func.func @main +// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%0, %c, %cst_0) +// CHECK-NOT: "stablehlo.gather" +// CHECK: %[[TS:.*]] = stablehlo.transpose %[[SCATTER]] +// CHECK: return %cst, %[[TS]] : tensor<4x3xf64>, tensor<7x6xf64> + +module { + func.func @main(%arg0: tensor<7x6xf64>, %arg1: tensor<4x3xf64>) -> (tensor<4x3xf64>, tensor<7x6xf64>) { + %c = stablehlo.constant dense<1> : tensor<3x1xi64> + %c_0 = stablehlo.constant dense<[[1], [3], [2]]> : tensor<3x1xi64> + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<7x6xf64>) -> tensor<6x7xf64> + %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<4x3xf64>) -> tensor<3x4xf64> + %2 = stablehlo.subtract %c_0, %c : tensor<3x1xi64> + %3 = "stablehlo.scatter"(%0, %2, %1) <{scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg2: tensor, %arg3: tensor): + stablehlo.return %arg3 : tensor + }) : (tensor<6x7xf64>, tensor<3x1xi64>, tensor<3x4xf64>) -> tensor<6x7xf64> + %4 = "stablehlo.gather"(%3, %2) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<6x7xf64>, tensor<3x1xi64>) -> tensor<3x4xf64> + %5 = stablehlo.transpose %4, dims = [1, 0] : (tensor<3x4xf64>) -> tensor<4x3xf64> + %6 = stablehlo.transpose %3, dims = [1, 0] : (tensor<6x7xf64>) -> tensor<7x6xf64> + return %5, %6 : tensor<4x3xf64>, tensor<7x6xf64> + } +} + +// CHECK: func.func @main +// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%0, %c, %1) +// CHECK-NOT: "stablehlo.gather" +// CHECK: %[[TS:.*]] = stablehlo.transpose %[[SCATTER]] +// CHECK: return %arg1, %[[TS]] : tensor<4x3xf64>, tensor<7x6xf64> + +module { + func.func @main(%arg0: tensor<7x6xf64>, %arg1: tensor<4x3xf64>, %arg2: tensor<3xi64>, %arg3: tensor<4xi64>) -> (tensor<4x3xf64>, tensor<7x6xf64>) { + %c = stablehlo.constant dense<1> : tensor<12x2xi64> + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<7x6xf64>) -> tensor<6x7xf64> + %1 = stablehlo.broadcast_in_dim %arg2, dims = [1] : (tensor<3xi64>) -> tensor<4x3xi64> + %2 = stablehlo.broadcast_in_dim %arg3, dims = [0] : (tensor<4xi64>) -> tensor<4x3xi64> + %3 = stablehlo.reshape %2 : (tensor<4x3xi64>) -> tensor<12x1xi64> + %4 = stablehlo.reshape %1 : (tensor<4x3xi64>) -> tensor<12x1xi64> + %5 = stablehlo.concatenate %4, %3, dim = 1 : (tensor<12x1xi64>, tensor<12x1xi64>) -> tensor<12x2xi64> + %6 = stablehlo.reshape %arg1 : (tensor<4x3xf64>) -> tensor<12xf64> + %7 = stablehlo.subtract %5, %c : tensor<12x2xi64> + %8 = "stablehlo.scatter"(%0, %7, %6) <{scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg4: tensor, %arg5: tensor): + stablehlo.return %arg5 : tensor + }) : (tensor<6x7xf64>, tensor<12x2xi64>, tensor<12xf64>) -> tensor<6x7xf64> + %9 = "stablehlo.gather"(%8, %7) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = false, slice_sizes = array}> : (tensor<6x7xf64>, tensor<12x2xi64>) -> tensor<12xf64> + %10 = stablehlo.reshape %9 : (tensor<12xf64>) -> tensor<4x3xf64> + %11 = stablehlo.transpose %8, dims = [1, 0] : (tensor<6x7xf64>) -> tensor<7x6xf64> + return %10, %11 : tensor<4x3xf64>, tensor<7x6xf64> + } +} + +// CHECK: func.func @main +// CHECK: %[[SCATTER:.*]] = "stablehlo.scatter"(%0, %7, %6) +// CHECK-NOT: "stablehlo.gather" +// CHECK: %[[TS:.*]] = stablehlo.transpose %[[SCATTER]] +// CHECK: return %arg1, %[[TS]] : tensor<4x3xf64>, tensor<7x6xf64>