50 changes: 49 additions & 1 deletion shardy/dialect/sdy/ir/canonicalization.cc
@@ -314,6 +314,54 @@ class ReduceScatterFusion : public OpRewritePattern<AllSliceOp> {
}
};

class AllToAllFusion : public OpRewritePattern<AllToAllOp> {
public:
using OpRewritePattern<AllToAllOp>::OpRewritePattern;

private:
LogicalResult matchAndRewrite(AllToAllOp allToAllOp,
PatternRewriter& rewriter) const override {
if (range_size(allToAllOp->getUsers()) != 1) {
return rewriter.notifyMatchFailure(
allToAllOp, "op has multiple users");
}
auto userAllToAllOp = dyn_cast<AllToAllOp>(*allToAllOp->user_begin());
if (!userAllToAllOp) {
return rewriter.notifyMatchFailure(
allToAllOp, "user is not all-to-all");
}
// Combine the params of the two all-to-all ops into one.
SmallVector<AllToAllParamAttr> combinedParams;
combinedParams.reserve(allToAllOp.getParams().size() +
userAllToAllOp.getParams().size());
combinedParams.append(allToAllOp.getParams().begin(),
allToAllOp.getParams().end());
combinedParams.append(userAllToAllOp.getParams().begin(),
userAllToAllOp.getParams().end());
// Check for overlap in the source and target dimensions.
BitVector seenDims(getTensorRank(allToAllOp.getResult()));
for (AllToAllParamAttr param : combinedParams) {
for (int64_t dim : {param.getSrcDim(), param.getTgtDim()}) {
if (seenDims.test(dim)) {
return rewriter.notifyMatchFailure(
allToAllOp, "overlapping dimensions in the combined parameters");
}
seenDims.set(dim);
}
}
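    // Sort the combined params by source dimension so the fused op lists
    // them in a canonical order.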
llvm::sort(combinedParams,
[](const AllToAllParamAttr& a, const AllToAllParamAttr& b) {
return a.getSrcDim() < b.getSrcDim();
});
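    // Replace the user with a single fused all-to-all on the original input;
    // the producer all-to-all is then dead and can be erased.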
rewriter.replaceOpWithNewOp<AllToAllOp>(
userAllToAllOp, userAllToAllOp.getResult().getType(),
allToAllOp.getTensor(), combinedParams,
userAllToAllOp.getOutSharding());
rewriter.eraseOp(allToAllOp);
return success();
}
};

} // namespace

void ManualComputationOp::getCanonicalizationPatterns(
@@ -345,7 +393,7 @@ void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet& results,

void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
// We don't have patterns for all-to-all for now.
  results.add<AllToAllFusion>(context);
}

void CollectivePermuteOp::getCanonicalizationPatterns(
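For reviewers who want to exercise the new pattern outside of the usual -canonicalize pass, the sketch below collects the AllToAllOp canonicalization patterns (now including the fusion pattern registered above) and applies them with MLIR's greedy rewrite driver. It is only an illustration of how this hook is consumed; the helper name fuseAllToAlls and the exact include paths are assumptions, not part of this change.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "shardy/dialect/sdy/ir/dialect.h"

// Gathers the AllToAllOp canonicalization patterns and applies them greedily
// over a module, fusing chains of compatible all-to-all ops.
mlir::LogicalResult fuseAllToAlls(mlir::ModuleOp module) {
  mlir::MLIRContext* context = module.getContext();
  mlir::RewritePatternSet patterns(context);
  mlir::sdy::AllToAllOp::getCanonicalizationPatterns(patterns, context);
  return mlir::applyPatternsAndFoldGreedily(module.getOperation(),
                                            std::move(patterns));
}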
31 changes: 31 additions & 0 deletions shardy/dialect/sdy/ir/test/collective_canonicalization.mlir
@@ -222,3 +222,34 @@ func.func @reduce_scatter_fusion_no_subaxis_prefix_match(%arg0 : tensor<64x16xf3
%1 = sdy.all_slice [{"r", "x", "z"}, {"y", "q"}] %0 out_sharding=<@mesh2, [{"r", "x", "z"}, {"y", "q"}]> : tensor<64x16xf32>
return %1 : tensor<64x16xf32>
}

// CHECK-LABEL: func @all_to_all_fusion_success
func.func @all_to_all_fusion_success(%arg0 : tensor<64x16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"x"}, {"y"}, {}, {}]>}) -> tensor<64x16x8x8xf32> {
// CHECK-NEXT: %0 = sdy.all_to_all [{"x"}: 0->2, {"y"}: 1->3] %arg0 out_sharding=<@mesh, [{}, {}, {"x"}, {"y"}]> : tensor<64x16x8x8xf32>
// CHECK-NEXT: return %0 : tensor<64x16x8x8xf32>
%0 = sdy.all_to_all [{"x"}: 0->2] %arg0 out_sharding=<@mesh, [{}, {"y"}, {"x"}, {}]> : tensor<64x16x8x8xf32>
%1 = sdy.all_to_all [{"y"}: 1->3] %0 out_sharding=<@mesh, [{}, {}, {"x"}, {"y"}]> : tensor<64x16x8x8xf32>
return %1 : tensor<64x16x8x8xf32>
}

// CHECK-LABEL: func @all_to_all_fusion_overlapping_dims
func.func @all_to_all_fusion_overlapping_dims(%arg0 : tensor<64x16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"x"}, {"y"}, {}, {}]>}) -> tensor<64x16x8x8xf32> {
// CHECK-NEXT: %0 = sdy.all_to_all [{"x"}: 0->2] %arg0 out_sharding=<@mesh, [{}, {"y"}, {"x"}, {}]> : tensor<64x16x8x8xf32>
// CHECK-NEXT: %1 = sdy.all_to_all [{"y"}: 1->0] %0 out_sharding=<@mesh, [{"y"}, {}, {"x"}, {}]> : tensor<64x16x8x8xf32>
// CHECK-NEXT: return %1 : tensor<64x16x8x8xf32>
%0 = sdy.all_to_all [{"x"}: 0->2] %arg0 out_sharding=<@mesh, [{}, {"y"}, {"x"}, {}]> : tensor<64x16x8x8xf32>
%1 = sdy.all_to_all [{"y"}: 1->0] %0 out_sharding=<@mesh, [{"y"}, {}, {"x"}, {}]> : tensor<64x16x8x8xf32>
return %1 : tensor<64x16x8x8xf32>
}

// CHECK-LABEL: func @all_to_all_fusion_multiple_uses
func.func @all_to_all_fusion_multiple_uses(%arg0 : tensor<64x16x8x8xf32> {sdy.sharding=#sdy.sharding<@mesh, [{"x"}, {"y"}, {}, {}]>}) -> tensor<64x16x8x8xf32> {
// CHECK-NEXT: %0 = sdy.all_to_all [{"x"}: 0->2] %arg0 out_sharding=<@mesh, [{}, {"y"}, {"x"}, {}]> : tensor<64x16x8x8xf32>
// CHECK-NEXT: %1 = sdy.all_to_all [{"y"}: 1->0] %0 out_sharding=<@mesh, [{"y"}, {}, {"x"}, {}]> : tensor<64x16x8x8xf32>
// CHECK-NEXT: %2 = sdy.all_to_all [{"x"}: 2->0] %0 out_sharding=<@mesh, [{"x"}, {"y"}, {}, {}]> : tensor<64x16x8x8xf32>
// CHECK-NEXT: return %2 : tensor<64x16x8x8xf32>
%0 = sdy.all_to_all [{"x"}: 0->2] %arg0 out_sharding=<@mesh, [{}, {"y"}, {"x"}, {}]> : tensor<64x16x8x8xf32>
%1 = sdy.all_to_all [{"y"}: 1->0] %0 out_sharding=<@mesh, [{"y"}, {}, {"x"}, {}]> : tensor<64x16x8x8xf32>
%2 = sdy.all_to_all [{"x"}: 2->0] %0 out_sharding=<@mesh, [{"x"}, {"y"}, {}, {}]> : tensor<64x16x8x8xf32>
return %2 : tensor<64x16x8x8xf32>
}
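Side note for reviewers: these CHECK lines are presumably driven by the lit RUN line at the top of collective_canonicalization.mlir (not visible in this diff), which pipes the file through sdy_opt canonicalization into FileCheck, so @all_to_all_fusion_success only collapses to a single fused all-to-all once the pattern registered above is in place.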