Skip to content

Commit f0fb575

Browse files
committed
Allow field name declaration in row literal
Add support for `row(a 1, b 2)` instead of the much more complex `cast(row(1, 2) as row(a integer, b integer))`.
1 parent aaffefb commit f0fb575

30 files changed

+353
-128
lines changed

core/trino-grammar/src/main/antlr4/io/trino/grammar/sql/SqlBase.g4

+5-1
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ primaryExpression
574574
| QUESTION_MARK #parameter
575575
| POSITION '(' valueExpression IN valueExpression ')' #position
576576
| '(' expression (',' expression)+ ')' #rowConstructor
577-
| ROW '(' expression (',' expression)* ')' #rowConstructor
577+
| ROW '(' fieldConstructor (',' fieldConstructor)* ')' #rowConstructor
578578
| name=LISTAGG '(' setQuantifier? expression (',' string)?
579579
(ON OVERFLOW listAggOverflowBehavior)? ')'
580580
(WITHIN GROUP '(' ORDER BY sortItem (',' sortItem)* ')')
@@ -646,6 +646,10 @@ primaryExpression
646646
')' #jsonArray
647647
;
648648

649+
fieldConstructor
650+
: expression (AS? identifier)?
651+
;
652+
649653
jsonPathInvocation
650654
: jsonValueExpression ',' path=string
651655
(AS pathName=identifier)?

core/trino-main/src/main/java/io/trino/sql/analyzer/AggregationAnalyzer.java

+7-1
Original file line numberDiff line numberDiff line change
@@ -701,10 +701,16 @@ protected Boolean visitTryExpression(TryExpression node, Void context)
701701
@Override
702702
protected Boolean visitRow(Row node, Void context)
703703
{
704-
return node.getItems().stream()
704+
return node.getFields().stream()
705705
.allMatch(item -> process(item, context));
706706
}
707707

