Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ bool isOffloadCustomCallOp(Operation* op) {
return false;
}

PropagationDirection isPassThroughOp(Operation* op, int64_t factorIndex,
PropagationDirection isPassThroughOp(Operation* op, int64_t,
bool allowMultiUse) {
if (isElementwise(op) || isOffloadCustomCallOp(op) ||
isa<stablehlo::ReshapeOp, stablehlo::TransposeOp, DataFlowEdgeOp>(op)) {
Expand All @@ -111,10 +111,25 @@ PropagationDirection isPassThroughOpMultiUse(Operation* op,
return isPassThroughOp(op, factorIndex, /*allowMultiUse=*/true);
}

// NOTE: if the `op` has no sharding rule, then we will assume it uses an
// identity sharding rule. For example, `DataFlowEdgeOp`.
PropagationDirection onlyBatchFactorsExceptBroadcast(Operation* op,
int64_t factorIndex) {
// If the `op` has no sharding rule, then we assume it uses an identity
// sharding rule. For example, `DataFlowEdgeOp`.
if (auto shardingRule =
op->getAttrOfType<OpShardingRuleAttr>(kShardingRuleAttr);
shardingRule && !shardingRule.isBatchingFactor(factorIndex)) {
return PropagationDirection::NONE;
}
if (isa<stablehlo::BroadcastInDimOp>(op)) {
return PropagationDirection::NONE;
}
return PropagationDirection::BOTH;
}

PropagationDirection onlyPassThroughFactorsBroadcastBackward(
Operation* op, int64_t factorIndex) {
// If the `op` has no sharding rule, then we assume it uses an identity
// sharding rule. For example, `DataFlowEdgeOp`.
if (auto shardingRule =
op->getAttrOfType<OpShardingRuleAttr>(kShardingRuleAttr);
shardingRule && !shardingRule.isPassThroughFactor(factorIndex)) {
Expand All @@ -134,10 +149,10 @@ PropagationDirection propagateAnyExceptBroadcastForward(Operation* op,
return PropagationDirection::BOTH;
}

constexpr std::array<GetDirectionToPropagateFnPtr, 5> opPropagationSchedule = {
isPassThroughOpSingleUse, isPassThroughOpMultiUse,
onlyPassThroughFactorsBroadcastBackward, propagateAnyExceptBroadcastForward,
propagateAny};
constexpr std::array<GetDirectionToPropagateFnPtr, 6> opPropagationSchedule = {
isPassThroughOpSingleUse, isPassThroughOpMultiUse,
onlyBatchFactorsExceptBroadcast, onlyPassThroughFactorsBroadcastBackward,
propagateAnyExceptBroadcastForward, propagateAny};

// Returns the direction in which the given operation should be propagated.
//
Expand Down
Loading