Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Verify.verify;
import static io.prestosql.operator.project.PageProcessor.MAX_BATCH_SIZE;
import static io.prestosql.spi.block.PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES;
import static java.util.Objects.requireNonNull;

Expand All @@ -38,6 +39,7 @@ public class LookupJoinPageBuilder
private final PageBuilder buildPageBuilder;
private final int buildOutputChannelCount;
private int estimatedProbeBlockBytes;
private int estimatedProbeRowSize = -1;
private boolean isSequentialProbeIndices = true;

public LookupJoinPageBuilder(List<Type> buildTypes)
Expand All @@ -48,7 +50,9 @@ public LookupJoinPageBuilder(List<Type> buildTypes)

public boolean isFull()
{
return estimatedProbeBlockBytes + buildPageBuilder.getSizeInBytes() >= DEFAULT_MAX_PAGE_SIZE_IN_BYTES || buildPageBuilder.isFull();
return estimatedProbeBlockBytes + buildPageBuilder.getSizeInBytes() >= DEFAULT_MAX_PAGE_SIZE_IN_BYTES ||
buildPageBuilder.getPositionCount() >= MAX_BATCH_SIZE ||
buildPageBuilder.isFull();
}

public boolean isEmpty()
Expand All @@ -62,6 +66,7 @@ public void reset()
probeIndexBuilder.clear();
buildPageBuilder.reset();
estimatedProbeBlockBytes = 0;
estimatedProbeRowSize = -1;
isSequentialProbeIndices = true;
}

Expand Down Expand Up @@ -176,10 +181,24 @@ private void appendProbeIndex(JoinProbe probe)
if (previousPosition == position) {
return;
}
estimatedProbeBlockBytes += getEstimatedProbeRowSize(probe);
}

