Add support for CORRESPONDING option in set operations #25260

Merged (2 commits, Apr 23, 2025)
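For context: with CORRESPONDING, a set operation matches columns by name instead of by position. A minimal sketch of the behavior added here, assuming hypothetical tables t1(a, b) and t2(b, a):

SELECT a, b FROM t1
UNION CORRESPONDING
SELECT b, a FROM t2;
-- Each input is projected onto the columns whose names appear on both sides;
-- the output column order follows the left relation (here: a, b).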
@@ -271,9 +271,9 @@ rowCount
;

queryTerm
-    : queryPrimary #queryTermDefault
-    | left=queryTerm operator=INTERSECT setQuantifier? right=queryTerm #setOperation
-    | left=queryTerm operator=(UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation
+    : queryPrimary #queryTermDefault
+    | left=queryTerm operator=INTERSECT setQuantifier? corresponding? right=queryTerm #setOperation
+    | left=queryTerm operator=(UNION | EXCEPT) setQuantifier? corresponding? right=queryTerm #setOperation
;

queryPrimary
@@ -283,6 +283,10 @@ queryPrimary
| '(' queryNoWith ')' #subquery
;

+corresponding
+    : CORRESPONDING (BY columnAliases)?
+    ;

sortItem
: expression ordering=(ASC | DESC)? (NULLS nullOrdering=(FIRST | LAST))?
;
@@ -1001,7 +1005,7 @@ nonReserved
// IMPORTANT: this rule must only contain tokens. Nested rules are not supported. See SqlParser.exitNonReserved
: ABSENT | ADD | ADMIN | AFTER | ALL | ANALYZE | ANY | ARRAY | ASC | AT | AUTHORIZATION
| BEGIN | BERNOULLI | BOTH
-    | CALL | CALLED | CASCADE | CATALOG | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CONDITIONAL | COPARTITION | COUNT | CURRENT
+    | CALL | CALLED | CASCADE | CATALOG | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CONDITIONAL | COPARTITION | CORRESPONDING | COUNT | CURRENT
| DATA | DATE | DAY | DECLARE | DEFAULT | DEFINE | DEFINER | DENY | DESC | DESCRIPTOR | DETERMINISTIC | DISTRIBUTED | DO | DOUBLE
| ELSEIF | EMPTY | ENCODING | ERROR | EXCLUDING | EXECUTE | EXPLAIN
| FETCH | FILTER | FINAL | FIRST | FOLLOWING | FORMAT | FUNCTION | FUNCTIONS
@@ -1062,6 +1066,7 @@ CONDITIONAL: 'CONDITIONAL';
CONSTRAINT: 'CONSTRAINT';
COUNT: 'COUNT';
COPARTITION: 'COPARTITION';
+CORRESPONDING: 'CORRESPONDING';
CREATE: 'CREATE';
CROSS: 'CROSS';
CUBE: 'CUBE';
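Two notes on the grammar change: CORRESPONDING is added to nonReserved, so it remains usable as a plain identifier; and the new corresponding rule also parses an optional BY column list, although the analyzer currently rejects that form (see visitSetOperation below). A hypothetical example of the rejected form:

SELECT a, b FROM t1
UNION CORRESPONDING BY (a)
SELECT a, c FROM t2;
-- parses, but fails analysis with: CORRESPONDING with columns is unsupported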
@@ -60,6 +60,7 @@ public void test()
"CONDITIONAL",
"CONSTRAINT",
"COPARTITION",
"CORRESPONDING",
"COUNT",
"CREATE",
"CROSS",
21 changes: 21 additions & 0 deletions core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
@@ -224,6 +224,7 @@ public class Analysis
private final Map<NodeRef<Identifier>, LambdaArgumentDeclaration> lambdaArgumentReferences = new LinkedHashMap<>();

private final Map<Field, ColumnHandle> columns = new LinkedHashMap<>();
+private final Map<NodeRef<Node>, CorrespondingAnalysis> correspondingAnalysis = new LinkedHashMap<>();

private final Map<NodeRef<SampledRelation>, Double> sampleRatios = new LinkedHashMap<>();

@@ -767,6 +768,16 @@ public ColumnHandle getColumn(Field field)
return columns.get(field);
}

+public CorrespondingAnalysis getCorrespondingAnalysis(Node node)
+{
+return correspondingAnalysis.get(NodeRef.of(node));
+}

+public void setCorrespondingAnalysis(Node node, CorrespondingAnalysis correspondingAnalysis)
+{
+this.correspondingAnalysis.put(NodeRef.of(node), correspondingAnalysis);
+}

public Optional<AnalyzeMetadata> getAnalyzeMetadata()
{
return analyzeMetadata;
@@ -2547,4 +2558,14 @@ public record JsonTableAnalysis(
requireNonNull(orderedOutputColumns, "orderedOutputColumns is null");
}
}

+public record CorrespondingAnalysis(List<Integer> indexes, List<Field> fields)
+{
+public CorrespondingAnalysis
+{
+indexes = ImmutableList.copyOf(indexes);
+fields = ImmutableList.copyOf(fields);
+checkArgument(indexes.size() == fields.size(), "indexes and fields must have the same size");
+}
+}
}
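To illustrate what the new record holds, a sketch assuming hypothetical tables t1(a, b) and t2(b, a):

SELECT a, b FROM t1
INTERSECT CORRESPONDING
SELECT b, a FROM t2;
-- left child:  CorrespondingAnalysis(indexes = [0, 1], fields = [a, b])
-- right child: CorrespondingAnalysis(indexes = [1, 0], fields = [a, b])
-- indexes point into each child's own visible fields; fields are in the
-- shared output order, which follows the left relation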
@@ -105,6 +105,7 @@
import io.trino.spi.type.VarcharType;
import io.trino.sql.InterpretedFunctionInvoker;
import io.trino.sql.PlannerContext;
+import io.trino.sql.analyzer.Analysis.CorrespondingAnalysis;
import io.trino.sql.analyzer.Analysis.GroupingSetAnalysis;
import io.trino.sql.analyzer.Analysis.JsonTableAnalysis;
import io.trino.sql.analyzer.Analysis.MergeAnalysis;
@@ -133,6 +134,7 @@
import io.trino.sql.tree.ColumnDefinition;
import io.trino.sql.tree.Comment;
import io.trino.sql.tree.Commit;
+import io.trino.sql.tree.Corresponding;
import io.trino.sql.tree.CreateCatalog;
import io.trino.sql.tree.CreateMaterializedView;
import io.trino.sql.tree.CreateSchema;
@@ -282,6 +284,7 @@
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -3191,9 +3194,54 @@ protected Scope visitSetOperation(SetOperation node, Optional<Scope> scope)
{
checkState(node.getRelations().size() >= 2);

-List<RelationType> childrenTypes = node.getRelations().stream()
-.map(relation -> process(relation, scope).getRelationType().withOnlyVisibleFields())
-.collect(toImmutableList());
+List<Relation> relations = node.getRelations();

+List<RelationType> childrenTypes = new ArrayList<>();
+if (node.getCorresponding().isPresent()) {
+checkState(relations.size() == 2, "CORRESPONDING requires 2 relations");

+Corresponding corresponding = node.getCorresponding().get();
+if (!corresponding.getColumns().isEmpty()) {
+throw semanticException(NOT_SUPPORTED, node, "CORRESPONDING with columns is unsupported");
+}

+RelationType leftRelation = process(relations.getFirst(), scope).getRelationType().withOnlyVisibleFields();
+RelationType rightRelation = process(relations.getLast(), scope).getRelationType().withOnlyVisibleFields();

+Map<String, Integer> leftFieldsByName = buildNameToIndex(node, leftRelation);
+Map<String, Integer> rightFieldsByName = buildNameToIndex(node, rightRelation);

+List<String> correspondingColumns = leftFieldsByName.keySet().stream()
+.filter(rightFieldsByName::containsKey)
+.collect(toImmutableList());

+if (correspondingColumns.isEmpty()) {
+throw semanticException(MISMATCHED_COLUMN_ALIASES, node, "No corresponding columns");
+}

+ImmutableList.Builder<Integer> leftColumnIndexes = ImmutableList.builderWithExpectedSize(correspondingColumns.size());
+ImmutableList.Builder<Integer> rightColumnIndexes = ImmutableList.builderWithExpectedSize(correspondingColumns.size());
+ImmutableList.Builder<Field> leftRequiredFields = ImmutableList.builderWithExpectedSize(correspondingColumns.size());
+ImmutableList.Builder<Field> rightRequiredFields = ImmutableList.builderWithExpectedSize(correspondingColumns.size());
+for (String correspondingColumn : correspondingColumns) {
+int leftFieldIndex = leftFieldsByName.get(correspondingColumn);
+int rightFieldIndex = rightFieldsByName.get(correspondingColumn);
+leftColumnIndexes.add(leftFieldIndex);
+rightColumnIndexes.add(rightFieldIndex);
+leftRequiredFields.add(leftRelation.getFieldByIndex(leftFieldIndex));
+rightRequiredFields.add(rightRelation.getFieldByIndex(rightFieldIndex));
+}

+analysis.setCorrespondingAnalysis(node.getRelations().getFirst(), new CorrespondingAnalysis(leftColumnIndexes.build(), leftRequiredFields.build()));
+analysis.setCorrespondingAnalysis(node.getRelations().getLast(), new CorrespondingAnalysis(rightColumnIndexes.build(), rightRequiredFields.build()));

+childrenTypes.add(new RelationType(leftRequiredFields.build()).withOnlyVisibleFields());
+childrenTypes.add(new RelationType(rightRequiredFields.build()).withOnlyVisibleFields());
+}
+else {
+childrenTypes.add(process(relations.getFirst(), scope).getRelationType().withOnlyVisibleFields());
+childrenTypes.add(process(relations.getLast(), scope).getRelationType().withOnlyVisibleFields());
+}

String setOperationName = node.getClass().getSimpleName().toUpperCase(ENGLISH);
Type[] outputFieldTypes = childrenTypes.get(0).getVisibleFields().stream()
@@ -3264,8 +3312,8 @@ protected Scope visitSetOperation(SetOperation node, Optional<Scope> scope)
.collect(toImmutableSet()));
}

-for (int i = 0; i < node.getRelations().size(); i++) {
-Relation relation = node.getRelations().get(i);
+for (int i = 0; i < relations.size(); i++) {
+Relation relation = relations.get(i);
RelationType relationType = childrenTypes.get(i);
for (int j = 0; j < relationType.getVisibleFields().size(); j++) {
Type outputFieldType = outputFieldTypes[j];
@@ -3279,6 +3327,22 @@ protected Scope visitSetOperation(SetOperation node, Optional<Scope> scope)
return createAndAssignScope(node, scope, outputDescriptorFields);
}

+private static Map<String, Integer> buildNameToIndex(SetOperation node, RelationType relationType)
+{
+Map<String, Integer> nameToIndex = new LinkedHashMap<>();
+for (int i = 0; i < relationType.getAllFieldCount(); i++) {
+Field field = relationType.getFieldByIndex(i);
+String name = field.getName()
+.orElseThrow(() -> semanticException(MISSING_COLUMN_NAME, node, "Anonymous columns are not allowed in set operations with CORRESPONDING"))
+// TODO https://github.com/trinodb/trino/issues/17 Add support for case sensitive identifiers
+.toLowerCase(ENGLISH);
+if (nameToIndex.put(name, i) != null) {
+throw semanticException(AMBIGUOUS_NAME, node, "Duplicate columns found when using CORRESPONDING in set operations: %s", name);
+}
+}
+return ImmutableMap.copyOf(nameToIndex);
+}

@Override
protected Scope visitJoin(Join node, Optional<Scope> scope)
{
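Name matching is case-insensitive (both sides are lower-cased; case-sensitive identifiers are tracked in trinodb/trino#17), and each error path above corresponds to a malformed query. Hypothetical examples:

-- "No corresponding columns" (MISMATCHED_COLUMN_ALIASES): no shared names
SELECT 1 AS x UNION CORRESPONDING SELECT 2 AS y;

-- "Anonymous columns are not allowed..." (MISSING_COLUMN_NAME)
SELECT 1 UNION CORRESPONDING SELECT 2 AS x;

-- "Duplicate columns found..." (AMBIGUOUS_NAME): a name repeated on one side
SELECT 1 AS x, 2 AS x UNION CORRESPONDING SELECT 3 AS x;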
@@ -40,6 +40,7 @@
import io.trino.spi.type.Type;
import io.trino.sql.PlannerContext;
import io.trino.sql.analyzer.Analysis;
+import io.trino.sql.analyzer.Analysis.CorrespondingAnalysis;
import io.trino.sql.analyzer.Analysis.JsonTableAnalysis;
import io.trino.sql.analyzer.Analysis.TableArgumentAnalysis;
import io.trino.sql.analyzer.Analysis.TableFunctionInvocationAnalysis;
@@ -1865,9 +1866,33 @@ private SetOperationPlan process(SetOperation node)
ImmutableListMultimap.Builder<Symbol, Symbol> symbolMapping = ImmutableListMultimap.builder();
ImmutableList.Builder<PlanNode> sources = ImmutableList.builder();

-for (Relation child : node.getRelations()) {
+List<Relation> relations = node.getRelations();
+checkArgument(relations.size() == 2, "relations size must be 2");
+for (Relation child : relations) {
RelationPlan plan = process(child, null);

+if (node.getCorresponding().isPresent()) {
+int[] fieldIndexForVisibleColumn = new int[plan.getDescriptor().getVisibleFieldCount()];
+int visibleColumn = 0;
+for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) {
+if (!plan.getDescriptor().getFieldByIndex(i).isHidden()) {
+fieldIndexForVisibleColumn[visibleColumn] = i;
+visibleColumn++;
+}
+}

+CorrespondingAnalysis correspondingAnalysis = analysis.getCorrespondingAnalysis(child);
+List<Symbol> requiredColumns = correspondingAnalysis.indexes().stream()
+.filter(column -> column < fieldIndexForVisibleColumn.length)
+.map(column -> fieldIndexForVisibleColumn[column])
+.map(plan::getSymbol)
+.collect(toImmutableList());

+ProjectNode projectNode = new ProjectNode(idAllocator.getNextId(), plan.getRoot(), Assignments.identity(requiredColumns));
+Scope scope = Scope.builder().withRelationType(plan.getScope().getRelationId(), new RelationType(correspondingAnalysis.fields())).build();
+plan = new RelationPlan(projectNode, scope, requiredColumns, plan.getOuterContext());
+}

NodeAndMappings planAndMappings;
List<Type> types = analysis.getRelationCoercion(child);
if (types == null) {
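At plan time, each child of a CORRESPONDING set operation is wrapped in an identity ProjectNode over only the required symbols, which both prunes non-matching columns and reorders the rest. A sketch assuming hypothetical tables t1(a, b, c) and t2(c, a):

SELECT a, b, c FROM t1
EXCEPT CORRESPONDING
SELECT c, a FROM t2;
-- left child is projected to (a, c); b has no counterpart and is dropped
-- right child is projected to (a, c), reordered to match the left side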