Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CORRESPONDING option in set operations #25260

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,10 @@ rowCount

queryTerm
: queryPrimary #queryTermDefault
| left=queryTerm operator=INTERSECT setQuantifier? right=queryTerm #setOperation
| left=queryTerm operator=(UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation
| left=queryTerm operator=INTERSECT
setQuantifier? CORRESPONDING? right=queryTerm #setOperation
| left=queryTerm operator=(UNION | EXCEPT)
setQuantifier? CORRESPONDING? right=queryTerm #setOperation
;

queryPrimary
Expand Down Expand Up @@ -1000,7 +1002,7 @@ nonReserved
// IMPORTANT: this rule must only contain tokens. Nested rules are not supported. See SqlParser.exitNonReserved
: ABSENT | ADD | ADMIN | AFTER | ALL | ANALYZE | ANY | ARRAY | ASC | AT | AUTHORIZATION
| BEGIN | BERNOULLI | BOTH
| CALL | CALLED | CASCADE | CATALOG | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CONDITIONAL | COPARTITION | COUNT | CURRENT
| CALL | CALLED | CASCADE | CATALOG | CATALOGS | COLUMN | COLUMNS | COMMENT | COMMIT | COMMITTED | CONDITIONAL | COPARTITION | CORRESPONDING | COUNT | CURRENT
| DATA | DATE | DAY | DECLARE | DEFAULT | DEFINE | DEFINER | DENY | DESC | DESCRIPTOR | DETERMINISTIC | DISTRIBUTED | DO | DOUBLE
| ELSEIF | EMPTY | ENCODING | ERROR | EXCLUDING | EXECUTE | EXPLAIN
| FETCH | FILTER | FINAL | FIRST | FOLLOWING | FORMAT | FUNCTION | FUNCTIONS
Expand Down Expand Up @@ -1061,6 +1063,7 @@ CONDITIONAL: 'CONDITIONAL';
CONSTRAINT: 'CONSTRAINT';
COUNT: 'COUNT';
COPARTITION: 'COPARTITION';
CORRESPONDING: 'CORRESPONDING';
CREATE: 'CREATE';
CROSS: 'CROSS';
CUBE: 'CUBE';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ public void test()
"CONDITIONAL",
"CONSTRAINT",
"COPARTITION",
"CORRESPONDING",
"COUNT",
"CREATE",
"CROSS",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.MoreCollectors.toOptional;
import static java.util.Objects.requireNonNull;

/**
Expand Down Expand Up @@ -81,6 +82,16 @@ public Field getFieldByIndex(int fieldIndex)
return allFields.get(fieldIndex);
}

/**
* Gets the field with the specified name.
*/
public Optional<Field> getFieldByName(String name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this handle standard SQL identifier canonicalization and matching?

{
return allFields.stream()
.filter(field -> field.getName().isPresent() && field.getName().get().equals(name))
.collect(toOptional());
}

/**
* Gets only the visible fields.
* No assumptions should be made about the order of the fields returned from this method.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3189,11 +3189,29 @@ protected Scope visitSubqueryExpression(SubqueryExpression node, Optional<Scope>
@Override
protected Scope visitSetOperation(SetOperation node, Optional<Scope> scope)
{
checkState(node.getRelations().size() >= 2);

List<RelationType> childrenTypes = node.getRelations().stream()
.map(relation -> process(relation, scope).getRelationType().withOnlyVisibleFields())
.collect(toImmutableList());
List<Relation> relations = node.getRelations();
checkState(relations.size() == 2, "relations size must be 2");
boolean corresponding = node.isCorresponding();

List<RelationType> childrenTypes = new ArrayList<>();
childrenTypes.add(process(relations.getFirst(), scope).getRelationType().withOnlyVisibleFields());
if (corresponding) {
RelationType left = childrenTypes.getFirst();
RelationType right = process(relations.getLast(), scope).getRelationType().withOnlyVisibleFields();
checkColumnNames(node, left.getVisibleFields());
checkColumnNames(node, right.getVisibleFields());

List<Field> fields = new ArrayList<>();
for (int i = 0; i < left.getAllFieldCount(); i++) {
Field field = left.getFieldByIndex(i);
String name = field.getName().orElseThrow();
fields.add(right.getFieldByName(name).orElseThrow(() -> semanticException(COLUMN_NOT_FOUND, node, "Column '%s' cannot be resolved", name)));
}
childrenTypes.add(new RelationType(fields).withOnlyVisibleFields());
}
else {
childrenTypes.add(process(relations.getLast(), scope).getRelationType().withOnlyVisibleFields());
}

String setOperationName = node.getClass().getSimpleName().toUpperCase(ENGLISH);
Type[] outputFieldTypes = childrenTypes.get(0).getVisibleFields().stream()
Expand Down Expand Up @@ -3264,8 +3282,8 @@ protected Scope visitSetOperation(SetOperation node, Optional<Scope> scope)
.collect(toImmutableSet()));
}

for (int i = 0; i < node.getRelations().size(); i++) {
Relation relation = node.getRelations().get(i);
for (int i = 0; i < relations.size(); i++) {
Relation relation = relations.get(i);
RelationType relationType = childrenTypes.get(i);
for (int j = 0; j < relationType.getVisibleFields().size(); j++) {
Type outputFieldType = outputFieldTypes[j];
Expand All @@ -3279,6 +3297,17 @@ protected Scope visitSetOperation(SetOperation node, Optional<Scope> scope)
return createAndAssignScope(node, scope, outputDescriptorFields);
}

/**
 * Validates the columns of one side of a set operation that uses CORRESPONDING.
 *
 * <p>Every field must have a name, and no name may occur more than once. Names are
 * compared case-insensitively because SQL identifiers in this engine are not case
 * sensitive; a case-sensitive comparison would let {@code x} and {@code X} slip
 * through as distinct columns.
 *
 * @param node the set operation being analyzed; used as the location of any error
 * @param fields the fields of one input relation of the set operation
 * @throws RuntimeException a semantic exception with {@code MISSING_COLUMN_NAME} if a field
 *         is anonymous, or with {@code AMBIGUOUS_NAME} if a name is repeated
 */
private static void checkColumnNames(SetOperation node, Collection<Field> fields)
{
    Set<String> seenNames = new HashSet<>();
    for (Field field : fields) {
        String name = field.getName().orElseThrow(() -> semanticException(MISSING_COLUMN_NAME, node, "Anonymous columns are not allowed in set operations with CORRESPONDING"));
        // Track a case-folded key, but report the name as it was written in the query
        if (!seenNames.add(name.toLowerCase(ENGLISH))) {
            throw semanticException(AMBIGUOUS_NAME, node, "Duplicate columns found when using CORRESPONDING in set operations: %s", name);
        }
    }
}

@Override
protected Scope visitJoin(Join node, Optional<Scope> scope)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1865,9 +1865,31 @@ private SetOperationPlan process(SetOperation node)
ImmutableListMultimap.Builder<Symbol, Symbol> symbolMapping = ImmutableListMultimap.builder();
ImmutableList.Builder<PlanNode> sources = ImmutableList.builder();

for (Relation child : node.getRelations()) {
List<Relation> relations = node.getRelations();
checkArgument(relations.size() == 2, "relations size must be 2");
Relation rightRelation = relations.getLast();
for (Relation child : relations) {
RelationPlan plan = process(child, null);

if (node.isCorresponding() && child.equals(rightRelation)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The planner should not be doing any mappings based on names. That should have been established during analysis, and any necessary supporting data structure recorded in the Analysis object. E.g., a list of fields to select in the order in which they should be considered + the corresponding relation types.

// Replace right relation's field order to match the output fields of the set operation
Map<String, Symbol> nameToSymbol = new HashMap<>();
RelationType descriptor = plan.getDescriptor();
Collection<Field> visibleFields = outputFields.getVisibleFields();
for (int i = 0; i < visibleFields.size(); i++) {
nameToSymbol.put(descriptor.getFieldByIndex(i).getName().orElseThrow(), plan.getSymbol(i));
}

List<Symbol> fieldMappings = visibleFields.stream()
.map(field -> nameToSymbol.get(field.getName().orElseThrow()))
.collect(toImmutableList());
ProjectNode projectNode = new ProjectNode(
idAllocator.getNextId(),
plan.getRoot(),
Assignments.identity(fieldMappings));
plan = new RelationPlan(projectNode, plan.getScope(), fieldMappings, plan.getOuterContext());
}

NodeAndMappings planAndMappings;
List<Type> types = analysis.getRelationCoercion(child);
if (types == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,116 @@ public void testIntersectWithEmptyBranches()
.describedAs("INTERSECT DISTINCT with empty branches")
.returnsEmptyResult();
}

@Test
void testExceptCorresponding()
{
    // The right side declares the same column names in the opposite order (y, x)
    String exceptDistinct =
            """
            SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y)
            EXCEPT CORRESPONDING
            SELECT * FROM (VALUES ('alice', 1)) t(y, x)
            """;
    assertThat(assertions.query(exceptDistinct))
            .returnsEmptyResult();

    // With ALL, one of the two duplicate left rows is expected to remain
    String exceptAll =
            """
            SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y)
            EXCEPT ALL CORRESPONDING
            SELECT * FROM (VALUES ('alice', 1)) t(y, x)
            """;
    assertThat(assertions.query(exceptAll))
            .matches("VALUES (1, 'alice')");
}

@Test
void testUnionCorresponding()
{
    // The right side declares the same column names in the opposite order (y, x)
    String unionDistinct =
            """
            SELECT 1 AS x, 'alice' AS y
            UNION CORRESPONDING
            SELECT * FROM (VALUES ('alice', 1), ('bob', 2)) t(y, x)
            """;
    assertThat(assertions.query(unionDistinct))
            .matches("VALUES (1, 'alice'), (2, 'bob')");

    // With ALL, the row present on both sides is expected twice
    String unionAll =
            """
            SELECT 1 AS x, 'alice' AS y
            UNION ALL CORRESPONDING
            SELECT * FROM (VALUES ('alice', 1), ('bob', 2)) t(y, x)
            """;
    assertThat(assertions.query(unionAll))
            .matches("VALUES (1, 'alice'), (1, 'alice'), (2, 'bob')");
}

@Test
void testIntersectCorresponding()
{
    // The right side declares the same column names in the opposite order (y, x)
    String intersectDistinct =
            """
            SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y)
            INTERSECT CORRESPONDING
            SELECT * FROM (VALUES ('alice', 1), ('alice', 1)) t(y, x)
            """;
    assertThat(assertions.query(intersectDistinct))
            .matches("VALUES (1, 'alice')");

    // With ALL, both duplicate rows are expected in the result
    String intersectAll =
            """
            SELECT * FROM (VALUES (1, 'alice'), (1, 'alice')) t(x, y)
            INTERSECT ALL CORRESPONDING
            SELECT * FROM (VALUES ('alice', 1), ('alice', 1)) t(y, x)
            """;
    assertThat(assertions.query(intersectAll))
            .matches("VALUES (1, 'alice'), (1, 'alice')");
}

@Test
void testCorrespondingDuplicateNames()
{
    // Every query repeats the column name x on one side and must fail with the same message
    String expectedMessage = "line 1:23: Duplicate columns found when using CORRESPONDING in set operations: x";
    String[] queries = {
            "SELECT 1 AS x, 2 AS y EXCEPT CORRESPONDING SELECT 1 AS x, 2 AS x",
            "SELECT 1 AS x, 2 AS x EXCEPT CORRESPONDING SELECT 1 AS y, 2 AS x",
            "SELECT 1 AS x, 2 AS y UNION CORRESPONDING SELECT 1 AS x, 2 AS x",
            "SELECT 1 AS x, 2 AS x UNION CORRESPONDING SELECT 1 AS x, 2 AS y",
            "SELECT 1 AS x, 2 AS y INTERSECT CORRESPONDING SELECT 1 AS x, 2 AS x",
            "SELECT 1 AS x, 2 AS x INTERSECT CORRESPONDING SELECT 1 AS x, 2 AS y",
    };

    for (String query : queries) {
        assertThat(assertions.query(query))
                .failure().hasMessage(expectedMessage);
    }
}

@Test
void testCorrespondingNameMismatch()
{
    // The left column x has no counterpart on the right side, whichever operator is used
    for (String operator : new String[] {"EXCEPT", "UNION", "INTERSECT"}) {
        String query = "SELECT 1 AS x " + operator + " CORRESPONDING SELECT 2 AS y";
        assertThat(assertions.query(query))
                .failure().hasMessage("line 1:15: Column 'x' cannot be resolved");
    }
}

@Test
void testCorrespondingWithAnonymousColumn()
{
    for (String operator : new String[] {"EXCEPT", "UNION", "INTERSECT"}) {
        // Unnamed column on the left side
        String anonymousLeft = "SELECT 1 " + operator + " CORRESPONDING SELECT 2 AS x";
        assertThat(assertions.query(anonymousLeft))
                .failure().hasMessage("line 1:10: Anonymous columns are not allowed in set operations with CORRESPONDING");

        // Unnamed column on the right side
        String anonymousRight = "SELECT 1 AS x " + operator + " CORRESPONDING SELECT 2";
        assertThat(assertions.query(anonymousRight))
                .failure().hasMessage("line 1:15: Anonymous columns are not allowed in set operations with CORRESPONDING");
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,9 @@ protected Void visitUnion(Union node, Integer indent)
if (!node.isDistinct()) {
builder.append("ALL ");
}
if (node.isCorresponding()) {
builder.append("CORRESPONDING ");
}
}
}

Expand All @@ -1049,6 +1052,9 @@ protected Void visitExcept(Except node, Integer indent)
if (!node.isDistinct()) {
builder.append("ALL ");
}
if (node.isCorresponding()) {
builder.append("CORRESPONDING ");
}

processRelation(node.getRight(), indent);

Expand All @@ -1068,6 +1074,9 @@ protected Void visitIntersect(Intersect node, Integer indent)
if (!node.isDistinct()) {
builder.append("ALL ");
}
if (node.isCorresponding()) {
builder.append("CORRESPONDING ");
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1383,11 +1383,12 @@ public Node visitSetOperation(SqlBaseParser.SetOperationContext context)
QueryBody right = (QueryBody) visit(context.right);

boolean distinct = context.setQuantifier() == null || context.setQuantifier().DISTINCT() != null;
boolean corresponding = context.CORRESPONDING() != null;

return switch (context.operator.getType()) {
case SqlBaseLexer.UNION -> new Union(getLocation(context.UNION()), ImmutableList.of(left, right), distinct);
case SqlBaseLexer.INTERSECT -> new Intersect(getLocation(context.INTERSECT()), ImmutableList.of(left, right), distinct);
case SqlBaseLexer.EXCEPT -> new Except(getLocation(context.EXCEPT()), left, right, distinct);
case SqlBaseLexer.UNION -> new Union(getLocation(context.UNION()), ImmutableList.of(left, right), distinct, corresponding);
case SqlBaseLexer.INTERSECT -> new Intersect(getLocation(context.INTERSECT()), ImmutableList.of(left, right), distinct, corresponding);
case SqlBaseLexer.EXCEPT -> new Except(getLocation(context.EXCEPT()), left, right, distinct, corresponding);
default -> throw new IllegalArgumentException("Unsupported set operation: " + context.operator.getText());
};
}
Expand Down
14 changes: 9 additions & 5 deletions core/trino-parser/src/main/java/io/trino/sql/tree/Except.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ public class Except
private final Relation left;
private final Relation right;

public Except(NodeLocation location, Relation left, Relation right, boolean distinct)
public Except(NodeLocation location, Relation left, Relation right, boolean distinct, boolean corresponding)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a boolean is not going to be sufficient for modeling the list of columns in the CORRESPONDING clause. It will require a bigger change later.

Add support for passing in the list of columns in the grammar and AST, even if, for the first version, the analyzer rejects any queries where the list is provided.

{
super(Optional.of(location), distinct);
super(Optional.of(location), distinct, corresponding);
requireNonNull(left, "left is null");
requireNonNull(right, "right is null");

Expand Down Expand Up @@ -73,6 +73,7 @@ public String toString()
.add("left", left)
.add("right", right)
.add("distinct", isDistinct())
.add("corresponding", isCorresponding())
.toString();
}

Expand All @@ -88,13 +89,14 @@ public boolean equals(Object obj)
Except o = (Except) obj;
return Objects.equals(left, o.left) &&
Objects.equals(right, o.right) &&
isDistinct() == o.isDistinct();
isDistinct() == o.isDistinct() &&
isCorresponding() == o.isCorresponding();
}

@Override
public int hashCode()
{
return Objects.hash(left, right, isDistinct());
return Objects.hash(left, right, isDistinct(), isCorresponding());
}

@Override
Expand All @@ -104,6 +106,8 @@ public boolean shallowEquals(Node other)
return false;
}

return this.isDistinct() == ((Except) other).isDistinct();
Except otherExcept = (Except) other;
return this.isDistinct() == otherExcept.isDistinct() &&
this.isCorresponding() == otherExcept.isCorresponding();
}
}
Loading