private int getEstimatedProbeRowSize(JoinProbe probe)
{
if (estimatedProbeRowSize != -1) {
return estimatedProbeRowSize;
}

int estimatedProbeRowSize = 0;
for (int index : probe.getOutputChannels()) {
Block block = probe.getPage().getBlock(index);
// Estimate the size of the current row
estimatedProbeBlockBytes += block.getSizeInBytes() / block.getPositionCount();
// Estimate the size of the probe row
// TODO: improve estimation for unloaded blocks by making it similar as in PageProcessor
estimatedProbeRowSize += block.getSizeInBytes() / block.getPositionCount();
}

this.estimatedProbeRowSize = estimatedProbeRowSize;
return estimatedProbeRowSize;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.optimizations.joins.JoinGraph;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.FilterNode;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.PlanNodeId;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.Expression;

import java.util.HashMap;
Expand Down Expand Up @@ -73,8 +71,8 @@ public boolean isEnabled(Session session)
@Override
public Result apply(JoinNode node, Captures captures, Context context)
{
JoinGraph joinGraph = JoinGraph.buildShallowFrom(node, context.getLookup());
if (joinGraph.size() < 3) {
JoinGraph joinGraph = JoinGraph.buildFrom(node, context.getLookup(), context.getIdAllocator());
if (joinGraph.size() < 3 || !joinGraph.isContainsCrossJoin()) {
return Result.empty();
}

Expand Down Expand Up @@ -201,13 +199,6 @@ public static PlanNode buildJoinTree(List<Symbol> expectedOutputSymbols, JoinGra
filter);
}

if (graph.getAssignments().isPresent()) {
result = new ProjectNode(
idAllocator.getNextId(),
result,
Assignments.copyOf(graph.getAssignments().get()));
}

// If needed, introduce a projection to constrain the outputs to what was originally expected
// Some nodes are sensitive to what's produced (e.g., DistinctLimit node)
return restrictOutputs(idAllocator, result, ImmutableSet.copyOf(expectedOutputSymbols)).orElse(result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.iterative.Rule;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.Expression;
import io.prestosql.sql.tree.Literal;
import io.prestosql.sql.tree.TryExpression;
import io.prestosql.sql.util.AstUtils;

import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -64,9 +66,16 @@ public Result apply(ProjectNode parent, Captures captures, Context context)
{
ProjectNode child = captures.get(CHILD);

Sets.SetView<Symbol> targets = extractInliningTargets(parent, child);
return inlineProjections(parent, child)
.map(Result::ofPlanNode)
.orElse(Result.empty());
}

static Optional<ProjectNode> inlineProjections(ProjectNode parent, ProjectNode child)
{
Set<Symbol> targets = extractInliningTargets(parent, child);
if (targets.isEmpty()) {
return Result.empty();
return Optional.empty();
}

// inline the expressions
Expand All @@ -79,36 +88,43 @@ public Result apply(ProjectNode parent, Captures captures, Context context)

// Synthesize identity assignments for the inputs of expressions that were inlined
// to place in the child projection.
// If all assignments end up becoming identity assignments, they'll get pruned by
// other rules
Set<Symbol> inputs = child.getAssignments()
.entrySet().stream()
.filter(entry -> targets.contains(entry.getKey()))
.map(Map.Entry::getValue)
.flatMap(entry -> SymbolsExtractor.extractAll(entry).stream())
.collect(toSet());

Assignments.Builder childAssignments = Assignments.builder();
Assignments.Builder newChildAssignmentsBuilder = Assignments.builder();
for (Map.Entry<Symbol, Expression> assignment : child.getAssignments().entrySet()) {
if (!targets.contains(assignment.getKey())) {
childAssignments.put(assignment);
newChildAssignmentsBuilder.put(assignment);
}
}
for (Symbol input : inputs) {
childAssignments.putIdentity(input);
newChildAssignmentsBuilder.putIdentity(input);
}

Assignments newChildAssignments = newChildAssignmentsBuilder.build();
PlanNode newChild;
if (newChildAssignments.isIdentity()) {
newChild = child.getSource();
}
else {
newChild = new ProjectNode(
child.getId(),
child.getSource(),
newChildAssignments);
}

return Result.ofPlanNode(
return Optional.of(
new ProjectNode(
parent.getId(),
new ProjectNode(
child.getId(),
child.getSource(),
childAssignments.build()),
newChild,
Assignments.copyOf(parentAssignments)));
}

private Expression inlineReferences(Expression expression, Assignments assignments)
private static Expression inlineReferences(Expression expression, Assignments assignments)
{
Function<Symbol, Expression> mapping = symbol -> {
Expression result = assignments.get(symbol);
Expand All @@ -122,7 +138,7 @@ private Expression inlineReferences(Expression expression, Assignments assignmen
return inlineSymbols(mapping, expression);
}

private Sets.SetView<Symbol> extractInliningTargets(ProjectNode parent, ProjectNode child)
private static Set<Symbol> extractInliningTargets(ProjectNode parent, ProjectNode child)
{
// candidates for inlining are
// 1. references to simple constants
Expand Down Expand Up @@ -162,7 +178,7 @@ private Sets.SetView<Symbol> extractInliningTargets(ProjectNode parent, ProjectN
return Sets.union(singletons, constants);
}

private Set<Symbol> extractTryArguments(Expression expression)
private static Set<Symbol> extractTryArguments(Expression expression)
{
return AstUtils.preOrder(expression)
.filter(TryExpression.class::isInstance)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.prestosql.sql.planner.DeterminismEvaluator;
import io.prestosql.sql.planner.PlanNodeIdAllocator;
import io.prestosql.sql.planner.Symbol;
import io.prestosql.sql.planner.SymbolsExtractor;
import io.prestosql.sql.planner.iterative.Lookup;
import io.prestosql.sql.planner.plan.Assignments;
import io.prestosql.sql.planner.plan.JoinNode;
import io.prestosql.sql.planner.plan.PlanNode;
import io.prestosql.sql.planner.plan.ProjectNode;
import io.prestosql.sql.tree.Expression;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.prestosql.sql.planner.SymbolsExtractor.extractUnique;
import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER;

/**
* Utility class for pushing projections through inner join so that joins are not separated
* by a project node and can participate in cross join elimination or join reordering.
*/
public final class PushProjectionThroughJoin
{
public static Optional<PlanNode> pushProjectionThroughJoin(ProjectNode projectNode, Lookup lookup, PlanNodeIdAllocator planNodeIdAllocator)
{
if (!projectNode.getAssignments().getExpressions().stream().allMatch(DeterminismEvaluator::isDeterministic)) {
return Optional.empty();
}

PlanNode child = lookup.resolve(projectNode.getSource());
if (!(child instanceof JoinNode)) {
return Optional.empty();
}

JoinNode joinNode = (JoinNode) child;
PlanNode leftChild = joinNode.getLeft();
PlanNode rightChild = joinNode.getRight();

if (joinNode.getType() != INNER) {
return Optional.empty();
}

Assignments.Builder leftAssignmentsBuilder = Assignments.builder();
Assignments.Builder rightAssignmentsBuilder = Assignments.builder();
for (Map.Entry<Symbol, Expression> assignment : projectNode.getAssignments().entrySet()) {
Expression expression = assignment.getValue();
Set<Symbol> symbols = extractUnique(expression);
if (leftChild.getOutputSymbols().containsAll(symbols)) {
// expression is satisfied with left child symbols
leftAssignmentsBuilder.put(assignment.getKey(), expression);
}
else if (rightChild.getOutputSymbols().containsAll(symbols)) {
// expression is satisfied with right child symbols
rightAssignmentsBuilder.put(assignment.getKey(), expression);
}
else {
// expression is using symbols from both join sides
return Optional.empty();
}
}

// add projections for symbols required by the join itself
Set<Symbol> joinRequiredSymbols = getJoinRequiredSymbols(joinNode);
for (Symbol requiredSymbol : joinRequiredSymbols) {
if (leftChild.getOutputSymbols().contains(requiredSymbol)) {
leftAssignmentsBuilder.putIdentity(requiredSymbol);
}
else {
checkState(rightChild.getOutputSymbols().contains(requiredSymbol));
rightAssignmentsBuilder.putIdentity(requiredSymbol);
}
}

Assignments leftAssignments = leftAssignmentsBuilder.build();
Assignments rightAssignments = rightAssignmentsBuilder.build();
List<Symbol> outputSymbols = Streams.concat(leftAssignments.getOutputs().stream(), rightAssignments.getOutputs().stream())
.filter(ImmutableSet.copyOf(projectNode.getOutputSymbols())::contains)
.collect(toImmutableList());

return Optional.of(new JoinNode(
joinNode.getId(),
joinNode.getType(),
inlineProjections(
new ProjectNode(planNodeIdAllocator.getNextId(), leftChild, leftAssignments),
lookup),
inlineProjections(
new ProjectNode(planNodeIdAllocator.getNextId(), rightChild, rightAssignments),
lookup),
joinNode.getCriteria(),
outputSymbols,
joinNode.getFilter(),
joinNode.getLeftHashSymbol(),
joinNode.getRightHashSymbol(),
joinNode.getDistributionType(),
joinNode.isSpillable(),
joinNode.getDynamicFilters()));
}

private static PlanNode inlineProjections(ProjectNode parentProjection, Lookup lookup)
{
PlanNode child = lookup.resolve(parentProjection.getSource());
if (!(child instanceof ProjectNode)) {
return parentProjection;
}
ProjectNode childProjection = (ProjectNode) child;

return InlineProjections.inlineProjections(parentProjection, childProjection)
.map(node -> inlineProjections(node, lookup))
.orElse(parentProjection);
}

private static Set<Symbol> getJoinRequiredSymbols(JoinNode node)
{
// extract symbols required by the join itself
return Streams.concat(
node.getCriteria().stream().map(JoinNode.EquiJoinClause::getLeft),
node.getCriteria().stream().map(JoinNode.EquiJoinClause::getRight),
node.getFilter().map(SymbolsExtractor::extractUnique).orElse(ImmutableSet.of()).stream(),
node.getLeftHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream(),
node.getRightHashSymbol().map(ImmutableSet::of).orElse(ImmutableSet.of()).stream())
.collect(toImmutableSet());
}

private PushProjectionThroughJoin() {}
}
Loading