|
20 | 20 | import io.trino.sql.ir.Expression;
|
21 | 21 | import io.trino.sql.ir.ExpressionTreeRewriter;
|
22 | 22 | import io.trino.sql.ir.Row;
|
23 |
| -import io.trino.type.UnknownType; |
24 | 23 |
|
25 | 24 | /**
|
26 | 25 | * Transforms expressions of the form
|
27 | 26 | *
|
28 | 27 | * <pre>
|
29 | 28 | * CAST(
|
30 |
| - * CAST( |
31 |
| - * ROW(x, y) |
32 |
| - * AS row(f1 type1, f2 type2)) |
33 |
| - * AS row(g1 type3, g2 type4)) |
| 29 | + * ROW(x, y) |
| 30 | + * AS row(f1 type1, f2 type2)) |
34 | 31 | * </pre>
|
35 | 32 | *
|
36 | 33 | * to
|
37 | 34 | *
|
38 | 35 | * <pre>
|
39 |
| - * CAST( |
40 |
| - * ROW( |
41 |
| - * CAST(x AS type1), |
42 |
| - * CAST(y AS type2)) |
43 |
| - * AS row(g1 type3, g2 type4)) |
| 36 | + * ROW( |
| 37 | + * CAST(x AS type1) as f1, |
| 38 | + * CAST(y AS type2) as f2) |
44 | 39 | * </pre>
|
45 |
| - * |
46 |
| - * Note: it preserves the top-level CAST if the row type has field names because the names are needed by the ROW to JSON cast |
47 |
| - * TODO: ideally, the types involved in ROW to JSON cast should be captured at analysis time and |
48 |
| - * remain fixed for the duration of the optimization process so as to have flexibility in terms |
49 |
| - * of removing field names, which are irrelevant in the IR |
50 | 40 | */
|
51 | 41 | public class PushCastIntoRow
|
52 | 42 | extends ExpressionRewriteRuleSet
|
53 | 43 | {
|
54 | 44 | public PushCastIntoRow()
|
55 | 45 | {
|
56 |
| - super((expression, context) -> ExpressionTreeRewriter.rewriteWith(new Rewriter(), expression, false)); |
| 46 | + super((expression, context) -> ExpressionTreeRewriter.rewriteWith(new Rewriter(), expression, null)); |
57 | 47 | }
|
58 | 48 |
|
59 | 49 | private static class Rewriter
|
60 |
| - extends io.trino.sql.ir.ExpressionRewriter<Boolean> |
| 50 | + extends io.trino.sql.ir.ExpressionRewriter<Void> |
61 | 51 | {
|
62 | 52 | @Override
|
63 |
| - public Expression rewriteCast(Cast node, Boolean inRowCast, ExpressionTreeRewriter<Boolean> treeRewriter) |
| 53 | + public Expression rewriteCast(Cast node, Void context, ExpressionTreeRewriter<Void> treeRewriter) |
64 | 54 | {
|
65 |
| - if (!(node.type() instanceof RowType type)) { |
66 |
| - return treeRewriter.defaultRewrite(node, false); |
| 55 | + if (!(node.type() instanceof RowType castToType)) { |
| 56 | + return treeRewriter.defaultRewrite(node, null); |
67 | 57 | }
|
68 | 58 |
|
69 |
| - // if inRowCast == true or row is anonymous, we're free to push Cast into Row. An enclosing CAST(... AS ROW) will take care of preserving field names |
70 |
| - // otherwise, apply recursively with inRowCast == true and don't push this one |
71 |
| - |
72 |
| - if (inRowCast || type.getFields().stream().allMatch(field -> field.getName().isEmpty())) { |
73 |
| - Expression value = treeRewriter.rewrite(node.expression(), true); |
74 |
| - |
75 |
| - if (value instanceof Row row) { |
76 |
| - ImmutableList.Builder<Expression> items = ImmutableList.builder(); |
77 |
| - for (int i = 0; i < row.items().size(); i++) { |
78 |
| - Expression item = row.items().get(i); |
79 |
| - Type itemType = type.getFields().get(i).getType(); |
80 |
| - if (!(itemType instanceof UnknownType)) { |
81 |
| - item = new Cast(item, itemType); |
82 |
| - } |
83 |
| - items.add(item); |
| 59 | + Expression value = treeRewriter.rewrite(node.expression(), null); |
| 60 | + if (value instanceof Row(java.util.List<Expression> expressions, RowType type)) { |
| 61 | + ImmutableList.Builder<Expression> items = ImmutableList.builder(); |
| 62 | + for (int i = 0; i < expressions.size(); i++) { |
| 63 | + Expression fieldValue = expressions.get(i); |
| 64 | + Type fieldType = castToType.getFields().get(i).getType(); |
| 65 | + if (!fieldValue.type().equals(fieldType)) { |
| 66 | + fieldValue = new Cast(fieldValue, fieldType); |
84 | 67 | }
|
85 |
| - return new Row(items.build()); |
| 68 | + items.add(fieldValue); |
86 | 69 | }
|
| 70 | + return new Row(items.build(), castToType); |
87 | 71 | }
|
88 | 72 |
|
89 |
| - return treeRewriter.defaultRewrite(node, true); |
| 73 | + return treeRewriter.defaultRewrite(node, null); |
90 | 74 | }
|
91 | 75 | }
|
92 | 76 | }
|
0 commit comments