diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
index 12f8d127..aecaf100 100644
--- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
+++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc
@@ -325,6 +325,20 @@ OpShardingRuleAttr createOpShardingRule(Operation* op,
             }
             return builder.build();
           })
+      .Case<stablehlo::BatchNormInferenceOp>(
+          [](stablehlo::BatchNormInferenceOp batchNormInference) {
+            OpShardingRuleBuilder builder(batchNormInference);
+            uint64_t featureIndex = batchNormInference.getFeatureIndex();
+            for (auto [dim, dimSize] : llvm::enumerate(
+                     batchNormInference.getOperand().getType().getShape())) {
+              SmallVector<int64_t> opDims(
+                  batchNormInference.getNumOperands(),
+                  /*value=*/dim == featureIndex ? 0 : kNullDim);
+              opDims[0] = dim;
+              builder.addFactor(opDims, dim, dimSize);
+            }
+            return builder.build();
+          })
       .Case<stablehlo::BroadcastInDimOp>(
           [](stablehlo::BroadcastInDimOp broadcast) {
             OpShardingRuleBuilder builder(broadcast);
diff --git a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir
index 5c3870f5..497074ae 100644
--- a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir
+++ b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir
@@ -65,6 +65,13 @@ func.func @all_to_all_same_dimension(%arg0: tensor<2x4xi64>, %arg1: tensor<2x4xi64>) -> (tensor<2x4xi64>, tensor<2x4xi64>) {
   return %0#0, %0#1 : tensor<2x4xi64>, tensor<2x4xi64>
 }
 
+// CHECK-LABEL: func @batch_norm_inference
+func.func @batch_norm_inference(%arg0: tensor<4x8x16x32xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<4x8x16x32xf32> {
+  // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l], [k], [k], [k], [k])->([i, j, k, l]) {i=4, j=8, k=16, l=32}>
+  %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 0.001 : f32, feature_index = 2 : i64} : (tensor<4x8x16x32xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<4x8x16x32xf32>
+  func.return %0 : tensor<4x8x16x32xf32>
+}
+
 // CHECK-LABEL: func @bitcast_convert_upcast
 func.func @bitcast_convert_upcast(%arg0: tensor<4x2x2xui32>) -> tensor<4x2xui64> {
   // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k])->([i, j]) {i=4, j=2, k=2} need_replication={k}>
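
Note (not part of the patch): a minimal sketch of how the new rule follows feature_index. The registry logic above maps the single dimension of the 1-D scale/offset/mean/variance operands to the factor of the operand dimension equal to feature_index, and leaves them out of every other factor. Assuming a hypothetical extra test case with feature_index = 0 on a 16x8x4x32 operand, the expected rule inferred from that logic would be:

// Hypothetical test case, inferred from the registry logic above; not part of the patch.
// CHECK-LABEL: func @batch_norm_inference_feature_index_zero
func.func @batch_norm_inference_feature_index_zero(%arg0: tensor<16x8x4x32xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<16x8x4x32xf32> {
  // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([i, j, k, l], [i], [i], [i], [i])->([i, j, k, l]) {i=16, j=8, k=4, l=32}>
  %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) {epsilon = 0.001 : f32, feature_index = 0 : i64} : (tensor<16x8x4x32xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x8x4x32xf32>
  func.return %0 : tensor<16x8x4x32xf32>
}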