From 5e50b6ad037e1d99b80dff29fb79698fc42c9a12 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Wed, 19 Jul 2017 16:24:28 +0200 Subject: [PATCH] Preserve symbols required by join during partial aggregation pushdown Partial aggregation needs to produce symbols required by join equi-conditions and filter expression --- .../PartialAggregationPushDown.java | 42 +++-- .../presto/tests/AbstractTestQueries.java | 162 ++++++++++-------- 2 files changed, 110 insertions(+), 94 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java index cc5a27df9cc83..6b9db42d97996 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/PartialAggregationPushDown.java @@ -44,12 +44,12 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; -import java.util.stream.Stream; import static com.facebook.presto.SystemSessionProperties.isPushAggregationThroughJoin; import static com.facebook.presto.sql.planner.plan.AggregationNode.Step.FINAL; @@ -60,6 +60,8 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Sets.intersection; import static java.util.Objects.requireNonNull; public class PartialAggregationPushDown @@ -192,18 +194,29 @@ else if (allAggregationsOn(node.getAggregations(), child.getRight().getOutputSym private PlanNode 
pushPartialToLeftChild(AggregationNode node, JoinNode child, RewriteContext context) { - List groupingSet = getPushedDownGroupingSet(node, child, ImmutableSet.copyOf(child.getLeft().getOutputSymbols())); + Set joinLeftChildSymbols = ImmutableSet.copyOf(child.getLeft().getOutputSymbols()); + List groupingSet = getPushedDownGroupingSet(node, joinLeftChildSymbols, intersection(getJoinRequiredSymbols(child), joinLeftChildSymbols)); AggregationNode pushedAggregation = replaceAggregationSource(node, child.getLeft(), child.getCriteria(), groupingSet, context); return pushPartialToJoin(pushedAggregation, child, pushedAggregation, context.rewrite(child.getRight()), child.getRight().getOutputSymbols()); } private PlanNode pushPartialToRightChild(AggregationNode node, JoinNode child, RewriteContext context) { - List groupingSet = getPushedDownGroupingSet(node, child, ImmutableSet.copyOf(child.getRight().getOutputSymbols())); + Set joinRightChildSymbols = ImmutableSet.copyOf(child.getRight().getOutputSymbols()); + List groupingSet = getPushedDownGroupingSet(node, joinRightChildSymbols, intersection(getJoinRequiredSymbols(child), joinRightChildSymbols)); AggregationNode pushedAggregation = replaceAggregationSource(node, child.getRight(), child.getCriteria(), groupingSet, context); return pushPartialToJoin(pushedAggregation, child, context.rewrite(child.getLeft()), pushedAggregation, child.getLeft().getOutputSymbols()); } + private Set getJoinRequiredSymbols(JoinNode node) + { + return ImmutableSet.builder() + .addAll(node.getCriteria().stream().map(EquiJoinClause::getLeft).collect(toImmutableSet())) + .addAll(node.getCriteria().stream().map(EquiJoinClause::getRight).collect(toImmutableSet())) + .addAll(node.getFilter().map(DependencyExtractor::extractUnique).orElse(ImmutableSet.of())) + .build(); + } + private PlanNode pushPartialToJoin( AggregationNode pushedAggregation, JoinNode child, @@ -265,28 +278,21 @@ private boolean allAggregationsOn(Map aggregations, List 
getPushedDownGroupingSet(AggregationNode aggregation, JoinNode join, Set availableSymbols) + private List getPushedDownGroupingSet(AggregationNode aggregation, Set availableSymbols, Set requiredJoinSymbols) { List groupingSet = Iterables.getOnlyElement(aggregation.getGroupingSets()); - Set joinKeys = Stream.concat( - join.getCriteria().stream().map(EquiJoinClause::getLeft), - join.getCriteria().stream().map(EquiJoinClause::getRight) - ).collect(Collectors.toSet()); - // keep symbols that are either directly from the join's child (availableSymbols) or there is - // an equality in join condition to a symbol for the join child + // keep symbols that are directly from the join's child (availableSymbols) List pushedDownGroupingSet = groupingSet.stream() - .filter(symbol -> joinKeys.contains(symbol) || availableSymbols.contains(symbol)) + .filter(availableSymbols::contains) .collect(Collectors.toList()); - if (pushedDownGroupingSet.size() != groupingSet.size() || pushedDownGroupingSet.isEmpty()) { - // If we dropped some symbol, we have to add all join key columns to the grouping set - Set existingSymbols = ImmutableSet.copyOf(pushedDownGroupingSet); + // add missing required join symbols to grouping set + Set existingSymbols = new HashSet<>(pushedDownGroupingSet); + requiredJoinSymbols.stream() + .filter(existingSymbols::add) + .forEach(pushedDownGroupingSet::add); - join.getCriteria().stream() - .filter(equiJoinClause -> !existingSymbols.contains(equiJoinClause.getLeft()) && !existingSymbols.contains(equiJoinClause.getRight())) - .forEach(joinClause -> pushedDownGroupingSet.add(joinClause.getLeft())); - } return pushedDownGroupingSet; } diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 81794c056f8cc..cc5fb93a12fe9 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ 
b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -1851,19 +1851,19 @@ public void testGrouping() throws Exception { assertQuery( - "SELECT a, b as t, sum(c), grouping(a, b) + grouping(a) " + - "FROM (VALUES ('h', 'j', 11), ('k', 'l', 7)) AS t (a, b, c) " + - "GROUP BY GROUPING SETS ( (a), (b)) " + - "ORDER BY grouping(b) ASC", - "VALUES (NULL, 'j', 11, 3), (NULL, 'l', 7, 3), ('h', NULL, 11, 1), ('k', NULL, 7, 1)"); + "SELECT a, b as t, sum(c), grouping(a, b) + grouping(a) " + + "FROM (VALUES ('h', 'j', 11), ('k', 'l', 7)) AS t (a, b, c) " + + "GROUP BY GROUPING SETS ( (a), (b)) " + + "ORDER BY grouping(b) ASC", + "VALUES (NULL, 'j', 11, 3), (NULL, 'l', 7, 3), ('h', NULL, 11, 1), ('k', NULL, 7, 1)"); assertQuery( - "SELECT a, sum(b), grouping(a) FROM (VALUES ('h', 11, 0), ('k', 7, 0)) AS t (a, b, c) GROUP BY GROUPING SETS (a)", - "VALUES ('h', 11, 0), ('k', 7, 0)"); + "SELECT a, sum(b), grouping(a) FROM (VALUES ('h', 11, 0), ('k', 7, 0)) AS t (a, b, c) GROUP BY GROUPING SETS (a)", + "VALUES ('h', 11, 0), ('k', 7, 0)"); assertQuery( - "SELECT a, b, sum(c), grouping(a, b) FROM (VALUES ('h', 'j', 11), ('k', 'l', 7) ) AS t (a, b, c) GROUP BY GROUPING SETS ( (a), (b)) HAVING grouping(a, b) > 1 ", - "VALUES (NULL, 'j', 11, 2), (NULL, 'l', 7, 2)"); + "SELECT a, b, sum(c), grouping(a, b) FROM (VALUES ('h', 'j', 11), ('k', 'l', 7) ) AS t (a, b, c) GROUP BY GROUPING SETS ( (a), (b)) HAVING grouping(a, b) > 1 ", + "VALUES (NULL, 'j', 11, 2), (NULL, 'l', 7, 2)"); assertQuery("SELECT a, grouping(a) * 1.0 FROM (VALUES (1) ) AS t (a) GROUP BY a", "VALUES (1, 0.0)"); @@ -1872,7 +1872,7 @@ public void testGrouping() "VALUES (1, 0, 0)"); assertQuery("SELECT grouping(a) FROM (VALUES ('h', 'j', 11), ('k', 'l', 7)) AS t (a, b, c) GROUP BY GROUPING SETS (a,c), c*2", - "VALUES (0), (1), (0), (1)"); + "VALUES (0), (1), (0), (1)"); } @Test @@ -1907,23 +1907,23 @@ public void testGroupingInWindowFunction() throws Exception { assertQuery( - "SELECT orderkey, 
custkey, sum(totalprice), grouping(orderkey)+grouping(custkey) as g, " + - " rank() OVER (PARTITION BY grouping(orderkey)+grouping(custkey), " + - " CASE WHEN grouping(orderkey) = 0 THEN custkey END ORDER BY orderkey ASC) as r " + - "FROM orders " + - "GROUP BY ROLLUP (orderkey, custkey) " + - "ORDER BY orderkey, custkey " + - "LIMIT 10", - "VALUES (1, 370, 172799.49, 0, 1), " + - " (1, NULL, 172799.49, 1, 1), " + - " (2, 781, 38426.09, 0, 1), " + - " (2, NULL, 38426.09, 1, 2), " + - " (3, 1234, 205654.30, 0, 1), " + - " (3, NULL, 205654.30, 1, 3), " + - " (4, 1369, 56000.91, 0, 1), " + - " (4, NULL, 56000.91, 1, 4), " + - " (5, 445, 105367.67, 0, 1), " + - " (5, NULL, 105367.67, 1, 5)"); + "SELECT orderkey, custkey, sum(totalprice), grouping(orderkey)+grouping(custkey) as g, " + + " rank() OVER (PARTITION BY grouping(orderkey)+grouping(custkey), " + + " CASE WHEN grouping(orderkey) = 0 THEN custkey END ORDER BY orderkey ASC) as r " + + "FROM orders " + + "GROUP BY ROLLUP (orderkey, custkey) " + + "ORDER BY orderkey, custkey " + + "LIMIT 10", + "VALUES (1, 370, 172799.49, 0, 1), " + + " (1, NULL, 172799.49, 1, 1), " + + " (2, 781, 38426.09, 0, 1), " + + " (2, NULL, 38426.09, 1, 2), " + + " (3, 1234, 205654.30, 0, 1), " + + " (3, NULL, 205654.30, 1, 3), " + + " (4, 1369, 56000.91, 0, 1), " + + " (4, NULL, 56000.91, 1, 4), " + + " (5, 445, 105367.67, 0, 1), " + + " (5, NULL, 105367.67, 1, 5)"); } @Test @@ -1938,54 +1938,54 @@ public void testGroupingInTableSubquery() // Inner query has a single GROUP BY and outer query has GROUPING SETS assertQuery( - "SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey), g " + - "FROM " + - " (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " + - " FROM orders " + - " GROUP BY orderkey, custkey " + - " ORDER BY agg_price ASC " + - " LIMIT 5) as t " + - "GROUP BY GROUPING SETS ((orderkey, custkey), g) " + - "ORDER BY outer_sum", - "VALUES (35271, 334, 874.89, 0, 
NULL), " + - " (28647, 1351, 924.33, 0, NULL), " + - " (58145, 862, 929.03, 0, NULL), " + - " (8354, 634, 974.04, 0, NULL), " + - " (37415, 301, 986.63, 0, NULL), " + - " (NULL, NULL, 4688.92, 3, 0)"); + "SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey), g " + + "FROM " + + " (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " + + " FROM orders " + + " GROUP BY orderkey, custkey " + + " ORDER BY agg_price ASC " + + " LIMIT 5) as t " + + "GROUP BY GROUPING SETS ((orderkey, custkey), g) " + + "ORDER BY outer_sum", + "VALUES (35271, 334, 874.89, 0, NULL), " + + " (28647, 1351, 924.33, 0, NULL), " + + " (58145, 862, 929.03, 0, NULL), " + + " (8354, 634, 974.04, 0, NULL), " + + " (37415, 301, 986.63, 0, NULL), " + + " (NULL, NULL, 4688.92, 3, 0)"); // Inner query has GROUPING SETS and outer query has GROUP BY assertQuery( - "SELECT orderkey, custkey, g, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " + - "FROM " + - " (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " + - " FROM orders " + - " GROUP BY GROUPING SETS ((custkey), (orderkey)) " + - " ORDER BY agg_price ASC " + - " LIMIT 5) as t " + - "GROUP BY orderkey, custkey, g", - "VALUES (28647, NULL, 2, 924.33, 0), " + - " (8354, NULL, 2, 974.04, 0), " + - " (37415, NULL, 2, 986.63, 0), " + - " (58145, NULL, 2, 929.03, 0), " + - " (35271, NULL, 2, 874.89, 0)"); + "SELECT orderkey, custkey, g, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " + + "FROM " + + " (SELECT orderkey, custkey, sum(totalprice) as agg_price, grouping(custkey, orderkey) as g " + + " FROM orders " + + " GROUP BY GROUPING SETS ((custkey), (orderkey)) " + + " ORDER BY agg_price ASC " + + " LIMIT 5) as t " + + "GROUP BY orderkey, custkey, g", + "VALUES (28647, NULL, 2, 924.33, 0), " + + " (8354, NULL, 2, 974.04, 0), " + + " (37415, NULL, 2, 986.63, 0), " + + " (58145, NULL, 2, 929.03, 0), " + + " (35271, NULL, 2, 
874.89, 0)"); // Inner query has GROUPING SETS but no grouping and outer query has a simple GROUP BY assertQuery( - "SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " + - "FROM " + - " (SELECT orderkey, custkey, sum(totalprice) as agg_price " + - " FROM orders " + - " GROUP BY GROUPING SETS ((custkey), (orderkey)) " + - " ORDER BY agg_price ASC NULLS FIRST) as t " + - "GROUP BY orderkey, custkey " + - "ORDER BY outer_sum ASC NULLS FIRST " + - "LIMIT 5", - "VALUES (35271, NULL, 874.89, 0), " + - " (28647, NULL, 924.33, 0), " + - " (58145, NULL, 929.03, 0), " + - " (8354, NULL, 974.04, 0), " + - " (37415, NULL, 986.63, 0)"); + "SELECT orderkey, custkey, sum(agg_price) as outer_sum, grouping(orderkey, custkey) " + + "FROM " + + " (SELECT orderkey, custkey, sum(totalprice) as agg_price " + + " FROM orders " + + " GROUP BY GROUPING SETS ((custkey), (orderkey)) " + + " ORDER BY agg_price ASC NULLS FIRST) as t " + + "GROUP BY orderkey, custkey " + + "ORDER BY outer_sum ASC NULLS FIRST " + + "LIMIT 5", + "VALUES (35271, NULL, 874.89, 0), " + + " (28647, NULL, 924.33, 0), " + + " (58145, NULL, 929.03, 0), " + + " (8354, NULL, 974.04, 0), " + + " (37415, NULL, 986.63, 0)"); } @Test @@ -6107,11 +6107,11 @@ public void testChainedUnionsWithOrder() public void testUnionWithTopN() { assertQuery("SELECT * FROM (" + - " SELECT regionkey FROM nation " + - " UNION ALL " + - " SELECT nationkey FROM nation" + - ") t(a) " + - "ORDER BY a LIMIT 1", + " SELECT regionkey FROM nation " + + " UNION ALL " + + " SELECT nationkey FROM nation" + + ") t(a) " + + "ORDER BY a LIMIT 1", "SELECT 0"); } @@ -7243,8 +7243,8 @@ public void testCorrelatedScalarSubqueriesWithScalarAggregation() //count in subquery assertQuery("SELECT * " + - "FROM (VALUES (0),( 1), (2), (7)) as v1(c1) " + - "WHERE v1.c1 > (SELECT count(c1) from (VALUES (0),( 1), (2)) as v2(c1) WHERE v1.c1 = v2.c1)", + "FROM (VALUES (0),( 1), (2), (7)) as v1(c1) " + + "WHERE v1.c1 > (SELECT count(c1) from 
(VALUES (0),( 1), (2)) as v2(c1) WHERE v1.c1 = v2.c1)", "VALUES (2), (7)"); } @@ -8921,6 +8921,16 @@ public void testAggregationPushedBelowOuterJoin() "VALUES 24"); } + @Test + public void testPartialAggregationPushDown() + { + // pushed down aggregation needs to preserve symbols required by join filter and equi-condition + assertQuery("" + + " SELECT orders.custkey AS custkey, orders.orderstatus AS orderstatus " + + " FROM orders JOIN lineitem ON lineitem.orderkey = orders.orderkey AND orders.orderkey = lineitem.partkey AND lineitem.orderkey + 1 = orders.orderkey + 1" + + " GROUP BY orders.custkey, orders.orderstatus"); + } + @Test public void testDefaultDecimalLiteralSwitch() throws Exception