diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
index 12f8d127d..3441a13a0 100644
--- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
@@ -449,10 +449,18 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
     if (!conservativePropagation) {
       // Only add a factor for spatial dimensions if we are not in
       // conservative mode.
-      for (auto [lhsDim, rhsDim, outDim] :
-           llvm::zip_equal(dimNums.getInputSpatialDimensions(),
-                           dimNums.getKernelSpatialDimensions(),
-                           dimNums.getOutputSpatialDimensions())) {
+      std::optional<ArrayRef<bool>> windowReversal =
+          conv.getWindowReversal();
+      for (auto [i, dims] : llvm::enumerate(
+               llvm::zip_equal(dimNums.getInputSpatialDimensions(),
+                               dimNums.getKernelSpatialDimensions(),
+                               dimNums.getOutputSpatialDimensions()))) {
+        if (windowReversal.has_value() && (*windowReversal)[i]) {
+          // TODO(b/396724444). Add support for the reversed dimensions when
+          // we support it in the mesh.
+          continue;
+        }
+        const auto& [lhsDim, rhsDim, outDim] = dims;
         // The input spatial dimension can be sharded along either the
         // number of windows (corresponds to the output spatial dimension)
         // or the window size (corresponds to the kernel spatial dimension),
@@ -1067,14 +1075,19 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
             return builder.build();
           })
       .Case<stablehlo::ReverseOp>([](stablehlo::ReverseOp reverse) {
-        std::function<FactorType(int64_t)> getFactorType = [&](int64_t dim) {
-          return llvm::is_contained(reverse.getDimensions(), dim)
-                     ? FactorType::kPermutation
-                     : FactorType::kPassThrough;
-        };
-        return OpShardingRuleBuilder(reverse)
-            .addPointwise(getTensorShape(reverse.getResult()), getFactorType)
-            .build();
+        OpShardingRuleBuilder builder(reverse);
+        for (const auto& [dim, dimSize] :
+             llvm::enumerate(getTensorShape(reverse.getResult()))) {
+          if (llvm::is_contained(reverse.getDimensions(), dim)) {
+            // TODO(b/396724444). Add support for the reversed dimensions when
+            // we support it in the mesh.
+            builder.addFactor(dim, dimSize, FactorType::kPermutation,
+                              /*isBlocked=*/true);
+          } else {
+            builder.addFactor(dim, dimSize);
+          }
+        }
+        return builder.build();
       })
       .Case<stablehlo::RngBitGeneratorOp>(
           [](stablehlo::RngBitGeneratorOp rngBitGenerator) {
diff --git a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir
index 5c3870f52..845f4edc8 100644
--- a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir
+++ b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir
@@ -215,14 +215,15 @@ func.func @concat_not_all_operands_are_from_slices_of_the_same_tensor(%arg0: ten
 
 // CHECK-LABEL: func @conv_simple
 func.func @conv_simple(%arg0 : tensor<2x224x224x192xf32>, %arg1 : tensor<3x3x192x64xf32>) -> tensor<2x112x112x64xf32> {
-  // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, jk, lm, n], [k, m, n, o])->([i, j, l, o]) {i=2, j=112, k=2, l=112, m=2, n=192, o=64} reduction={k, m, n} permutation={j, l}>
+  // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, n, jk, l], [o, k, l, m])->([i, p, j, m]) {i=2, j=112, k=2, l=192, m=64, n=1, o=1, p=1} reduction={k, l} permutation={j}>
   %0 = stablehlo.convolution(%arg0, %arg1)
     dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
     window = {stride = [2, 2], pad = [[0, 1], [0, 1]]} {
       batch_group_count = 1 : i64,
       feature_group_count = 1 : i64,
       lhs_dilations = dense<1> : tensor<2xi64>,
-      rhs_dilations = dense<1> : tensor<2xi64>
+      rhs_dilations = dense<1> : tensor<2xi64>,
+      window_reversal = array<i1: true, false>
     } : (tensor<2x224x224x192xf32>, tensor<3x3x192x64xf32>) -> tensor<2x112x112x64xf32>
   return %0 : tensor<2x112x112x64xf32>
 }
@@ -896,7 +897,7 @@ func.func @reshape_split_dim_with_intermediate_one(%arg0: tensor<32xf32>) -> ten
 
 // CHECK-LABEL: func @reverse
 func.func @reverse(%arg0: tensor<4x32x8x2xf32>) -> tensor<4x32x8x2xf32> {
-  // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, j, k, l]) {i=4, j=32, k=8, l=2} permutation={j, l}>
+  // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, j, k, l]) {i=4, j=32, k=8, l=2} permutation={j, l} blocked_propagation={j, l}>
   %0 = stablehlo.reverse %arg0, dims = [1, 3] : tensor<4x32x8x2xf32>
   return %0 : tensor<4x32x8x2xf32>
 }
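For illustration, here is a hypothetical extra test case (not part of this change) sketching how the new blocked permutation factors would generalize when only a single dimension is reversed: per the updated ReverseOp rule above, each reversed dimension's factor appears in both permutation and blocked_propagation, while non-reversed dimensions stay plain pass-through factors. The function name and shapes below are invented for the example:

 // CHECK-LABEL: func @reverse_single_dim
 func.func @reverse_single_dim(%arg0: tensor<4x32xf32>) -> tensor<4x32xf32> {
   // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j])->([i, j]) {i=4, j=32} permutation={j} blocked_propagation={j}>
   %0 = stablehlo.reverse %arg0, dims = [1] : tensor<4x32xf32>
   return %0 : tensor<4x32xf32>
 }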