diff --git a/shardy/dialect/mpmd/transforms/common/BUILD b/shardy/dialect/mpmd/transforms/common/BUILD index e7c5204a2..2efe64a20 100644 --- a/shardy/dialect/mpmd/transforms/common/BUILD +++ b/shardy/dialect/mpmd/transforms/common/BUILD @@ -31,12 +31,14 @@ cc_library( name = "passes", srcs = [ "absorb_inferred_fragments.cc", + "add_side_effect_to_avoid_cse.cc", "call_rewrites.cc", "copy_constants.cc", "fragment_dce.cc", "fragment_dedup.cc", "merge_fragments.cc", "merge_transfers.cc", + "remove_side_effect_after_cse.cc", "remove_transfer_cycles.cc", "rule_based_merge.cc", "split_bwd_fragments.cc", diff --git a/shardy/dialect/mpmd/transforms/common/add_side_effect_to_avoid_cse.cc b/shardy/dialect/mpmd/transforms/common/add_side_effect_to_avoid_cse.cc new file mode 100644 index 000000000..7d3f3d959 --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/add_side_effect_to_avoid_cse.cc @@ -0,0 +1,53 @@ +/* Copyright 2025 The MPMD Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "shardy/dialect/mpmd/transforms/common/passes.h" +#include "shardy/dialect/mpmd/transforms/common/utils.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::mpmd { + +namespace { + +#define GEN_PASS_DEF_ADDSIDEEFFECTTOAVOIDCSEPASS +#include "shardy/dialect/mpmd/transforms/common/passes.h.inc" + +struct AddSideEffectToAvoidCSEPass + : public impl::AddSideEffectToAvoidCSEPassBase< + AddSideEffectToAvoidCSEPass> { + using impl::AddSideEffectToAvoidCSEPassBase< + AddSideEffectToAvoidCSEPass>::AddSideEffectToAvoidCSEPassBase; + + void runOnOperation() override { + getOperation().walk([](stablehlo::CustomCallOp customCallOp) { + if (customCallOp->hasAttr(kMhloNoCseAttr)) { + customCallOp.setHasSideEffect(true); + } + }); + } +}; + +} // namespace + +std::unique_ptr createAddSideEffectToAvoidCSEPass() { + return std::make_unique(); +} + +} // namespace mlir::mpmd diff --git a/shardy/dialect/mpmd/transforms/common/passes.td b/shardy/dialect/mpmd/transforms/common/passes.td index 7316c85a0..6fd4d9429 100644 --- a/shardy/dialect/mpmd/transforms/common/passes.td +++ b/shardy/dialect/mpmd/transforms/common/passes.td @@ -15,6 +15,30 @@ limitations under the License. include "mlir/Pass/PassBase.td" +def AddSideEffectToAvoidCSEPass : + PassBase<"mpmd-add-side-effect-to-avoid-cse", "OperationPass"> { + let summary = "Adds a side effect attribute to custom_call ops with " + "{mhlo.no_cse} to avoid CSE."; + let description = [{ + For `stablehlo.custom_call` operations that have the `{mhlo.no_cse}` + attribute, this pass adds a `{has_side_effect = true}` attribute. + This prevents MLIR's CSE pass from eliminating these operations, because + CSE skips operations with side effects. 
+ }]; +} + +def RemoveSideEffectAfterCSEPass : + PassBase<"mpmd-remove-side-effect-after-cse", "OperationPass"> { + let summary = "Removes side effect attribute from custom_call ops with " + "{mhlo.no_cse}."; + let description = [{ + For `stablehlo.custom_call` operations that have the `{mhlo.no_cse}` + attribute, this pass removes the `has_side_effect` attribute if it exists, + even when it was not added by `mpmd-add-side-effect-to-avoid-cse`. Run it + after the last CSE, once the attribute is no longer needed. + }]; +} + // TODO: b/374694825 - This pass is not complete yet. In particular, we also // need to consider: (a) side-ways merging. We need to be careful with this as // it may have performance and jitting time implications. (b) relax the diff --git a/shardy/dialect/mpmd/transforms/common/remove_side_effect_after_cse.cc b/shardy/dialect/mpmd/transforms/common/remove_side_effect_after_cse.cc new file mode 100644 index 000000000..4b32c7887 --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/remove_side_effect_after_cse.cc @@ -0,0 +1,54 @@ +/* Copyright 2025 The MPMD Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "shardy/dialect/mpmd/transforms/common/passes.h" +#include "shardy/dialect/mpmd/transforms/common/utils.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "mlir/Pass/Pass.h" + +namespace mlir::mpmd { + +namespace { + +#define GEN_PASS_DEF_REMOVESIDEEFFECTAFTERCSEPASS +#include "shardy/dialect/mpmd/transforms/common/passes.h.inc" + +struct RemoveSideEffectAfterCSEPass + : public impl::RemoveSideEffectAfterCSEPassBase< + RemoveSideEffectAfterCSEPass> { + using impl::RemoveSideEffectAfterCSEPassBase< + RemoveSideEffectAfterCSEPass>::RemoveSideEffectAfterCSEPassBase; + + void runOnOperation() override { + getOperation().walk([](stablehlo::CustomCallOp customCallOp) { + if (customCallOp->hasAttr(kMhloNoCseAttr)) { + customCallOp.setHasSideEffect(std::nullopt); + } + }); + } +}; + +} // namespace + +std::unique_ptr createRemoveSideEffectAfterCSEPass() { + return std::make_unique(); +} + +} // namespace mlir::mpmd diff --git a/shardy/dialect/mpmd/transforms/common/test/add_side_effect_to_avoid_cse.mlir b/shardy/dialect/mpmd/transforms/common/test/add_side_effect_to_avoid_cse.mlir new file mode 100644 index 000000000..ccb7bebe5 --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/test/add_side_effect_to_avoid_cse.mlir @@ -0,0 +1,30 @@ +// RUN: mpmd_opt %s -mpmd-add-side-effect-to-avoid-cse | FileCheck %s + +// CHECK-LABEL: func @custom_call_with_no_cse_should_add_side_effect +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @custom_call_with_no_cse_should_add_side_effect(%arg0: tensor) -> tensor { + // CHECK: %[[RES0:.*]] = stablehlo.custom_call @Sharding(%arg0) + // CHECK-SAME: has_side_effect = true + // CHECK-SAME: mhlo.no_cse + // CHECK-SAME: : (tensor) -> tensor + %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.no_cse} : (tensor) -> tensor + func.return 
%0 : tensor +} + +// CHECK-LABEL: func @custom_call_without_no_cse_should_not_add_side_effect +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @custom_call_without_no_cse_should_not_add_side_effect(%arg0: tensor) -> tensor { + // CHECK-NOT: has_side_effect + // CHECK: stablehlo.custom_call @Sharding(%arg0) : (tensor) -> tensor + %0 = stablehlo.custom_call @Sharding(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @other_op_with_no_cse_should_not_add_side_effect +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @other_op_with_no_cse_should_not_add_side_effect(%arg0: tensor) -> tensor { + // CHECK-NOT: has_side_effect + // CHECK: stablehlo.add %arg0, %arg0 {mhlo.no_cse} : tensor + %0 = stablehlo.add %arg0, %arg0 {mhlo.no_cse} : tensor + func.return %0 : tensor +} diff --git a/shardy/dialect/mpmd/transforms/common/test/avoid_cse_on_custom_calls_marked_with_no_cse.mlir b/shardy/dialect/mpmd/transforms/common/test/avoid_cse_on_custom_calls_marked_with_no_cse.mlir new file mode 100644 index 000000000..63f280c67 --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/test/avoid_cse_on_custom_calls_marked_with_no_cse.mlir @@ -0,0 +1,16 @@ +// RUN: mpmd_opt %s -mpmd-add-side-effect-to-avoid-cse -cse -mpmd-remove-side-effect-after-cse | FileCheck %s + +// CHECK-LABEL: func @duplicate_custom_call_with_no_cse_should_not_be_csed +// CHECK-SAME: (%arg0: tensor) -> (tensor, tensor) +func.func @duplicate_custom_call_with_no_cse_should_not_be_csed(%arg0: tensor) -> (tensor, tensor) { + // CHECK: %[[RES0:.*]] = stablehlo.custom_call @Sharding(%arg0) + // CHECK-NOT: has_side_effect + // CHECK-SAME: mhlo.no_cse + // CHECK: %[[RES1:.*]] = stablehlo.custom_call @Sharding(%arg0) + // CHECK-NOT: has_side_effect + // CHECK-SAME: mhlo.no_cse + // CHECK: return %[[RES0]], %[[RES1]] + %0 = stablehlo.custom_call @Sharding(%arg0) {mhlo.no_cse} : (tensor) -> tensor + %1 = stablehlo.custom_call @Sharding(%arg0) {mhlo.no_cse} : (tensor) -> tensor + func.return %0, 
%1 : tensor, tensor +} diff --git a/shardy/dialect/mpmd/transforms/common/test/remove_side_effect_after_cse.mlir b/shardy/dialect/mpmd/transforms/common/test/remove_side_effect_after_cse.mlir new file mode 100644 index 000000000..7fc14072e --- /dev/null +++ b/shardy/dialect/mpmd/transforms/common/test/remove_side_effect_after_cse.mlir @@ -0,0 +1,28 @@ +// RUN: mpmd_opt %s -mpmd-remove-side-effect-after-cse | FileCheck %s + +// CHECK-LABEL: func @custom_call_with_no_cse_should_remove_side_effect +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @custom_call_with_no_cse_should_remove_side_effect(%arg0: tensor) -> tensor { + // CHECK: %[[RES0:.*]] = stablehlo.custom_call @Sharding(%arg0) + // CHECK-NOT: has_side_effect + // CHECK-SAME: mhlo.no_cse + // CHECK-SAME: : (tensor) -> tensor + %0 = stablehlo.custom_call @Sharding(%arg0) {has_side_effect = true, mhlo.no_cse} : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @custom_call_without_no_cse_should_do_nothing +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @custom_call_without_no_cse_should_do_nothing(%arg0: tensor) -> tensor { + // CHECK: stablehlo.custom_call @Sharding(%arg0) {has_side_effect = true} + %0 = stablehlo.custom_call @Sharding(%arg0) {has_side_effect = true}: (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: func @other_op_with_no_cse_should_do_nothing +// CHECK-SAME: (%arg0: tensor) -> tensor +func.func @other_op_with_no_cse_should_do_nothing(%arg0: tensor) -> tensor { + // CHECK: stablehlo.add %arg0, %arg0 {has_side_effect = true, mhlo.no_cse} : tensor + %0 = stablehlo.add %arg0, %arg0 {has_side_effect = true, mhlo.no_cse} : tensor + func.return %0 : tensor +} diff --git a/shardy/dialect/mpmd/transforms/common/utils.h b/shardy/dialect/mpmd/transforms/common/utils.h index eaa85200f..01cfa9a59 100644 --- a/shardy/dialect/mpmd/transforms/common/utils.h +++ b/shardy/dialect/mpmd/transforms/common/utils.h @@ -34,6 +34,9 @@ limitations under the 
License. namespace mlir::mpmd { +// Name of the attribute that marks ops which must not be CSE'd. +inline constexpr StringRef kMhloNoCseAttr = "mhlo.no_cse"; + // The name of the attribute that keeps track of how many times a loop has been // unrolled. constexpr StringRef kUnrollCounterAttrName = "unroll_counter"; diff --git a/shardy/dialect/mpmd/transforms/export/export_pipeline.cc b/shardy/dialect/mpmd/transforms/export/export_pipeline.cc index 3f9bf971a..cb135eae3 100644 --- a/shardy/dialect/mpmd/transforms/export/export_pipeline.cc +++ b/shardy/dialect/mpmd/transforms/export/export_pipeline.cc @@ -79,6 +79,10 @@ void addExportPipeline(OpPassManager& pm, const ExportOptions& options) { // and by DCE'ing the fragment bodies. pm.addNestedPass(createFragmentDcePass()); + // Now that all CSE is done, remove the side effect flag from custom calls + // that have the no_cse attribute. + pm.addNestedPass(createRemoveSideEffectAfterCSEPass()); + // Must be applied after the last -mpmd-fragment-dedup, as it may add // duplicated fragment results and after -canonicalize, as it may add // identity fragments, which would be canonicalized away. 
diff --git a/shardy/dialect/mpmd/transforms/export/test/e2e_pipeline.mlir b/shardy/dialect/mpmd/transforms/export/test/e2e_pipeline.mlir new file mode 100644 index 000000000..7fb7335f0 --- /dev/null +++ b/shardy/dialect/mpmd/transforms/export/test/e2e_pipeline.mlir @@ -0,0 +1,22 @@ +// RUN: mpmd_opt %s -mpmd-import-pipeline='name-to-mesh-assignment=f1@m1,f2@m2' -mpmd-optimize-pipeline -mpmd-sharding-propagation-pipeline -mpmd-export-pipeline 2>&1 | FileCheck %s + +#topology = #mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>> + +// CHECK-LABEL: func.func @main +func.func @main(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> attributes { + "topology"=#topology} { + // CHECK: %[[FRAGMENT_CALL:.*]] = mpmd.fragment_call @p0_f1_fwd.main(%arg0) + %1:2 = mpmd.named_computation<"f1"> (%arg0, %arg0) (%arg3: tensor<4x8xf32>, %arg4: tensor<4x8xf32>) { + %2 = stablehlo.custom_call @sdy_testonly(%arg3) {mhlo.no_cse} : (tensor<4x8xf32>) -> tensor<4x8xf32> + %3 = stablehlo.custom_call @sdy_testonly(%arg4) {mhlo.no_cse} : (tensor<4x8xf32>) -> tensor<4x8xf32> + mpmd.return %2, %3 : tensor<4x8xf32>, tensor<4x8xf32> + } : (tensor<4x8xf32>, tensor<4x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) + func.return %1#0 : tensor<4x8xf32> +} +// CHECK-LABEL: func.func @p0_f1_fwd.main +// CHECK: %[[CUSTOM_CALL:.*]] = stablehlo.custom_call @sdy_testonly +// CHECK-NOT: has_side_effect +// CHECK-SAME: {mhlo.no_cse} +// CHECK-NEXT: %[[CUSTOM_CALL_2:.*]] = stablehlo.custom_call @sdy_testonly +// CHECK-NOT: has_side_effect +// CHECK-SAME: {mhlo.no_cse} diff --git a/shardy/dialect/mpmd/transforms/export/test/export_pipeline.mlir b/shardy/dialect/mpmd/transforms/export/test/export_pipeline.mlir index 2d992f56a..9ffb35139 100644 --- a/shardy/dialect/mpmd/transforms/export/test/export_pipeline.mlir +++ b/shardy/dialect/mpmd/transforms/export/test/export_pipeline.mlir @@ -1,4 +1,4 @@ -// RUN: mpmd_opt %s -mpmd-export-pipeline 2>&1 | FileCheck %s +// RUN: mpmd_opt %s -mpmd-export-pipeline -split-input-file 
2>&1 | FileCheck %s !mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> @@ -17,3 +17,28 @@ func.func @main(%arg0: !mesh_1_tensor_4_8_f32 {tf.aliasing_output = 0: i32}, %ar } : (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) -> (!mesh_1_tensor_4_8_f32) func.return %0 : !mesh_1_tensor_4_8_f32 } + + +// ----- +!mesh_1_tensor_4_8_f32 = !mpmd.mesh_tensor<"m1", tensor<4x8xf32>> + +// CHECK-LABEL: func.func @main +func.func @main(%arg0: !mesh_1_tensor_4_8_f32 {tf.aliasing_output = 0: i32}, %arg1: !mesh_1_tensor_4_8_f32) + -> (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) attributes { + "topology"=#mpmd.topology< + <"m1": <["x"=2]>>, + <"m2": <["x"=2]>> + >} { +// CHECK: mpmd.fragment_call @[[CALLEE:.*]](%arg0, %arg1) +// CHECK: func.func @[[CALLEE]](%arg0: tensor<4x8xf32> {tf.aliasing_output = 0 : i32}, %arg1: tensor<4x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) attributes {mesh_shape = #sdy.mesh<["x"=2]>, xla_tpu_user_reserved_hbm_bytes = 0 : i64} +// Note the has_side_effect = true attribute has been dropped from the custom_call. 
+// CHECK: %[[FIRST_CUSTOM_CALL:.*]] = stablehlo.custom_call @Sharding(%arg0) {mhlo.no_cse} : (tensor<4x8xf32>) -> tensor<4x8xf32> +// CHECK: %[[SECOND_CUSTOM_CALL:.*]] = stablehlo.custom_call @Sharding(%arg1) {mhlo.no_cse} : (tensor<4x8xf32>) -> tensor<4x8xf32> +// CHECK: return %[[FIRST_CUSTOM_CALL]], %[[SECOND_CUSTOM_CALL]] : tensor<4x8xf32>, tensor<4x8xf32> + %0:2 = mpmd.fragment (%arg0, %arg1) (%arg2: tensor<4x8xf32>, %arg3: tensor<4x8xf32>) { + %0 = stablehlo.custom_call @Sharding(%arg2) {has_side_effect = true, mhlo.no_cse} : (tensor<4x8xf32>) -> tensor<4x8xf32> + %1 = stablehlo.custom_call @Sharding(%arg3) {has_side_effect = true, mhlo.no_cse} : (tensor<4x8xf32>) -> tensor<4x8xf32> + mpmd.return %0, %1 : tensor<4x8xf32>, tensor<4x8xf32> + } : (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) -> (!mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32) + func.return %0#0, %0#1 : !mesh_1_tensor_4_8_f32, !mesh_1_tensor_4_8_f32 +} diff --git a/shardy/dialect/mpmd/transforms/import/import_pipeline.cc b/shardy/dialect/mpmd/transforms/import/import_pipeline.cc index 2358e79d9..73b05b4c6 100644 --- a/shardy/dialect/mpmd/transforms/import/import_pipeline.cc +++ b/shardy/dialect/mpmd/transforms/import/import_pipeline.cc @@ -29,6 +29,7 @@ limitations under the License. #include "shardy/dialect/mpmd/transforms/import/passes.h" #include "stablehlo/transforms/Passes.h" #include "stablehlo/transforms/optimization/Passes.h" +#include "shardy/dialect/mpmd/transforms/common/passes.h" namespace mlir::mpmd { @@ -39,6 +40,9 @@ void addImportPipeline(OpPassManager& pm, ImportOptions options) { pm.addNestedPass(stablehlo::createChloLegalizeToStablehloPass()); pm.addPass(createInlinerPass()); + + // Add side effect to custom calls with a no_cse attribute to avoid CSE. 
+ pm.addNestedPass(createAddSideEffectToAvoidCSEPass()); pm.addNestedPass(createCSEPass()); // Canonicalization / Target Independent Optimization needed for two things: diff --git a/shardy/dialect/mpmd/transforms/import/test/import_pipeline.mlir b/shardy/dialect/mpmd/transforms/import/test/import_pipeline.mlir index 33520eb12..6aee8e507 100644 --- a/shardy/dialect/mpmd/transforms/import/test/import_pipeline.mlir +++ b/shardy/dialect/mpmd/transforms/import/test/import_pipeline.mlir @@ -108,3 +108,25 @@ func.func private @f(%arg0: tensor<3x5xf32>, %arg1: tensor<3x5xf32>) -> (tensor< return %0, %arg1 : tensor<3x5xf32>, tensor<3x5xf32> } // No error. + +// ----- +// CHECK-LABEL: sdy.mesh @mesh = <["x"=2]> +#topology = #mpmd.topology<<"m1": <["x"=2]>>, <"m2": <["x"=2]>>> + +// Custom calls with the no_cse attribute must not be CSE'd. Each one should +// also gain the side effect attribute. +// CHECK-LABEL: func @main +func.func @main(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> attributes { + "topology"=#topology} { +// CHECK-NEXT: %[[FRAG:.*]] = mpmd.fragment (%arg0) (%arg1 +// CHECK-NEXT: %[[CUSTOM_CALL:.*]] = stablehlo.custom_call @Sharding +// CHECK-SAME: {has_side_effect = true, mhlo.no_cse} +// CHECK-NEXT: %[[CUSTOM_CALL_2:.*]] = stablehlo.custom_call @Sharding +// CHECK-SAME: {has_side_effect = true, mhlo.no_cse} + %1:2 = mpmd.named_computation<"f1"> (%arg0, %arg0) (%arg3: tensor<4x8xf32>, %arg4: tensor<4x8xf32>) { + %2 = stablehlo.custom_call @Sharding(%arg3) {mhlo.no_cse} : (tensor<4x8xf32>) -> tensor<4x8xf32> + %3 = stablehlo.custom_call @Sharding(%arg4) {mhlo.no_cse} : (tensor<4x8xf32>) -> tensor<4x8xf32> + mpmd.return %2, %3 : tensor<4x8xf32>, tensor<4x8xf32> + } : (tensor<4x8xf32>, tensor<4x8xf32>) -> (tensor<4x8xf32>, tensor<4x8xf32>) + func.return %1#0 : tensor<4x8xf32> +}