Skip to content

Commit

Permalink
Fix multiple sources build side in AdaptiveReorderPartitionedJoin
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav8297 authored and losipiuk committed Sep 20, 2024
1 parent 550544f commit 6c25ebf
Show file tree
Hide file tree
Showing 3 changed files with 455 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import io.airlift.units.DataSize;
Expand All @@ -25,36 +26,43 @@
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.StreamPreferredProperties;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.SimplePlanRewriter;

import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.SystemSessionProperties.getFaultTolerantExecutionAdaptiveJoinReorderingMinSizeThreshold;
import static io.trino.SystemSessionProperties.getFaultTolerantExecutionAdaptiveJoinReorderingSizeDifferenceRatio;
import static io.trino.SystemSessionProperties.getRetryPolicy;
import static io.trino.SystemSessionProperties.isFaultTolerantExecutionAdaptiveJoinReorderingEnabled;
import static io.trino.cost.PlanNodeStatsEstimateMath.getFirstKnownOutputSizeInBytes;
import static io.trino.operator.RetryPolicy.TASK;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static io.trino.sql.planner.optimizations.StreamPreferredProperties.partitionedOn;
import static io.trino.sql.planner.optimizations.StreamPropertyDerivations.StreamProperties;
import static io.trino.sql.planner.optimizations.StreamPropertyDerivations.deriveStreamPropertiesWithoutActualProperties;
import static io.trino.sql.planner.plan.AggregationNode.Step.PARTIAL;
import static io.trino.sql.planner.plan.ChildReplacer.replaceChildren;
import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static io.trino.sql.planner.plan.ExchangeNode.Type.GATHER;
import static io.trino.sql.planner.plan.ExchangeNode.partitionedExchange;
import static io.trino.sql.planner.plan.ExchangeNode.roundRobinExchange;
import static io.trino.sql.planner.plan.JoinNode.DistributionType.PARTITIONED;
import static io.trino.sql.planner.plan.Patterns.Join.right;
import static io.trino.sql.planner.plan.Patterns.aggregation;
import static io.trino.sql.planner.plan.Patterns.exchange;
import static io.trino.sql.planner.plan.Patterns.join;
import static io.trino.sql.planner.plan.Patterns.source;
import static io.trino.sql.planner.plan.SimplePlanRewriter.rewriteWith;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -138,12 +146,9 @@ public Result apply(JoinNode joinNode, Captures captures, Context context)
return Result.empty();
}

// Remove local exchange from build side
JoinNode joinNodeWithoutExchanges = removeLocalExchangeFromBuildSide(joinNode, localExchangeNode, context);
JoinNode flippedJoin = flipJoinBasedOnStats(joinNodeWithoutExchanges, context);
if (flippedJoin != joinNodeWithoutExchanges) {
// Add local exchange back to build side after flipping
return Result.ofPlanNode(addLocalExchangeToBuildSideIfNeeded(flippedJoin, metadata, context));
boolean flipJoin = flipJoinBasedOnStats(joinNode, context);
if (flipJoin) {
return Result.ofPlanNode(flipJoinAndFixLocalExchanges(joinNode, localExchangeNode.getId(), metadata, context));
}
return Result.empty();
}
Expand All @@ -157,70 +162,59 @@ private static boolean isBuildSideLocalExchangeNode(ExchangeNode exchangeNode, S
&& exchangeNode.getPartitioningScheme().getHashColumn().isEmpty();
}

private static JoinNode removeLocalExchangeFromBuildSide(JoinNode joinNode, ExchangeNode localExchangeNode, Context context)
private static JoinNode flipJoinAndFixLocalExchanges(
JoinNode joinNode,
PlanNodeId buildSideLocalExchangeId,
Metadata metadata,
Context context)
{
PlanNode rightNodeWithoutLocalExchange = PlanNodeSearcher
.searchFrom(joinNode.getRight(), context.getLookup())
.where(node -> node.getId().equals(localExchangeNode.getId()))
.removeFirst();
JoinNode flippedJoinNode = joinNode.flipChildren();

return new JoinNode(
joinNode.getId(),
joinNode.getType(),
joinNode.getLeft(),
rightNodeWithoutLocalExchange,
joinNode.getCriteria(),
joinNode.getLeftOutputSymbols(),
joinNode.getRightOutputSymbols(),
joinNode.isMaySkipOutputDuplicates(),
joinNode.getFilter(),
joinNode.getLeftHashSymbol(),
joinNode.getRightHashSymbol(),
joinNode.getDistributionType(),
joinNode.isSpillable(),
joinNode.getDynamicFilters(),
joinNode.getReorderJoinStatsAndCost());
}
// Fix local exchange on probe side
BuildToProbeLocalExchangeRewriter buildToProbeLocalExchangeRewriter = new BuildToProbeLocalExchangeRewriter(
buildSideLocalExchangeId,
context);
PlanNode probeSide = rewriteWith(
buildToProbeLocalExchangeRewriter,
context.getLookup().resolve(flippedJoinNode.getLeft()));

