From bc2a28b74453565066640a31368f98762619020a Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Thu, 22 Aug 2019 16:37:44 +0200 Subject: [PATCH 01/16] Remove JoinGraph#removeShallowFrom --- .../iterative/rule/EliminateCrossJoins.java | 2 +- .../optimizations/joins/JoinGraph.java | 57 +------------------ .../rule/TestEliminateCrossJoins.java | 14 ++--- 3 files changed, 11 insertions(+), 62 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java index 029e010c04c6..d66fff941100 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -73,7 +73,7 @@ public boolean isEnabled(Session session) @Override public Result apply(JoinNode node, Captures captures, Context context) { - JoinGraph joinGraph = JoinGraph.buildShallowFrom(node, context.getLookup()); + JoinGraph joinGraph = JoinGraph.buildFrom(node, context.getLookup()); if (joinGraph.size() < 3) { return Result.empty(); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java index 9c28e03add2a..bf4e416a7563 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java @@ -27,7 +27,6 @@ import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.sql.tree.Expression; -import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; import java.util.List; @@ -52,31 +51,12 @@ public class JoinGraph private final Multimap edges; private final PlanNodeId rootId; - /** - * Builds all (distinct) {@link JoinGraph}-es whole plan tree. - */ - public static List buildFrom(PlanNode plan) - { - return buildFrom(plan, Lookup.noLookup()); - } - /** * Builds {@link JoinGraph} containing {@code plan} node. */ - public static JoinGraph buildShallowFrom(PlanNode plan, Lookup lookup) + public static JoinGraph buildFrom(PlanNode plan, Lookup lookup) { - JoinGraph graph = plan.accept(new Builder(true, lookup), new Context()); - return graph; - } - - private static List buildFrom(PlanNode plan, Lookup lookup) - { - Context context = new Context(); - JoinGraph graph = plan.accept(new Builder(false, lookup), context); - if (graph.size() > 1) { - context.addSubGraph(graph); - } - return context.getGraphs(); + return plan.accept(new Builder(lookup), new Context()); } public JoinGraph(PlanNode node) @@ -127,11 +107,6 @@ public PlanNodeId getRootId() return rootId; } - public JoinGraph withRootId(PlanNodeId rootId) - { - return new JoinGraph(nodes, edges, rootId, filters, assignments); - } - public boolean isEmpty() { return nodes.isEmpty(); @@ -218,29 +193,16 @@ private JoinGraph joinWith(JoinGraph other, List joinCl private static class Builder extends PlanVisitor { - // TODO When io.prestosql.sql.planner.optimizations.EliminateCrossJoins is removed, remove 'shallow' flag - private final boolean shallow; private final Lookup lookup; - private Builder(boolean shallow, Lookup lookup) + private Builder(Lookup lookup) { - this.shallow = shallow; this.lookup = requireNonNull(lookup, "lookup cannot be null"); } @Override protected JoinGraph visitPlan(PlanNode node, Context context) { - if (!shallow) { - for (PlanNode child : node.getSources()) { - JoinGraph graph = child.accept(this, context); - if (graph.size() < 2) { - continue; - } - context.addSubGraph(graph.withRootId(child.getId())); - } - } - for (Symbol symbol : node.getOutputSymbols()) { context.setSymbolSource(symbol, node); } @@ -345,19 +307,11 @@ private static class Context { private final Map symbolSources = new HashMap<>(); - // TODO When io.prestosql.sql.planner.optimizations.EliminateCrossJoins is removed, remove 'joinGraphs' - private final List joinGraphs = new ArrayList<>(); - public void setSymbolSource(Symbol symbol, PlanNode node) { symbolSources.put(symbol, node); } - public void addSubGraph(JoinGraph graph) - { - joinGraphs.add(graph); - } - public boolean containsSymbol(Symbol symbol) { return symbolSources.containsKey(symbol); @@ -368,10 +322,5 @@ public PlanNode getSymbolSource(Symbol symbol) checkState(containsSymbol(symbol)); return symbolSources.get(symbol); } - - public List getGraphs() - { - return joinGraphs; - } } } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java index e60b095ac13c..a19d93c5dc71 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -18,6 +18,7 @@ import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.iterative.GroupReference; +import io.prestosql.sql.planner.iterative.Lookup; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; import io.prestosql.sql.planner.optimizations.joins.JoinGraph; @@ -38,7 +39,6 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.Iterables.getOnlyElement; import static io.prestosql.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.any; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; @@ -115,7 +115,7 @@ public void testJoinOrder() symbol("a"), symbol("c"), symbol("c"), symbol("b")); - JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); assertEquals( getJoinOrder(joinGraph), @@ -145,7 +145,7 @@ public void testJoinOrderWithRealCrossJoin() PlanNode plan = joinNode(leftPlan, rightPlan); - JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); assertEquals( getJoinOrder(joinGraph), @@ -165,7 +165,7 @@ public void testJoinOrderWithMultipleEdgesBetweenNodes() symbol("c1"), symbol("b1"), symbol("c2"), symbol("b2")); - JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); assertEquals( getJoinOrder(joinGraph), @@ -184,7 +184,7 @@ public void testDonNotChangeOrderWithoutCrossJoin() values(symbol("c")), symbol("c"), symbol("b")); - JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); assertEquals( getJoinOrder(joinGraph), @@ -202,7 +202,7 @@ public void testDoNotReorderCrossJoins() values(symbol("c")), symbol("c"), symbol("b")); - JoinGraph joinGraph = getOnlyElement(JoinGraph.buildFrom(plan)); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); assertEquals( getJoinOrder(joinGraph), @@ -224,7 +224,7 @@ public void testGiveUpOnNonIdentityProjections() symbol("a2"), symbol("c"), symbol("c"), symbol("b")); - assertEquals(JoinGraph.buildFrom(plan).size(), 2); + assertEquals(JoinGraph.buildFrom(plan, Lookup.noLookup()).size(), 2); } private Function crossJoinAndJoin(JoinNode.Type secondJoinType) From 5afaf67c5d4f5ce85f3b0cc84ad7e7e62169ab33 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Fri, 23 Aug 2019 11:20:47 +0200 Subject: [PATCH 02/16] Don't run eliminate cross join optimizer when there are no cross joins --- .../iterative/rule/EliminateCrossJoins.java | 2 +- .../optimizations/joins/JoinGraph.java | 22 +++++++++++++------ .../TestEliminateCrossJoins.java | 18 +++++++++++++++ 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java index d66fff941100..8e5d44bdee43 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -74,7 +74,7 @@ public boolean isEnabled(Session session) public Result apply(JoinNode node, Captures captures, Context context) { JoinGraph joinGraph = JoinGraph.buildFrom(node, context.getLookup()); - if (joinGraph.size() < 3) { + if (joinGraph.size() < 3 || !joinGraph.isContainsCrossJoin()) { return Result.empty(); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java index bf4e416a7563..2f0160229930 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java @@ -50,6 +50,7 @@ public class JoinGraph private final List nodes; // nodes in order of their appearance in tree plan (left, right, parent) private final Multimap edges; private final PlanNodeId rootId; + private final boolean containsCrossJoin; /** * Builds {@link JoinGraph} containing {@code plan} node. @@ -61,7 +62,7 @@ public static JoinGraph buildFrom(PlanNode plan, Lookup lookup) public JoinGraph(PlanNode node) { - this(ImmutableList.of(node), ImmutableMultimap.of(), node.getId(), ImmutableList.of(), Optional.empty()); + this(ImmutableList.of(node), ImmutableMultimap.of(), node.getId(), ImmutableList.of(), Optional.empty(), false); } public JoinGraph( @@ -69,18 +70,20 @@ public JoinGraph( Multimap edges, PlanNodeId rootId, List filters, - Optional> assignments) + Optional> assignments, + boolean containsCrossJoin) { this.nodes = nodes; this.edges = edges; this.rootId = rootId; this.filters = filters; this.assignments = assignments; + this.containsCrossJoin = containsCrossJoin; } public JoinGraph withAssignments(Map assignments) { - return new JoinGraph(nodes, edges, rootId, filters, Optional.of(assignments)); + return new JoinGraph(nodes, edges, rootId, filters, Optional.of(assignments), containsCrossJoin); } public Optional> getAssignments() @@ -94,7 +97,7 @@ public JoinGraph withFilter(Expression expression) filters.addAll(this.filters); filters.add(expression); - return new JoinGraph(nodes, edges, rootId, filters.build(), assignments); + return new JoinGraph(nodes, edges, rootId, filters.build(), assignments, containsCrossJoin); } public List getFilters() @@ -132,6 +135,11 @@ public Collection getEdges(PlanNode node) return ImmutableList.copyOf(edges.get(node.getId())); } + public boolean isContainsCrossJoin() + { + return containsCrossJoin; + } + @Override public String toString() { @@ -155,7 +163,7 @@ public String toString() return builder.toString(); } - private JoinGraph joinWith(JoinGraph other, List joinClauses, Context context, PlanNodeId newRoot) + private JoinGraph joinWith(JoinGraph other, List joinClauses, Context context, PlanNodeId newRoot, boolean containsCrossJoin) { for (PlanNode node : other.nodes) { checkState(!edges.containsKey(node.getId()), "Node [%s] appeared in two JoinGraphs", node); @@ -187,7 +195,7 @@ private JoinGraph joinWith(JoinGraph other, List joinCl edges.put(right.getId(), new Edge(left, rightSymbol, leftSymbol)); } - return new JoinGraph(nodes, edges.build(), newRoot, joinedFilters, Optional.empty()); + return new JoinGraph(nodes, edges.build(), newRoot, joinedFilters, Optional.empty(), this.containsCrossJoin || containsCrossJoin); } private static class Builder @@ -227,7 +235,7 @@ public JoinGraph visitJoin(JoinNode node, Context context) JoinGraph left = node.getLeft().accept(this, context); JoinGraph right = node.getRight().accept(this, context); - JoinGraph graph = left.joinWith(right, node.getCriteria(), context, node.getId()); + JoinGraph graph = left.joinWith(right, node.getCriteria(), context, node.getId(), node.isCrossJoin()); if (node.getFilter().isPresent()) { return graph.withFilter(node.getFilter().get()); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateCrossJoins.java b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateCrossJoins.java index e8f829299c34..0d7b1bbd486d 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/optimizations/TestEliminateCrossJoins.java @@ -26,6 +26,7 @@ import static io.prestosql.sql.planner.assertions.PlanMatchPattern.equiJoinClause; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictTableScan; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; @@ -75,6 +76,23 @@ public void testEliminateSimpleCrossJoin() anyTree(ORDERS_TABLESCAN)))); } + @Test + public void testDoesNotReorderJoinsWhenNoCrossJoinPresent() + { + assertPlan("SELECT o1.orderkey, o2.custkey, o3.orderstatus, o4.totalprice " + + "FROM (orders o1 JOIN orders o2 ON o1.orderkey = o2.orderkey) " + + "JOIN (orders o3 JOIN orders o4 ON o3.orderkey = o4.orderkey) ON o1.orderkey = o3.orderkey", + anyTree( + join(INNER, ImmutableList.of(equiJoinClause("O1_ORDERKEY", "O3_ORDERKEY")), + join(INNER, ImmutableList.of(equiJoinClause("O1_ORDERKEY", "O2_ORDERKEY")), + anyTree(strictTableScan("orders", ImmutableMap.of("O1_ORDERKEY", "orderkey"))), + anyTree(strictTableScan("orders", ImmutableMap.of("O2_ORDERKEY", "orderkey", "O2_CUSTKEY", "custkey")))), + anyTree( + join(INNER, ImmutableList.of(equiJoinClause("O3_ORDERKEY", "O4_ORDERKEY")), + anyTree(strictTableScan("orders", ImmutableMap.of("O3_ORDERKEY", "orderkey", "O3_ORDERSTATUS", "orderstatus"))), + anyTree(strictTableScan("orders", ImmutableMap.of("O4_ORDERKEY", "orderkey", "O4_totalprice", "totalprice")))))))); + } + @Test public void testGiveUpOnCrossJoin() { From 2633ae2fcc329fa309486e0f1ca9e74320423f46 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Tue, 27 Aug 2019 12:29:30 +0200 Subject: [PATCH 03/16] Use Set instead of Sets.SetView --- .../sql/planner/iterative/rule/InlineProjections.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java index 5ae6bad95514..6a3291c511ff 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java @@ -64,7 +64,7 @@ public Result apply(ProjectNode parent, Captures captures, Context context) { ProjectNode child = captures.get(CHILD); - Sets.SetView targets = extractInliningTargets(parent, child); + Set targets = extractInliningTargets(parent, child); if (targets.isEmpty()) { return Result.empty(); } @@ -122,7 +122,7 @@ private Expression inlineReferences(Expression expression, Assignments assignmen return inlineSymbols(mapping, expression); } - private Sets.SetView extractInliningTargets(ProjectNode parent, ProjectNode child) + private Set extractInliningTargets(ProjectNode parent, ProjectNode child) { // candidates for inlining are // 1. references to simple constants From 894983d50ced2134e8eccc7856f174d18bdff1e9 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Tue, 27 Aug 2019 12:41:24 +0200 Subject: [PATCH 04/16] Eliminate fully inlined child projection --- .../iterative/rule/InlineProjections.java | 26 ++++++++++++------- .../iterative/rule/TestInlineProjections.java | 21 +++++++++++++++ 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java index 6a3291c511ff..936ab319e18d 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java @@ -22,6 +22,7 @@ import io.prestosql.sql.planner.SymbolsExtractor; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.Literal; @@ -79,8 +80,6 @@ public Result apply(ProjectNode parent, Captures captures, Context context) // Synthesize identity assignments for the inputs of expressions that were inlined // to place in the child projection. - // If all assignments end up becoming identity assignments, they'll get pruned by - // other rules Set inputs = child.getAssignments() .entrySet().stream() .filter(entry -> targets.contains(entry.getKey())) @@ -88,23 +87,32 @@ public Result apply(ProjectNode parent, Captures captures, Context context) .flatMap(entry -> SymbolsExtractor.extractAll(entry).stream()) .collect(toSet()); - Assignments.Builder childAssignments = Assignments.builder(); + Assignments.Builder newChildAssignmentsBuilder = Assignments.builder(); for (Map.Entry assignment : child.getAssignments().entrySet()) { if (!targets.contains(assignment.getKey())) { - childAssignments.put(assignment); + newChildAssignmentsBuilder.put(assignment); } } for (Symbol input : inputs) { - childAssignments.putIdentity(input); + newChildAssignmentsBuilder.putIdentity(input); + } + + Assignments newChildAssignments = newChildAssignmentsBuilder.build(); + PlanNode newChild; + if (newChildAssignments.isIdentity()) { + newChild = child.getSource(); + } + else { + newChild = new ProjectNode( + child.getId(), + child.getSource(), + newChildAssignments); } return Result.ofPlanNode( new ProjectNode( parent.getId(), - new ProjectNode( - child.getId(), - child.getSource(), - childAssignments.build()), + newChild, Assignments.copyOf(parentAssignments))); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java index b04ce1cc3bda..b185e81ea8ed 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java @@ -67,6 +67,27 @@ public void test() values(ImmutableMap.of("x", 0))))); } + @Test + public void testEliminatesIdentityProjection() + { + tester().assertThat(new InlineProjections()) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("single_complex"), expression("complex + 2")) // complex expression referenced only once + .build(), + p.project(Assignments.builder() + .put(p.symbol("complex"), expression("x - 1")) + .build(), + p.values(p.symbol("x"))))) + .matches( + project( + ImmutableMap.builder() + .put("out1", PlanMatchPattern.expression("x - 1 + 2")) + .build(), + values(ImmutableMap.of("x", 0)))); + } + @Test public void testIdentityProjections() { From 25c706ca45a516ba784a9b85ff73fc0fba86ee90 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Fri, 23 Aug 2019 15:30:03 +0200 Subject: [PATCH 05/16] Remove noop TestEliminateCrossJoins#symbol method --- .../rule/TestEliminateCrossJoins.java | 76 +++++++++---------- 1 file changed, 35 insertions(+), 41 deletions(-) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java index a19d93c5dc71..9f76f1c9733c 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -109,11 +109,11 @@ public void testJoinOrder() PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b"))), - values(symbol("c")), - symbol("a"), symbol("c"), - symbol("c"), symbol("b")); + values("a"), + values("b")), + values("c"), + "a", "c", + "c", "b"); JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); @@ -128,20 +128,20 @@ public void testJoinOrderWithRealCrossJoin() PlanNode leftPlan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b"))), - values(symbol("c")), - symbol("a"), symbol("c"), - symbol("c"), symbol("b")); + values("a"), + values("b")), + values("c"), + "a", "c", + "c", "b"); PlanNode rightPlan = joinNode( joinNode( - values(symbol("x")), - values(symbol("y"))), - values(symbol("z")), - symbol("x"), symbol("z"), - symbol("z"), symbol("y")); + values("x"), + values("y")), + values("z"), + "x", "z", + "z", "y"); PlanNode plan = joinNode(leftPlan, rightPlan); @@ -158,12 +158,12 @@ public void testJoinOrderWithMultipleEdgesBetweenNodes() PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b1"), symbol("b2"))), - values(symbol("c1"), symbol("c2")), - symbol("a"), symbol("c1"), - symbol("c1"), symbol("b1"), - symbol("c2"), symbol("b2")); + values("a"), + values("b1", "b2")), + values("c1", "c2"), + "a", "c1", + "c1", "b1", + "c2", "b2"); JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); @@ -178,11 +178,11 @@ public void testDonNotChangeOrderWithoutCrossJoin() PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b")), - symbol("a"), symbol("b")), - values(symbol("c")), - symbol("c"), symbol("b")); + values("a"), + values("b"), + "a", "b"), + values("c"), + "c", "b"); JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); @@ -197,10 +197,10 @@ public void testDoNotReorderCrossJoins() PlanNode plan = joinNode( joinNode( - values(symbol("a")), - values(symbol("b"))), - values(symbol("c")), - symbol("c"), symbol("b")); + values("a"), + values("b")), + values("c"), + "c", "b"); JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); @@ -216,13 +216,12 @@ public void testGiveUpOnNonIdentityProjections() joinNode( projectNode( joinNode( - values(symbol("a1")), - values(symbol("b"))), - symbol("a2"), + values("a1"), + values("b")), + "a2", new ArithmeticUnaryExpression(MINUS, new SymbolReference("a1"))), - values(symbol("c")), - symbol("a2"), symbol("c"), - symbol("c"), symbol("b")); + values("c"), + "a2", "c"); assertEquals(JoinGraph.buildFrom(plan, Lookup.noLookup()).size(), 2); } @@ -254,11 +253,6 @@ private PlanNode projectNode(PlanNode source, String symbol, Expression expressi Assignments.of(new Symbol(symbol), expression)); } - private String symbol(String name) - { - return name; - } - private JoinNode joinNode(PlanNode left, PlanNode right, String... symbols) { checkArgument(symbols.length % 2 == 0); From 1b7e1447af0a6461e96306c558629c015c04f59a Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Fri, 23 Aug 2019 13:07:05 +0200 Subject: [PATCH 06/16] Fix typo --- .../sql/planner/iterative/rule/TestEliminateCrossJoins.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java index 9f76f1c9733c..6ce101b7ed82 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -173,7 +173,7 @@ public void testJoinOrderWithMultipleEdgesBetweenNodes() } @Test - public void testDonNotChangeOrderWithoutCrossJoin() + public void testDoesNotChangeOrderWithoutCrossJoin() { PlanNode plan = joinNode( From b4738c8fd30ce75728a4214bfbfbcf4ac877d2b0 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Fri, 23 Aug 2019 15:26:12 +0200 Subject: [PATCH 07/16] Add support for projections in eliminate cross joins --- .../iterative/rule/EliminateCrossJoins.java | 2 +- .../iterative/rule/InlineProjections.java | 18 ++- .../rule/PushProjectionThroughJoin.java | 142 ++++++++++++++++++ .../optimizations/joins/JoinGraph.java | 16 +- .../rule/TestEliminateCrossJoins.java | 95 ++++++++++-- .../rule/TestPushProjectionThroughJoin.java | 127 ++++++++++++++++ 6 files changed, 379 insertions(+), 21 deletions(-) create mode 100644 presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java create mode 100644 presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java index 8e5d44bdee43..9ff4c06c63da 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -73,7 +73,7 @@ public boolean isEnabled(Session session) @Override public Result apply(JoinNode node, Captures captures, Context context) { - JoinGraph joinGraph = JoinGraph.buildFrom(node, context.getLookup()); + JoinGraph joinGraph = JoinGraph.buildFrom(node, context.getLookup(), context.getIdAllocator()); if (joinGraph.size() < 3 || !joinGraph.isContainsCrossJoin()) { return Result.empty(); } diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java index 936ab319e18d..e8afb662dd34 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java @@ -30,6 +30,7 @@ import io.prestosql.sql.util.AstUtils; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; @@ -65,9 +66,16 @@ public Result apply(ProjectNode parent, Captures captures, Context context) { ProjectNode child = captures.get(CHILD); + return inlineProjections(parent, child) + .map(Result::ofPlanNode) + .orElse(Result.empty()); + } + + static Optional inlineProjections(ProjectNode parent, ProjectNode child) + { Set targets = extractInliningTargets(parent, child); if (targets.isEmpty()) { - return Result.empty(); + return Optional.empty(); } // inline the expressions @@ -109,14 +117,14 @@ public Result apply(ProjectNode parent, Captures captures, Context context) newChildAssignments); } - return Result.ofPlanNode( + return Optional.of( new ProjectNode( parent.getId(), newChild, Assignments.copyOf(parentAssignments))); } - private Expression inlineReferences(Expression expression, Assignments assignments) + private static Expression inlineReferences(Expression expression, Assignments assignments) { Function mapping = symbol -> { Expression result = assignments.get(symbol); @@ -130,7 +138,7 @@ private Expression inlineReferences(Expression expression, Assignments assignmen return inlineSymbols(mapping, expression); } - private Set extractInliningTargets(ProjectNode parent, ProjectNode child) + private static Set extractInliningTargets(ProjectNode parent, ProjectNode child) { // candidates for inlining are // 1. references to simple constants @@ -170,7 +178,7 @@ private Set extractInliningTargets(ProjectNode parent, ProjectNode child return Sets.union(singletons, constants); } - private Set extractTryArguments(Expression expression) + private static Set extractTryArguments(Expression expression) { return AstUtils.preOrder(expression) .filter(TryExpression.class::isInstance) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java new file mode 100644 index 000000000000..447335187f78 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java @@ -0,0 +1,142 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; +import io.prestosql.sql.planner.DeterminismEvaluator; +import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.SymbolsExtractor; +import io.prestosql.sql.planner.iterative.Lookup; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.JoinNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.Expression; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.sql.planner.SymbolsExtractor.extractUnique; + +/** + * Utility class for pushing simple projections through join so that joins can participate + * in cross join elimination or join reordering. + */ +public final class PushProjectionThroughJoin +{ + public static Optional pushProjectionThroughJoin(ProjectNode projectNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator) + { + if (!projectNode.getAssignments().getExpressions().stream().allMatch(DeterminismEvaluator::isDeterministic)) { + return Optional.empty(); + } + + PlanNode child = lookup.resolve(projectNode.getSource()); + if (!(child instanceof JoinNode)) { + return Optional.empty(); + } + + JoinNode joinNode = (JoinNode) child; + PlanNode leftChild = joinNode.getLeft(); + PlanNode rightChild = joinNode.getRight(); + + Assignments.Builder leftAssignmentsBuilder = Assignments.builder(); + Assignments.Builder rightAssignmentsBuilder = Assignments.builder(); + for (Map.Entry assignment : projectNode.getAssignments().entrySet()) { + Expression expression = assignment.getValue(); + Set symbols = extractUnique(expression); + if (leftChild.getOutputSymbols().containsAll(symbols)) { + // expression is satisfied with left child symbols + leftAssignmentsBuilder.put(assignment.getKey(), expression); + } + else if (rightChild.getOutputSymbols().containsAll(symbols)) { + // expression is satisfied with right child symbols + rightAssignmentsBuilder.put(assignment.getKey(), expression); + } + else { + // expression is using symbols from both join sides + return Optional.empty(); + } + } + + // add projections for symbols required by the join itself + Set joinRequiredSymbols = getJoinRequiredSymbols(joinNode); + for (Symbol requiredSymbol : joinRequiredSymbols) { + if (leftChild.getOutputSymbols().contains(requiredSymbol)) { + leftAssignmentsBuilder.putIdentity(requiredSymbol); + } + else { + checkState(rightChild.getOutputSymbols().contains(requiredSymbol)); + rightAssignmentsBuilder.putIdentity(requiredSymbol); + } + } + + Assignments leftAssignments = leftAssignmentsBuilder.build(); + Assignments rightAssignments = rightAssignmentsBuilder.build(); + List outputSymbols = Streams.concat(leftAssignments.getOutputs().stream(), rightAssignments.getOutputs().stream()) + .filter(ImmutableSet.copyOf(projectNode.getOutputSymbols())::contains) + .collect(toImmutableList()); + + return Optional.of(new JoinNode( + joinNode.getId(), + joinNode.getType(), + inlineProjections( + new ProjectNode(planNodeIdAllocator.getNextId(), leftChild, leftAssignments), + lookup), + inlineProjections( + new ProjectNode(planNodeIdAllocator.getNextId(), rightChild, rightAssignments), + lookup), + joinNode.getCriteria(), + outputSymbols, + joinNode.getFilter(), + joinNode.getLeftHashSymbol(), + joinNode.getRightHashSymbol(), + joinNode.getDistributionType(), + joinNode.isSpillable(), + joinNode.getDynamicFilters())); + } + + private static PlanNode inlineProjections(ProjectNode parentProjection, Lookup lookup) + { + PlanNode child = lookup.resolve(parentProjection.getSource()); + if (!(child instanceof ProjectNode)) { + return parentProjection; + } + ProjectNode childProjection = (ProjectNode) child; + + return InlineProjections.inlineProjections(parentProjection, childProjection) + .map(node -> inlineProjections(node, lookup)) + .orElse(parentProjection); + } + + private static Set getJoinRequiredSymbols(JoinNode node) + { + // extract symbols required by the join itself + return Streams.concat( + node.getCriteria().stream().map(JoinNode.EquiJoinClause::getLeft), + node.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight), + node.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of()).stream(), + node.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream(), + node.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream()) + .collect(toImmutableSet()); + } + + private PushProjectionThroughJoin() {} +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java index 2f0160229930..e46132975f41 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.Multimap; +import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.iterative.GroupReference; import io.prestosql.sql.planner.iterative.Lookup; @@ -35,6 +36,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.prestosql.sql.planner.iterative.rule.PushProjectionThroughJoin.pushProjectionThroughJoin; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; import static java.util.Objects.requireNonNull; @@ -55,9 +57,9 @@ public class JoinGraph /** * Builds {@link JoinGraph} containing {@code plan} node. */ - public static JoinGraph buildFrom(PlanNode plan, Lookup lookup) + public static JoinGraph buildFrom(PlanNode plan, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator) { - return plan.accept(new Builder(lookup), new Context()); + return plan.accept(new Builder(lookup, planNodeIdAllocator), new Context()); } public JoinGraph(PlanNode node) @@ -202,10 +204,12 @@ private static class Builder extends PlanVisitor { private final Lookup lookup; + private final PlanNodeIdAllocator planNodeIdAllocator; - private Builder(Lookup lookup) + private Builder(Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator) { this.lookup = requireNonNull(lookup, "lookup cannot be null"); + this.planNodeIdAllocator = requireNonNull(planNodeIdAllocator, "planNodeIdAllocator is null"); } @Override @@ -250,6 +254,12 @@ public JoinGraph visitProject(ProjectNode node, Context context) JoinGraph graph = node.getSource().accept(this, context); return graph.withAssignments(node.getAssignments().getMap()); } + + Optional rewrittenNode = pushProjectionThroughJoin(node, lookup, planNodeIdAllocator); + if (rewrittenNode.isPresent()) { + return rewrittenNode.get().accept(this, context); + } + return visitPlan(node, context); } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java index 6ce101b7ed82..729ce0c6d46f 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestEliminateCrossJoins.java @@ -17,8 +17,8 @@ import com.google.common.collect.ImmutableMap; import io.prestosql.sql.planner.PlanNodeIdAllocator; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.assertions.PlanMatchPattern; import io.prestosql.sql.planner.iterative.GroupReference; -import io.prestosql.sql.planner.iterative.Lookup; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; import io.prestosql.sql.planner.optimizations.joins.JoinGraph; @@ -28,6 +28,7 @@ import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.sql.planner.plan.ValuesNode; +import io.prestosql.sql.tree.ArithmeticBinaryExpression; import io.prestosql.sql.tree.ArithmeticUnaryExpression; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.SymbolReference; @@ -41,11 +42,15 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.any; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.node; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.iterative.Lookup.noLookup; import static io.prestosql.sql.planner.iterative.rule.EliminateCrossJoins.getJoinOrder; import static io.prestosql.sql.planner.iterative.rule.EliminateCrossJoins.isOriginalOrder; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; +import static io.prestosql.sql.tree.ArithmeticBinaryExpression.Operator.ADD; import static io.prestosql.sql.tree.ArithmeticUnaryExpression.Sign.MINUS; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -115,7 +120,7 @@ public void testJoinOrder() "a", "c", "c", "b"); - JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()); assertEquals( getJoinOrder(joinGraph), @@ -145,7 +150,7 @@ public void testJoinOrderWithRealCrossJoin() PlanNode plan = joinNode(leftPlan, rightPlan); - JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()); assertEquals( getJoinOrder(joinGraph), @@ -165,7 +170,7 @@ public void testJoinOrderWithMultipleEdgesBetweenNodes() "c1", "b1", "c2", "b2"); - JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()); assertEquals( getJoinOrder(joinGraph), @@ -184,7 +189,7 @@ public void testDoesNotChangeOrderWithoutCrossJoin() values("c"), "c", "b"); - JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()); assertEquals( getJoinOrder(joinGraph), @@ -202,7 +207,7 @@ public void testDoNotReorderCrossJoins() values("c"), "c", "b"); - JoinGraph joinGraph = JoinGraph.buildFrom(plan, Lookup.noLookup()); + JoinGraph joinGraph = JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()); assertEquals( getJoinOrder(joinGraph), @@ -210,7 +215,68 @@ public void testDoNotReorderCrossJoins() } @Test - public void testGiveUpOnNonIdentityProjections() + public void testEliminateCrossJoinWithNonIdentityProjections() + { + tester().assertThat(new EliminateCrossJoins()) + .setSystemProperty(JOIN_REORDERING_STRATEGY, "ELIMINATE_CROSS_JOINS") + .on(p -> { + Symbol a1 = p.symbol("a1"); + Symbol a2 = p.symbol("a2"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + + return p.join( + INNER, + p.project( + Assignments.of( + a2, new ArithmeticUnaryExpression(MINUS, new SymbolReference("a1")), + f, new SymbolReference("f")), + p.join( + INNER, + p.project( + Assignments.of( + a1, new SymbolReference("a1"), + f, new ArithmeticUnaryExpression(MINUS, new SymbolReference("b"))), + p.join( + INNER, + p.values(a1), + p.values(b))), + p.values(e), + new EquiJoinClause(a1, e))), + p.values(c, d), + new EquiJoinClause(a2, c), + new EquiJoinClause(f, d)); + }) + .matches( + node(ProjectNode.class, + join( + INNER, + ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("d"), new Symbol("f"))), + join( + INNER, + ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("a2"), new Symbol("c"))), + join(INNER, + ImmutableList.of(aliases -> new EquiJoinClause(new Symbol("a1"), new Symbol("e"))), + strictProject( + ImmutableMap.of( + "a2", expression("-a1"), + "a1", expression("a1")), + PlanMatchPattern.values("a1")), + strictProject( + ImmutableMap.of( + "e", expression("e")), + PlanMatchPattern.values("e"))), + any()), + strictProject( + ImmutableMap.of("f", expression("-b")), + PlanMatchPattern.values("b"))))); + } + + @Test + public void testGiveUpOnComplexProjections() { PlanNode plan = joinNode( @@ -219,11 +285,14 @@ public void testGiveUpOnNonIdentityProjections() values("a1"), values("b")), "a2", - new ArithmeticUnaryExpression(MINUS, new SymbolReference("a1"))), + new ArithmeticBinaryExpression(ADD, new SymbolReference("a1"), new SymbolReference("b")), + "b", + new SymbolReference("b")), values("c"), - "a2", "c"); + "a2", "c", + "c", "b"); - assertEquals(JoinGraph.buildFrom(plan, Lookup.noLookup()).size(), 2); + assertEquals(JoinGraph.buildFrom(plan, noLookup(), new PlanNodeIdAllocator()).size(), 2); } private Function crossJoinAndJoin(JoinNode.Type secondJoinType) @@ -245,12 +314,14 @@ private Function crossJoinAndJoin(JoinNode.Type secondJoi }; } - private PlanNode projectNode(PlanNode source, String symbol, Expression expression) + private PlanNode projectNode(PlanNode source, String symbol1, Expression expression1, String symbol2, Expression expression2) { return new ProjectNode( idAllocator.getNextId(), source, - Assignments.of(new Symbol(symbol), expression)); + Assignments.of( + new Symbol(symbol1), expression1, + new Symbol(symbol2), expression2)); } private JoinNode joinNode(PlanNode left, PlanNode right, String... symbols) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java new file mode 100644 index 000000000000..3902eb35105a --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java @@ -0,0 +1,127 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.Plan; +import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.assertions.PlanMatchPattern; +import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.JoinNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.ArithmeticBinaryExpression; +import io.prestosql.sql.tree.ArithmeticUnaryExpression; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.prestosql.cost.PlanNodeStatsEstimate.unknown; +import static io.prestosql.cost.StatsAndCosts.empty; +import static io.prestosql.metadata.AbstractMockMetadata.dummyMetadata; +import static io.prestosql.sql.planner.assertions.PlanAssert.assertPlan; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.iterative.Lookup.noLookup; +import static io.prestosql.sql.planner.iterative.rule.PushProjectionThroughJoin.pushProjectionThroughJoin; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; +import static io.prestosql.sql.tree.ArithmeticBinaryExpression.Operator.ADD; +import static io.prestosql.sql.tree.ArithmeticUnaryExpression.Sign.MINUS; +import static io.prestosql.sql.tree.ArithmeticUnaryExpression.Sign.PLUS; +import static io.prestosql.testing.TestingSession.testSessionBuilder; +import static org.assertj.core.api.Assertions.assertThat; +import static org.testng.Assert.assertTrue; + +public class TestPushProjectionThroughJoin +{ + @Test + public void testPushesProjectionThroughJoin() + { + PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator(); + PlanBuilder p = new PlanBuilder(idAllocator, dummyMetadata()); + Symbol a0 = p.symbol("a0"); + Symbol a1 = p.symbol("a1"); + Symbol a2 = p.symbol("a2"); + Symbol a3 = p.symbol("a3"); + Symbol b0 = p.symbol("b0"); + Symbol b1 = p.symbol("b1"); + Symbol b2 = p.symbol("b2"); + + ProjectNode planNode = p.project( + Assignments.of( + a3, new ArithmeticUnaryExpression(MINUS, a2.toSymbolReference()), + b2, new ArithmeticUnaryExpression(PLUS, b1.toSymbolReference())), + p.join( + INNER, + // intermediate non-identity projections should be fully inlined + p.project( + Assignments.of( + a2, new ArithmeticUnaryExpression(PLUS, a0.toSymbolReference()), + a1, a1.toSymbolReference()), + p.project( + Assignments.builder() + .putIdentity(a0) + .putIdentity(a1) + .build(), + p.values(a0, a1))), + p.values(b0, b1), + new JoinNode.EquiJoinClause(a1, b1))); + + Optional rewritten = pushProjectionThroughJoin(planNode, noLookup(), idAllocator); + assertTrue(rewritten.isPresent()); + assertPlan( + testSessionBuilder().build(), + dummyMetadata(), + node -> unknown(), + new Plan(rewritten.get(), p.getTypes(), empty()), noLookup(), + join( + INNER, + ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("a1"), new Symbol("b1"))), + strictProject(ImmutableMap.of( + "a3", expression("-(+a0)"), + "a1", expression("a1")), + strictProject(ImmutableMap.of( + "a0", expression("a0"), + "a1", expression("a1")), + PlanMatchPattern.values("a0", "a1"))), + strictProject(ImmutableMap.of( + "b2", expression("+b1"), + "b1", expression("b1")), + PlanMatchPattern.values("b0", "b1"))) + .withExactOutputs("a3", "b2")); + } + + @Test + public void testDoesNotPushStraddlingProjection() + { + PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), dummyMetadata()); + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + + ProjectNode planNode = p.project( + Assignments.of( + c, new ArithmeticBinaryExpression(ADD, a.toSymbolReference(), b.toSymbolReference())), + p.join( + INNER, + p.values(a), + p.values(b))); + Optional rewritten = pushProjectionThroughJoin(planNode, noLookup(), new PlanNodeIdAllocator()); + assertThat(rewritten).isEmpty(); + } +} From df84db4f31ae0ac880bdfc397fdbd4a087ce17ab Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Fri, 23 Aug 2019 16:26:26 +0200 Subject: [PATCH 08/16] Remove unneeded JoinGraph#assignments --- .../iterative/rule/EliminateCrossJoins.java | 9 ------- .../optimizations/joins/JoinGraph.java | 26 +++---------------- 2 files changed, 4 insertions(+), 31 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java index 9ff4c06c63da..813acc41c147 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/EliminateCrossJoins.java @@ -24,12 +24,10 @@ import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.iterative.Rule; import io.prestosql.sql.planner.optimizations.joins.JoinGraph; -import io.prestosql.sql.planner.plan.Assignments; import io.prestosql.sql.planner.plan.FilterNode; import io.prestosql.sql.planner.plan.JoinNode; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.PlanNodeId; -import io.prestosql.sql.planner.plan.ProjectNode; import io.prestosql.sql.tree.Expression; import java.util.HashMap; @@ -201,13 +199,6 @@ public static PlanNode buildJoinTree(List expectedOutputSymbols, JoinGra filter); } - if (graph.getAssignments().isPresent()) { - result = new ProjectNode( - idAllocator.getNextId(), - result, - Assignments.copyOf(graph.getAssignments().get())); - } - // If needed, introduce a projection to constrain the outputs to what was originally expected // Some nodes are sensitive to what's produced (e.g., DistinctLimit node) return restrictOutputs(idAllocator, result, ImmutableSet.copyOf(expectedOutputSymbols)).orElse(result); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java index e46132975f41..c33831cb6914 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/optimizations/joins/JoinGraph.java @@ -47,7 +47,6 @@ */ public class JoinGraph { - private final Optional> assignments; private final List filters; private final List nodes; // nodes in order of their appearance in tree plan (left, right, parent) private final Multimap edges; @@ -64,7 +63,7 @@ public static JoinGraph buildFrom(PlanNode plan, Lookup lookup, PlanNodeIdAlloca public JoinGraph(PlanNode node) { - this(ImmutableList.of(node), ImmutableMultimap.of(), node.getId(), ImmutableList.of(), Optional.empty(), false); + this(ImmutableList.of(node), ImmutableMultimap.of(), node.getId(), ImmutableList.of(), false); } public JoinGraph( @@ -72,34 +71,22 @@ public JoinGraph( Multimap edges, PlanNodeId rootId, List filters, - Optional> assignments, boolean containsCrossJoin) { this.nodes = nodes; this.edges = edges; this.rootId = rootId; this.filters = filters; - this.assignments = assignments; this.containsCrossJoin = containsCrossJoin; } - public JoinGraph withAssignments(Map assignments) - { - return new JoinGraph(nodes, edges, rootId, filters, Optional.of(assignments), containsCrossJoin); - } - - public Optional> getAssignments() - { - return assignments; - } - public JoinGraph withFilter(Expression expression) { ImmutableList.Builder filters = ImmutableList.builder(); filters.addAll(this.filters); filters.add(expression); - return new JoinGraph(nodes, edges, rootId, filters.build(), assignments, containsCrossJoin); + return new JoinGraph(nodes, edges, rootId, filters.build(), containsCrossJoin); } public List getFilters() @@ -197,7 +184,7 @@ private JoinGraph joinWith(JoinGraph other, List joinCl edges.put(right.getId(), new Edge(left, rightSymbol, leftSymbol)); } - return new JoinGraph(nodes, edges.build(), newRoot, joinedFilters, Optional.empty(), this.containsCrossJoin || containsCrossJoin); + return new JoinGraph(nodes, edges.build(), newRoot, joinedFilters, this.containsCrossJoin || containsCrossJoin); } private static class Builder @@ -250,11 +237,6 @@ public JoinGraph visitJoin(JoinNode node, Context context) @Override public JoinGraph visitProject(ProjectNode node, Context context) { - if (node.isIdentity()) { - JoinGraph graph = node.getSource().accept(this, context); - return graph.withAssignments(node.getAssignments().getMap()); - } - Optional rewrittenNode = pushProjectionThroughJoin(node, lookup, planNodeIdAllocator); if (rewrittenNode.isPresent()) { return rewrittenNode.get().accept(this, context); @@ -276,7 +258,7 @@ public JoinGraph visitGroupReference(GroupReference node, Context context) private boolean isTrivialGraph(JoinGraph graph) { - return graph.nodes.size() < 2 && graph.edges.isEmpty() && graph.filters.isEmpty() && !graph.assignments.isPresent(); + return graph.nodes.size() < 2 && graph.edges.isEmpty() && graph.filters.isEmpty(); } private JoinGraph replacementGraph(PlanNode oldNode, PlanNode newNode, Context context) From 9b81c6d2d995464dd046f73f9dab9309564cbf76 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Fri, 30 Aug 2019 14:52:13 +0200 Subject: [PATCH 09/16] Update description --- .../sql/planner/iterative/rule/PushProjectionThroughJoin.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java index 447335187f78..4b26f65ff593 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java @@ -37,8 +37,8 @@ import static io.prestosql.sql.planner.SymbolsExtractor.extractUnique; /** - * Utility class for pushing simple projections through join so that joins can participate - * in cross join elimination or join reordering. + * Utility class for pushing projections through join so that joins are not separated + * by a project node and can participate in cross join elimination or join reordering. */ public final class PushProjectionThroughJoin { From f5873a35cca28beca39a6d51813b0e7a235860d8 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Thu, 26 Sep 2019 14:10:46 +0200 Subject: [PATCH 10/16] Do not push projections through outer joins --- .../rule/PushProjectionThroughJoin.java | 7 ++- .../rule/TestPushProjectionThroughJoin.java | 20 +++++++ .../java/io/prestosql/sql/query/TestJoin.java | 53 +++++++++++++++++++ 3 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 presto-main/src/test/java/io/prestosql/sql/query/TestJoin.java diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java index 4b26f65ff593..bba966d23880 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushProjectionThroughJoin.java @@ -35,9 +35,10 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.prestosql.sql.planner.SymbolsExtractor.extractUnique; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; /** - * Utility class for pushing projections through join so that joins are not separated + * Utility class for pushing projections through inner join so that joins are not separated * by a project node and can participate in cross join elimination or join reordering. */ public final class PushProjectionThroughJoin @@ -57,6 +58,10 @@ public static Optional pushProjectionThroughJoin(ProjectNode projectNo PlanNode leftChild = joinNode.getLeft(); PlanNode rightChild = joinNode.getRight(); + if (joinNode.getType() != INNER) { + return Optional.empty(); + } + Assignments.Builder leftAssignmentsBuilder = Assignments.builder(); Assignments.Builder rightAssignmentsBuilder = Assignments.builder(); for (Map.Entry assignment : projectNode.getAssignments().entrySet()) { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java index 3902eb35105a..bb4ea159b1f6 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughJoin.java @@ -40,6 +40,7 @@ import static io.prestosql.sql.planner.iterative.Lookup.noLookup; import static io.prestosql.sql.planner.iterative.rule.PushProjectionThroughJoin.pushProjectionThroughJoin; import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; +import static io.prestosql.sql.planner.plan.JoinNode.Type.LEFT; import static io.prestosql.sql.tree.ArithmeticBinaryExpression.Operator.ADD; import static io.prestosql.sql.tree.ArithmeticUnaryExpression.Sign.MINUS; import static io.prestosql.sql.tree.ArithmeticUnaryExpression.Sign.PLUS; @@ -124,4 +125,23 @@ c, new ArithmeticBinaryExpression(ADD, a.toSymbolReference(), b.toSymbolReferenc Optional rewritten = pushProjectionThroughJoin(planNode, noLookup(), new PlanNodeIdAllocator()); assertThat(rewritten).isEmpty(); } + + @Test + public void testDoesNotPushProjectionThroughOuterJoin() + { + PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), dummyMetadata()); + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + + ProjectNode planNode = p.project( + Assignments.of( + c, new ArithmeticUnaryExpression(MINUS, a.toSymbolReference())), + p.join( + LEFT, + p.values(a), + p.values(b))); + Optional rewritten = pushProjectionThroughJoin(planNode, noLookup(), new PlanNodeIdAllocator()); + assertThat(rewritten).isEmpty(); + } } diff --git a/presto-main/src/test/java/io/prestosql/sql/query/TestJoin.java b/presto-main/src/test/java/io/prestosql/sql/query/TestJoin.java new file mode 100644 index 000000000000..174e49db28d8 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/query/TestJoin.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.query; + +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +public class TestJoin +{ + private QueryAssertions assertions; + + @BeforeClass + public void init() + { + assertions = new QueryAssertions(); + } + + @AfterClass(alwaysRun = true) + public void teardown() + { + assertions.close(); + assertions = null; + } + + @Test + public void testCrossJoinEliminationWithOuterJoin() + { + assertions.assertQuery( + "WITH " + + " a AS (SELECT id FROM (VALUES (1)) AS t(id))," + + " b AS (SELECT id FROM (VALUES (1)) AS t(id))," + + " c AS (SELECT id FROM (VALUES ('1')) AS t(id))," + + " d as (SELECT id FROM (VALUES (1)) AS t(id))" + + "SELECT a.id " + + "FROM a " + + "LEFT JOIN b ON a.id = b.id " + + "JOIN c ON a.id = CAST(c.id AS bigint) " + + "JOIN d ON d.id = a.id", + "VALUES 1"); + } +} From 928e4ccd8875da1ef2949873ba7e53f5f9354322 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Tue, 27 Aug 2019 12:33:05 +0200 Subject: [PATCH 11/16] Add Assignments#isIdentity() method --- .../io/prestosql/sql/planner/plan/Assignments.java | 12 ++++++++++++ .../io/prestosql/sql/planner/plan/ProjectNode.java | 12 +----------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java index 68c219dffa5f..7ea37ff5468b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java @@ -128,6 +128,18 @@ public boolean isIdentity(Symbol output) return expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(output.getName()); } + public boolean isIdentity() + { + for (Map.Entry entry : assignments.entrySet()) { + Expression expression = entry.getValue(); + Symbol symbol = entry.getKey(); + if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(symbol.getName()))) { + return false; + } + } + return true; + } + private Collector, Builder, Assignments> toAssignments() { return Collector.of( diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/ProjectNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/ProjectNode.java index 648d17b51998..335b1eb7303e 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/ProjectNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/ProjectNode.java @@ -18,13 +18,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.Iterables; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.tree.Expression; -import io.prestosql.sql.tree.SymbolReference; import javax.annotation.concurrent.Immutable; import java.util.List; -import java.util.Map; import static java.util.Objects.requireNonNull; @@ -76,14 +73,7 @@ public PlanNode getSource() public boolean isIdentity() { - for (Map.Entry entry : assignments.entrySet()) { - Expression expression = entry.getValue(); - Symbol symbol = entry.getKey(); - if (!(expression instanceof SymbolReference && ((SymbolReference) expression).getName().equals(symbol.getName()))) { - return false; - } - } - return true; + return assignments.isIdentity(); } @Override From 3858f5ab908cd30b435b3c172cc6c6f72249fbe1 Mon Sep 17 00:00:00 2001 From: praveenkrishna Date: Sun, 11 Aug 2019 21:13:08 +0530 Subject: [PATCH 12/16] Use lazy blocks for probe side page during Join --- .../operator/LookupJoinPageBuilder.java | 28 ++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java index 8d3a4eacbca8..8aa821e86b4a 100644 --- a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java +++ b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java @@ -16,6 +16,7 @@ import io.prestosql.spi.Page; import io.prestosql.spi.PageBuilder; import io.prestosql.spi.block.Block; +import io.prestosql.spi.block.LazyBlock; import io.prestosql.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrayList; @@ -23,6 +24,7 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Verify.verify; +import static io.prestosql.operator.project.PageProcessor.MAX_BATCH_SIZE; import static io.prestosql.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES; import static java.util.Objects.requireNonNull; @@ -48,7 +50,9 @@ public LookupJoinPageBuilder(List buildTypes) public boolean isFull() { - return estimatedProbeBlockBytes + buildPageBuilder.getSizeInBytes() >= DEFAULT_MAX_PAGE_SIZE_IN_BYTES || buildPageBuilder.isFull(); + return estimatedProbeBlockBytes + buildPageBuilder.getSizeInBytes() >= DEFAULT_MAX_PAGE_SIZE_IN_BYTES || + buildPageBuilder.getPositionCount() >= MAX_BATCH_SIZE || + buildPageBuilder.isFull(); } public boolean isEmpty() @@ -104,7 +108,14 @@ public Page build(JoinProbe probe) for (int i = 0; i < probeOutputChannels.length; i++) { Block probeBlock = probe.getPage().getBlock(probeOutputChannels[i]); if (!isSequentialProbeIndices || length == 0) { - blocks[i] = probeBlock.getPositions(probeIndices, 0, probeIndices.length); + if (probeBlock.isLoaded()) { + blocks[i] = probeBlock.getPositions(probeIndices, 0, probeIndices.length); + } + else { + blocks[i] = new LazyBlock( + probeIndices.length, + lazyBlock -> lazyBlock.setBlock(probeBlock.getPositions(probeIndices, 0, probeIndices.length))); + } } else if (length == probeBlock.getPositionCount()) { // probeIndices are a simple covering of the block @@ -115,7 +126,14 @@ else if (length == probeBlock.getPositionCount()) { else { // probeIndices are sequential without holes verify(probeIndices[length - 1] - probeIndices[0] == length - 1); - blocks[i] = probeBlock.getRegion(probeIndices[0], length); + if (probeBlock.isLoaded()) { + blocks[i] = probeBlock.getRegion(probeIndices[0], length); + } + else { + blocks[i] = new LazyBlock( + length, + lazyBlock -> lazyBlock.setBlock(probeBlock.getRegion(probeIndices[0], length))); + } } } @@ -179,7 +197,9 @@ private void appendProbeIndex(JoinProbe probe) for (int index : probe.getOutputChannels()) { Block block = probe.getPage().getBlock(index); // Estimate the size of the current row - estimatedProbeBlockBytes += block.getSizeInBytes() / block.getPositionCount(); + if (!block.isLoaded()) { + estimatedProbeBlockBytes += block.getSizeInBytes() / block.getPositionCount(); + } } } } From 2a6d1899fd26f9c9e8b55fd40e9a537635a466eb Mon Sep 17 00:00:00 2001 From: Piotr Findeisen Date: Mon, 4 Nov 2019 08:18:40 +0100 Subject: [PATCH 13/16] Fix merge compilation error --- .../java/io/prestosql/operator/LookupJoinPageBuilder.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java index 8aa821e86b4a..39e41c1d7d5a 100644 --- a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java +++ b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java @@ -114,7 +114,7 @@ public Page build(JoinProbe probe) else { blocks[i] = new LazyBlock( probeIndices.length, - lazyBlock -> lazyBlock.setBlock(probeBlock.getPositions(probeIndices, 0, probeIndices.length))); + () -> probeBlock.getPositions(probeIndices, 0, probeIndices.length)); } } else if (length == probeBlock.getPositionCount()) { @@ -132,7 +132,7 @@ else if (length == probeBlock.getPositionCount()) { else { blocks[i] = new LazyBlock( length, - lazyBlock -> lazyBlock.setBlock(probeBlock.getRegion(probeIndices[0], length))); + () -> probeBlock.getRegion(probeIndices[0], length)); } } } From d762e6939c0b45964da4fd7390fcb9ce53077ed9 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Sun, 3 Nov 2019 22:53:54 +0100 Subject: [PATCH 14/16] Remove unneeded isLoaded check getRegion and getPositions do not load underlying blocks now. --- .../operator/LookupJoinPageBuilder.java | 23 +++---------------- 1 file changed, 3 insertions(+), 20 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java index 39e41c1d7d5a..3bc696797848 100644 --- a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java +++ b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java @@ -16,7 +16,6 @@ import io.prestosql.spi.Page; import io.prestosql.spi.PageBuilder; import io.prestosql.spi.block.Block; -import io.prestosql.spi.block.LazyBlock; import io.prestosql.spi.type.Type; import it.unimi.dsi.fastutil.ints.IntArrayList; @@ -108,14 +107,7 @@ public Page build(JoinProbe probe) for (int i = 0; i < probeOutputChannels.length; i++) { Block probeBlock = probe.getPage().getBlock(probeOutputChannels[i]); if (!isSequentialProbeIndices || length == 0) { - if (probeBlock.isLoaded()) { - blocks[i] = probeBlock.getPositions(probeIndices, 0, probeIndices.length); - } - else { - blocks[i] = new LazyBlock( - probeIndices.length, - () -> probeBlock.getPositions(probeIndices, 0, probeIndices.length)); - } + blocks[i] = probeBlock.getPositions(probeIndices, 0, probeIndices.length); } else if (length == probeBlock.getPositionCount()) { // probeIndices are a simple covering of the block @@ -126,14 +118,7 @@ else if (length == probeBlock.getPositionCount()) { else { // probeIndices are sequential without holes verify(probeIndices[length - 1] - probeIndices[0] == length - 1); - if (probeBlock.isLoaded()) { - blocks[i] = probeBlock.getRegion(probeIndices[0], length); - } - else { - blocks[i] = new LazyBlock( - length, - () -> probeBlock.getRegion(probeIndices[0], length)); - } + blocks[i] = probeBlock.getRegion(probeIndices[0], length); } } @@ -197,9 +182,7 @@ private void appendProbeIndex(JoinProbe probe) for (int index : probe.getOutputChannels()) { Block block = probe.getPage().getBlock(index); // Estimate the size of the current row - if (!block.isLoaded()) { - estimatedProbeBlockBytes += block.getSizeInBytes() / block.getPositionCount(); - } + estimatedProbeBlockBytes += block.getSizeInBytes() / block.getPositionCount(); } } } From fb2bde757368a8cadb600dbc38e4fdb30ffd6a0f Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Sun, 3 Nov 2019 22:50:53 +0100 Subject: [PATCH 15/16] Add missing TODO --- .../main/java/io/prestosql/operator/LookupJoinPageBuilder.java | 1 + 1 file changed, 1 insertion(+) diff --git a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java index 3bc696797848..11cdce9b4aa2 100644 --- a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java +++ b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java @@ -182,6 +182,7 @@ private void appendProbeIndex(JoinProbe probe) for (int index : probe.getOutputChannels()) { Block block = probe.getPage().getBlock(index); // Estimate the size of the current row + // TODO: improve estimation for unloaded blocks by making it similar as in PageProcessor estimatedProbeBlockBytes += block.getSizeInBytes() / block.getPositionCount(); } } From d5dcbe5137e343cb47d85228c6201fcbd9dd656e Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Tue, 10 Mar 2020 23:15:52 +0100 Subject: [PATCH 16/16] Precompute probe row size estimate once Don't enumerate probe blocks for every appended join row. --- .../operator/LookupJoinPageBuilder.java | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java index 11cdce9b4aa2..cc65bb8c1c0e 100644 --- a/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java +++ b/presto-main/src/main/java/io/prestosql/operator/LookupJoinPageBuilder.java @@ -39,6 +39,7 @@ public class LookupJoinPageBuilder private final PageBuilder buildPageBuilder; private final int buildOutputChannelCount; private int estimatedProbeBlockBytes; + private int estimatedProbeRowSize = -1; private boolean isSequentialProbeIndices = true; public LookupJoinPageBuilder(List buildTypes) @@ -65,6 +66,7 @@ public void reset() probeIndexBuilder.clear(); buildPageBuilder.reset(); estimatedProbeBlockBytes = 0; + estimatedProbeRowSize = -1; isSequentialProbeIndices = true; } @@ -179,11 +181,24 @@ private void appendProbeIndex(JoinProbe probe) if (previousPosition == position) { return; } + estimatedProbeBlockBytes += getEstimatedProbeRowSize(probe); + } + + private int getEstimatedProbeRowSize(JoinProbe probe) + { + if (estimatedProbeRowSize != -1) { + return estimatedProbeRowSize; + } + + int estimatedProbeRowSize = 0; for (int index : probe.getOutputChannels()) { Block block = probe.getPage().getBlock(index); - // Estimate the size of the current row + // Estimate the size of the probe row // TODO: improve estimation for unloaded blocks by making it similar as in PageProcessor - estimatedProbeBlockBytes += block.getSizeInBytes() / block.getPositionCount(); + estimatedProbeRowSize += block.getSizeInBytes() / block.getPositionCount(); } + + this.estimatedProbeRowSize = estimatedProbeRowSize; + return estimatedProbeRowSize; } }