diff --git a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java index 8d3a4eacbca8..cc65bb8c1c0e 100644 --- a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java +++ b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java @@ -23,6 +23,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Verify.verify; +import static io.prestosql.operator.project.PageProcessor.MAX_BATCH_SIZE; import static io.prestosql.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; import static java.util.Objects.requireNonNull; @@ -38,6 +39,7 @@ public class LookupJoinPageBuilder private final PageBuilder buildPageBuilder; private final int buildOutputChannelCount; private int estimatedProbeBlockBytes; + private int estimatedProbeRowSize = -1; private boolean isSequentialProbeIndices = true; public LookupJoinPageBuilder(List buildTypes) @@ -48,7 +50,9 @@ public LookupJoinPageBuilder(List buildTypes) public boolean isFull() { - return estimatedProbeBlockBytes + buildPageBuilder.getSizeInBytes() >= DEFAULT_MAX_PAGE_SIZE_IN_BYTES || buildPageBuilder.isFull(); + return estimatedProbeBlockBytes + buildPageBuilder.getSizeInBytes() >= DEFAULT_MAX_PAGE_SIZE_IN_BYTES || + buildPageBuilder.getPositionCount() >= MAX_BATCH_SIZE || + buildPageBuilder.isFull(); } public boolean isEmpty() @@ -62,6 +66,7 @@ public void reset() probeIndexBuilder.clear(); buildPageBuilder.reset(); estimatedProbeBlockBytes = 0; + estimatedProbeRowSize = -1; isSequentialProbeIndices = true; } @@ -176,10 +181,24 @@ private void appendProbeIndex(JoinProbe probe) if (previousPosition == position) { return; } + estimatedProbeBlockBytes += getEstimatedProbeRowSize(probe); + } + + private int getEstimatedProbeRowSize(JoinProbe probe) + { + if (estimatedProbeRowSize != -1) { + return estimatedProbeRowSize; + } + + int estimatedProbeRowSize = 0; for (int index : probe.getOutputChannels()) { Block block = probe.getPage().getBlock(index); - // Estimate the size of the current row - estimatedProbeBlockBytes += block.getSizeInBytes() / block.getPositionCount(); + // Estimate the size of the probe row + // TODO: improve estimation for unloaded blocks by making it similar as in PageProcessor + estimatedProbeRowSize += block.getSizeInBytes() / block.getPositionCount(); } + + this.estimatedProbeRowSize = estimatedProbeRowSize; + return estimatedProbeRowSize; } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java index 029e010c04c6..813acc41c147 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -24,12 +24,10 @@ import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.optimizations.joins.JoinGraph; -import io.prestosql.sql.planner.plan.Assignments; import io.prestosql.sql.planner.plan.FilterNode; import io.prestosql.sql.planner.plan.JoinNode; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.PlanNodeId; -import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.sql.tree.Expression; import java.util.HashMap; @@ -73,8 +71,8 @@ public boolean isEnabled(Session session) @Override public Result apply(JoinNode node, Captures captures, Context context) { - JoinGraph joinGraph = JoinGraph.buildShallowFrom(node, context.getLookup()); - if (joinGraph.size() < 3) { + JoinGraph joinGraph = JoinGraph.buildFrom(node, context.getLookup(), context.getIdAllocator()); + if (joinGraph.size() < 3 || !joinGraph.isContainsCrossJoin()) { return Result.empty(); } @@ -201,13 +199,6 @@ public static PlanNode buildJoinTree(List expectedOutputSymbols, JoinGra filter); } - if (graph.getAssignments().isPresent()) { - result = new ProjectNode( - idAllocator.getNextId(), - result, - Assignments.copyOf(graph.getAssignments().get())); - } - // If needed, introduce a projection to constrain the outputs to what was originally expected // Some nodes are sensitive to what's produced (e.g., DistinctLimit node) return restrictOutputs(idAllocator, result, ImmutableSet.copyOf(expectedOutputSymbols)).orElse(result); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java index 5ae6bad95514..e8afb662dd34 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java @@ -22,6 +22,7 @@ import io.prestosql.sql.planner.SymbolsExtractor; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.Literal; @@ -29,6 +30,7 @@ import io.prestosql.sql.util.AstUtils; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; @@ -64,9 +66,16 @@ public Result apply(ProjectNode parent, Captures captures, Context context) { ProjectNode child = captures.get(CHILD); - Sets.SetView targets = extractInliningTargets(parent, child); + return inlineProjections(parent, child) + .map(Result::ofPlanNode) + .orElse(Result.empty()); + } + + static Optional inlineProjections(ProjectNode parent, ProjectNode child) + { + Set targets = extractInliningTargets(parent, child); if (targets.isEmpty()) { - return Result.empty(); + return Optional.empty(); } // inline the expressions @@ -79,8 +88,6 @@ public Result apply(ProjectNode parent, Captures captures, Context context) // Synthesize identity assignments for the inputs of expressions that were inlined // to place in the child projection. - // If all assignments end up becoming identity assignments, they'll get pruned by - // other rules Set inputs = child.getAssignments() .entrySet().stream() .filter(entry -> targets.contains(entry.getKey())) @@ -88,27 +95,36 @@ public Result apply(ProjectNode parent, Captures captures, Context context) .flatMap(entry -> SymbolsExtractor.extractAll(entry).stream()) .collect(toSet()); - Assignments.Builder childAssignments = Assignments.builder(); + Assignments.Builder newChildAssignmentsBuilder = Assignments.builder(); for (Map.Entry assignment : child.getAssignments().entrySet()) { if (!targets.contains(assignment.getKey())) { - childAssignments.put(assignment); + newChildAssignmentsBuilder.put(assignment); } } for (Symbol input : inputs) { - childAssignments.putIdentity(input); + newChildAssignmentsBuilder.putIdentity(input); + } + + Assignments newChildAssignments = newChildAssignmentsBuilder.build(); + PlanNode newChild; + if (newChildAssignments.isIdentity()) { + newChild = child.getSource(); + } + else { + newChild = new ProjectNode( + child.getId(), + child.getSource(), + newChildAssignments); } - return Result.ofPlanNode( + return Optional.of( new ProjectNode( parent.getId(), - new ProjectNode( - child.getId(), - child.getSource(), - childAssignments.build()), + newChild, Assignments.copyOf(parentAssignments))); } - private Expression inlineReferences(Expression expression, Assignments assignments) + private static Expression inlineReferences(Expression expression, Assignments assignments) { Function mapping = symbol -> { Expression result = assignments.get(symbol); @@ -122,7 +138,7 @@ private Expression inlineReferences(Expression expression, Assignments assignmen return inlineSymbols(mapping, expression); } - private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectNode child) + private static Set extractInliningTargets(ProjectNode parent, ProjectNode child) { // candidates for inlining are // 1. references to simple constants @@ -162,7 +178,7 @@ private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectN return Sets.union(singletons, constants); } - private Set extractTryArguments(Expression expression) + private static Set extractTryArguments(Expression expression) { return AstUtils.preOrder(expression) .filter(TryExpression.class::isInstance) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java new file mode 100644 index 000000000000..bba966d23880 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java @@ -0,0 +1,147 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; +import io.prestosql.sql.planner.DeterminismEvaluator; +import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.iterative.Lookup; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.JoinNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.Expression; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.sql.planner.SymbolsExtractor.extractUnique; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; + +/** + * Utility class for pushing projections through inner join so that joins are not separated + * by a project node and can participate in cross join elimination or join reordering. + */ +public final class PushProjectionThroughJoin +{ + public static Optional pushProjectionThroughJoin(ProjectNode projectNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator) + { + if (!projectNode.getAssignments().getExpressions().stream().allMatch(DeterminismEvaluator::isDeterministic)) { + return Optional.empty(); + } + + PlanNode child = lookup.resolve(projectNode.getSource()); + if (!(child instanceof JoinNode)) { + return Optional.empty(); + } + + JoinNode joinNode = (JoinNode) child; + PlanNode leftChild = joinNode.getLeft(); + PlanNode rightChild = joinNode.getRight(); + + if (joinNode.getType() != INNER) { + return Optional.empty(); + } + + Assignments.Builder leftAssignmentsBuilder = Assignments.builder(); + Assignments.Builder rightAssignmentsBuilder = Assignments.builder(); + for (Map.Entry assignment : projectNode.getAssignments().entrySet()) { + Expression expression = assignment.getValue(); + Set symbols = extractUnique(expression); + if (leftChild.getOutputSymbols().containsAll(symbols)) { + // expression is satisfied with left child symbols + leftAssignmentsBuilder.put(assignment.getKey(), expression); + } + else if (rightChild.getOutputSymbols().containsAll(symbols)) { + // expression is satisfied with right child symbols + rightAssignmentsBuilder.put(assignment.getKey(), expression); + } + else { + // expression is using symbols from both join sides + return Optional.empty(); + } + } + + // add projections for symbols required by the join itself + Set joinRequiredSymbols = getJoinRequiredSymbols(joinNode); + for (Symbol requiredSymbol : joinRequiredSymbols) { + if (leftChild.getOutputSymbols().contains(requiredSymbol)) { + leftAssignmentsBuilder.putIdentity(requiredSymbol); + } + else { + checkState(rightChild.getOutputSymbols().contains(requiredSymbol)); + rightAssignmentsBuilder.putIdentity(requiredSymbol); + } + } + + Assignments leftAssignments = leftAssignmentsBuilder.build(); + Assignments rightAssignments = rightAssignmentsBuilder.build(); + List outputSymbols = Streams.concat(leftAssignments.getOutputs().stream(), rightAssignments.getOutputs().stream()) + .filter(ImmutableSet.copyOf(projectNode.getOutputSymbols())::contains) + .collect(toImmutableList()); + + return Optional.of(new JoinNode( + joinNode.getId(), + joinNode.getType(), + inlineProjections( + new ProjectNode(planNodeIdAllocator.getNextId(), leftChild, leftAssignments), + lookup), + inlineProjections( + new ProjectNode(planNodeIdAllocator.getNextId(), rightChild, rightAssignments), + lookup), + joinNode.getCriteria(), + outputSymbols, + joinNode.getFilter(), + joinNode.getLeftHashSymbol(), + joinNode.getRightHashSymbol(), + joinNode.getDistributionType(), + joinNode.isSpillable(), + joinNode.getDynamicFilters())); + } + + private static PlanNode inlineProjections(ProjectNode parentProjection, Lookup lookup) + { + PlanNode child = lookup.resolve(parentProjection.getSource()); + if (!(child instanceof ProjectNode)) { + return parentProjection; + } + ProjectNode childProjection = (ProjectNode) child; + + return InlineProjections.inlineProjections(parentProjection, childProjection) + .map(node -> inlineProjections(node, lookup)) + .orElse(parentProjection); + } + + private static Set getJoinRequiredSymbols(JoinNode node) + { + // extract symbols required by the join itself + return Streams.concat( + node.getCriteria().stream().map(JoinNode.EquiJoinClause::getLeft), + node.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight), + node.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of()).stream(), + node.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream(), + node.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream()) + .collect(toImmutableSet()); + } + + private PushProjectionThroughJoin() {} +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java index 9c28e03add2a..c33831cb6914 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; +import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.iterative.GroupReference; import io.prestosql.sql.planner.iterative.Lookup; @@ -27,7 +28,6 @@ import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.sql.tree.Expression; -import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -36,6 +36,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.prestosql.sql.planner.iterative.rule.PushProjectionThroughJoin.pushProjectionThroughJoin; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; import static java.util.Objects.requireNonNull; @@ -46,42 +47,23 @@ */ public class JoinGraph { - private final Optional> assignments; private final List filters; private final List nodes; // nodes in order of their appearance in tree plan (left, right, parent) private final Multimap edges; private final PlanNodeId rootId; - - /** - * Builds all (distinct) {@link JoinGraph}-es whole plan tree. - */ - public static List buildFrom(PlanNode plan) - { - return buildFrom(plan, Lookup.noLookup()); - } + private final boolean containsCrossJoin; /** * Builds {@link JoinGraph} containing {@code plan} node. */ - public static JoinGraph buildShallowFrom(PlanNode plan, Lookup lookup) - { - JoinGraph graph = plan.accept(new Builder(true, lookup), new Context()); - return graph; - } - - private static List buildFrom(PlanNode plan, Lookup lookup) + public static JoinGraph buildFrom(PlanNode plan, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator) { - Context context = new Context(); - JoinGraph graph = plan.accept(new Builder(false, lookup), context); - if (graph.size() > 1) { - context.addSubGraph(graph); - } - return context.getGraphs(); + return plan.accept(new Builder(lookup, planNodeIdAllocator), new Context()); } public JoinGraph(PlanNode node) { - this(ImmutableList.of(node), ImmutableMultimap.of(), node.getId(), ImmutableList.of(), Optional.empty()); + this(ImmutableList.of(node), ImmutableMultimap.of(), node.getId(), ImmutableList.of(), false); } public JoinGraph( @@ -89,23 +71,13 @@ public JoinGraph( Multimap edges, PlanNodeId rootId, List filters, - Optional> assignments) + boolean containsCrossJoin) { this.nodes = nodes; this.edges = edges; this.rootId = rootId; this.filters = filters; - this.assignments = assignments; - } - - public JoinGraph withAssignments(Map assignments) - { - return new JoinGraph(nodes, edges, rootId, filters, Optional.of(assignments)); - } - - public Optional> getAssignments() - { - return assignments; + this.containsCrossJoin = containsCrossJoin; } public JoinGraph withFilter(Expression expression) @@ -114,7 +86,7 @@ public JoinGraph withFilter(Expression expression) filters.addAll(this.filters); filters.add(expression); - return new JoinGraph(nodes, edges, rootId, filters.build(), assignments); + return new JoinGraph(nodes, edges, rootId, filters.build(), containsCrossJoin); } public List getFilters() @@ -127,11 +99,6 @@ public PlanNodeId getRootId() return rootId; } - public JoinGraph withRootId(PlanNodeId rootId) - { - return new JoinGraph(nodes, edges, rootId, filters, assignments); - } - public boolean isEmpty() { return nodes.isEmpty(); @@ -157,6 +124,11 @@ public Collection getEdges(PlanNode node) return ImmutableList.copyOf(edges.get(node.getId())); } + public boolean isContainsCrossJoin() + { + return containsCrossJoin; + } + @Override public String toString() { @@ -180,7 +152,7 @@ public String toString() return builder.toString(); } - private JoinGraph joinWith(JoinGraph other, List joinClauses, Context context, PlanNodeId newRoot) + private JoinGraph joinWith(JoinGraph other, List joinClauses, Context context, PlanNodeId newRoot, boolean containsCrossJoin) { for (PlanNode node : other.nodes) { checkState(!edges.containsKey(node.getId()), "Node [%s] appeared in two JoinGraphs", node); @@ -212,35 +184,24 @@ private JoinGraph joinWith(JoinGraph other, List joinCl edges.put(right.getId(), new Edge(left, rightSymbol, leftSymbol)); } - return new JoinGraph(nodes, edges.build(), newRoot, joinedFilters, Optional.empty()); + return new JoinGraph(nodes, edges.build(), newRoot, joinedFilters, this.containsCrossJoin || containsCrossJoin); } private static class Builder extends PlanVisitor { - // TODO When io.prestosql.sql.planner.optimizations.EliminateCrossJoins is removed, remove 'shallow' flag - private final boolean shallow; private final Lookup lookup; + private final PlanNodeIdAllocator planNodeIdAllocator; - private Builder(boolean shallow, Lookup lookup) + private Builder(Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator) { - this.shallow = shallow; this.lookup = requireNonNull(lookup, "lookup cannot be null"); + this.planNodeIdAllocator = requireNonNull(planNodeIdAllocator, "planNodeIdAllocator is null"); } @Override protected JoinGraph visitPlan(PlanNode node, Context context) { - if (!shallow) { - for (PlanNode child : node.getSources()) { - JoinGraph graph = child.accept(this, context); - if (graph.size() < 2) { - continue; - } - context.addSubGraph(graph.withRootId(child.getId())); - } - } - for (Symbol symbol : node.getOutputSymbols()) { context.setSymbolSource(symbol, node); } @@ -265,7 +226,7 @@ public JoinGraph visitJoin(JoinNode node, Context context) JoinGraph left = node.getLeft().accept(this, context); JoinGraph right = node.getRight().accept(this, context); - JoinGraph graph = left.joinWith(right, node.getCriteria(), context, node.getId()); + JoinGraph graph = left.joinWith(right, node.getCriteria(), context, node.getId(), node.isCrossJoin()); if (node.getFilter().isPresent()) { return graph.withFilter(node.getFilter().get()); @@ -276,10 +237,11 @@ public JoinGraph visitJoin(JoinNode node, Context context) @Override public JoinGraph visitProject(ProjectNode node, Context context) { - if (node.isIdentity()) { - JoinGraph graph = node.getSource().accept(this, context); - return graph.withAssignments(node.getAssignments().getMap()); + Optional rewrittenNode = pushProjectionThroughJoin(node, lookup, planNodeIdAllocator); + if (rewrittenNode.isPresent()) { + return rewrittenNode.get().accept(this, context); } + return visitPlan(node, context); } @@ -296,7 +258,7 @@ public JoinGraph visitGroupReference(GroupReference node, Context context) private boolean isTrivialGraph(JoinGraph graph) { - return graph.nodes.size() < 2 && graph.edges.isEmpty() && graph.filters.isEmpty() && !graph.assignments.isPresent(); + return graph.nodes.size() < 2 && graph.edges.isEmpty() && graph.filters.isEmpty(); } private JoinGraph replacementGraph(PlanNode oldNode, PlanNode newNode, Context context) @@ -345,19 +307,11 @@ private static class Context { private final Map symbolSources = new HashMap<>(); - // TODO When io.prestosql.sql.planner.optimizations.EliminateCrossJoins is removed, remove 'joinGraphs' - private final List joinGraphs = new ArrayList<>(); - public void setSymbolSource(Symbol symbol, PlanNode node) { symbolSources.put(symbol, node); } - public void addSubGraph(JoinGraph graph) - { - joinGraphs.add(graph); - } - public boolean containsSymbol(Symbol symbol) { return symbolSources.containsKey(symbol); @@ -368,10 +322,5 @@ public PlanNode getSymbolSource(Symbol symbol) checkState(containsSymbol(symbol)); return symbolSources.get(symbol); } - - public List getGraphs() - { - return joinGraphs; - } } } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java index 68c219dffa5f..7ea37ff5468b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java @@ -128,6 +128,18 @@ public boolean isIdentity(Symbol output) return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName()); } + public boolean isIdentity() + { + for (Map.Entry entry : assignments.entrySet()) { + Expression expression = entry.getValue(); + Symbol symbol = entry.getKey(); + if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(symbol.getName()))) { + return false; + } + } + return true; + } + private Collector, Builder, Assignments> toAssignments() { return Collector.of( diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/ProjectNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/ProjectNode.java index 648d17b51998..335b1eb7303e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/ProjectNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/ProjectNode.java @@ -18,13 +18,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.SymbolReference; import javax.annotation.concurrent.Immutable; import java.util.List; -import java.util.Map; import static java.util.Objects.requireNonNull; @@ -76,14 +73,7 @@ public PlanNode getSource() public boolean isIdentity() { - for (Map.Entry entry : assignments.entrySet()) { - Expression expression = entry.getValue(); - Symbol symbol = entry.getKey(); - if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(symbol.getName()))) { - return false; - } - } - return true; + return assignments.isIdentity(); } @Override diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java index e60b095ac13c..729ce0c6d46f 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.assertions.PlanMatchPattern; import io.prestosql.sql.planner.iterative.GroupReference; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; @@ -27,6 +28,7 @@ import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.sql.planner.plan.ValuesNode; +import io.prestosql.sql.tree.ArithmeticBinaryExpression; import io.prestosql.sql.tree.ArithmeticUnaryExpression; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.SymbolReference; @@ -38,14 +40,17 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Iterables.getOnlyElement; import static io.prestosql.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.any; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.node; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.iterative.Lookup.noLookup; import static io.prestosql.sql.planner.iterative.rule.EliminateCrossJoins.getJoinOrder; import static io.prestosql.sql.planner.iterative.rule.EliminateCrossJoins.isOriginalOrder; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; +import static io.prestosql.sql.tree.ArithmeticBinaryExpression.Operator.ADD; import static io.prestosql.sql.tree.ArithmeticUnaryExpression.Sign.MINUS; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -109,13 +114,13 @@ public void testJoinOrder() PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b"))), - values(symbol("c")), - symbol("a"), symbol("c"), - symbol("c"), symbol("b")); + values("a"), + values("b")), + values("c"), + "a", "c", + "c", "b"); - JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()); assertEquals( getJoinOrder(joinGraph), @@ -128,24 +133,24 @@ public void testJoinOrderWithRealCrossJoin() PlanNode leftPlan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b"))), - values(symbol("c")), - symbol("a"), symbol("c"), - symbol("c"), symbol("b")); + values("a"), + values("b")), + values("c"), + "a", "c", + "c", "b"); PlanNode rightPlan = joinNode( joinNode( - values(symbol("x")), - values(symbol("y"))), - values(symbol("z")), - symbol("x"), symbol("z"), - symbol("z"), symbol("y")); + values("x"), + values("y")), + values("z"), + "x", "z", + "z", "y"); PlanNode plan = joinNode(leftPlan, rightPlan); - JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()); assertEquals( getJoinOrder(joinGraph), @@ -158,14 +163,14 @@ public void testJoinOrderWithMultipleEdgesBetweenNodes() PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b1"), symbol("b2"))), - values(symbol("c1"), symbol("c2")), - symbol("a"), symbol("c1"), - symbol("c1"), symbol("b1"), - symbol("c2"), symbol("b2")); + values("a"), + values("b1", "b2")), + values("c1", "c2"), + "a", "c1", + "c1", "b1", + "c2", "b2"); - JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()); assertEquals( getJoinOrder(joinGraph), @@ -173,18 +178,18 @@ public void testJoinOrderWithMultipleEdgesBetweenNodes() } @Test - public void testDonNotChangeOrderWithoutCrossJoin() + public void testDoesNotChangeOrderWithoutCrossJoin() { PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b")), - symbol("a"), symbol("b")), - values(symbol("c")), - symbol("c"), symbol("b")); + values("a"), + values("b"), + "a", "b"), + values("c"), + "c", "b"); - JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()); assertEquals( getJoinOrder(joinGraph), @@ -197,12 +202,12 @@ public void testDoNotReorderCrossJoins() PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b"))), - values(symbol("c")), - symbol("c"), symbol("b")); + values("a"), + values("b")), + values("c"), + "c", "b"); - JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()); assertEquals( getJoinOrder(joinGraph), @@ -210,21 +215,84 @@ public void testDoNotReorderCrossJoins() } @Test - public void testGiveUpOnNonIdentityProjections() + public void testEliminateCrossJoinWithNonIdentityProjections() + { + tester().assertThat(new EliminateCrossJoins()) + .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") + .on(p -> { + Symbol a1 = p.symbol("a1"); + Symbol a2 = p.symbol("a2"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + + return p.join( + INNER, + p.project( + Assignments.of( + a2, new ArithmeticUnaryExpression(MINUS, new SymbolReference("a1")), + f, new SymbolReference("f")), + p.join( + INNER, + p.project( + Assignments.of( + a1, new SymbolReference("a1"), + f, new ArithmeticUnaryExpression(MINUS, new SymbolReference("b"))), + p.join( + INNER, + p.values(a1), + p.values(b))), + p.values(e), + new EquiJoinClause(a1, e))), + p.values(c, d), + new EquiJoinClause(a2, c), + new EquiJoinClause(f, d)); + }) + .matches( + node(ProjectNode.class, + join( + INNER, + ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("d"), new Symbol("f"))), + join( + INNER, + ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("a2"), new Symbol("c"))), + join(INNER, + ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("a1"), new Symbol("e"))), + strictProject( + ImmutableMap.of( + "a2", expression("-a1"), + "a1", expression("a1")), + PlanMatchPattern.values("a1")), + strictProject( + ImmutableMap.of( + "e", expression("e")), + PlanMatchPattern.values("e"))), + any()), + strictProject( + ImmutableMap.of("f", expression("-b")), + PlanMatchPattern.values("b"))))); + } + + @Test + public void testGiveUpOnComplexProjections() { PlanNode plan = joinNode( projectNode( joinNode( - values(symbol("a1")), - values(symbol("b"))), - symbol("a2"), - new ArithmeticUnaryExpression(MINUS, new SymbolReference("a1"))), - values(symbol("c")), - symbol("a2"), symbol("c"), - symbol("c"), symbol("b")); - - assertEquals(JoinGraph.buildFrom(plan).size(), 2); + values("a1"), + values("b")), + "a2", + new ArithmeticBinaryExpression(ADD, new SymbolReference("a1"), new SymbolReference("b")), + "b", + new SymbolReference("b")), + values("c"), + "a2", "c", + "c", "b"); + + assertEquals(JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()).size(), 2); } private Function crossJoinAndJoin(JoinNode.Type secondJoinType) @@ -246,17 +314,14 @@ private Function crossJoinAndJoin(JoinNode.Type secondJoi }; } - private PlanNode projectNode(PlanNode source, String symbol, Expression expression) + private PlanNode projectNode(PlanNode source, String symbol1, Expression expression1, String symbol2, Expression expression2) { return new ProjectNode( idAllocator.getNextId(), source, - Assignments.of(new Symbol(symbol), expression)); - } - - private String symbol(String name) - { - return name; + Assignments.of( + new Symbol(symbol1), expression1, + new Symbol(symbol2), expression2)); } private JoinNode joinNode(PlanNode left, PlanNode right, String... symbols) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java index b04ce1cc3bda..b185e81ea8ed 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java @@ -67,6 +67,27 @@ public void test() values(ImmutableMap.of("x", 0))))); } + @Test + public void testEliminatesIdentityProjection() + { + tester().assertThat(new InlineProjections()) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("single_complex"), expression("complex + 2")) // complex expression referenced only once + .build(), + p.project(Assignments.builder() + .put(p.symbol("complex"), expression("x - 1")) + .build(), + p.values(p.symbol("x"))))) + .matches( + project( + ImmutableMap.builder() + .put("out1", PlanMatchPattern.expression("x - 1 + 2")) + .build(), + values(ImmutableMap.of("x", 0)))); + } + @Test public void testIdentityProjections() { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java new file mode 100644 index 000000000000..bb4ea159b1f6 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java @@ -0,0 +1,147 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.Plan; +import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.assertions.PlanMatchPattern; +import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.JoinNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.ArithmeticBinaryExpression; +import io.prestosql.sql.tree.ArithmeticUnaryExpression; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.prestosql.cost.PlanNodeStatsEstimate.unknown; +import static io.prestosql.cost.StatsAndCosts.empty; +import static io.prestosql.metadata.AbstractMockMetadata.dummyMetadata; +import static io.prestosql.sql.planner.assertions.PlanAssert.assertPlan; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.iterative.Lookup.noLookup; +import static io.prestosql.sql.planner.iterative.rule.PushProjectionThroughJoin.pushProjectionThroughJoin; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; +import static io.prestosql.sql.planner.plan.JoinNode.Type.LEFT; +import static io.prestosql.sql.tree.ArithmeticBinaryExpression.Operator.ADD; +import static io.prestosql.sql.tree.ArithmeticUnaryExpression.Sign.MINUS; +import static io.prestosql.sql.tree.ArithmeticUnaryExpression.Sign.PLUS; +import static io.prestosql.testing.TestingSession.testSessionBuilder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.testng.Assert.assertTrue; + +public class TestPushProjectionThroughJoin +{ + @Test + public void testPushesProjectionThroughJoin() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder p = new PlanBuilder(idAllocator, dummyMetadata()); + Symbol a0 = p.symbol("a0"); + Symbol a1 = p.symbol("a1"); + Symbol a2 = p.symbol("a2"); + Symbol a3 = p.symbol("a3"); + Symbol b0 = p.symbol("b0"); + Symbol b1 = p.symbol("b1"); + Symbol b2 = p.symbol("b2"); + + ProjectNode planNode = p.project( + Assignments.of( + a3, new ArithmeticUnaryExpression(MINUS, a2.toSymbolReference()), + b2, new ArithmeticUnaryExpression(PLUS, b1.toSymbolReference())), + p.join( + INNER, + // intermediate non-identity projections should be fully inlined + p.project( + Assignments.of( + a2, new ArithmeticUnaryExpression(PLUS, a0.toSymbolReference()), + a1, a1.toSymbolReference()), + p.project( + Assignments.builder() + .putIdentity(a0) + .putIdentity(a1) + .build(), + p.values(a0, a1))), + p.values(b0, b1), + new JoinNode.EquiJoinClause(a1, b1))); + + Optional rewritten = pushProjectionThroughJoin(planNode, noLookup(), idAllocator); + assertTrue(rewritten.isPresent()); + assertPlan( + testSessionBuilder().build(), + dummyMetadata(), + node -> unknown(), + new Plan(rewritten.get(), p.getTypes(), empty()), noLookup(), + join( + INNER, + ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("a1"), new Symbol("b1"))), + strictProject(ImmutableMap.of( + "a3", expression("-(+a0)"), + "a1", expression("a1")), + strictProject(ImmutableMap.of( + "a0", expression("a0"), + "a1", expression("a1")), + PlanMatchPattern.values("a0", "a1"))), + strictProject(ImmutableMap.of( + "b2", expression("+b1"), + "b1", expression("b1")), + PlanMatchPattern.values("b0", "b1"))) + .withExactOutputs("a3", "b2")); + } + + @Test + public void testDoesNotPushStraddlingProjection() + { + PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), dummyMetadata()); + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + + ProjectNode planNode = p.project( + Assignments.of( + c, new ArithmeticBinaryExpression(ADD, a.toSymbolReference(), b.toSymbolReference())), + p.join( + INNER, + p.values(a), + p.values(b))); + Optional rewritten = pushProjectionThroughJoin(planNode, noLookup(), new PlanNodeIdAllocator()); + assertThat(rewritten).isEmpty(); + } + + @Test + public void testDoesNotPushProjectionThroughOuterJoin() + { + PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), dummyMetadata()); + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + + ProjectNode planNode = p.project( + Assignments.of( + c, new ArithmeticUnaryExpression(MINUS, a.toSymbolReference())), + p.join( + LEFT, + p.values(a), + p.values(b))); + Optional rewritten = pushProjectionThroughJoin(planNode, noLookup(), new PlanNodeIdAllocator()); + assertThat(rewritten).isEmpty(); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateCrossJoins.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateCrossJoins.java index e8f829299c34..0d7b1bbd486d 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateCrossJoins.java @@ -26,6 +26,7 @@ import static io.prestosql.sql.planner.assertions.PlanMatchPattern.equiJoinClause; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictTableScan; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; @@ -75,6 +76,23 @@ public void testEliminateSimpleCrossJoin() anyTree(ORDERS_TABLESCAN)))); } + @Test + public void testDoesNotReorderJoinsWhenNoCrossJoinPresent() + { + assertPlan("SELECT o1.orderkey, o2.custkey, o3.orderstatus, o4.totalprice " + + "FROM (orders o1 JOIN orders o2 ON o1.orderkey = o2.orderkey) " + + "JOIN (orders o3 JOIN orders o4 ON o3.orderkey = o4.orderkey) ON o1.orderkey = o3.orderkey", + anyTree( + join(INNER, ImmutableList.of(equiJoinClause("O1_ORDERKEY", "O3_ORDERKEY")), + join(INNER, ImmutableList.of(equiJoinClause("O1_ORDERKEY", "O2_ORDERKEY")), + anyTree(strictTableScan("orders", ImmutableMap.of("O1_ORDERKEY", "orderkey"))), + anyTree(strictTableScan("orders", ImmutableMap.of("O2_ORDERKEY", "orderkey", "O2_CUSTKEY", "custkey")))), + anyTree( + join(INNER, ImmutableList.of(equiJoinClause("O3_ORDERKEY", "O4_ORDERKEY")), + anyTree(strictTableScan("orders", ImmutableMap.of("O3_ORDERKEY", "orderkey", "O3_ORDERSTATUS", "orderstatus"))), + anyTree(strictTableScan("orders", ImmutableMap.of("O4_ORDERKEY", "orderkey", "O4_totalprice", "totalprice")))))))); + } + @Test public void testGiveUpOnCrossJoin() { diff --git a/presto-main/src/test/java/io/prestosql/sql/query/TestJoin.java b/presto-main/src/test/java/io/prestosql/sql/query/TestJoin.java new file mode 100644 index 000000000000..174e49db28d8 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/query/TestJoin.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.query; + +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class TestJoin +{ + private QueryAssertions assertions; + + @BeforeClass + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterClass(alwaysRun = true) + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testCrossJoinEliminationWithOuterJoin() + { + assertions.assertQuery( + "WITH " + + " a AS (SELECT id FROM (VALUES (1)) AS t(id))," + + " b AS (SELECT id FROM (VALUES (1)) AS t(id))," + + " c AS (SELECT id FROM (VALUES ('1')) AS t(id))," + + " d as (SELECT id FROM (VALUES (1)) AS t(id))" + + "SELECT a.id " + + "FROM a " + + "LEFT JOIN b ON a.id = b.id " + + "JOIN c ON a.id = CAST(c.id AS bigint) " + + "JOIN d ON d.id = a.id", + "VALUES 1"); + } +}