private static JoinNode addLocalExchangeToBuildSideIfNeeded(JoinNode joinNode, Metadata metadata, Context context)
{
// Fix local exchange on build side
PlanNode buildSide = flippedJoinNode.getRight();
StreamProperties rightProperties = deriveStreamPropertiesRecursively(
joinNode.getRight(),
buildSide,
metadata,
context.getLookup(),
context.getSession());
List<Symbol> buildSymbols = Lists.transform(joinNode.getCriteria(), JoinNode.EquiJoinClause::getRight);
List<Symbol> buildSymbols = Lists.transform(flippedJoinNode.getCriteria(), JoinNode.EquiJoinClause::getRight);
StreamPreferredProperties expectedRightProperties = partitionedOn(buildSymbols);
if (expectedRightProperties.isSatisfiedBy(rightProperties)) {
return joinNode;
// Do not add local exchange if the partitioning properties are already satisfied
if (!expectedRightProperties.isSatisfiedBy(rightProperties)) {
ProbeToBuildLocalExchangeRewriter probeToBuildLocalExchangeRewriter = new ProbeToBuildLocalExchangeRewriter(
buildSymbols,
context);
// Rewrite build side with local exchange
buildSide = rewriteWith(probeToBuildLocalExchangeRewriter, context.getLookup().resolve(buildSide));
}

// Add local exchange to build side
ExchangeNode exchangeNode = partitionedExchange(
context.getIdAllocator().getNextId(),
LOCAL,
joinNode.getRight(),
buildSymbols,
joinNode.getRightHashSymbol());
return new JoinNode(
joinNode.getId(),
joinNode.getType(),
joinNode.getLeft(),
exchangeNode,
joinNode.getCriteria(),
joinNode.getLeftOutputSymbols(),
joinNode.getRightOutputSymbols(),
joinNode.isMaySkipOutputDuplicates(),
joinNode.getFilter(),
joinNode.getLeftHashSymbol(),
joinNode.getRightHashSymbol(),
joinNode.getDistributionType(),
joinNode.isSpillable(),
joinNode.getDynamicFilters(),
joinNode.getReorderJoinStatsAndCost());
flippedJoinNode.getId(),
flippedJoinNode.getType(),
probeSide,
buildSide,
flippedJoinNode.getCriteria(),
flippedJoinNode.getLeftOutputSymbols(),
flippedJoinNode.getRightOutputSymbols(),
flippedJoinNode.isMaySkipOutputDuplicates(),
flippedJoinNode.getFilter(),
flippedJoinNode.getLeftHashSymbol(),
flippedJoinNode.getRightHashSymbol(),
flippedJoinNode.getDistributionType(),
flippedJoinNode.isSpillable(),
flippedJoinNode.getDynamicFilters(),
flippedJoinNode.getReorderJoinStatsAndCost());
}

private static JoinNode flipJoinBasedOnStats(JoinNode joinNode, Context context)
private static boolean flipJoinBasedOnStats(JoinNode joinNode, Context context)
{
double leftOutputSizeInBytes = getFirstKnownOutputSizeInBytes(
joinNode.getLeft(),
Expand All @@ -232,11 +226,8 @@ private static JoinNode flipJoinBasedOnStats(JoinNode joinNode, Context context)
context.getStatsProvider());
DataSize minSizeThreshold = getFaultTolerantExecutionAdaptiveJoinReorderingMinSizeThreshold(context.getSession());
double sizeDifferenceRatio = getFaultTolerantExecutionAdaptiveJoinReorderingSizeDifferenceRatio(context.getSession());
if (rightOutputSizeInBytes > minSizeThreshold.toBytes()
&& rightOutputSizeInBytes > sizeDifferenceRatio * leftOutputSizeInBytes) {
return joinNode.flipChildren();
}
return joinNode;
return rightOutputSizeInBytes > minSizeThreshold.toBytes()
&& rightOutputSizeInBytes > sizeDifferenceRatio * leftOutputSizeInBytes;
}

private static StreamProperties deriveStreamPropertiesRecursively(PlanNode node, Metadata metadata, Lookup lookup, Session session)
Expand All @@ -247,4 +238,104 @@ private static StreamProperties deriveStreamPropertiesRecursively(PlanNode node,
.collect(toImmutableList());
return deriveStreamPropertiesWithoutActualProperties(resolvedNode, inputProperties, metadata, session);
}

