diff --git a/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir b/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir index 2e0c19bb4..63023010d 100644 --- a/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir +++ b/shardy/dialect/sdy/ir/test/manual_computation_verification.mlir @@ -17,9 +17,8 @@ func.func @man_comp_different_meshes(%arg0: tensor<16x32xf32>) -> tensor<16x32xf sdy.mesh @meshA = <["a"=4]> sdy.mesh @meshB = <["b"=4]> -// TODO(b/415376816). This should be an error since we use different meshes in -// the body. func.func @man_comp_different_meshes_in_body(%arg0: tensor<16x32xf32>) -> tensor<16x32xf32> { + // expected-error @+1 {{all shardings must be bound to the same mesh. stablehlo.add is bound to mesh #sdy.mesh<["b"=4]>, while the common mesh is #sdy.mesh<["a"=4]>}} %0 = sdy.manual_computation(%arg0) in_shardings=[<@meshA, [{}, {}]>] out_shardings=[<@meshA, [{}, {}]>] manual_axes={"a"} (%arg1: tensor<16x32xf32>) { %1 = stablehlo.add %arg1, %arg1 {sdy.sharding = #sdy.sharding_per_value<[<@meshB, [{}, {}]>]>} : tensor<16x32xf32> sdy.return %1 : tensor<16x32xf32> diff --git a/shardy/dialect/sdy/ir/verifiers.cc b/shardy/dialect/sdy/ir/verifiers.cc index f5172bd0f..733058bfe 100644 --- a/shardy/dialect/sdy/ir/verifiers.cc +++ b/shardy/dialect/sdy/ir/verifiers.cc @@ -866,8 +866,7 @@ LogicalResult verifyManualComputationValue( SmallVector newDimSizes; auto globalShapedType = mlir::dyn_cast(globalType); if (!globalShapedType) { - // Skipping verification for non-shaped types. This could for example be - // a token type. + // Skipping verification for non-shaped types, e.g., the token type. continue; } for (auto [dimensionSize, dimSharding] : llvm::zip_equal( @@ -970,6 +969,28 @@ LogicalResult ManualComputationOp::verify() { } } + if (mesh) { + for (Operation& op : getBody().getOps()) { + for (auto [idx, sharding] : llvm::enumerate(getShardings(&op))) { + if (!sharding) { + continue; + } + MeshAttr meshForOp = sharding.getMesh(&op); + if (meshForOp.isMaximal()) { + // We allow different meshes if the sharding is maximal, which is + // usually for token types. + continue; + } + if (meshForOp != mesh) { + return emitOpError("all shardings must be bound to the same mesh. ") + << op.getName() << " is bound to mesh " + << sharding.getMesh(&op) << ", while the common mesh is " + << mesh; + } + } + } + } + return success(); }