Skip to content

Commit 3097962

Browse files
committed
Register builtin functions and handle return types
1 parent 01d22b9 commit 3097962

File tree

2 files changed

+126
-19
lines changed

2 files changed

+126
-19
lines changed

src/substrait/sql/extended_expression.py

+53-9
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,29 @@
44

55
from substrait import proto
66

7-
7+
# Lookup tables from the ``key`` of a parsed SQL expression (``expr.key`` in
# ``_parse_expression``) to the name of the Substrait function implementing it.
# NOTE(review): the keys look like sqlglot expression keys (lower-cased class
# names) -- confirm against the parser in use.
SQL_UNARY_FUNCTIONS = {"not": "not"}
SQL_BINARY_FUNCTIONS = {
    # Arithmetic
    "add": "add",
    "div": "div",
    "mul": "mul",
    "sub": "sub",
    "mod": "modulus",
    "bitwiseand": "bitwise_and",
    "bitwiseor": "bitwise_or",
    "bitwisexor": "bitwise_xor",
    # Comparisons
    "eq": "equal",
    "nullsafeeq": "is_not_distinct_from",
    # Was misspelled "new", which made not-equal comparisons unreachable.
    "neq": "not_equal",
    "gt": "gt",
    "gte": "gte",
    "lt": "lt",
    "lte": "lte",
    # Logical
    "and": "and",
    "or": "or",
}
1731

1832

@@ -124,6 +138,17 @@ def _parse_expression(self, expr, invoked_functions):
124138
expr.this, invoked_functions
125139
)
126140
return expr.output_name, aliased_type, aliased_expr
141+
elif expr.key in SQL_UNARY_FUNCTIONS:
142+
argument_name, argument_type, argument = self._parse_expression(
143+
expr.this, invoked_functions
144+
)
145+
function_name = SQL_UNARY_FUNCTIONS[expr.key]
146+
signature, result_type, function_expression = (
147+
self._parse_function_invokation(function_name, argument_type, argument)
148+
)
149+
invoked_functions.add(signature)
150+
result_name = f"{function_name}_{argument_name}_{next(self._counter)}"
151+
return result_name, result_type, function_expression
127152
elif expr.key in SQL_BINARY_FUNCTIONS:
128153
left_name, left_type, left = self._parse_expression(
129154
expr.left, invoked_functions
@@ -148,26 +173,45 @@ def _parse_expression(self, expr, invoked_functions):
148173
)
149174

150175
def _parse_function_invokation(
    self, function_name, left_type, left, right_type=None, right=None
):
    """Build the scalar function call for a unary or binary function.

    Args:
        function_name: Name of the function, without argument types.
        left_type: Type of the first argument, in the form accepted by
            ``self._functions_catalog.signature``.
        left: Expression producing the first argument.
        right_type: Type of the second argument; ``None`` for unary calls.
        right: Expression producing the second argument; ``None`` for
            unary calls.

    Returns:
        A ``(signature, return_type, expression)`` tuple where ``signature``
        is the catalog signature that was resolved, ``return_type`` is the
        return type recorded in the catalog (or ``left_type`` when the
        catalog has none) and ``expression`` is the ``proto.Expression``
        wrapping the scalar function call.

    Raises:
        KeyError: If neither the exact signature nor its ``any1`` fallback
            is registered in the functions catalog.
    """
    catalog = self._functions_catalog

    # The call is binary as soon as any part of a second argument is given.
    # Compare against None explicitly instead of relying on truthiness.
    binary = right_type is not None or right is not None

    # Argument types and argument values are collected side by side so the
    # unary/binary distinction is decided in a single place.
    argtypes = [left_type]
    arguments = [proto.FunctionArgument(value=left)]
    if binary:
        argtypes.append(right_type)
        arguments.append(proto.FunctionArgument(value=right))
    signature = catalog.signature(function_name, argtypes)

    try:
        function_anchor = catalog.function_anchor(signature)
    except KeyError:
        # No function found with the exact types, try the generic
        # any1 / any1_any1 version.
        # TODO: What about cases like i32_any1? What about any instead of any1?
        exact_signature = signature
        signature = f"{function_name}:{'_'.join(['any1'] * len(argtypes))}"
        try:
            function_anchor = catalog.function_anchor(signature)
        except KeyError:
            # Report both attempts; the bare KeyError would only mention
            # the fallback and hide what was actually requested.
            raise KeyError(
                f"No function registered for {exact_signature} or {signature}"
            ) from None

    function_return_type = catalog.function_return_type(signature)
    if function_return_type is None:
        print("No return type for", signature)
        # TODO: Is this the right way to handle this?
        function_return_type = left_type

    return (
        signature,
        function_return_type,
        proto.Expression(
            scalar_function=proto.Expression.ScalarFunction(
                function_reference=function_anchor,
                arguments=arguments,
            )
        ),
    )

src/substrait/sql/functions_catalog.py

+73-10
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class FunctionsCatalog:
2323
"/functions_arithmetic_decimal.yaml",
2424
"/functions_boolean.yaml",
2525
"/functions_comparison.yaml",
26-
"/functions_datetime.yaml",
26+
# "/functions_datetime.yaml", for now skip, it has duplicated functions
2727
"/functions_geometry.yaml",
2828
"/functions_logarithmic.yaml",
2929
"/functions_rounding.yaml",
@@ -32,9 +32,10 @@ class FunctionsCatalog:
3232
)
3333

3434
def __init__(self):
    """Create a catalog holding only the built-in functions."""
    # Extension URI -> registered ``proto.SimpleExtensionURI``.
    self._registered_extensions = {}
    # Function signature -> extension function declaration, and
    # function signature -> return type (``None`` when unknown).
    self._functions, self._functions_return_type = {}, {}
    # Must run last: it fills the two function tables created above.
    self._register_builtins()
