diff --git a/shardy/dialect/sdy/ir/canonicalization.cc b/shardy/dialect/sdy/ir/canonicalization.cc
index ce950a3c8..01c808b11 100644
--- a/shardy/dialect/sdy/ir/canonicalization.cc
+++ b/shardy/dialect/sdy/ir/canonicalization.cc
@@ -314,6 +314,54 @@ class ReduceScatterFusion : public OpRewritePattern {
   }
 };
 
+class AllToAllFusion : public OpRewritePattern<AllToAllOp> {
+ public:
+  using OpRewritePattern::OpRewritePattern;
+
+ private:
+  LogicalResult matchAndRewrite(AllToAllOp allToAllOp,
+                                PatternRewriter& rewriter) const override {
+    if (range_size(allToAllOp->getUsers()) != 1) {
+      return rewriter.notifyMatchFailure(
+          allToAllOp, "op has multiple users");
+    }
+    auto userAllToAllOp = dyn_cast<AllToAllOp>(*allToAllOp->user_begin());
+    if (!userAllToAllOp) {
+      return rewriter.notifyMatchFailure(
+          allToAllOp, "user is not all-to-all");
+    }
+    // Combine the params of the two all-to-all ops into one.
+    SmallVector<AllToAllParamAttr> combinedParams;
+    combinedParams.reserve(allToAllOp.getParams().size() +
+                           userAllToAllOp.getParams().size());
+    combinedParams.append(allToAllOp.getParams().begin(),
+                          allToAllOp.getParams().end());
+    combinedParams.append(userAllToAllOp.getParams().begin(),
+                          userAllToAllOp.getParams().end());
+    // Check for overlap in the source and target dimensions.
+    BitVector seenDims(getTensorRank(allToAllOp.getResult()));
+    for (AllToAllParamAttr param : combinedParams) {
+      for (int64_t dim : {param.getSrcDim(), param.getTgtDim()}) {
+        if (seenDims.test(dim)) {
+          return rewriter.notifyMatchFailure(
+              allToAllOp, "overlapping dimensions in the combined parameters");
+        }
+        seenDims.set(dim);
+      }
+    }
+    llvm::sort(combinedParams,
+               [](const AllToAllParamAttr& a, const AllToAllParamAttr& b) {
+                 return a.getSrcDim() < b.getSrcDim();
+               });
+    rewriter.replaceOpWithNewOp<AllToAllOp>(
+        userAllToAllOp, userAllToAllOp.getResult().getType(),
+        allToAllOp.getTensor(), combinedParams,
+        userAllToAllOp.getOutSharding());
+    rewriter.eraseOp(allToAllOp);
+    return success();
+  }
+};
+
 }  // namespace
 
 void ManualComputationOp::getCanonicalizationPatterns(
@@ -345,7 +393,7 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet& results,
 
 void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet& results,
                                              MLIRContext* context) {
-  // We don't have patterns for all-to-all for now.
+  results.add<AllToAllFusion>(context);
 }
 
 void CollectivePermuteOp::getCanonicalizationPatterns(
diff --git a/shardy/dialect/sdy/ir/test/collective_canonicalization.mlir b/shardy/dialect/sdy/ir/test/collective_canonicalization.mlir
index 0e35d05bc..3a2fa22a4 100644
--- a/shardy/dialect/sdy/ir/test/collective_canonicalization.mlir
+++ b/shardy/dialect/sdy/ir/test/collective_canonicalization.mlir
@@ -222,3 +222,34 @@ func.func @reduce_scatter_fusion_no_subaxis_prefix_match(%arg0 : tensor<64x16xf3
   %1 = sdy.all_slice [{"r", "x", "z"}, {"y", "q"}] %0 out_sharding=<@mesh2, [{"r", "x", "z"}, {"y", "q"}]> : tensor<64x16xf32>
   return %1 : tensor<64x16xf32>
 }
+
+// CHECK-LABEL: func @all_to_all_fusion_success
+func.func @all_to_all_fusion_success(%arg0 : tensor<64x16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"x"}, {"y"}, {}, {}]>}) -> tensor<64x16x8x8xf32> {
+  // CHECK-NEXT: %0 = sdy.all_to_all [{"x"}: 0->2, {"y"}: 1->3] %arg0 out_sharding=<@mesh, [{}, {}, {"x"}, {"y"}]> : tensor<64x16x8x8xf32>
+  // CHECK-NEXT: return %0 : tensor<64x16x8x8xf32>
+  %0 = sdy.all_to_all [{"x"}: 0->2] %arg0 out_sharding=<@mesh, [{}, {"y"}, {"x"}, {}]> : tensor<64x16x8x8xf32>
+  %1 = sdy.all_to_all [{"y"}: 1->3] %0 out_sharding=<@mesh, [{}, {}, {"x"}, {"y"}]> : tensor<64x16x8x8xf32>
+  return %1 : tensor<64x16x8x8xf32>
+}
+
+// CHECK-LABEL: func @all_to_all_fusion_overlapping_dims
+func.func @all_to_all_fusion_overlapping_dims(%arg0 : tensor<64x16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"x"}, {"y"}, {}, {}]>}) -> tensor<64x16x8x8xf32> {
+  // CHECK-NEXT: %0 = sdy.all_to_all [{"x"}: 0->2] %arg0 out_sharding=<@mesh, [{}, {"y"}, {"x"}, {}]> : tensor<64x16x8x8xf32>
+  // CHECK-NEXT: %1 = sdy.all_to_all [{"y"}: 1->0] %0 out_sharding=<@mesh, [{"y"}, {}, {"x"}, {}]> : tensor<64x16x8x8xf32>
+  // CHECK-NEXT: return %1 : tensor<64x16x8x8xf32>
+  %0 = sdy.all_to_all [{"x"}: 0->2] %arg0 out_sharding=<@mesh, [{}, {"y"}, {"x"}, {}]> : tensor<64x16x8x8xf32>
+  %1 = sdy.all_to_all [{"y"}: 1->0] %0 out_sharding=<@mesh, [{"y"}, {}, {"x"}, {}]> : tensor<64x16x8x8xf32>
+  return %1 : tensor<64x16x8x8xf32>
+}
+
+// CHECK-LABEL: func @all_to_all_fusion_multiple_uses
+func.func @all_to_all_fusion_multiple_uses(%arg0 : tensor<64x16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"x"}, {"y"}, {}, {}]>}) -> tensor<64x16x8x8xf32> {
+  // CHECK-NEXT: %0 = sdy.all_to_all [{"x"}: 0->2] %arg0 out_sharding=<@mesh, [{}, {"y"}, {"x"}, {}]> : tensor<64x16x8x8xf32>
+  // CHECK-NEXT: %1 = sdy.all_to_all [{"y"}: 1->0] %0 out_sharding=<@mesh, [{"y"}, {}, {"x"}, {}]> : tensor<64x16x8x8xf32>
+  // CHECK-NEXT: %2 = sdy.all_to_all [{"x"}: 2->0] %0 out_sharding=<@mesh, [{"x"}, {"y"}, {}, {}]> : tensor<64x16x8x8xf32>
+  // CHECK-NEXT: return %2 : tensor<64x16x8x8xf32>
+  %0 = sdy.all_to_all [{"x"}: 0->2] %arg0 out_sharding=<@mesh, [{}, {"y"}, {"x"}, {}]> : tensor<64x16x8x8xf32>
+  %1 = sdy.all_to_all [{"y"}: 1->0] %0 out_sharding=<@mesh, [{"y"}, {}, {"x"}, {}]> : tensor<64x16x8x8xf32>
+  %2 = sdy.all_to_all [{"x"}: 2->0] %0 out_sharding=<@mesh, [{"x"}, {"y"}, {}, {}]> : tensor<64x16x8x8xf32>
+  return %2 : tensor<64x16x8x8xf32>
+}