@@ -449,10 +449,18 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
  if (!conservativePropagation) {
    // Only add a factor for spatial dimensions if we are not in
    // conservative mode.
-   for (auto [lhsDim, rhsDim, outDim] :
-        llvm::zip_equal(dimNums.getInputSpatialDimensions(),
-                        dimNums.getKernelSpatialDimensions(),
-                        dimNums.getOutputSpatialDimensions())) {
+   std::optional<ArrayRef<bool>> windowReversal =
+       conv.getWindowReversal();
+   for (auto [i, dims] : llvm::enumerate(
+            llvm::zip_equal(dimNums.getInputSpatialDimensions(),
+                            dimNums.getKernelSpatialDimensions(),
+                            dimNums.getOutputSpatialDimensions()))) {
+     if (windowReversal.has_value() && (*windowReversal)[i]) {
+       // TODO(b/396724444): Add support for reversed dimensions once they
+       // are supported in the mesh.
+       continue;
+     }
+     const auto& [lhsDim, rhsDim, outDim] = dims;
      // The input spatial dimension can be sharded along either the
      // number of windows (corresponds to the output spatial dimension)
      // or the window size (corresponds to the kernel spatial dimension),
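
A side note on the iteration idiom introduced above: `llvm::zip_equal` asserts (in debug builds) that all three spatial-dimension ranges have the same length, and wrapping it in `llvm::enumerate` exposes the index needed to consult the window-reversal mask. Below is a minimal, self-contained sketch of that control flow; `visitSpatialDims` and its parameters are made-up placeholders for illustration, not Shardy or StableHLO APIs:

```cpp
#include <cstdint>
#include <optional>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"

// Visits (input, kernel, output) spatial-dimension triples by index and
// skips any index flagged as reversed, mirroring the loop added in this
// change. zip_equal asserts the three ranges have equal length.
void visitSpatialDims(llvm::ArrayRef<int64_t> inputDims,
                      llvm::ArrayRef<int64_t> kernelDims,
                      llvm::ArrayRef<int64_t> outputDims,
                      std::optional<llvm::ArrayRef<bool>> windowReversal) {
  for (auto [i, dims] :
       llvm::enumerate(llvm::zip_equal(inputDims, kernelDims, outputDims))) {
    if (windowReversal.has_value() && (*windowReversal)[i]) {
      continue;  // No sharding factor for reversed dimensions yet.
    }
    const auto& [inputDim, kernelDim, outputDim] = dims;
    // ... add factors relating inputDim, kernelDim, and outputDim ...
    (void)inputDim;
    (void)kernelDim;
    (void)outputDim;
  }
}
```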
@@ -1067,14 +1075,19 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
        return builder.build();
      })
      .Case<stablehlo::ReverseOp>([](stablehlo::ReverseOp reverse) {
-       std::function<FactorType(int64_t)> getFactorType = [&](int64_t dim) {
-         return llvm::is_contained(reverse.getDimensions(), dim)
-                    ? FactorType::kPermutation
-                    : FactorType::kPassThrough;
-       };
-       return OpShardingRuleBuilder(reverse)
-           .addPointwise(getTensorShape(reverse.getResult()), getFactorType)
-           .build();
+       OpShardingRuleBuilder builder(reverse);
+       for (const auto& [dim, dimSize] :
+            llvm::enumerate(getTensorShape(reverse.getResult()))) {
+         if (llvm::is_contained(reverse.getDimensions(), dim)) {
+           // TODO(b/396724444): Add support for reversed dimensions once
+           // they are supported in the mesh.
+           builder.addFactor(dim, dimSize, FactorType::kPermutation,
+                             /*isBlocked=*/true);
+         } else {
+           builder.addFactor(dim, dimSize);
+         }
+       }
+       return builder.build();
      })
      .Case<stablehlo::RngBitGeneratorOp>(
          [](stablehlo::RngBitGeneratorOp rngBitGenerator) {
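
The `ReverseOp` case also shows why the lambda-based `addPointwise(shape, getFactorType)` form went away: a callback that only returns a `FactorType` per dimension presumably has no way to also mark a factor as blocked, so the rule is now assembled factor by factor. A sketch of that shape as a free function; `buildReverseRule` is a hypothetical wrapper added for illustration, while the builder calls are exactly the ones visible in this diff:

```cpp
// (Shardy/StableHLO includes omitted; all names below come from the
// surrounding file in this change.)
OpShardingRuleAttr buildReverseRule(stablehlo::ReverseOp reverse) {
  OpShardingRuleBuilder builder(reverse);
  for (const auto& [dim, dimSize] :
       llvm::enumerate(getTensorShape(reverse.getResult()))) {
    if (llvm::is_contained(reverse.getDimensions(), dim)) {
      // Reversed dims stay permutation factors, but are now also blocked,
      // so propagation will not move shardings through them.
      builder.addFactor(dim, dimSize, FactorType::kPermutation,
                        /*isBlocked=*/true);
    } else {
      // Dims that are not reversed behave like plain pointwise factors.
      builder.addFactor(dim, dimSize);
    }
  }
  return builder.build();
}
```

For the `@reverse` test updated below (`dims = [1, 3]` on `tensor<4x32x8x2xf32>`), this produces `permutation={j, l}` together with `blocked_propagation={j, l}`, matching the new CHECK line.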
@@ -215,14 +215,15 @@ func.func @concat_not_all_operands_are_from_slices_of_the_same_tensor(%arg0: ten

// CHECK-LABEL: func @conv_simple
func.func @conv_simple(%arg0 : tensor<2x224x224x192xf32>, %arg1 : tensor<3x3x192x64xf32>) -> tensor<2x112x112x64xf32> {
- // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, jk, lm, n], [k, m, n, o])->([i, j, l, o]) {i=2, j=112, k=2, l=112, m=2, n=192, o=64} reduction={k, m, n} permutation={j, l}>
+ // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, n, jk, l], [o, k, l, m])->([i, p, j, m]) {i=2, j=112, k=2, l=192, m=64, n=1, o=1, p=1} reduction={k, l} permutation={j}>
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [2, 2], pad = [[0, 1], [0, 1]]} {
batch_group_count = 1 : i64,
feature_group_count = 1 : i64,
lhs_dilations = dense<1> : tensor<2xi64>,
- rhs_dilations = dense<1> : tensor<2xi64>
+ rhs_dilations = dense<1> : tensor<2xi64>,
+ window_reversal = array<i1: true, false>
} : (tensor<2x224x224x192xf32>, tensor<3x3x192x64xf32>) -> tensor<2x112x112x64xf32>
return %0 : tensor<2x112x112x64xf32>
}
@@ -896,7 +897,7 @@ func.func @reshape_split_dim_with_intermediate_one(%arg0: tensor<32xf32>) -> ten

// CHECK-LABEL: func @reverse
func.func @reverse(%arg0: tensor<4x32x8x2xf32>) -> tensor<4x32x8x2xf32> {
- // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, j, k, l]) {i=4, j=32, k=8, l=2} permutation={j, l}>
+ // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l])->([i, j, k, l]) {i=4, j=32, k=8, l=2} permutation={j, l} blocked_propagation={j, l}>
%0 = stablehlo.reverse %arg0, dims = [1, 3] : tensor<4x32x8x2xf32>
return %0 : tensor<4x32x8x2xf32>
}