diff --git a/shardy/dialect/mpmd/transforms/import/test/import_pipeline_with_heterogeneous_meshes.mlir b/shardy/dialect/mpmd/transforms/import/test/import_pipeline_with_heterogeneous_meshes.mlir index 117deda2..28c104ae 100644 --- a/shardy/dialect/mpmd/transforms/import/test/import_pipeline_with_heterogeneous_meshes.mlir +++ b/shardy/dialect/mpmd/transforms/import/test/import_pipeline_with_heterogeneous_meshes.mlir @@ -32,3 +32,25 @@ module @multiple_input_meshes { return %1 : tensor<16xf32> } } + +// ----- + +module @missing_sdy_mesh { + // CHECK-DAG: sdy.mesh @tpu = <["tpu_x"=2, "tpu_y"=4]> + // CHECK-DAG: sdy.mesh @cpu = <["cpu_z"=8]> + + func.func @main( + %arg0: tensor<16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) + -> (tensor<16xf32>) attributes { + topology = #mpmd.topology<<"tpu" : <["tpu_x"=2, "tpu_y"=4]>>, <"cpu" : <["cpu_z"=8]>>>} { + %0 = mpmd.named_computation<"f1"> (%arg0, %arg2) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) { + %2 = stablehlo.add %arg4, %arg3 : tensor<16xf32> + mpmd.return %2 : tensor<16xf32> + } : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> + %1 = mpmd.named_computation<"f2"> (%arg1, %0) (%arg3: tensor<16xf32>, %arg4: tensor<16xf32>) { + %2 = stablehlo.add %arg4, %arg3 : tensor<16xf32> + mpmd.return %2 : tensor<16xf32> + } : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> + return %1 : tensor<16xf32> + } +}