diff --git a/shardy/dialect/sdy/transforms/propagation/op_priority_propagation.cc b/shardy/dialect/sdy/transforms/propagation/op_priority_propagation.cc index 81d9858d..55510966 100644 --- a/shardy/dialect/sdy/transforms/propagation/op_priority_propagation.cc +++ b/shardy/dialect/sdy/transforms/propagation/op_priority_propagation.cc @@ -87,7 +87,7 @@ bool isOffloadCustomCallOp(Operation* op) { return false; } -PropagationDirection isPassThroughOp(Operation* op, int64_t factorIndex, +PropagationDirection isPassThroughOp(Operation* op, int64_t, bool allowMultiUse) { if (isElementwise(op) || isOffloadCustomCallOp(op) || isa(op)) { @@ -111,10 +111,25 @@ PropagationDirection isPassThroughOpMultiUse(Operation* op, return isPassThroughOp(op, factorIndex, /*allowMultiUse=*/true); } -// NOTE: if the `op` has no sharding rule, then we will assume it uses an -// identity sharding rule. For example, `DataFlowEdgeOp`. +PropagationDirection onlyBatchFactorsExceptBroadcast(Operation* op, + int64_t factorIndex) { + // If the `op` has no sharding rule, then we assume it uses an identity + // sharding rule. For example, `DataFlowEdgeOp`. + if (auto shardingRule = + op->getAttrOfType(kShardingRuleAttr); + shardingRule && !shardingRule.isBatchingFactor(factorIndex)) { + return PropagationDirection::NONE; + } + if (isa(op)) { + return PropagationDirection::NONE; + } + return PropagationDirection::BOTH; +} + PropagationDirection onlyPassThroughFactorsBroadcastBackward( Operation* op, int64_t factorIndex) { + // If the `op` has no sharding rule, then we assume it uses an identity + // sharding rule. For example, `DataFlowEdgeOp`. if (auto shardingRule = op->getAttrOfType(kShardingRuleAttr); shardingRule && !shardingRule.isPassThroughFactor(factorIndex)) { @@ -134,10 +149,10 @@ PropagationDirection propagateAnyExceptBroadcastForward(Operation* op, return PropagationDirection::BOTH; } -constexpr std::array opPropagationSchedule = { - isPassThroughOpSingleUse, isPassThroughOpMultiUse, - onlyPassThroughFactorsBroadcastBackward, propagateAnyExceptBroadcastForward, - propagateAny}; +constexpr std::array opPropagationSchedule = { + isPassThroughOpSingleUse, isPassThroughOpMultiUse, + onlyBatchFactorsExceptBroadcast, onlyPassThroughFactorsBroadcastBackward, + propagateAnyExceptBroadcastForward, propagateAny}; // Returns the direction in which the given operation should be propagated. //