Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9986d09

Browse files
committedApr 17, 2024·
Refactor parsing of expressions and propagate column names
1 parent d7522de commit 9986d09

File tree

2 files changed

+97
-118
lines changed

2 files changed

+97
-118
lines changed
 

‎src/substrait/sql/__main__.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@
88
catalog.load_standard_extensions(
99
pathlib.Path(__file__).parent.parent.parent.parent / "third_party" / "substrait" / "extensions",
1010
)
11+
12+
# TODO: Turn this into a command line tool to test more queries.
13+
# We can probably have a quick way to declare schema using command line args.
14+
# like first_name=String,surname=String,age=I32 etc...
1115
schema = proto.NamedStruct(
1216
names=["first_name", "surname", "age"],
1317
struct=proto.Type.Struct(
@@ -38,5 +42,4 @@
3842
print("---- PROJECTION ----")
3943
print(projection_expr)
4044
print("---- FILTER ----")
41-
print(filter_expr)
42-
# parse_extended_expression("INSERT INTO table VALUES(1, 2, 3)")
45+
print(filter_expr)
+92-116
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import itertools
2+
13
import sqlglot
24

35
from substrait import proto
@@ -19,21 +21,28 @@ def parse_sql_extended_expression(catalog, schema, sql):
1921
if not isinstance(select, sqlglot.expressions.Select):
2022
raise ValueError("a SELECT statement was expected")
2123

22-
invoked_functions_projection, projections = _substrait_projection_from_sqlglot(
23-
catalog, schema, select.expressions
24-
)
24+
sqlglot_parser = SQLGlotParser(catalog, schema)
25+
26+
# Handle the projections in the SELECT statemenent.
27+
project_expressions = []
28+
projection_invoked_functions = set()
29+
for sqlexpr in select.expressions:
30+
invoked_functions, output_name, expr = sqlglot_parser.expression_from_sqlglot(sqlexpr)
31+
projection_invoked_functions.update(invoked_functions)
32+
project_expressions.append(proto.ExpressionReference(expression=expr, output_names=[output_name]))
2533
extension_uris, extensions = catalog.extensions_for_functions(
26-
invoked_functions_projection
34+
projection_invoked_functions
2735
)
2836
projection_extended_expr = proto.ExtendedExpression(
2937
extension_uris=extension_uris,
3038
extensions=extensions,
3139
base_schema=schema,
32-
referred_expr=projections,
40+
referred_expr=project_expressions,
3341
)
3442

35-
invoked_functions_filter, filter_expr = _substrait_expression_from_sqlglot(
36-
catalog, schema, select.find(sqlglot.expressions.Where).this
43+
# Handle WHERE clause in the SELECT statement.
44+
invoked_functions_filter, _, filter_expr = sqlglot_parser.expression_from_sqlglot(
45+
select.find(sqlglot.expressions.Where).this
3746
)
3847
extension_uris, extensions = catalog.extensions_for_functions(
3948
invoked_functions_filter
@@ -48,122 +57,89 @@ def parse_sql_extended_expression(catalog, schema, sql):
4857
return projection_extended_expr, filter_extended_expr
4958

5059

51-
def _substrait_projection_from_sqlglot(catalog, schema, expressions):
52-
if not expressions:
53-
return set(), []
60+
class SQLGlotParser:
61+
def __init__(self, functions_catalog, schema):
62+
self._functions_catalog = functions_catalog
63+
self._schema = schema
64+
self._counter = itertools.count()
5465

55-
# My understanding of ExtendedExpressions is that they are meant to directly
56-
# point to the Expression that ProjectRel would contain, so we don't actually
57-
# need a ProjectRel at all.
58-
"""
59-
projection_sub = proto.ProjectRel(
60-
input=proto.Rel(
61-
read=proto.ReadRel(
62-
named_table=proto.ReadRel.NamedTable(names=["__table__"]),
63-
base_schema=schema,
64-
)
65-
),
66-
expressions=[],
67-
)
68-
"""
69-
70-
substrait_expressions = []
71-
invoked_functions = set()
72-
for sqlexpr in expressions:
73-
output_names = []
74-
if isinstance(sqlexpr, sqlglot.expressions.Alias):
75-
output_names = [sqlexpr.output_name]
76-
sqlexpr = sqlexpr.this
77-
_, substrait_expr = _parse_expression(
78-
catalog, schema, sqlexpr, invoked_functions
79-
)
80-
substrait_expr_reference = proto.ExpressionReference(
81-
expression=substrait_expr, output_names=output_names
66+
def expression_from_sqlglot(self, sqlglot_node):
67+
invoked_functions = set()
68+
output_name, _, substrait_expr = self._parse_expression(
69+
sqlglot_node, invoked_functions
8270
)
83-
substrait_expressions.append(substrait_expr_reference)
84-
85-
return invoked_functions, substrait_expressions
86-
71+
return invoked_functions, output_name, substrait_expr
8772

88-
def _substrait_expression_from_sqlglot(catalog, schema, sqlglot_node):
89-
if not sqlglot_node:
90-
return set(), None
91-
92-
invoked_functions = set()
93-
_, substrait_expr = _parse_expression(
94-
catalog, schema, sqlglot_node, invoked_functions
95-
)
96-
return invoked_functions, substrait_expr
97-
98-
99-
def _parse_expression(catalog, schema, expr, invoked_functions):
100-
# TODO: Propagate up column names (output_names) so that the projections _always_ have an output_name
101-
if isinstance(expr, sqlglot.expressions.Literal):
102-
if expr.is_string:
103-
return proto.Type(string=proto.Type.String()), proto.Expression(
104-
literal=proto.Expression.Literal(string=expr.text)
73+
def _parse_expression(self, expr, invoked_functions):
74+
if isinstance(expr, sqlglot.expressions.Literal):
75+
if expr.is_string:
76+
return f"literal_{next(self._counter)}", proto.Type(string=proto.Type.String()), proto.Expression(
77+
literal=proto.Expression.Literal(string=expr.text)
78+
)
79+
elif expr.is_int:
80+
return f"literal_{next(self._counter)}", proto.Type(i32=proto.Type.I32()), proto.Expression(
81+
literal=proto.Expression.Literal(i32=int(expr.name))
82+
)
83+
elif sqlglot.helper.is_float(expr.name):
84+
return f"literal_{next(self._counter)}", proto.Type(fp32=proto.Type.FP32()), proto.Expression(
85+
literal=proto.Expression.Literal(float=float(expr.name))
86+
)
87+
else:
88+
raise ValueError(f"Unsupporter literal: {expr.text}")
89+
elif isinstance(expr, sqlglot.expressions.Column):
90+
column_name = expr.output_name
91+
schema_field = list(self._schema.names).index(column_name)
92+
schema_type = self._schema.struct.types[schema_field]
93+
return column_name, schema_type, proto.Expression(
94+
selection=proto.Expression.FieldReference(
95+
direct_reference=proto.Expression.ReferenceSegment(
96+
struct_field=proto.Expression.ReferenceSegment.StructField(
97+
field=schema_field
98+
)
99+
)
100+
)
105101
)
106-
elif expr.is_int:
107-
return proto.Type(i32=proto.Type.I32()), proto.Expression(
108-
literal=proto.Expression.Literal(i32=int(expr.name))
102+
elif isinstance(expr, sqlglot.expressions.Alias):
103+
_, aliased_type, aliased_expr = self._parse_expression(expr.this, invoked_functions)
104+
return expr.output_name, aliased_type, aliased_expr
105+
elif expr.key in SQL_BINARY_FUNCTIONS:
106+
left_name, left_type, left = self._parse_expression(
107+
expr.left, invoked_functions
109108
)
110-
elif sqlglot.helper.is_float(expr.name):
111-
return proto.Type(fp32=proto.Type.FP32()), proto.Expression(
112-
literal=proto.Expression.Literal(float=float(expr.name))
109+
right_name, right_type, right = self._parse_expression(
110+
expr.right, invoked_functions
113111
)
112+
function_name = SQL_BINARY_FUNCTIONS[expr.key]
113+
signature, result_type, function_expression = self._parse_function_invokation(
114+
function_name, left_type, left, right_type, right
115+
)
116+
invoked_functions.add(signature)
117+
result_name = f"{left_name}_{function_name}_{right_name}_{next(self._counter)}"
118+
return result_name, result_type, function_expression
114119
else:
115-
raise ValueError(f"Unsupporter literal: {expr.text}")
116-
elif isinstance(expr, sqlglot.expressions.Column):
117-
column_name = expr.output_name
118-
schema_field = list(schema.names).index(column_name)
119-
schema_type = schema.struct.types[schema_field]
120-
return schema_type, proto.Expression(
121-
selection=proto.Expression.FieldReference(
122-
direct_reference=proto.Expression.ReferenceSegment(
123-
struct_field=proto.Expression.ReferenceSegment.StructField(
124-
field=schema_field
125-
)
126-
)
120+
raise ValueError(
121+
f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}"
127122
)
128-
)
129-
elif expr.key in SQL_BINARY_FUNCTIONS:
130-
left_type, left = _parse_expression(
131-
catalog, schema, expr.left, invoked_functions
132-
)
133-
right_type, right = _parse_expression(
134-
catalog, schema, expr.right, invoked_functions
135-
)
136-
function_name = SQL_BINARY_FUNCTIONS[expr.key]
137-
signature, result_type, function_expression = _parse_function_invokation(
138-
catalog, function_name, left_type, left, right_type, right
139-
)
140-
invoked_functions.add(signature)
141-
return result_type, function_expression
142-
else:
143-
raise ValueError(
144-
f"Unsupported expression in ExtendedExpression: '{expr.key}' -> {expr}"
145-
)
146123

147-
148-
def _parse_function_invokation(catalog, function_name, left_type, left, right_type, right):
149-
signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}"
150-
try:
151-
function_anchor = catalog.function_anchor(signature)
152-
except KeyError:
153-
# not function found with the exact types, try any1_any1 version
154-
signature = f"{function_name}:any1_any1"
155-
function_anchor = catalog.function_anchor(signature)
156-
return (
157-
signature,
158-
left_type,
159-
proto.Expression(
160-
scalar_function=proto.Expression.ScalarFunction(
161-
function_reference=function_anchor,
162-
arguments=[
163-
proto.FunctionArgument(value=left),
164-
proto.FunctionArgument(value=right),
165-
],
166-
)
167-
),
168-
)
124+
def _parse_function_invokation(self, function_name, left_type, left, right_type, right):
125+
signature = f"{function_name}:{left_type.WhichOneof('kind')}_{right_type.WhichOneof('kind')}"
126+
try:
127+
function_anchor = self._functions_catalog.function_anchor(signature)
128+
except KeyError:
129+
# not function found with the exact types, try any1_any1 version
130+
signature = f"{function_name}:any1_any1"
131+
function_anchor = self._functions_catalog.function_anchor(signature)
132+
return (
133+
signature,
134+
left_type,
135+
proto.Expression(
136+
scalar_function=proto.Expression.ScalarFunction(
137+
function_reference=function_anchor,
138+
arguments=[
139+
proto.FunctionArgument(value=left),
140+
proto.FunctionArgument(value=right),
141+
],
142+
)
143+
),
144+
)
169145

0 commit comments

Comments
 (0)
Please sign in to comment.