708+
@Override
709+
protected Boolean visitRowField(Row.Field node, Void context)
710+
{
711+
return process(node.getExpression(), context);
712+
}
713+
708714
@Override
709715
protected Boolean visitParameter(Parameter node, Void context)
710716
{

core/trino-main/src/main/java/io/trino/sql/analyzer/ExpressionAnalyzer.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -679,11 +679,11 @@ public Type process(Node node, @Nullable Context context)
679679
@Override
680680
protected Type visitRow(Row node, Context context)
681681
{
682-
List<Type> types = node.getItems().stream()
683-
.map(child -> process(child, context))
682+
List<RowType.Field> fields = node.getFields().stream()
683+
.map(field -> new RowType.Field(field.getName().map(Identifier::getCanonicalValue), process(field.getExpression(), context)))
684684
.collect(toImmutableList());
685685

686-
Type type = RowType.anonymous(types);
686+
Type type = RowType.from(fields);
687687
return setExpressionType(node, type);
688688
}
689689

core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3929,7 +3929,7 @@ protected Scope visitValues(Values node, Optional<Scope> scope)
39293929
// TODO coerce the whole Row and add an Optimizer rule that converts CAST(ROW(...) AS ...) into ROW(CAST(...), CAST(...), ...).
39303930
// The rule would also handle Row-type expressions that were specified as CAST(ROW). It should support multiple casts over a ROW.
39313931
for (int i = 0; i < actualType.getTypeParameters().size(); i++) {
3932-
Expression item = ((Row) row).getItems().get(i);
3932+
Expression item = ((Row) row).getFields().get(i).getExpression();
39333933
Type actualItemType = actualType.getTypeParameters().get(i);
39343934
Type expectedItemType = commonSuperType.getTypeParameters().get(i);
39353935
if (!actualItemType.equals(expectedItemType)) {

core/trino-main/src/main/java/io/trino/sql/ir/ExpressionFormatter.java

+14-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import com.google.common.base.CharMatcher;
1717
import com.google.common.base.Joiner;
1818
import com.google.common.collect.ImmutableList;
19+
import io.trino.spi.type.RowType;
1920
import io.trino.sql.planner.Symbol;
2021

2122
import java.util.List;
@@ -67,9 +68,19 @@ protected String visitArray(Array node, Void context)
6768
@Override
6869
protected String visitRow(Row node, Void context)
6970
{
70-
return node.items().stream()
71-
.map(child -> process(child, context))
72-
.collect(joining(", ", "ROW (", ")"));
71+
List<RowType.Field> fieldTypes = ((RowType) node.type()).getFields();
72+
73+
StringBuilder builder = new StringBuilder();
74+
builder.append("ROW (");
75+
for (int i = 0; i < fieldTypes.size(); i++) {
76+
if (i > 0) {
77+
builder.append(", ");
78+
}
79+
builder.append(node.items().get(i).accept(this, context));
80+
fieldTypes.get(i).getName().ifPresent(name -> builder.append(" AS ").append(name));
81+
}
82+
builder.append(")");
83+
return builder.toString();
7384
}
7485

7586
@Override

core/trino-main/src/main/java/io/trino/sql/ir/ExpressionTreeRewriter.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ protected Expression visitRow(Row node, Context<C> context)
110110
List<Expression> items = rewrite(node.items(), context);
111111

112112
if (!sameElements(node.items(), items)) {
113-
return new Row(items);
113+
return new Row(items, node.type());
114114
}
115115

116116
return node;

core/trino-main/src/main/java/io/trino/sql/ir/Row.java

+25-5
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616
import com.fasterxml.jackson.databind.annotation.JsonSerialize;
1717
import com.google.common.collect.ImmutableList;
1818
import io.trino.spi.type.RowType;
19-
import io.trino.spi.type.Type;
2019

2120
import java.util.List;
21+
import java.util.Optional;
2222
import java.util.stream.Collectors;
2323

2424
import static java.util.Objects.requireNonNull;
2525

2626
@JsonSerialize
27-
public record Row(List<Expression> items)
27+
public record Row(List<Expression> items, RowType type)
2828
implements Expression
2929
{
3030
public Row
@@ -33,10 +33,9 @@ public record Row(List<Expression> items)
3333
items = ImmutableList.copyOf(items);
3434
}
3535

36-
@Override
37-
public Type type()
36+
public Row(List<Expression> items)
3837
{
39-
return RowType.anonymous(items.stream().map(Expression::type).collect(Collectors.toList()));
38+
this(items, RowType.anonymous(items.stream().map(Expression::type).toList()));
4039
}
4140

4241
@Override
@@ -60,4 +59,25 @@ public String toString()
6059
.collect(Collectors.joining(", ")) +
6160
")";
6261
}
62+
63+
@JsonSerialize
64+
public record Field(Optional<String> name, Expression value)
65+
{
66+
public Field
67+
{
68+
requireNonNull(name, "name is null");
69+
requireNonNull(value, "value is null");
70+
}
71+
72+
public static Field anonymousField(Expression value)
73+
{
74+
return new Field(Optional.empty(), value);
75+
}
76+
77+
@Override
78+
public String toString()
79+
{
80+
return name.map(n -> n + " " + value).orElseGet(value::toString);
81+
}
82+
}
6383
}

core/trino-main/src/main/java/io/trino/sql/ir/optimizer/IrExpressionOptimizer.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ private Optional<Expression> processChildren(Expression expression, Session sess
192192
case Logical logical -> process(logical.terms(), session, bindings).map(arguments -> new Logical(logical.operator(), arguments));
193193
case Call call -> process(call.arguments(), session, bindings).map(arguments -> new Call(call.function(), arguments));
194194
case Array array -> process(array.elements(), session, bindings).map(elements -> new Array(array.elementType(), elements));
195-
case Row row -> process(row.items(), session, bindings).map(fields -> new Row(fields));
195+
case Row row -> process(row.items(), session, bindings).map(fields -> new Row(fields, row.type()));
196196
case Between between -> {
197197
Optional<Expression> value = process(between.value(), session, bindings);
198198
Optional<Expression> min = process(between.min(), session, bindings);

core/trino-main/src/main/java/io/trino/sql/ir/optimizer/rule/EvaluateRow.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public class EvaluateRow
3737
@Override
3838
public Optional<Expression> apply(Expression expression, Session session, Map<Symbol, Expression> bindings)
3939
{
40-
if (!(expression instanceof Row(List<Expression> fields)) || !fields.stream().allMatch(Constant.class::isInstance)) {
40+
if (!(expression instanceof Row(List<Expression> fields, RowType _)) || !fields.stream().allMatch(Constant.class::isInstance)) {
4141
return Optional.empty();
4242
}
4343

core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java

+10-4
Original file line numberDiff line numberDiff line change
@@ -1764,10 +1764,16 @@ protected RelationPlan visitValues(Values node, Void context)
17641764

17651765
ImmutableList.Builder<Expression> rows = ImmutableList.builder();
17661766
for (io.trino.sql.tree.Expression row : node.getRows()) {
1767-
if (row instanceof io.trino.sql.tree.Row) {
1768-
rows.add(new Row(((io.trino.sql.tree.Row) row).getItems().stream()
1769-
.map(item -> coerceIfNecessary(analysis, item, translationMap.rewrite(item)))
1770-
.collect(toImmutableList())));
1767+
if (row instanceof io.trino.sql.tree.Row astRow) {
1768+
ImmutableList.Builder<Expression> fields = ImmutableList.builder();
1769+
ImmutableList.Builder<RowType.Field> typeFields = ImmutableList.builder();
1770+
for (int i = 0; i < astRow.getFields().size(); i++) {
1771+
io.trino.sql.tree.Row.Field astField = astRow.getFields().get(i);
1772+
Expression expression = coerceIfNecessary(analysis, astField.getExpression(), translationMap.rewrite(astField.getExpression()));
1773+
fields.add(expression);
1774+
typeFields.add(new RowType.Field(astField.getName().map(Identifier::getCanonicalValue), expression.type()));
1775+
}
1776+
rows.add(new Row(fields.build(), RowType.from(typeFields.build())));
17711777
}
17721778
else if (analysis.getType(row) instanceof RowType) {
17731779
rows.add(coerceIfNecessary(analysis, row, translationMap.rewrite(row)));

core/trino-main/src/main/java/io/trino/sql/planner/TranslationMap.java

+11-4
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@
150150
import static io.trino.sql.ir.Comparison.Operator.NOT_EQUAL;
151151
import static io.trino.sql.ir.IrExpressions.ifExpression;
152152
import static io.trino.sql.ir.IrExpressions.not;
153+
import static io.trino.sql.planner.QueryPlanner.coerceIfNecessary;
153154
import static io.trino.sql.planner.ScopeAware.scopeAwareKey;
154155
import static io.trino.sql.tree.JsonQuery.EmptyOrErrorBehavior.ERROR;
155156
import static io.trino.sql.tree.JsonQuery.QuotesBehavior.KEEP;
@@ -545,11 +546,17 @@ private io.trino.sql.ir.Expression translate(NotExpression expression)
545546
return not(plannerContext.getMetadata(), translateExpression(expression.getValue()));
546547
}
547548

548-
private io.trino.sql.ir.Expression translate(Row expression)
549+
private io.trino.sql.ir.Expression translate(Row row)
549550
{
550-
return new io.trino.sql.ir.Row(expression.getItems().stream()
551-
.map(this::translateExpression)
552-
.collect(toImmutableList()));
551+
ImmutableList.Builder<io.trino.sql.ir.Expression> fields = ImmutableList.builder();
552+
ImmutableList.Builder<RowType.Field> typeFields = ImmutableList.builder();
553+
for (int i = 0; i < row.getFields().size(); i++) {
554+
io.trino.sql.tree.Row.Field field = row.getFields().get(i);
555+
io.trino.sql.ir.Expression expression = coerceIfNecessary(analysis, field.getExpression(), translateExpression(field.getExpression()));
556+
fields.add(expression);
557+
typeFields.add(new RowType.Field(field.getName().map(Identifier::getCanonicalValue), expression.type()));
558+
}
559+
return new io.trino.sql.ir.Row(fields.build(), RowType.from(typeFields.build()));
553560
}
554561

555562
private io.trino.sql.ir.Expression translate(ComparisonExpression expression)

core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ExpressionRewriteRuleSet.java

+9-7
Original file line numberDiff line numberDiff line change
@@ -300,18 +300,20 @@ public Result apply(ValuesNode valuesNode, Captures captures, Context context)
300300

301301
boolean anyRewritten = false;
302302
ImmutableList.Builder<Expression> rows = ImmutableList.builder();
303-
for (Expression row : valuesNode.getRows().get()) {
303+
for (Expression original : valuesNode.getRows().get()) {
304304
Expression rewritten;
305-
if (row instanceof Row) {
305+
if (original instanceof Row row) {
306306
// preserve the structure of row
307-
rewritten = new Row(((Row) row).items().stream()
308-
.map(item -> rewriter.rewrite(item, context))
309-
.collect(toImmutableList()));
307+
rewritten = new Row(
308+
row.items().stream()
309+
.map(item -> rewriter.rewrite(item, context))
310+
.collect(toImmutableList()),
311+
row.type());
310312
}
311313
else {
312-
rewritten = rewriter.rewrite(row, context);
314+
rewritten = rewriter.rewrite(original, context);
313315
}
314-
if (!row.equals(rewritten)) {
316+
if (!original.equals(rewritten)) {
315317
anyRewritten = true;
316318
}
317319
rows.add(rewritten);

core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushCastIntoRow.java

+21-37
Original file line numberDiff line numberDiff line change
@@ -20,73 +20,57 @@
2020
import io.trino.sql.ir.Expression;
2121
import io.trino.sql.ir.ExpressionTreeRewriter;
2222
import io.trino.sql.ir.Row;
23-
import io.trino.type.UnknownType;
2423

2524
/**
2625
* Transforms expressions of the form
2726
*
2827
* <pre>
2928
* CAST(
30-
* CAST(
31-
* ROW(x, y)
32-
* AS row(f1 type1, f2 type2))
33-
* AS row(g1 type3, g2 type4))
29+
* ROW(x, y)
30+
* AS row(f1 type1, f2 type2))
3431
* </pre>
3532
*
3633
* to
3734
*
3835
* <pre>
39-
* CAST(
40-
* ROW(
41-
* CAST(x AS type1),
42-
* CAST(y AS type2))
43-
* AS row(g1 type3, g2 type4))
36+
* ROW(
37+
* CAST(x AS type1) as f1,
38+
* CAST(y AS type2) as f2)
4439
* </pre>
45-
*
46-
* Note: it preserves the top-level CAST if the row type has field names because the names are needed by the ROW to JSON cast
47-
* TODO: ideally, the types involved in ROW to JSON cast should be captured at analysis time and
48-
* remain fixed for the duration of the optimization process so as to have flexibility in terms
49-
* of removing field names, which are irrelevant in the IR
5040
*/
5141
public class PushCastIntoRow
5242
extends ExpressionRewriteRuleSet
5343
{
5444
public PushCastIntoRow()
5545
{
56-
super((expression, context) -> ExpressionTreeRewriter.rewriteWith(new Rewriter(), expression, false));
46+
super((expression, context) -> ExpressionTreeRewriter.rewriteWith(new Rewriter(), expression, null));
5747
}
5848

5949
private static class Rewriter
60-
extends io.trino.sql.ir.ExpressionRewriter<Boolean>
50+
extends io.trino.sql.ir.ExpressionRewriter<Void>
6151
{
6252
@Override
63-
public Expression rewriteCast(Cast node, Boolean inRowCast, ExpressionTreeRewriter<Boolean> treeRewriter)
53+
public Expression rewriteCast(Cast node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
6454
{
65-
if (!(node.type() instanceof RowType type)) {
66-
return treeRewriter.defaultRewrite(node, false);
55+
if (!(node.type() instanceof RowType castToType)) {
56+
return treeRewriter.defaultRewrite(node, null);
6757
}
6858

69-
// if inRowCast == true or row is anonymous, we're free to push Cast into Row. An enclosing CAST(... AS ROW) will take care of preserving field names
70-
// otherwise, apply recursively with inRowCast == true and don't push this one
71-
72-
if (inRowCast || type.getFields().stream().allMatch(field -> field.getName().isEmpty())) {
73-
Expression value = treeRewriter.rewrite(node.expression(), true);
74-
75-
if (value instanceof Row row) {
76-
ImmutableList.Builder<Expression> items = ImmutableList.builder();
77-
for (int i = 0; i < row.items().size(); i++) {
78-
Expression item = row.items().get(i);
79-
Type itemType = type.getFields().get(i).getType();
80-
if (!(itemType instanceof UnknownType)) {
81-
item = new Cast(item, itemType);
82-
}
83-
items.add(item);
59+
Expression value = treeRewriter.rewrite(node.expression(), null);
60+
if (value instanceof Row(java.util.List<Expression> expressions, RowType type)) {
61+
ImmutableList.Builder<Expression> items = ImmutableList.builder();
62+
for (int i = 0; i < expressions.size(); i++) {
63+
Expression fieldValue = expressions.get(i);
64+
Type fieldType = castToType.getFields().get(i).getType();
65+
if (!fieldValue.type().equals(fieldType)) {
66+
fieldValue = new Cast(fieldValue, fieldType);
8467
}
85-
return new Row(items.build());
68+
items.add(fieldValue);
8669
}
70+
return new Row(items.build(), castToType);
8771
}
8872

89-
return treeRewriter.defaultRewrite(node, true);
73+
return treeRewriter.defaultRewrite(node, null);
9074
}
9175
}
9276
}

core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedSingleRowSubqueryToProject.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ public Result apply(CorrelatedJoinNode parent, Captures captures, Context contex
9393
.putIdentities(parent.getInput().getOutputSymbols());
9494
forEachPair(
9595
values.getOutputSymbols().stream(),
96-
row.items().stream(),
96+
row.children().stream(),
9797
assignments::put);
9898
return Result.ofPlanNode(projectNode(parent.getInput(), assignments.build(), context));
9999
}

0 commit comments

Comments
 (0)