1
+ import itertools
2
+
1
3
import sqlglot
2
4
3
5
from substrait import proto
@@ -19,21 +21,28 @@ def parse_sql_extended_expression(catalog, schema, sql):
19
21
if not isinstance (select , sqlglot .expressions .Select ):
20
22
raise ValueError ("a SELECT statement was expected" )
21
23
22
- invoked_functions_projection , projections = _substrait_projection_from_sqlglot (
23
- catalog , schema , select .expressions
24
- )
24
+ sqlglot_parser = SQLGlotParser (catalog , schema )
25
+
26
+ # Handle the projections in the SELECT statement.
27
+ project_expressions = []
28
+ projection_invoked_functions = set ()
29
+ for sqlexpr in select .expressions :
30
+ invoked_functions , output_name , expr = sqlglot_parser .expression_from_sqlglot (sqlexpr )
31
+ projection_invoked_functions .update (invoked_functions )
32
+ project_expressions .append (proto .ExpressionReference (expression = expr , output_names = [output_name ]))
25
33
extension_uris , extensions = catalog .extensions_for_functions (
26
- invoked_functions_projection
34
+ projection_invoked_functions
27
35
)
28
36
projection_extended_expr = proto .ExtendedExpression (
29
37
extension_uris = extension_uris ,
30
38
extensions = extensions ,
31
39
base_schema = schema ,
32
- referred_expr = projections ,
40
+ referred_expr = project_expressions ,
33
41
)
34
42
35
- invoked_functions_filter , filter_expr = _substrait_expression_from_sqlglot (
36
- catalog , schema , select .find (sqlglot .expressions .Where ).this
43
+ # Handle WHERE clause in the SELECT statement.
44
+ invoked_functions_filter , _ , filter_expr = sqlglot_parser .expression_from_sqlglot (
45
+ select .find (sqlglot .expressions .Where ).this
37
46
)
38
47
extension_uris , extensions = catalog .extensions_for_functions (
39
48
invoked_functions_filter
@@ -48,122 +57,89 @@ def parse_sql_extended_expression(catalog, schema, sql):
48
57
return projection_extended_expr , filter_extended_expr
49
58
50
59
51
- def _substrait_projection_from_sqlglot (catalog , schema , expressions ):
52
- if not expressions :
53
- return set (), []
60
class SQLGlotParser:
    """Convert sqlglot expression nodes into Substrait expressions.

    Every parsed node yields a triple ``(output_name, type, expression)``
    so that callers always have a name available for the output, even for
    literals and function invocations that carry no explicit alias.
    """

    def __init__(self, functions_catalog, schema):
        """Create a parser bound to a functions catalog and a base schema.

        Args:
            functions_catalog: Object providing ``function_anchor(signature)``,
                used to resolve the functions invoked by the expressions.
            schema: Schema exposing ``names`` and ``struct.types``, used to
                resolve column references to field indexes and types.
        """
        self._functions_catalog = functions_catalog
        self._schema = schema
        # Monotonic counter shared by all generated names, so that two
        # unnamed outputs parsed by the same parser never collide.
        self._counter = itertools.count()

    def expression_from_sqlglot(self, sqlglot_node):
        """Parse a sqlglot node into a Substrait expression.

        Args:
            sqlglot_node: The sqlglot expression node to convert.

        Returns:
            A tuple ``(invoked_functions, output_name, expression)`` where
            ``invoked_functions`` is the set of function signatures used
            anywhere inside the expression.

        Raises:
            ValueError: If the node (or one of its children) is not supported
                or references a column that is not part of the schema.
        """
        invoked_functions = set()
        output_name, _, substrait_expr = self._parse_expression(
            sqlglot_node, invoked_functions
        )
        return invoked_functions, output_name, substrait_expr

    def _parse_expression(self, expr, invoked_functions):
        """Recursively convert ``expr``.

        Signatures of the functions invoked along the way are added to the
        ``invoked_functions`` set, which is mutated in place.

        Returns:
            A tuple ``(output_name, type, expression)``.

        Raises:
            ValueError: On unsupported literals or expressions, or when a
                column is not found in the schema.
        """
        if isinstance(expr, sqlglot.expressions.Literal):
            if expr.is_string:
                return (
                    f"literal_{next(self._counter)}",
                    proto.Type(string=proto.Type.String()),
                    proto.Expression(
                        literal=proto.Expression.Literal(string=expr.text)
                    ),
                )
            elif expr.is_int:
                return (
                    f"literal_{next(self._counter)}",
                    proto.Type(i32=proto.Type.I32()),
                    proto.Expression(
                        literal=proto.Expression.Literal(i32=int(expr.name))
                    ),
                )
            elif sqlglot.helper.is_float(expr.name):
                # The literal field has to match the declared fp32 type:
                # Substrait's Expression.Literal has no ``float`` member,
                # the single precision value is stored in ``fp32``.
                return (
                    f"literal_{next(self._counter)}",
                    proto.Type(fp32=proto.Type.FP32()),
                    proto.Expression(
                        literal=proto.Expression.Literal(fp32=float(expr.name))
                    ),
                )
            else:
                raise ValueError(f"Unsupporter literal: {expr.text}")
        elif isinstance(expr, sqlglot.expressions.Column):
            column_name = expr.output_name
            try:
                schema_field = list(self._schema.names).index(column_name)
            except ValueError:
                # Same exception type as before, but with a message that
                # actually tells which column could not be resolved.
                raise ValueError(
                    f"Column '{column_name}' not found in schema"
                ) from None
            schema_type = self._schema.struct.types[schema_field]
            return (
                column_name,
                schema_type,
                proto.Expression(
                    selection=proto.Expression.FieldReference(
                        direct_reference=proto.Expression.ReferenceSegment(
                            struct_field=proto.Expression.ReferenceSegment.StructField(
                                field=schema_field
                            )
                        )
                    )
                ),
            )
        elif isinstance(expr, sqlglot.expressions.Alias):
            # The alias only replaces the name; type and expression are
            # those of the aliased node.
            _, aliased_type, aliased_expr = self._parse_expression(
                expr.this, invoked_functions
            )
            return expr.output_name, aliased_type, aliased_expr
        elif expr.key in SQL_BINARY_FUNCTIONS:
            left_name, left_type, left = self._parse_expression(
                expr.left, invoked_functions
            )
            right_name, right_type, right = self._parse_expression(
                expr.right, invoked_functions
            )
            function_name = SQL_BINARY_FUNCTIONS[expr.key]
            signature, result_type, function_expression = (
                self._parse_function_invokation(
                    function_name, left_type, left, right_type, right
                )
            )
            invoked_functions.add(signature)
            # The counter suffix keeps names distinct when the same
            # operation over the same operands appears more than once.
            result_name = (
                f"{left_name}_{function_name}_{right_name}_{next(self._counter)}"
            )
            return result_name, result_type, function_expression
        else:
            raise ValueError(
                f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}"
            )

    def _parse_function_invokation(
        self, function_name, left_type, left, right_type, right
    ):
        """Build the scalar function call for a binary function.

        The function is looked up in the catalog by the signature
        ``<name>:<left kind>_<right kind>``; when the catalog raises
        ``KeyError`` for it, the generic ``<name>:any1_any1`` signature is
        tried instead (a ``KeyError`` from that second lookup propagates).

        Returns:
            A tuple ``(signature, result_type, expression)``.
        """
        signature = (
            f"{function_name}:"
            f"{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}"
        )
        try:
            function_anchor = self._functions_catalog.function_anchor(signature)
        except KeyError:
            # No function found with the exact types, try the any1_any1 version.
            signature = f"{function_name}:any1_any1"
            function_anchor = self._functions_catalog.function_anchor(signature)
        return (
            signature,
            # NOTE(review): the left operand type is reported as the result
            # type; presumably not accurate for functions whose return type
            # differs from their arguments (e.g. comparisons) - confirm.
            left_type,
            proto.Expression(
                scalar_function=proto.Expression.ScalarFunction(
                    function_reference=function_anchor,
                    arguments=[
                        proto.FunctionArgument(value=left),
                        proto.FunctionArgument(value=right),
                    ],
                )
            ),
        )
169
145
0 commit comments