private static class BuildToProbeLocalExchangeRewriter
extends SimplePlanRewriter<Void>
{
private final PlanNodeId localExchangeNodeId;
private final Context context;

private BuildToProbeLocalExchangeRewriter(PlanNodeId localExchangeNodeId, Context context)
{
this.localExchangeNodeId = requireNonNull(localExchangeNodeId, "localExchangeNodeId is null");
this.context = requireNonNull(context, "context is null");
}

@Override
public PlanNode visitPlan(PlanNode node, RewriteContext<Void> context)
{
throw new UnsupportedOperationException("Unexpected plan node: " + node.getClass().getSimpleName());
}

@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> ctx)
{
// Other than partial aggregation is not possible in this rule since pattern matches either
// (partial aggregation + local exchange) or (local exchange) on build side
verify(node.getStep() == PARTIAL, "Unexpected aggregation step: %s", node.getStep());
// Skip partial aggregation and rewrite sources which contains build side local exchange
return rewriteSources(this, node, context);
}

@Override
public PlanNode visitExchange(ExchangeNode node, RewriteContext<Void> ctx)
{
verify(
node.getScope().equals(LOCAL) && node.getId().equals(localExchangeNodeId),
"Unexpected exchange node: %s", node.getId());
// Remove local exchange if there is only one source since we are converting build side
// to probe side
if (node.getSources().size() == 1) {
return node.getSources().getFirst();
}

// Add RoundRobinExchange to replace the partitioned local exchange if there are multiple sources
return roundRobinExchange(
context.getIdAllocator().getNextId(),
LOCAL,
node.getSources(),
node.getOutputSymbols());
}
}

private static class ProbeToBuildLocalExchangeRewriter
extends SimplePlanRewriter<Void>
{
private final Context context;
private final List<Symbol> buildSymbols;

private ProbeToBuildLocalExchangeRewriter(List<Symbol> buildSymbols, Context context)
{
this.buildSymbols = requireNonNull(buildSymbols, "buildSymbols is null");
this.context = requireNonNull(context, "context is null");
}

@Override
public PlanNode visitPlan(PlanNode node, RewriteContext<Void> ctx)
{
// Add partitioned local exchange to the probe side which is now the build side since we have
// flipped the join.
return partitionedExchange(
context.getIdAllocator().getNextId(),
LOCAL,
node,
buildSymbols,
Optional.empty());
}

@Override
public PlanNode visitExchange(ExchangeNode node, RewriteContext<Void> ctx)
{
// if there are multiple sources with round-robin exchange, replace it with partitioned exchange
// instead of adding one.
if (node.getScope().equals(LOCAL)
&& node.getSources().size() > 1
&& node.getPartitioningScheme().getPartitioning().getHandle().equals(FIXED_ARBITRARY_DISTRIBUTION)) {
return partitionedExchange(
context.getIdAllocator().getNextId(),
LOCAL,
node.getSources(),
buildSymbols,
node.getOutputSymbols());
}
return visitPlan(node, ctx);
}
}

private static PlanNode rewriteSources(SimplePlanRewriter<Void> rewriter, PlanNode node, Context context)
{
ImmutableList.Builder<PlanNode> children = ImmutableList.builderWithExpectedSize(node.getSources().size());
node.getSources().forEach(source -> children.add(rewriteWith(rewriter, context.getLookup().resolve(source))));
return replaceChildren(node, children.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
Expand Down Expand Up @@ -150,6 +151,23 @@ public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanN
Optional.empty());
}

public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, List<PlanNode> sources, List<Symbol> partitioningColumns, List<Symbol> outputSymbols)
{
List<List<Symbol>> sourceInputs = sources.stream()
.map(PlanNode::getOutputSymbols)
.collect(toImmutableList());
return new ExchangeNode(
id,
ExchangeNode.Type.REPARTITION,
scope,
new PartitioningScheme(
Partitioning.create(FIXED_HASH_DISTRIBUTION, partitioningColumns),
outputSymbols),
sources,
sourceInputs,
Optional.empty());
}

public static ExchangeNode replicatedExchange(PlanNodeId id, Scope scope, PlanNode child)
{
return new ExchangeNode(
Expand Down Expand Up @@ -183,6 +201,21 @@ public static ExchangeNode roundRobinExchange(PlanNodeId id, Scope scope, PlanNo
new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), child.getOutputSymbols()));
}

public static ExchangeNode roundRobinExchange(PlanNodeId id, Scope scope, List<PlanNode> sources, List<Symbol> outputSymbols)
{
List<List<Symbol>> sourceInputs = sources.stream()
.map(PlanNode::getOutputSymbols)
.collect(toImmutableList());
return new ExchangeNode(
id,
ExchangeNode.Type.REPARTITION,
scope,
new PartitioningScheme(Partitioning.create(FIXED_ARBITRARY_DISTRIBUTION, ImmutableList.of()), outputSymbols),
sources,
sourceInputs,
Optional.empty());
}

public static ExchangeNode mergingExchange(PlanNodeId id, Scope scope, PlanNode child, OrderingScheme orderingScheme)
{
PartitioningHandle partitioningHandle = scope == LOCAL ? FIXED_PASSTHROUGH_DISTRIBUTION : SINGLE_DISTRIBUTION;
Expand Down
Loading

0 comments on commit 6c25ebf

Please sign in to comment.