@@ -23,7 +23,7 @@ class FunctionsCatalog:
23
23
"/functions_arithmetic_decimal.yaml" ,
24
24
"/functions_boolean.yaml" ,
25
25
"/functions_comparison.yaml" ,
26
- "/functions_datetime.yaml" ,
26
+ # "/functions_datetime.yaml", for now skip, it has duplicated functions
27
27
"/functions_geometry.yaml" ,
28
28
"/functions_logarithmic.yaml" ,
29
29
"/functions_rounding.yaml" ,
@@ -32,9 +32,10 @@ class FunctionsCatalog:
32
32
)
33
33
34
34
def __init__ (self ):
35
- self ._declarations = {}
36
35
self ._registered_extensions = {}
37
36
self ._functions = {}
37
+ self ._functions_return_type = {}
38
+ self ._register_builtins ()
38
39
39
40
def load_standard_extensions (self , dirpath ):
40
41
for ext in self .STANDARD_EXTENSIONS :
@@ -45,6 +46,7 @@ def load(self, dirpath, filename):
45
46
sections = yaml .safe_load (f )
46
47
47
48
loaded_functions = set ()
49
+ functions_return_type = {}
48
50
for functions in sections .values ():
49
51
for function in functions :
50
52
function_name = function ["name" ]
@@ -55,12 +57,16 @@ def load(self, dirpath, filename):
55
57
signature = function_name
56
58
else :
57
59
signature = f"{ function_name } :{ '_' .join (argtypes )} "
58
- self ._declarations [signature ] = filename
59
60
loaded_functions .add (signature )
61
+ functions_return_type [signature ] = self ._type_from_name (
62
+ impl ["return" ]
63
+ )
60
64
61
- self ._register_extensions (filename , loaded_functions )
65
+ self ._register_extensions (filename , loaded_functions , functions_return_type )
62
66
63
- def _register_extensions (self , extension_uri , loaded_functions ):
67
+ def _register_extensions (
68
+ self , extension_uri , loaded_functions , functions_return_type
69
+ ):
64
70
if extension_uri not in self ._registered_extensions :
65
71
ext_anchor_id = len (self ._registered_extensions ) + 1
66
72
self ._registered_extensions [extension_uri ] = proto .SimpleExtensionURI (
@@ -70,14 +76,12 @@ def _register_extensions(self, extension_uri, loaded_functions):
70
76
for function in loaded_functions :
71
77
if function in self ._functions :
72
78
extensions_by_anchor = self .extension_uris_by_anchor
73
- function = self ._functions [function ]
79
+ existing_function = self ._functions [function ]
74
80
function_extension = extensions_by_anchor [
75
- function .extension_uri_reference
81
+ existing_function .extension_uri_reference
76
82
].uri
77
- # TODO: Support overloading of functions from different extensionUris.
78
- continue
79
83
raise ValueError (
80
- f"Duplicate function definition: { function .name } from { extension_uri } , already loaded from { function_extension } "
84
+ f"Duplicate function definition: { existing_function .name } from { extension_uri } , already loaded from { function_extension } "
81
85
)
82
86
extension_anchor = self ._registered_extensions [
83
87
extension_uri
@@ -90,6 +94,48 @@ def _register_extensions(self, extension_uri, loaded_functions):
90
94
function_anchor = function_anchor ,
91
95
)
92
96
)
97
+ self ._functions_return_type [function ] = functions_return_type [function ]
98
+
99
+ def _register_builtins (self ):
100
+ self ._functions ["not:boolean" ] = (
101
+ proto .SimpleExtensionDeclaration .ExtensionFunction (
102
+ name = "not" ,
103
+ function_anchor = len (self ._functions ) + 1 ,
104
+ )
105
+ )
106
+ self ._functions_return_type ["not:boolean" ] = proto .Type (
107
+ bool = proto .Type .Boolean ()
108
+ )
109
+
110
+ def _type_from_name (self , typename ):
111
+ nullable = False
112
+ if typename .endswith ("?" ):
113
+ nullable = True
114
+
115
+ typename = typename .strip ("?" )
116
+ if typename in ("any" , "any1" ):
117
+ return None
118
+
119
+ if typename == "boolean" :
120
+ # For some reason boolean is an exception to the naming convention
121
+ typename = "bool"
122
+
123
+ try :
124
+ type_descriptor = proto .Type .DESCRIPTOR .fields_by_name [
125
+ typename
126
+ ].message_type
127
+ except KeyError :
128
+ # TODO: improve resolution of complext type like LIST?<any>
129
+ print ("Unsupported type" , typename )
130
+ return None
131
+
132
+ type_class = getattr (proto .Type , type_descriptor .name )
133
+ nullability = (
134
+ proto .Type .Nullability .NULLABILITY_REQUIRED
135
+ if not nullable
136
+ else proto .Type .Nullability .NULLABILITY_NULLABLE
137
+ )
138
+ return proto .Type (** {typename : type_class (nullability = nullability )})
93
139
94
140
@property
95
141
def extension_uris_by_anchor (self ):
@@ -106,14 +152,31 @@ def extension_uris(self):
106
152
def extensions (self ):
107
153
return list (self ._functions .values ())
108
154
155
+ def signature (self , function_name , proto_argtypes ):
156
+ def _normalize_arg_types (argtypes ):
157
+ for argtype in argtypes :
158
+ kind = argtype .WhichOneof ("kind" )
159
+ if kind == "bool" :
160
+ yield "boolean"
161
+ else :
162
+ yield kind
163
+
164
+ return f"{ function_name } :{ '_' .join (_normalize_arg_types (proto_argtypes ))} "
165
+
109
166
def function_anchor (self , function ):
110
167
return self ._functions [function ].function_anchor
111
168
169
+ def function_return_type (self , function ):
170
+ return self ._functions_return_type [function ]
171
+
112
172
def extensions_for_functions (self , functions ):
113
173
uris_anchors = set ()
114
174
extensions = []
115
175
for f in functions :
116
176
ext = self ._functions [f ]
177
+ if not ext .extension_uri_reference :
178
+ # Built-in function
179
+ continue
117
180
uris_anchors .add (ext .extension_uri_reference )
118
181
extensions .append (proto .SimpleExtensionDeclaration (extension_function = ext ))
119
182
0 commit comments