15
15
import os
16
16
import uuid
17
17
import shutil
18
+ import llvmlite .ir
18
19
19
20
from typing import Optional , List
20
21
21
22
from typed_python .compiler .binary_shared_object import LoadedBinarySharedObject , BinarySharedObject
22
23
from typed_python .compiler .directed_graph import DirectedGraph
23
24
from typed_python .compiler .typed_call_target import TypedCallTarget
25
+ import typed_python .compiler .native_ast as native_ast
24
26
from typed_python .SerializationContext import SerializationContext
25
27
from typed_python import Dict , ListOf
26
28
@@ -67,6 +69,8 @@ def __init__(self, cacheDir):
67
69
self .targetsLoaded : Dict [str , TypedCallTarget ] = {}
68
70
# the set of link_names for functions with linked and validated globals (i.e. ready to be run).
69
71
self .targetsValidated = set ()
72
+ # the total number of instructions for each link_name
73
+ self .targetComplexity = Dict (str , int )()
70
74
# link_name -> link_name
71
75
self .function_dependency_graph = DirectedGraph ()
72
76
# dict from link_name to list of global names (should be llvm keys in serialisedGlobalDefinitions)
@@ -90,6 +94,21 @@ def getTarget(self, func_name: str) -> TypedCallTarget:
90
94
self .loadForSymbol (link_name )
91
95
return self .targetsLoaded [link_name ]
92
96
97
+ def getIR (self , func_name : str ) -> llvmlite .ir .Function :
98
+ if not self .hasSymbol (func_name ):
99
+ raise ValueError (f'symbol not found for func_name { func_name } ' )
100
+ link_name = self ._select_link_name (func_name )
101
+ module_hash = self .link_name_to_module_hash [link_name ]
102
+ return self .loadedBinarySharedObjects [module_hash ].binarySharedObject .functionIRs [func_name ]
103
+
104
+ def getDefinition (self , func_name : str ) -> native_ast .Function :
105
+ if not self .hasSymbol (func_name ):
106
+ raise ValueError (f'symbol not found for func_name { func_name } ' )
107
+ link_name = self ._select_link_name (func_name )
108
+ module_hash = self .link_name_to_module_hash [link_name ]
109
+ serialized_definition = self .loadedBinarySharedObjects [module_hash ].binarySharedObject .serializedFunctionDefinitions [func_name ]
110
+ return SerializationContext ().deserialize (serialized_definition )
111
+
93
112
def _generate_link_name (self , func_name : str , module_hash : str ) -> str :
94
113
return func_name + "." + module_hash
95
114
@@ -126,6 +145,14 @@ def loadForSymbol(self, linkName: str) -> None:
126
145
if not self .loadedBinarySharedObjects [moduleHash ].validateGlobalVariables (definitionsToLink ):
127
146
raise RuntimeError ('failed to validate globals when loading:' , linkName )
128
147
148
+ def complexityForSymbol (self , func_name : str ) -> int :
149
+ """Get the total number of LLVM instructions for a given symbol."""
150
+ try :
151
+ link_name = self ._select_link_name (func_name )
152
+ return self .targetComplexity [link_name ]
153
+ except KeyError as e :
154
+ raise ValueError (f'No complexity value cached for { func_name } ' ) from e
155
+
129
156
def loadModuleByHash (self , moduleHash : str ) -> None :
130
157
"""Load a module by name.
131
158
@@ -139,23 +166,23 @@ def loadModuleByHash(self, moduleHash: str) -> None:
139
166
140
167
# TODO (Will) - store these names as module consts, use one .dat only
141
168
with open (os .path .join (targetDir , "type_manifest.dat" ), "rb" ) as f :
142
- # func_name -> typedcalltarget
143
169
callTargets = SerializationContext ().deserialize (f .read ())
144
-
145
170
with open (os .path .join (targetDir , "globals_manifest.dat" ), "rb" ) as f :
146
171
serializedGlobalVarDefs = SerializationContext ().deserialize (f .read ())
147
-
148
172
with open (os .path .join (targetDir , "native_type_manifest.dat" ), "rb" ) as f :
149
173
functionNameToNativeType = SerializationContext ().deserialize (f .read ())
150
-
151
174
with open (os .path .join (targetDir , "submodules.dat" ), "rb" ) as f :
152
175
submodules = SerializationContext ().deserialize (f .read (), ListOf (str ))
153
-
154
176
with open (os .path .join (targetDir , "function_dependencies.dat" ), "rb" ) as f :
155
177
dependency_edgelist = SerializationContext ().deserialize (f .read ())
156
-
157
178
with open (os .path .join (targetDir , "global_dependencies.dat" ), "rb" ) as f :
158
179
globalDependencies = SerializationContext ().deserialize (f .read ())
180
+ with open (os .path .join (targetDir , "function_complexities.dat" ), "rb" ) as f :
181
+ functionComplexities = SerializationContext ().deserialize (f .read ())
182
+ with open (os .path .join (targetDir , "function_irs.dat" ), "rb" ) as f :
183
+ functionIRs = SerializationContext ().deserialize (f .read ())
184
+ with open (os .path .join (targetDir , "function_definitions.dat" ), "rb" ) as f :
185
+ functionDefinitions = SerializationContext ().deserialize (f .read ())
159
186
160
187
# load the submodules first
161
188
for submodule in submodules :
@@ -167,7 +194,10 @@ def loadModuleByHash(self, moduleHash: str) -> None:
167
194
modulePath ,
168
195
serializedGlobalVarDefs ,
169
196
functionNameToNativeType ,
170
- globalDependencies
197
+ globalDependencies ,
198
+ functionComplexities ,
199
+ functionIRs ,
200
+ functionDefinitions
171
201
).loadFromPath (modulePath )
172
202
173
203
self .loadedBinarySharedObjects [moduleHash ] = loaded
@@ -177,8 +207,11 @@ def loadModuleByHash(self, moduleHash: str) -> None:
177
207
assert link_name not in self .targetsLoaded
178
208
self .targetsLoaded [link_name ] = callTarget
179
209
180
- link_name_global_dependencies = {self ._generate_link_name (x , moduleHash ): y for x , y in globalDependencies .items ()}
210
+ for func_name , complexity in functionComplexities .items ():
211
+ link_name = self ._generate_link_name (func_name , moduleHash )
212
+ self .targetComplexity [link_name ] = complexity
181
213
214
+ link_name_global_dependencies = {self ._generate_link_name (x , moduleHash ): y for x , y in globalDependencies .items ()}
182
215
assert not any (key in self .global_dependencies for key in link_name_global_dependencies )
183
216
184
217
self .global_dependencies .update (link_name_global_dependencies )
@@ -222,6 +255,10 @@ def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies,
222
255
223
256
path = self .writeModuleToDisk (binarySharedObject , hashToUse , nameToTypedCallTarget , dependentHashes , link_name_dependency_edgelist )
224
257
258
+ for func_name , complexity in binarySharedObject .functionComplexities .items ():
259
+ link_name = self ._generate_link_name (func_name , hashToUse )
260
+ self .targetComplexity [link_name ] = complexity
261
+
225
262
self .loadedBinarySharedObjects [hashToUse ] = (
226
263
binarySharedObject .loadFromPath (os .path .join (path , "module.so" ))
227
264
)
@@ -314,6 +351,15 @@ def writeModuleToDisk(self, binarySharedObject, hashToUse, nameToTypedCallTarget
314
351
with open (os .path .join (tempTargetDir , "global_dependencies.dat" ), "wb" ) as f :
315
352
f .write (SerializationContext ().serialize (binarySharedObject .globalDependencies ))
316
353
354
+ with open (os .path .join (tempTargetDir , "function_complexities.dat" ), "wb" ) as f :
355
+ f .write (SerializationContext ().serialize (binarySharedObject .functionComplexities ))
356
+
357
+ with open (os .path .join (tempTargetDir , "function_irs.dat" ), "wb" ) as f :
358
+ f .write (SerializationContext ().serialize (binarySharedObject .functionIRs ))
359
+
360
+ with open (os .path .join (tempTargetDir , "function_definitions.dat" ), "wb" ) as f :
361
+ f .write (SerializationContext ().serialize (binarySharedObject .serializedFunctionDefinitions ))
362
+
317
363
try :
318
364
os .rename (tempTargetDir , targetDir )
319
365
except IOError :
0 commit comments