Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions shardy/dialect/mpmd/ir/dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,9 @@ ShapedType MeshTensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
}

bool MeshTensorType::isOnHost() {
return getMemoryKind() && getMemoryKind().getValue() == kMemoryKindPinnedHost;
return getMemoryKind() &&
(getMemoryKind().getValue() == kMemoryKindPinnedHost ||
getMemoryKind().getValue() == kMemoryKindUnpinnedHost);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1031,6 +1033,11 @@ StringAttr FindMemoryKindInAttributes(Value value, FuncOp func) {
return nullptr;
}

bool IsValidMemoryKind(StringRef memory_kind) {
return memory_kind == kMemoryKindPinnedHost ||
memory_kind == kMemoryKindUnpinnedHost ||
memory_kind == kMemoryKindDevice;
}
} // namespace

LogicalResult TransferOp::verify() {
Expand All @@ -1049,20 +1056,22 @@ LogicalResult TransferOp::verify() {
}

if (StringAttr in_memory_kind = mesh_type_in.getMemoryKind()) {
if (in_memory_kind.getValue() != kMemoryKindPinnedHost &&
in_memory_kind.getValue() != kMemoryKindDevice) {
StringRef in_memory_kind_value = in_memory_kind.getValue();
if (!IsValidMemoryKind(in_memory_kind_value)) {
return emitError("memory kind must be either '")
<< kMemoryKindPinnedHost << "' or '" << kMemoryKindDevice
<< "'. Found '" << in_memory_kind.getValue() << "'.";
<< kMemoryKindPinnedHost << "' or '" << kMemoryKindUnpinnedHost
<< "' or '" << kMemoryKindDevice << "'. Found '"
<< in_memory_kind.getValue() << "'.";
}
}

if (StringAttr out_memory_kind = mesh_type_out.getMemoryKind()) {
if (out_memory_kind.getValue() != kMemoryKindPinnedHost &&
out_memory_kind.getValue() != kMemoryKindDevice) {
StringRef out_memory_kind_value = out_memory_kind.getValue();
if (!IsValidMemoryKind(out_memory_kind_value)) {
return emitError("memory kind must be either '")
<< kMemoryKindPinnedHost << "' or '" << kMemoryKindDevice
<< "'. Found '" << out_memory_kind.getValue() << "'.";
<< kMemoryKindPinnedHost << "' or '" << kMemoryKindUnpinnedHost
<< "' or '" << kMemoryKindDevice << "'. Found '"
<< out_memory_kind.getValue() << "'.";
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func.func @f(%arg0 : !m_device)
{
%t1 = mpmd.transfer %arg0 : (!m_device) -> !m_undefined // No error.
%t2 = mpmd.transfer %t1 : (!m_undefined) -> !m_host // No error.
// expected-error@+1 {{memory kind must be either 'pinned_host' or 'device'. Found 'qwerty'.}}
// expected-error@+1 {{memory kind must be either 'pinned_host' or 'unpinned_host' or 'device'. Found 'qwerty'.}}
%t3 = mpmd.transfer %t2 : (!m_host) -> !m_invalid
func.return
}
Expand Down
4 changes: 3 additions & 1 deletion shardy/dialect/mpmd/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,10 @@ constexpr StringRef kCpuMeshSuffix = "/cpu";
// Attr on func args and results to indicate whether the value lives on host or
// device. If not present, it means it lives on device.
inline constexpr StringRef kMemoryKindAttr = "mhlo.memory_kind";
// Attr value to indicate whether the value is on the host.
// Attr value to indicate whether the value is pinned on the host.
inline constexpr StringRef kMemoryKindPinnedHost = "pinned_host";
// Attr value to indicate whether the value is unpinned on the host.
inline constexpr StringRef kMemoryKindUnpinnedHost = "unpinned_host";
// Attr value to indicate whether the value is on the device.
inline constexpr StringRef kMemoryKindDevice = "device";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ std::optional<unsigned int> FindAliasingOutput(
// 2. the output has not been aliased yet, and
// 3. the output is not on host memory.
// 4. the input and output have the same layout.
if (output.getType() != input_type || IsResultOnHost(output) ||
if (output.getType() != input_type || GetMemoryKindIfResultOnHost(output) ||
!IsInputOutputLayoutMatch(op, input_index, output_index)) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class MarkFragmentReservedMemoryPass
}
// Add sizes of the results which are now live.
for (OpResult result : op.getResults()) {
if (!IsResultOnHost(result) && !result.use_empty()) {
if (!GetMemoryKindIfResultOnHost(result).has_value() && !result.use_empty()) {
MeshTensorType type = cast<MeshTensorType>(result.getType());
AddLiveValue(type, current_memory_usage_per_mesh, &op);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ Value WalkBackwardThroughOffloadCompatibleResult(OpResult res) {
// Returns whether a value is a result, and whether the result is stored on the
// host memory via an annotate custom call on the source of the result, or if
// the result was computed on host.
bool IsResultAndOnHostMemory(Value val) {
auto res = dyn_cast_if_present<OpResult>(val);
bool IsResultAndOnHostMemory(Value val, mlir::StringRef memory_kind) {
auto res = mlir::dyn_cast_if_present<OpResult>(val);
if (!res) {
return false;
}
Expand All @@ -155,10 +155,20 @@ bool IsResultAndOnHostMemory(Value val) {
}
}
if (Value operand = WalkBackwardThroughOffloadCompatibleResult(res)) {
return IsResultAndOnHostMemory(operand);
return IsResultAndOnHostMemory(operand, memory_kind);
}

return GetOffloadValueIfExists(res.getOwner()) == kMemoryKindPinnedHost;
return GetOffloadValueIfExists(res.getOwner()) == memory_kind;
}

std::optional<StringRef> GetOnHostMemoryKindIfResult(Value val) {
if (IsResultAndOnHostMemory(val, kMemoryKindPinnedHost)) {
return kMemoryKindPinnedHost;
}
if (IsResultAndOnHostMemory(val, kMemoryKindUnpinnedHost)) {
return kMemoryKindUnpinnedHost;
}
return std::nullopt;
}

// Gets memory kind from user.
Expand Down Expand Up @@ -299,20 +309,27 @@ class MarkOffloadedInputOutputPass
// - Propagating host annotation through ops implicitly computed on host.
//
// We don't propagate device memory kinds, because that's the default.

void PropagateHostMemoryKindOnFragments(FragmentOp frag, FuncOp parent) {
SmallVector<Attribute> arg_attrs = GetArgAttrsOrCreateDefault(frag);
for (OpOperand& operand : frag->getOpOperands()) {
if (auto result = dyn_cast<OpResult>(operand.get());
result && isa<FragmentOp>(result.getOwner()) &&
IsResultOnHost(result)) {
InsertAttr(arg_attrs[operand.getOperandNumber()], kMemoryKindAttr,
StringAttr::get(frag.getContext(), kMemoryKindPinnedHost));
continue;
if (auto result = mlir::dyn_cast<OpResult>(operand.get());
result && mlir::isa<FragmentOp>(result.getOwner())) {
if (std::optional<mlir::StringRef> memory_kind =
GetMemoryKindIfResultOnHost(result)) {
mlir::mpmd::InsertAttr(
arg_attrs[operand.getOperandNumber()], kMemoryKindAttr,
mlir::StringAttr::get(frag.getContext(), memory_kind.value()));
continue;
}
}
if (auto block_arg = dyn_cast<BlockArgument>(operand.get());
block_arg && IsArgOnHost(parent, block_arg.getArgNumber())) {
InsertAttr(arg_attrs[operand.getOperandNumber()], kMemoryKindAttr,
StringAttr::get(frag.getContext(), kMemoryKindPinnedHost));
if (auto block_arg = mlir::dyn_cast<BlockArgument>(operand.get())) {
if (std::optional<mlir::StringRef> memory_kind =
GetMemoryKindIfArgOnHost(parent, block_arg.getArgNumber())) {
mlir::mpmd::InsertAttr(
arg_attrs[operand.getOperandNumber()], kMemoryKindAttr,
mlir::StringAttr::get(frag.getContext(), memory_kind.value()));
}
continue;
}
}
Expand All @@ -321,9 +338,11 @@ class MarkOffloadedInputOutputPass
SmallVector<Attribute> res_attrs = GetResAttrsOrCreateDefault(frag);
for (auto [idx, return_operand] : llvm::enumerate(
frag.getRegion().front().getTerminator()->getOperands())) {
if (IsResultAndOnHostMemory(return_operand)) {
InsertAttr(res_attrs[idx], kMemoryKindAttr,
StringAttr::get(frag.getContext(), kMemoryKindPinnedHost));
if (std::optional<StringRef> memory_kind =
GetOnHostMemoryKindIfResult(return_operand)) {
mlir::mpmd::InsertAttr(
res_attrs[idx], kMemoryKindAttr,
mlir::StringAttr::get(frag.getContext(), memory_kind.value()));
}
}
SetResAttrs(frag, res_attrs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,33 @@ func.func @simple(%func_arg: !m1_16 {mhlo.memory_kind = "pinned_host"}) ->
func.return %f : !m1_16
}

// CHECK-LABEL: func @simple_unpinned_host(%arg0: {{.*}} {mhlo.memory_kind = "unpinned_host"}) ->
// CHECK-SAME: {mhlo.memory_kind = "unpinned_host"})
func.func @simple_unpinned_host(%func_arg: !m1_16 {mhlo.memory_kind = "unpinned_host"}) ->
(!m1_16 {mhlo.memory_kind = "unpinned_host"})
attributes {topology=#topology} {

// CHECK-NEXT: fragment
// CHECK-SAME: {arg_attrs = [{mhlo.memory_kind = "unpinned_host"}], res_attrs = [{mhlo.memory_kind = "unpinned_host"}]}
%f = mpmd.fragment<mesh="m1", origin=[]> (%func_arg)
(%arg0: tensor<16xf32>) {
%7 = stablehlo.custom_call @annotate_device_placement(%arg0) {
backend_config = "", has_side_effect = true,
mhlo.frontend_attributes = {_xla_buffer_placement = "device"}
} : (tensor<16xf32>) -> tensor<16xf32>

%8 = stablehlo.add %7, %7 : tensor<16xf32>

%9 = stablehlo.custom_call @annotate_device_placement(%8) {
backend_config = "", has_side_effect = true,
mhlo.frontend_attributes = {_xla_buffer_placement = "unpinned_host"}
} : (tensor<16xf32>) -> tensor<16xf32>
mpmd.return %9 : tensor<16xf32>
} : (!m1_16) -> !m1_16

func.return %f : !m1_16
}

// CHECK-LABEL: func @with_optimization_barrier(%arg0
// CHECK-SAME: %arg1: {{.*}} {mhlo.memory_kind = "pinned_host"}) ->
func.func @with_optimization_barrier(%func_arg0: !m1_16,
Expand Down Expand Up @@ -173,9 +200,9 @@ func.func @place_host_with_incompatible_reshape_and_custom_call(%func_arg0: !m1_
func.return %f#0, %f#1 : !m1_4x4, !m1_16
}

// CHECK-LABEL: func @func_arg_multiple_matching_users
// CHECK-LABEL: func @func_arg_multiple_matching_users_pinned_host
// CHECK-SAME: (%arg0: {{.*}} {mhlo.memory_kind = "pinned_host"})
func.func @func_arg_multiple_matching_users(
func.func @func_arg_multiple_matching_users_pinned_host(
%func_arg: !m1_16 {mhlo.memory_kind = "pinned_host"}) -> (!m1_16, !m1_16)
attributes {topology=#topology} {

Expand Down Expand Up @@ -205,6 +232,38 @@ func.func @func_arg_multiple_matching_users(
func.return %f1, %f2 : !m1_16, !m1_16
}

// CHECK-LABEL: func @func_arg_multiple_matching_users_unpinned_host
// CHECK-SAME: (%arg0: {{.*}} {mhlo.memory_kind = "unpinned_host"})
func.func @func_arg_multiple_matching_users_unpinned_host(
%func_arg: !m1_16 {mhlo.memory_kind = "unpinned_host"}) -> (!m1_16, !m1_16)
attributes {topology=#topology} {

// CHECK: fragment{{.*}} origin=["f"]
// CHECK-SAME: {arg_attrs = [{mhlo.memory_kind = "unpinned_host"}]
%f1 = mpmd.fragment<mesh="m1", origin=["f"]> (%func_arg)
(%arg0: tensor<16xf32>) {
%7 = stablehlo.custom_call @annotate_device_placement(%arg0) {
backend_config = "", has_side_effect = true,
mhlo.frontend_attributes = {_xla_buffer_placement = "device"}
} : (tensor<16xf32>) -> tensor<16xf32>

mpmd.return %7 : tensor<16xf32>
} : (!m1_16) -> !m1_16

// CHECK: fragment{{.*}} origin=["g"]
// CHECK-SAME: {arg_attrs = [{mhlo.memory_kind = "unpinned_host"}]
%f2 = mpmd.fragment<mesh="m1", origin=["g"]> (%func_arg)
(%arg0: tensor<16xf32>) {
%7 = stablehlo.custom_call @annotate_device_placement(%arg0) {
backend_config = "", has_side_effect = true,
mhlo.frontend_attributes = {_xla_buffer_placement = "device"}
} : (tensor<16xf32>) -> tensor<16xf32>

mpmd.return %7 : tensor<16xf32>
} : (!m1_16) -> !m1_16
func.return %f1, %f2 : !m1_16, !m1_16
}

// CHECK-LABEL: func @func_arg_match
// CHECK-SAME: (%arg0: {{.*}} {mhlo.memory_kind = "pinned_host"})
func.func @func_arg_match(%func_arg: !m1_16 {mhlo.memory_kind = "pinned_host"}) -> !m1_16
Expand Down
57 changes: 42 additions & 15 deletions shardy/dialect/mpmd/transforms/export/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,26 +63,53 @@ DenseSet<BlockArgument> GetDonatedBlockArguments(
DenseMap<Operation*, SmallVector<unsigned int>>
OperandsForDeletionMapping(func::FuncOp main_func);

inline bool IsMemoryKindOnHost(mlir::StringAttr memory_kind) {
if (!memory_kind) {
return false;
}
mlir::StringRef memory_kind_val = memory_kind.getValue();
return memory_kind_val == mpmd::kMemoryKindPinnedHost ||
memory_kind_val == mpmd::kMemoryKindUnpinnedHost;
}


// Checks the arg attrs of the op to see if the arg is on the host.
inline bool IsArgOnHost(Operation* op, int index) {
return GetArgAttr(op, index, kMemoryKindAttr) ==
StringAttr::get(op->getContext(), kMemoryKindPinnedHost);
inline bool IsArgOnHost(mlir::Operation* op, int index) {
mlir::StringAttr memory_kind_attr = mlir::dyn_cast_or_null<mlir::StringAttr>(
mlir::mpmd::GetArgAttr(op, index, mlir::mpmd::kMemoryKindAttr));
return IsMemoryKindOnHost(memory_kind_attr);
}

// Checks the arg of the function is on the host.
inline bool IsArgOnHost(func::FuncOp func, int index) {
return func.getArgAttrOfType<StringAttr>(index,
kMemoryKindAttr) ==
StringAttr::get(func.getContext(), kMemoryKindPinnedHost);
inline std::optional<mlir::StringRef> GetMemoryKindIfArgOnHost(
mlir::func::FuncOp func, int index) {
mlir::StringAttr memory_kind_attr = func.getArgAttrOfType<mlir::StringAttr>(
index, mlir::mpmd::kMemoryKindAttr);
if (!memory_kind_attr) {
return std::nullopt;
}
mlir::StringRef memory_kind_val = memory_kind_attr.getValue();
if (memory_kind_val == mlir::mpmd::kMemoryKindPinnedHost ||
memory_kind_val == mlir::mpmd::kMemoryKindUnpinnedHost) {
return memory_kind_val;
}
return std::nullopt;
}

// Checks the result attrs of the op to see if the result is on the host.
inline bool IsResultOnHost(OpResult op_result) {
return GetResAttr(op_result.getOwner(),
op_result.getResultNumber(),
kMemoryKindAttr) ==
StringAttr::get(op_result.getContext(),
kMemoryKindPinnedHost);
inline std::optional<mlir::StringRef> GetMemoryKindIfResultOnHost(
mlir::OpResult op_result) {
auto memory_kind = dyn_cast_or_null<mlir::StringAttr>(
mlir::mpmd::GetResAttr(op_result.getOwner(), op_result.getResultNumber(),
mlir::mpmd::kMemoryKindAttr));

if (!memory_kind) {
return std::nullopt;
}
mlir::StringRef memory_kind_val = memory_kind.getValue();
if (memory_kind_val == mlir::mpmd::kMemoryKindPinnedHost ||
memory_kind_val == mlir::mpmd::kMemoryKindUnpinnedHost) {
return memory_kind_val;
}
return std::nullopt;
}

// Checks if the layout of the input and output match.
Expand Down
Loading