Skip to content

Commit

Permalink
Change transfers to do pulls on multi transfer (#19931)
Browse files Browse the repository at this point in the history
Scatter-gather reduction operations are preferable due to non-parallel
data transfers. To adjust for this we should only push tensors when they
are mono-transfers. When distributing to multiple devices it is
preferable to have each device fetch its data: in the case of
1...8 -> 1 -> 1...8, we would prefer each device to handle its own data transfers.
  • Loading branch information
rsuderman authored Feb 18, 2025
1 parent 4e10bb5 commit 33a770e
Show file tree
Hide file tree
Showing 15 changed files with 285 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,
// region - we must still block on loads though.
LLVM_DEBUG(llvm::dbgs() << "(ignoring global store)\n");
continue;
} else if (!isa<IREE::Stream::StreamableOpInterface>(op)) {
} else if (!isa<IREE::Stream::StreamableOpInterface>(op) &&
!isa<IREE::Stream::AsyncBarrierOp>(op)) {
// Not a streamable op. If it has side-effects then we force a hazard on
// all builders so that we don't move ops across it.
if (!mlir::wouldOpBeTriviallyDead(&op)) {
Expand All @@ -169,12 +170,18 @@ partitionStreamableOpsReference(IREE::Stream::PartitioningConfigAttr config,

// Synchronizing operations should join with their producers if the producer
// is streamable.
if (dyn_cast<IREE::Stream::AsyncBarrierOp>(op) ||
dyn_cast<IREE::Stream::AsyncTransferOp>(op)) {
if (dyn_cast<IREE::Stream::AsyncTransferOp>(op)) {
auto producer = op.getOperand(0).getDefiningOp();
auto streamable =
dyn_cast_or_null<IREE::Stream::StreamableOpInterface>(producer);
if (streamable) {

auto srcAffinity =
dyn_cast_or_null<IREE::Stream::AffinityOpInterface>(producer);
auto opAffinity = dyn_cast_or_null<IREE::Stream::AffinityOpInterface>(op);

if (streamable && srcAffinity && srcAffinity.getAffinityAttr() &&
IREE::Stream::AffinityAttr::canExecuteTogether(
opAffinity.getAffinityAttr(), srcAffinity.getAffinityAttr())) {
if (!syncOps.contains(producer))
syncOps[producer] = llvm::SmallVector<Operation *>();
syncOps[producer].push_back(&op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ struct ConvertTensorBarrierOp
auto barrierOp = rewriter.create<IREE::Stream::AsyncBarrierOp>(
op.getLoc(), operand.resource.getType(), operand.resource,
operand.resourceSize,
/*affinity=*/executionAffinityAttr);
/*affinity=*/cast<IREE::Stream::AffinityAttr>(op.getTargetAttr()));
rewriter.replaceOpWithMultiple(op, {{barrierOp, operand.resourceSize}});
return success();
}
Expand All @@ -272,7 +272,8 @@ struct ConvertTensorTransferOp
op.getLoc(), unknownType, operand.resource, operand.resourceSize,
operand.resourceSize,
/*source_affinity=*/operand.affinity,
/*result_affinity=*/executionAffinityAttr);
/*result_affinity=*/
cast<IREE::Stream::AffinityAttr>(op.getTargetAttr()));
rewriter.replaceOpWithMultiple(op, {{transferOp, operand.resourceSize}});
return success();
}
Expand Down
62 changes: 6 additions & 56 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2500,8 +2500,6 @@ void AsyncCollectiveOp::getAsyncAccessRanges(
// stream.async.barrier
//===----------------------------------------------------------------------===//

bool AsyncBarrierOp::isMetadata() { return true; }

LogicalResult AsyncBarrierOp::verify() { return success(); }

Value AsyncBarrierOp::getTiedResult(unsigned resultIndex) {
Expand Down Expand Up @@ -2530,60 +2528,12 @@ LogicalResult AsyncTransferOp::verify() {
return success();
}

// Returns the affinity this transfer is considered to execute on, selected
// from the lifetimes of the source and result resource types:
//   - staging -> staging:                      source affinity
//   - external or staging source (otherwise):  result affinity
//   - anything else:                           source affinity
IREE::Stream::AffinityAttr AsyncTransferOp::getAffinityAttr() {
  const auto sourceLifetime =
      cast<IREE::Stream::ResourceType>(getSource().getType()).getLifetime();
  const auto resultLifetime =
      cast<IREE::Stream::ResourceType>(getResult().getType()).getLifetime();

  const bool sourceIsStaging =
      sourceLifetime == IREE::Stream::Lifetime::Staging;
  const bool resultIsStaging =
      resultLifetime == IREE::Stream::Lifetime::Staging;

  // TODO(multi-device): figure out how to model staging->staging transfers.
  if (sourceIsStaging && resultIsStaging)
    return getSourceAffinityAttr();

  // An external or staging source means the op should execute on the consumer.
  if (sourceIsStaging || sourceLifetime == IREE::Stream::Lifetime::External)
    return getResultAffinityAttr();

  // Every remaining case (an external/staging result, or neither side being
  // external/staging) resolves to the source affinity.
  return getSourceAffinityAttr();
}

void AsyncTransferOp::setAffinityAttr(IREE::Stream::AffinityAttr value) {
auto sourceType = cast<IREE::Stream::ResourceType>(getSource().getType());
auto resultType = cast<IREE::Stream::ResourceType>(getResult().getType());
if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging &&
resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
// TODO(multi-device): figure out how to model staging->staging transfers.
if (value) {
setSourceAffinityAttr(value);
} else {
removeSourceAffinityAttr();
}
} else if (sourceType.getLifetime() == IREE::Stream::Lifetime::Staging) {
// If source is staging then the op should execute on the consumer.
if (value) {
setResultAffinityAttr(value);
} else {
removeResultAffinityAttr();
}
} else if (resultType.getLifetime() == IREE::Stream::Lifetime::Staging) {
// If result is staging then the op should execute on the producer.
if (value) {
setSourceAffinityAttr(value);
} else {
removeSourceAffinityAttr();
}
} else {
// Default to result affinity.
if (value) {
setResultAffinityAttr(value);
} else {
removeResultAffinityAttr();
}
}
// Convenience builder that takes only the source and result affinities.
// Forwards every argument to the build overload that accepts one additional
// trailing argument and passes nullptr for it.
// NOTE(review): the trailing argument is presumably the optional `affinity`
// attribute of the op — confirm against the op definition.
void AsyncTransferOp::build(OpBuilder &builder, OperationState &state,
Type type, Value source, Value source_size,
Value result_size, AffinityAttr source_attr,
AffinityAttr result_attr) {
build(builder, state, type, source, source_size, result_size, source_attr,
result_attr, nullptr);
}

void AsyncTransferOp::getAsyncAccessRanges(
Expand Down
22 changes: 14 additions & 8 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2294,9 +2294,6 @@ def Stream_AsyncBarrierOp : Stream_Op<"async.barrier", [
AllTypesMatch<["source", "result"]>,
Stream_AffinityOp,
Stream_AsyncPhaseOp,
DeclareOpInterfaceMethods<Stream_StreamableOp, [
"isMetadata",
]>,
Util_SizeAwareOp,
DeclareOpInterfaceMethods<Util_TiedOpInterface, [
"getTiedResult",
Expand Down Expand Up @@ -2344,10 +2341,7 @@ def Stream_AsyncBarrierOp : Stream_Op<"async.barrier", [
}

def Stream_AsyncTransferOp : Stream_Op<"async.transfer", [
DeclareOpInterfaceMethods<Stream_AffinityOp, [
"getAffinityAttr",
"setAffinityAttr",
]>,
DeclareOpInterfaceMethods<Stream_AffinityOp>,
Stream_AsyncPhaseOp,
Stream_StreamableOp,
DeclareOpInterfaceMethods<Stream_AsyncAccessOp, [
Expand All @@ -2371,7 +2365,8 @@ def Stream_AsyncTransferOp : Stream_Op<"async.transfer", [
Stream_Size:$source_size,
Stream_Size:$result_size,
OptionalAttr<Stream_AffinityAttr>:$source_affinity,
OptionalAttr<Stream_AffinityAttr>:$result_affinity
OptionalAttr<Stream_AffinityAttr>:$result_affinity,
OptionalAttr<Stream_AffinityAttr>:$affinity
);
let results = (outs
AnyTypeOf<[
Expand All @@ -2383,6 +2378,7 @@ def Stream_AsyncTransferOp : Stream_Op<"async.transfer", [
let assemblyFormat = [{
$source `:` type($source)
`` `{` $source_size `}`
(`on` `(` $affinity^ `)`)?
(`from` `(` $source_affinity^ `)`)?
`->`
(`to` `(` $result_affinity^ `)`)?
Expand All @@ -2399,6 +2395,16 @@ def Stream_AsyncTransferOp : Stream_Op<"async.transfer", [

let hasCanonicalizer = 1;
let hasFolder = 1;

let builders = [
OpBuilder<(ins
"Type":$type,
"Value":$source,
"Value":$source_size,
"Value":$result_size,
"AffinityAttr":$source_affinity,
"AffinityAttr":$result_affinity)>,
];
}

def Stream_AsyncLoadOp : Stream_PureOp<"async.load", [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ iree_compiler_cc_library(
"ElideTimepoints.cpp",
"EmplaceAllocations.cpp",
"EncodeTensors.cpp",
"ExecutionPlacement.cpp",
"FoldUniformOperands.cpp",
"FuseDispatchBindings.cpp",
"LayoutSlices.cpp",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ iree_cc_library(
"ElideTimepoints.cpp"
"EmplaceAllocations.cpp"
"EncodeTensors.cpp"
"ExecutionPlacement.cpp"
"FoldUniformOperands.cpp"
"FuseDispatchBindings.cpp"
"LayoutSlices.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/Stream/Analysis/Partitioning.h"
#include "iree/compiler/Dialect/Stream/IR/StreamDialect.h"
#include "iree/compiler/Dialect/Stream/IR/StreamOps.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h"
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-stream-execution-placement"

namespace mlir::iree_compiler::IREE::Stream {

#define GEN_PASS_DEF_EXECUTIONPLACEMENTPASS
#include "iree/compiler/Dialect/Stream/Transforms/Passes.h.inc"

namespace {

/// Walks every `stream.async.transfer` under the operation the pass runs on
/// and assigns an affinity to those that do not have one yet.
///
/// The affinity is picked in this order:
///   1. the affinity of the op producing the source, when the source value has
///      exactly one use and the producer is streamable with an affinity set;
///   2. the transfer's result affinity;
///   3. the transfer's source affinity.
/// When none is available an op error is emitted.
struct ExecutionPlacementPass
    : public IREE::Stream::impl::ExecutionPlacementPassBase<
          ExecutionPlacementPass> {
  void runOnOperation() override {
    getOperation()->walk([](IREE::Stream::AsyncTransferOp transferOp) {
      // Transfers that already carry an affinity are left untouched.
      if (transferOp.getAffinityAttr())
        return;

      if (IREE::Stream::AffinityAttr placement = choosePlacement(transferOp)) {
        transferOp.setAffinityAttr(placement);
      } else {
        // NOTE(review): only a diagnostic is emitted here; the pass does not
        // call signalPassFailure() — confirm that this is intended.
        transferOp->emitOpError("Unknown src/dest affinity");
      }
    });
  }

private:
  /// Returns the affinity to assign to `transferOp`, or a null attribute when
  /// neither the producer nor the transfer itself provides one.
  static IREE::Stream::AffinityAttr
  choosePlacement(IREE::Stream::AsyncTransferOp transferOp) {
    auto source = transferOp.getSource();
    auto *sourceOp = source.getDefiningOp();

    // Prefer the producer's affinity, but only when this transfer is the sole
    // user of the value and the producer is a streamable op.
    const bool producerIsStreamable =
        static_cast<bool>(
            dyn_cast_or_null<IREE::Stream::StreamableOpInterface>(sourceOp));
    if (source.hasOneUse() && producerIsStreamable) {
      if (auto producerAffinityOp =
              dyn_cast_or_null<IREE::Stream::AffinityOpInterface>(sourceOp)) {
        if (auto producerAffinity = producerAffinityOp.getAffinityAttr())
          return producerAffinity;
      }
    }

    // Otherwise fall back to the result affinity, then the source affinity.
    if (auto resultAffinity = transferOp.getResultAffinityAttr())
      return resultAffinity;
    return transferOp.getSourceAffinityAttr();
  }
};

} // namespace
} // namespace mlir::iree_compiler::IREE::Stream
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ void buildStreamAsyncPassPipeline(OpPassManager &passManager,
//----------------------------------------------------------------------------

FunctionLikeNest(passManager)
// Analyze and assign execution placement.
.addPass(IREE::Stream::createExecutionPlacementPass)
// Combine async work into execution regions.
.addPass(IREE::Stream::createScheduleExecutionPass)
// Group concurrently executable work into waves.
Expand Down
13 changes: 13 additions & 0 deletions compiler/src/iree/compiler/Dialect/Stream/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,19 @@ def RefineUsagePass :
// Stream formation and scheduling
//===----------------------------------------------------------------------===//

// Assigns execution placement to `stream.async` ops ahead of execution
// scheduling.
def ExecutionPlacementPass :
    InterfacePass<"iree-stream-execution-placement", "mlir::CallableOpInterface"> {
  let summary = "Runs an analysis and placement for stream executions.";
  let description = [{
    Handles placement analysis for `stream.async` operators that have a preferable
    placement. This is so that more complex analysis can be separated from the
    execution scheduling pass.
  }];
  let dependentDialects = [
    "IREE::Stream::StreamDialect",
  ];
}

def ScheduleExecutionPass :
InterfacePass<"iree-stream-schedule-execution", "mlir::CallableOpInterface"> {
let summary = "Identifies and groups asynchronous operations into executable regions within function-like regions.";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -668,19 +668,8 @@ applyAsyncCollectiveOp(IREE::Stream::AsyncCollectiveOp asyncOp,
// Handles a stream.async.barrier op during command application.
//
// NOTE(review): as listed, this body contains two variants back to back: one
// that copies the source range into the result range, erases the op and
// returns success, followed by one that emits an error and returns failure.
// Everything after the first `return` is unreachable. This looks like the
// removed and added sides of a change shown together — confirm which variant
// is the intended one before relying on this listing.
static LogicalResult applyAsyncBarrierOp(IREE::Stream::AsyncBarrierOp barrierOp,
AllocationScope &scope,
OpBuilder builder) {
// TODO: barriers are being treated as copies, they should just be metadata
// operations but currently it's causing failures to be removed.
auto sourceRange = scope.lookupResourceRange(barrierOp.getSource());
auto targetRange = scope.lookupResourceRange(barrierOp.getResult());

// Perform the copy.
builder.create<IREE::Stream::CmdCopyOp>(
barrierOp.getLoc(), sourceRange.resource, sourceRange.resourceSize,
sourceRange.offset, targetRange.resource, targetRange.resourceSize,
targetRange.offset, sourceRange.length);

barrierOp.erase();
return success();
// NOTE(review): unreachable as listed (see the note at the top). The message
// text "should not longer exist" presumably means "should no longer exist".
barrierOp->emitError("Async barrier should not longer exist");
return failure();
}

static LogicalResult applyAsyncTransferOp(IREE::Stream::AsyncTransferOp asyncOp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ iree_lit_test_suite(
"encode_host_tensors_encoding.mlir",
"encode_host_tensors_packing.mlir",
"encode_host_tensors_packing_i1_experimental_clopt.mlir",
"execution_placement.mlir",
"fold_globals.mlir",
"fold_uniform_operands.mlir",
"fuse_dispatch_bindings.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ iree_lit_test_suite(
"encode_host_tensors_encoding.mlir"
"encode_host_tensors_packing.mlir"
"encode_host_tensors_packing_i1_experimental_clopt.mlir"
"execution_placement.mlir"
"fold_globals.mlir"
"fold_uniform_operands.mlir"
"fuse_dispatch_bindings.mlir"
Expand Down
Loading

0 comments on commit 33a770e

Please sign in to comment.