diff --git a/shardy/dialect/mpmd/ir/dialect.cc b/shardy/dialect/mpmd/ir/dialect.cc index 5fc8c90a9..8577252b5 100644 --- a/shardy/dialect/mpmd/ir/dialect.cc +++ b/shardy/dialect/mpmd/ir/dialect.cc @@ -296,7 +296,9 @@ ShapedType MeshTensorType::cloneWith(std::optional> shape, } bool MeshTensorType::isOnHost() { - return getMemoryKind() && getMemoryKind().getValue() == kMemoryKindPinnedHost; + return getMemoryKind() && + (getMemoryKind().getValue() == kMemoryKindPinnedHost || + getMemoryKind().getValue() == kMemoryKindUnpinnedHost); } //===----------------------------------------------------------------------===// @@ -1031,6 +1033,11 @@ StringAttr FindMemoryKindInAttributes(Value value, FuncOp func) { return nullptr; } +bool IsValidMemoryKind(StringRef memory_kind) { + return memory_kind == kMemoryKindPinnedHost || + memory_kind == kMemoryKindUnpinnedHost || + memory_kind == kMemoryKindDevice; +} } // namespace LogicalResult TransferOp::verify() { @@ -1049,20 +1056,22 @@ LogicalResult TransferOp::verify() { } if (StringAttr in_memory_kind = mesh_type_in.getMemoryKind()) { - if (in_memory_kind.getValue() != kMemoryKindPinnedHost && - in_memory_kind.getValue() != kMemoryKindDevice) { + StringRef in_memory_kind_value = in_memory_kind.getValue(); + if (!IsValidMemoryKind(in_memory_kind_value)) { return emitError("memory kind must be either '") - << kMemoryKindPinnedHost << "' or '" << kMemoryKindDevice - << "'. Found '" << in_memory_kind.getValue() << "'."; + << kMemoryKindPinnedHost << "' or '" << kMemoryKindUnpinnedHost + << "' or '" << kMemoryKindDevice << "'. Found '" + << in_memory_kind.getValue() << "'."; } } if (StringAttr out_memory_kind = mesh_type_out.getMemoryKind()) { - if (out_memory_kind.getValue() != kMemoryKindPinnedHost && - out_memory_kind.getValue() != kMemoryKindDevice) { + StringRef out_memory_kind_value = out_memory_kind.getValue(); + if (!IsValidMemoryKind(out_memory_kind_value)) { return emitError("memory kind must be either '") - << kMemoryKindPinnedHost << "' or '" << kMemoryKindDevice - << "'. Found '" << out_memory_kind.getValue() << "'."; + << kMemoryKindPinnedHost << "' or '" << kMemoryKindUnpinnedHost + << "' or '" << kMemoryKindDevice << "'. Found '" + << out_memory_kind.getValue() << "'."; } } diff --git a/shardy/dialect/mpmd/ir/test/transfer_w_memory_kind_verify.mlir b/shardy/dialect/mpmd/ir/test/transfer_w_memory_kind_verify.mlir index 6a9680d5a..83425314b 100644 --- a/shardy/dialect/mpmd/ir/test/transfer_w_memory_kind_verify.mlir +++ b/shardy/dialect/mpmd/ir/test/transfer_w_memory_kind_verify.mlir @@ -10,7 +10,7 @@ func.func @f(%arg0 : !m_device) { %t1 = mpmd.transfer %arg0 : (!m_device) -> !m_undefined // No error. %t2 = mpmd.transfer %t1 : (!m_undefined) -> !m_host // No error. - // expected-error@+1 {{memory kind must be either 'pinned_host' or 'device'. Found 'qwerty'.}} + // expected-error@+1 {{memory kind must be either 'pinned_host' or 'unpinned_host' or 'device'. Found 'qwerty'.}} %t3 = mpmd.transfer %t2 : (!m_host) -> !m_invalid func.return } diff --git a/shardy/dialect/mpmd/ir/utils.h b/shardy/dialect/mpmd/ir/utils.h index 556d00a66..f8a84e175 100644 --- a/shardy/dialect/mpmd/ir/utils.h +++ b/shardy/dialect/mpmd/ir/utils.h @@ -68,8 +68,10 @@ constexpr StringRef kCpuMeshSuffix = "/cpu"; // Attr on func args and results to indicate whether the value lives on host or // device. If not present, it means it lives on device. inline constexpr StringRef kMemoryKindAttr = "mhlo.memory_kind"; -// Attr value to indicate whether the value is on the host. +// Attr value to indicate whether the value is pinned on the host. inline constexpr StringRef kMemoryKindPinnedHost = "pinned_host"; +// Attr value to indicate whether the value is unpinned on the host. +inline constexpr StringRef kMemoryKindUnpinnedHost = "unpinned_host"; // Attr value to indicate whether the value is on the device. inline constexpr StringRef kMemoryKindDevice = "device"; diff --git a/shardy/dialect/mpmd/transforms/export/mark_aliasing_and_donation.cc b/shardy/dialect/mpmd/transforms/export/mark_aliasing_and_donation.cc index 0735de345..3e097bb01 100644 --- a/shardy/dialect/mpmd/transforms/export/mark_aliasing_and_donation.cc +++ b/shardy/dialect/mpmd/transforms/export/mark_aliasing_and_donation.cc @@ -60,7 +60,7 @@ std::optional FindAliasingOutput( // 2. the output has not been aliased yet, and // 3. the output is not on host memory. // 4. the input and output have the same layout. - if (output.getType() != input_type || IsResultOnHost(output) || + if (output.getType() != input_type || GetMemoryKindIfResultOnHost(output) || !IsInputOutputLayoutMatch(op, input_index, output_index)) { continue; } diff --git a/shardy/dialect/mpmd/transforms/export/mark_fragment_reserved_memory.cc b/shardy/dialect/mpmd/transforms/export/mark_fragment_reserved_memory.cc index 69a49a0dd..54d7d25c3 100644 --- a/shardy/dialect/mpmd/transforms/export/mark_fragment_reserved_memory.cc +++ b/shardy/dialect/mpmd/transforms/export/mark_fragment_reserved_memory.cc @@ -151,7 +151,7 @@ class MarkFragmentReservedMemoryPass } // Add sizes of the results which are now live. for (OpResult result : op.getResults()) { - if (!IsResultOnHost(result) && !result.use_empty()) { + if (!GetMemoryKindIfResultOnHost(result).has_value() && !result.use_empty()) { MeshTensorType type = cast(result.getType()); AddLiveValue(type, current_memory_usage_per_mesh, &op); } diff --git a/shardy/dialect/mpmd/transforms/export/mark_offloaded_input_output.cc b/shardy/dialect/mpmd/transforms/export/mark_offloaded_input_output.cc index de4e17057..ded0cf247 100644 --- a/shardy/dialect/mpmd/transforms/export/mark_offloaded_input_output.cc +++ b/shardy/dialect/mpmd/transforms/export/mark_offloaded_input_output.cc @@ -141,8 +141,8 @@ Value WalkBackwardThroughOffloadCompatibleResult(OpResult res) { // Returns whether a value is a result, and whether the result is stored on the // host memory via an annotate custom call on the source of the result, or if // the result was computed on host. -bool IsResultAndOnHostMemory(Value val) { - auto res = dyn_cast_if_present(val); +bool IsResultAndOnHostMemory(Value val, mlir::StringRef memory_kind) { + auto res = mlir::dyn_cast_if_present(val); if (!res) { return false; } @@ -155,10 +155,20 @@ bool IsResultAndOnHostMemory(Value val) { } } if (Value operand = WalkBackwardThroughOffloadCompatibleResult(res)) { - return IsResultAndOnHostMemory(operand); + return IsResultAndOnHostMemory(operand, memory_kind); } - return GetOffloadValueIfExists(res.getOwner()) == kMemoryKindPinnedHost; + return GetOffloadValueIfExists(res.getOwner()) == memory_kind; +} + +std::optional GetOnHostMemoryKindIfResult(Value val) { + if (IsResultAndOnHostMemory(val, kMemoryKindPinnedHost)) { + return kMemoryKindPinnedHost; + } + if (IsResultAndOnHostMemory(val, kMemoryKindUnpinnedHost)) { + return kMemoryKindUnpinnedHost; + } + return std::nullopt; } // Gets memory kind from user. @@ -299,20 +309,27 @@ class MarkOffloadedInputOutputPass // - Propagating host annotation through ops implicitly computed on host. // // We don't propagate device memory kinds, because that's the default. + void PropagateHostMemoryKindOnFragments(FragmentOp frag, FuncOp parent) { SmallVector arg_attrs = GetArgAttrsOrCreateDefault(frag); for (OpOperand& operand : frag->getOpOperands()) { - if (auto result = dyn_cast(operand.get()); - result && isa(result.getOwner()) && - IsResultOnHost(result)) { - InsertAttr(arg_attrs[operand.getOperandNumber()], kMemoryKindAttr, - StringAttr::get(frag.getContext(), kMemoryKindPinnedHost)); - continue; + if (auto result = mlir::dyn_cast(operand.get()); + result && mlir::isa(result.getOwner())) { + if (std::optional memory_kind = + GetMemoryKindIfResultOnHost(result)) { + mlir::mpmd::InsertAttr( + arg_attrs[operand.getOperandNumber()], kMemoryKindAttr, + mlir::StringAttr::get(frag.getContext(), memory_kind.value())); + continue; + } } - if (auto block_arg = dyn_cast(operand.get()); - block_arg && IsArgOnHost(parent, block_arg.getArgNumber())) { - InsertAttr(arg_attrs[operand.getOperandNumber()], kMemoryKindAttr, - StringAttr::get(frag.getContext(), kMemoryKindPinnedHost)); + if (auto block_arg = mlir::dyn_cast(operand.get())) { + if (std::optional memory_kind = + GetMemoryKindIfArgOnHost(parent, block_arg.getArgNumber())) { + mlir::mpmd::InsertAttr( + arg_attrs[operand.getOperandNumber()], kMemoryKindAttr, + mlir::StringAttr::get(frag.getContext(), memory_kind.value())); + } continue; } } @@ -321,9 +338,11 @@ class MarkOffloadedInputOutputPass SmallVector res_attrs = GetResAttrsOrCreateDefault(frag); for (auto [idx, return_operand] : llvm::enumerate( frag.getRegion().front().getTerminator()->getOperands())) { - if (IsResultAndOnHostMemory(return_operand)) { - InsertAttr(res_attrs[idx], kMemoryKindAttr, - StringAttr::get(frag.getContext(), kMemoryKindPinnedHost)); + if (std::optional memory_kind = + GetOnHostMemoryKindIfResult(return_operand)) { + mlir::mpmd::InsertAttr( + res_attrs[idx], kMemoryKindAttr, + mlir::StringAttr::get(frag.getContext(), memory_kind.value())); } } SetResAttrs(frag, res_attrs); diff --git a/shardy/dialect/mpmd/transforms/export/test/mark_offloaded_input_output.mlir b/shardy/dialect/mpmd/transforms/export/test/mark_offloaded_input_output.mlir index ba10359d5..72a2c274d 100644 --- a/shardy/dialect/mpmd/transforms/export/test/mark_offloaded_input_output.mlir +++ b/shardy/dialect/mpmd/transforms/export/test/mark_offloaded_input_output.mlir @@ -32,6 +32,33 @@ func.func @simple(%func_arg: !m1_16 {mhlo.memory_kind = "pinned_host"}) -> func.return %f : !m1_16 } +// CHECK-LABEL: func @simple_unpinned_host(%arg0: {{.*}} {mhlo.memory_kind = "unpinned_host"}) -> +// CHECK-SAME: {mhlo.memory_kind = "unpinned_host"}) +func.func @simple_unpinned_host(%func_arg: !m1_16 {mhlo.memory_kind = "unpinned_host"}) -> + (!m1_16 {mhlo.memory_kind = "unpinned_host"}) + attributes {topology=#topology} { + + // CHECK-NEXT: fragment + // CHECK-SAME: {arg_attrs = [{mhlo.memory_kind = "unpinned_host"}], res_attrs = [{mhlo.memory_kind = "unpinned_host"}]} + %f = mpmd.fragment (%func_arg) + (%arg0: tensor<16xf32>) { + %7 = stablehlo.custom_call @annotate_device_placement(%arg0) { + backend_config = "", has_side_effect = true, + mhlo.frontend_attributes = {_xla_buffer_placement = "device"} + } : (tensor<16xf32>) -> tensor<16xf32> + + %8 = stablehlo.add %7, %7 : tensor<16xf32> + + %9 = stablehlo.custom_call @annotate_device_placement(%8) { + backend_config = "", has_side_effect = true, + mhlo.frontend_attributes = {_xla_buffer_placement = "unpinned_host"} + } : (tensor<16xf32>) -> tensor<16xf32> + mpmd.return %9 : tensor<16xf32> + } : (!m1_16) -> !m1_16 + + func.return %f : !m1_16 +} + // CHECK-LABEL: func @with_optimization_barrier(%arg0 // CHECK-SAME: %arg1: {{.*}} {mhlo.memory_kind = "pinned_host"}) -> func.func @with_optimization_barrier(%func_arg0: !m1_16, @@ -173,9 +200,9 @@ func.func @place_host_with_incompatible_reshape_and_custom_call(%func_arg0: !m1_ func.return %f#0, %f#1 : !m1_4x4, !m1_16 } -// CHECK-LABEL: func @func_arg_multiple_matching_users +// CHECK-LABEL: func @func_arg_multiple_matching_users_pinned_host // CHECK-SAME: (%arg0: {{.*}} {mhlo.memory_kind = "pinned_host"}) -func.func @func_arg_multiple_matching_users( +func.func @func_arg_multiple_matching_users_pinned_host( %func_arg: !m1_16 {mhlo.memory_kind = "pinned_host"}) -> (!m1_16, !m1_16) attributes {topology=#topology} { @@ -205,6 +232,38 @@ func.func @func_arg_multiple_matching_users( func.return %f1, %f2 : !m1_16, !m1_16 } +// CHECK-LABEL: func @func_arg_multiple_matching_users_unpinned_host +// CHECK-SAME: (%arg0: {{.*}} {mhlo.memory_kind = "unpinned_host"}) +func.func @func_arg_multiple_matching_users_unpinned_host( + %func_arg: !m1_16 {mhlo.memory_kind = "unpinned_host"}) -> (!m1_16, !m1_16) + attributes {topology=#topology} { + + // CHECK: fragment{{.*}} origin=["f"] + // CHECK-SAME: {arg_attrs = [{mhlo.memory_kind = "unpinned_host"}] + %f1 = mpmd.fragment (%func_arg) + (%arg0: tensor<16xf32>) { + %7 = stablehlo.custom_call @annotate_device_placement(%arg0) { + backend_config = "", has_side_effect = true, + mhlo.frontend_attributes = {_xla_buffer_placement = "device"} + } : (tensor<16xf32>) -> tensor<16xf32> + + mpmd.return %7 : tensor<16xf32> + } : (!m1_16) -> !m1_16 + + // CHECK: fragment{{.*}} origin=["g"] + // CHECK-SAME: {arg_attrs = [{mhlo.memory_kind = "unpinned_host"}] + %f2 = mpmd.fragment (%func_arg) + (%arg0: tensor<16xf32>) { + %7 = stablehlo.custom_call @annotate_device_placement(%arg0) { + backend_config = "", has_side_effect = true, + mhlo.frontend_attributes = {_xla_buffer_placement = "device"} + } : (tensor<16xf32>) -> tensor<16xf32> + + mpmd.return %7 : tensor<16xf32> + } : (!m1_16) -> !m1_16 + func.return %f1, %f2 : !m1_16, !m1_16 +} + // CHECK-LABEL: func @func_arg_match // CHECK-SAME: (%arg0: {{.*}} {mhlo.memory_kind = "pinned_host"}) func.func @func_arg_match(%func_arg: !m1_16 {mhlo.memory_kind = "pinned_host"}) -> !m1_16 diff --git a/shardy/dialect/mpmd/transforms/export/utils.h b/shardy/dialect/mpmd/transforms/export/utils.h index fd893a217..6a0dfc36d 100644 --- a/shardy/dialect/mpmd/transforms/export/utils.h +++ b/shardy/dialect/mpmd/transforms/export/utils.h @@ -63,26 +63,53 @@ DenseSet GetDonatedBlockArguments( DenseMap> OperandsForDeletionMapping(func::FuncOp main_func); +inline bool IsMemoryKindOnHost(mlir::StringAttr memory_kind) { + if (!memory_kind) { + return false; + } + mlir::StringRef memory_kind_val = memory_kind.getValue(); + return memory_kind_val == mpmd::kMemoryKindPinnedHost || + memory_kind_val == mpmd::kMemoryKindUnpinnedHost; +} + + // Checks the arg attrs of the op to see if the arg is on the host. -inline bool IsArgOnHost(Operation* op, int index) { - return GetArgAttr(op, index, kMemoryKindAttr) == - StringAttr::get(op->getContext(), kMemoryKindPinnedHost); +inline bool IsArgOnHost(mlir::Operation* op, int index) { + mlir::StringAttr memory_kind_attr = mlir::dyn_cast_or_null( + mlir::mpmd::GetArgAttr(op, index, mlir::mpmd::kMemoryKindAttr)); + return IsMemoryKindOnHost(memory_kind_attr); } -// Checks the arg of the function is on the host. -inline bool IsArgOnHost(func::FuncOp func, int index) { - return func.getArgAttrOfType(index, - kMemoryKindAttr) == - StringAttr::get(func.getContext(), kMemoryKindPinnedHost); +inline std::optional GetMemoryKindIfArgOnHost( + mlir::func::FuncOp func, int index) { + mlir::StringAttr memory_kind_attr = func.getArgAttrOfType( + index, mlir::mpmd::kMemoryKindAttr); + if (!memory_kind_attr) { + return std::nullopt; + } + mlir::StringRef memory_kind_val = memory_kind_attr.getValue(); + if (memory_kind_val == mlir::mpmd::kMemoryKindPinnedHost || + memory_kind_val == mlir::mpmd::kMemoryKindUnpinnedHost) { + return memory_kind_val; + } + return std::nullopt; } -// Checks the result attrs of the op to see if the result is on the host. -inline bool IsResultOnHost(OpResult op_result) { - return GetResAttr(op_result.getOwner(), - op_result.getResultNumber(), - kMemoryKindAttr) == - StringAttr::get(op_result.getContext(), - kMemoryKindPinnedHost); +inline std::optional GetMemoryKindIfResultOnHost( + mlir::OpResult op_result) { + auto memory_kind = dyn_cast_or_null( + mlir::mpmd::GetResAttr(op_result.getOwner(), op_result.getResultNumber(), + mlir::mpmd::kMemoryKindAttr)); + + if (!memory_kind) { + return std::nullopt; + } + mlir::StringRef memory_kind_val = memory_kind.getValue(); + if (memory_kind_val == mlir::mpmd::kMemoryKindPinnedHost || + memory_kind_val == mlir::mpmd::kMemoryKindUnpinnedHost) { + return memory_kind_val; + } + return std::nullopt; } // Checks if the layout of the input and output match.