3839

3940
def load_standard_extensions(self, dirpath):
4041
for ext in self.STANDARD_EXTENSIONS:
@@ -45,6 +46,7 @@ def load(self, dirpath, filename):
4546
sections = yaml.safe_load(f)
4647

4748
loaded_functions = set()
49+
functions_return_type = {}
4850
for functions in sections.values():
4951
for function in functions:
5052
function_name = function["name"]
@@ -55,12 +57,16 @@ def load(self, dirpath, filename):
5557
signature = function_name
5658
else:
5759
signature = f"{function_name}:{'_'.join(argtypes)}"
58-
self._declarations[signature] = filename
5960
loaded_functions.add(signature)
61+
functions_return_type[signature] = self._type_from_name(
62+
impl["return"]
63+
)
6064

61-
self._register_extensions(filename, loaded_functions)
65+
self._register_extensions(filename, loaded_functions, functions_return_type)
6266

63-
def _register_extensions(self, extension_uri, loaded_functions):
67+
def _register_extensions(
68+
self, extension_uri, loaded_functions, functions_return_type
69+
):
6470
if extension_uri not in self._registered_extensions:
6571
ext_anchor_id = len(self._registered_extensions) + 1
6672
self._registered_extensions[extension_uri] = proto.SimpleExtensionURI(
@@ -70,14 +76,12 @@ def _register_extensions(self, extension_uri, loaded_functions):
7076
for function in loaded_functions:
7177
if function in self._functions:
7278
extensions_by_anchor = self.extension_uris_by_anchor
73-
function = self._functions[function]
79+
existing_function = self._functions[function]
7480
function_extension = extensions_by_anchor[
75-
function.extension_uri_reference
81+
existing_function.extension_uri_reference
7682
].uri
77-
# TODO: Support overloading of functions from different extensionUris.
78-
continue
7983
raise ValueError(
80-
f"Duplicate function definition: {function.name} from {extension_uri}, already loaded from {function_extension}"
84+
f"Duplicate function definition: {existing_function.name} from {extension_uri}, already loaded from {function_extension}"
8185
)
8286
extension_anchor = self._registered_extensions[
8387
extension_uri
@@ -90,6 +94,48 @@ def _register_extensions(self, extension_uri, loaded_functions):
9094
function_anchor=function_anchor,
9195
)
9296
)
97+
self._functions_return_type[function] = functions_return_type[function]
98+
99+
def _register_builtins(self):
    """Register the functions that are available without any extension.

    Only boolean ``not`` is registered. Its declaration carries no
    extension URI reference.
    """
    builtin_signature = "not:boolean"
    # Anchors are 1-based and follow the number of functions known so far.
    next_anchor = len(self._functions) + 1
    not_declaration = proto.SimpleExtensionDeclaration.ExtensionFunction(
        name="not",
        function_anchor=next_anchor,
    )
    boolean_type = proto.Type(bool=proto.Type.Boolean())

    self._functions[builtin_signature] = not_declaration
    self._functions_return_type[builtin_signature] = boolean_type
110+
def _type_from_name(self, typename):
111+
nullable = False
112+
if typename.endswith("?"):
113+
nullable = True
114+
115+
typename = typename.strip("?")
116+
if typename in ("any", "any1"):
117+
return None
118+
119+
if typename == "boolean":
120+
# For some reason boolean is an exception to the naming convention
121+
typename = "bool"
122+
123+
try:
124+
type_descriptor = proto.Type.DESCRIPTOR.fields_by_name[
125+
typename
126+
].message_type
127+
except KeyError:
128+
# TODO: improve resolution of complext type like LIST?<any>
129+
print("Unsupported type", typename)
130+
return None
131+
132+
type_class = getattr(proto.Type, type_descriptor.name)
133+
nullability = (
134+
proto.Type.Nullability.NULLABILITY_REQUIRED
135+
if not nullable
136+
else proto.Type.Nullability.NULLABILITY_NULLABLE
137+
)
138+
return proto.Type(**{typename: type_class(nullability=nullability)})
93139

94140
@property
95141
def extension_uris_by_anchor(self):
@@ -106,14 +152,31 @@ def extension_uris(self):
106152
def extensions(self):
    """Return a new list with every registered function declaration."""
    registered = self._functions
    return list(registered.values())
108154

155+
def signature(self, function_name, proto_argtypes):
    """Build the catalog signature ``name:type1_type2`` for a function call.

    Each argument type contributes the name of the field set in its
    ``kind`` oneof; ``bool`` is spelled ``boolean`` in signatures.
    """
    kind_names = []
    for argtype in proto_argtypes:
        kind = argtype.WhichOneof("kind")
        kind_names.append("boolean" if kind == "bool" else kind)
    return f"{function_name}:{'_'.join(kind_names)}"
165+
109166
def function_anchor(self, function):
    """Return the anchor of the function registered under this signature.

    Raises ``KeyError`` when the signature is not registered.
    """
    declaration = self._functions[function]
    return declaration.function_anchor
111168

169+
def function_return_type(self, function):
    """Return the return type recorded for this function signature.

    The stored value may be ``None`` when no return type is known.
    Raises ``KeyError`` when the signature has no entry at all.
    """
    return_types = self._functions_return_type
    return return_types[function]
171+
112172
def extensions_for_functions(self, functions):
113173
uris_anchors = set()
114174
extensions = []
115175
for f in functions:
116176
ext = self._functions[f]
177+
if not ext.extension_uri_reference:
178+
# Built-in function
179+
continue
117180
uris_anchors.add(ext.extension_uri_reference)
118181
extensions.append(proto.SimpleExtensionDeclaration(extension_function=ext))
119182

0 commit comments

Comments
 (0)