Skip to content

Commit 39dcf7c

Browse files
author
William Grant
committed
Add function complexities and intermediates to cache.
In order to inline cached functions, we must store the native and llvm layer intermediates along with the complexities. This will increase cache load time, but should improve compiled function speed and predictability.
1 parent e2d754a commit 39dcf7c

6 files changed

+139
-33
lines changed

typed_python/compiler/binary_shared_object.py

+44-6
Original file line numberDiff line numberDiff line change
@@ -36,32 +36,64 @@ def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlo
3636
class BinarySharedObject:
3737
"""Models a shared object library (.so) loadable on linux systems."""
3838

39-
def __init__(self, binaryForm, functionTypes, serializedGlobalVariableDefinitions, globalDependencies):
39+
def __init__(self,
40+
binaryForm,
41+
functionTypes,
42+
serializedGlobalVariableDefinitions,
43+
globalDependencies,
44+
functionComplexities,
45+
functionIRs,
46+
serializedFunctionDefinitions
47+
):
4048
"""
4149
Args:
4250
binaryForm: a bytes object containing the actual compiled code for the module
4351
serializedGlobalVariableDefinitions: a map from name to GlobalVariableDefinition
44-
globalDependencies: a dict from function linkname to the list of global variables it depends on
52+
globalDependencies: a dict from function name to the list of global variables it depends on
53+
functionComplexities: a dict from function name to the total number of llvm instructions in the function (used for inlining)
54+
functionIRs: a dict from function name to the llvm IR Functions (used for inlining)
55+
serializedFunctionDefinitions: a dict from function name to the serialized native_ast.Function (used for inlining)
4556
"""
4657
self.binaryForm = binaryForm
4758
self.functionTypes = functionTypes
4859
self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions
4960
self.globalDependencies = globalDependencies
61+
self.functionComplexities = functionComplexities
62+
self.functionIRs = functionIRs
63+
self.serializedFunctionDefinitions = serializedFunctionDefinitions
5064
self.hash = sha_hash(binaryForm)
5165

5266
@property
5367
def definedSymbols(self):
5468
return self.functionTypes.keys()
5569

5670
@staticmethod
57-
def fromDisk(path, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
71+
def fromDisk(path,
72+
serializedGlobalVariableDefinitions,
73+
functionNameToType,
74+
globalDependencies,
75+
functionComplexities,
76+
functionIRs,
77+
serializedFunctionDefinitions):
5878
with open(path, "rb") as f:
5979
binaryForm = f.read()
6080

61-
return BinarySharedObject(binaryForm, functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
81+
return BinarySharedObject(binaryForm,
82+
functionNameToType,
83+
serializedGlobalVariableDefinitions,
84+
globalDependencies,
85+
functionComplexities,
86+
functionIRs,
87+
serializedFunctionDefinitions)
6288

6389
@staticmethod
64-
def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
90+
def fromModule(module,
91+
serializedGlobalVariableDefinitions,
92+
functionNameToType,
93+
globalDependencies,
94+
functionComplexities,
95+
functionIRs,
96+
serializedFunctionDefinitions):
6597
target_triple = llvm.get_process_triple()
6698
target = llvm.Target.from_triple(target_triple)
6799
target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default')
@@ -82,7 +114,13 @@ def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType,
82114
)
83115

84116
with open(os.path.join(tf, "module.so"), "rb") as so_file:
85-
return BinarySharedObject(so_file.read(), functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
117+
return BinarySharedObject(so_file.read(),
118+
functionNameToType,
119+
serializedGlobalVariableDefinitions,
120+
globalDependencies,
121+
functionComplexities,
122+
functionIRs,
123+
serializedFunctionDefinitions)
86124

87125
def load(self, storageDir):
88126
"""Instantiate this .so in temporary storage and return a dict from symbol -> integer function pointer"""

typed_python/compiler/compiler_cache.py

+54-8
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
import os
1616
import uuid
1717
import shutil
18+
import llvmlite.ir
1819

1920
from typing import Optional, List
2021

2122
from typed_python.compiler.binary_shared_object import LoadedBinarySharedObject, BinarySharedObject
2223
from typed_python.compiler.directed_graph import DirectedGraph
2324
from typed_python.compiler.typed_call_target import TypedCallTarget
25+
import typed_python.compiler.native_ast as native_ast
2426
from typed_python.SerializationContext import SerializationContext
2527
from typed_python import Dict, ListOf
2628

@@ -67,6 +69,8 @@ def __init__(self, cacheDir):
6769
self.targetsLoaded: Dict[str, TypedCallTarget] = {}
6870
# the set of link_names for functions with linked and validated globals (i.e. ready to be run).
6971
self.targetsValidated = set()
72+
# the total number of instructions for each link_name
73+
self.targetComplexity = Dict(str, int)()
7074
# link_name -> link_name
7175
self.function_dependency_graph = DirectedGraph()
7276
# dict from link_name to list of global names (should be llvm keys in serialisedGlobalDefinitions)
@@ -90,6 +94,21 @@ def getTarget(self, func_name: str) -> TypedCallTarget:
9094
self.loadForSymbol(link_name)
9195
return self.targetsLoaded[link_name]
9296

97+
def getIR(self, func_name: str) -> llvmlite.ir.Function:
98+
if not self.hasSymbol(func_name):
99+
raise ValueError(f'symbol not found for func_name {func_name}')
100+
link_name = self._select_link_name(func_name)
101+
module_hash = self.link_name_to_module_hash[link_name]
102+
return self.loadedBinarySharedObjects[module_hash].binarySharedObject.functionIRs[func_name]
103+
104+
def getDefinition(self, func_name: str) -> native_ast.Function:
105+
if not self.hasSymbol(func_name):
106+
raise ValueError(f'symbol not found for func_name {func_name}')
107+
link_name = self._select_link_name(func_name)
108+
module_hash = self.link_name_to_module_hash[link_name]
109+
serialized_definition = self.loadedBinarySharedObjects[module_hash].binarySharedObject.serializedFunctionDefinitions[func_name]
110+
return SerializationContext().deserialize(serialized_definition)
111+
93112
def _generate_link_name(self, func_name: str, module_hash: str) -> str:
94113
return func_name + "." + module_hash
95114

@@ -126,6 +145,14 @@ def loadForSymbol(self, linkName: str) -> None:
126145
if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink):
127146
raise RuntimeError('failed to validate globals when loading:', linkName)
128147

148+
def complexityForSymbol(self, func_name: str) -> int:
149+
"""Get the total number of LLVM instructions for a given symbol."""
150+
try:
151+
link_name = self._select_link_name(func_name)
152+
return self.targetComplexity[link_name]
153+
except KeyError as e:
154+
raise ValueError(f'No complexity value cached for {func_name}') from e
155+
129156
def loadModuleByHash(self, moduleHash: str) -> None:
130157
"""Load a module by name.
131158
@@ -139,23 +166,23 @@ def loadModuleByHash(self, moduleHash: str) -> None:
139166

140167
# TODO (Will) - store these names as module consts, use one .dat only
141168
with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f:
142-
# func_name -> typedcalltarget
143169
callTargets = SerializationContext().deserialize(f.read())
144-
145170
with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f:
146171
serializedGlobalVarDefs = SerializationContext().deserialize(f.read())
147-
148172
with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f:
149173
functionNameToNativeType = SerializationContext().deserialize(f.read())
150-
151174
with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
152175
submodules = SerializationContext().deserialize(f.read(), ListOf(str))
153-
154176
with open(os.path.join(targetDir, "function_dependencies.dat"), "rb") as f:
155177
dependency_edgelist = SerializationContext().deserialize(f.read())
156-
157178
with open(os.path.join(targetDir, "global_dependencies.dat"), "rb") as f:
158179
globalDependencies = SerializationContext().deserialize(f.read())
180+
with open(os.path.join(targetDir, "function_complexities.dat"), "rb") as f:
181+
functionComplexities = SerializationContext().deserialize(f.read())
182+
with open(os.path.join(targetDir, "function_irs.dat"), "rb") as f:
183+
functionIRs = SerializationContext().deserialize(f.read())
184+
with open(os.path.join(targetDir, "function_definitions.dat"), "rb") as f:
185+
functionDefinitions = SerializationContext().deserialize(f.read())
159186

160187
# load the submodules first
161188
for submodule in submodules:
@@ -167,7 +194,10 @@ def loadModuleByHash(self, moduleHash: str) -> None:
167194
modulePath,
168195
serializedGlobalVarDefs,
169196
functionNameToNativeType,
170-
globalDependencies
197+
globalDependencies,
198+
functionComplexities,
199+
functionIRs,
200+
functionDefinitions
171201
).loadFromPath(modulePath)
172202

173203
self.loadedBinarySharedObjects[moduleHash] = loaded
@@ -177,8 +207,11 @@ def loadModuleByHash(self, moduleHash: str) -> None:
177207
assert link_name not in self.targetsLoaded
178208
self.targetsLoaded[link_name] = callTarget
179209

180-
link_name_global_dependencies = {self._generate_link_name(x, moduleHash): y for x, y in globalDependencies.items()}
210+
for func_name, complexity in functionComplexities.items():
211+
link_name = self._generate_link_name(func_name, moduleHash)
212+
self.targetComplexity[link_name] = complexity
181213

214+
link_name_global_dependencies = {self._generate_link_name(x, moduleHash): y for x, y in globalDependencies.items()}
182215
assert not any(key in self.global_dependencies for key in link_name_global_dependencies)
183216

184217
self.global_dependencies.update(link_name_global_dependencies)
@@ -222,6 +255,10 @@ def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies,
222255

223256
path = self.writeModuleToDisk(binarySharedObject, hashToUse, nameToTypedCallTarget, dependentHashes, link_name_dependency_edgelist)
224257

258+
for func_name, complexity in binarySharedObject.functionComplexities.items():
259+
link_name = self._generate_link_name(func_name, hashToUse)
260+
self.targetComplexity[link_name] = complexity
261+
225262
self.loadedBinarySharedObjects[hashToUse] = (
226263
binarySharedObject.loadFromPath(os.path.join(path, "module.so"))
227264
)
@@ -314,6 +351,15 @@ def writeModuleToDisk(self, binarySharedObject, hashToUse, nameToTypedCallTarget
314351
with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f:
315352
f.write(SerializationContext().serialize(binarySharedObject.globalDependencies))
316353

354+
with open(os.path.join(tempTargetDir, "function_complexities.dat"), "wb") as f:
355+
f.write(SerializationContext().serialize(binarySharedObject.functionComplexities))
356+
357+
with open(os.path.join(tempTargetDir, "function_irs.dat"), "wb") as f:
358+
f.write(SerializationContext().serialize(binarySharedObject.functionIRs))
359+
360+
with open(os.path.join(tempTargetDir, "function_definitions.dat"), "wb") as f:
361+
f.write(SerializationContext().serialize(binarySharedObject.serializedFunctionDefinitions))
362+
317363
try:
318364
os.rename(tempTargetDir, targetDir)
319365
except IOError:

typed_python/compiler/compiler_cache_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import threading
1717
import os
1818
import pytest
19+
1920
from typed_python.test_util import evaluateExprInFreshProcess
2021

2122
MAIN_MODULE = """

typed_python/compiler/llvm_compiler.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@ def buildSharedObject(self, functions):
123123
mod,
124124
serializedGlobalVariableDefinitions,
125125
module.functionNameToType,
126-
module.globalDependencies
126+
module.globalDependencies,
127+
{name: self.converter.totalFunctionComplexity(name) for name in functions},
128+
{name: self.converter._functions_by_name[name] for name in functions},
129+
{name: SerializationContext().serialize(self.converter._function_definitions[name]) for name in functions},
127130
)
128131

129132
def function_pointer_by_name(self, name):

typed_python/compiler/native_ast_to_llvm.py

+33-15
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typed_python.compiler.global_variable_definition import GlobalVariableDefinition
1919
from typed_python.compiler.module_definition import ModuleDefinition
2020
from typing import Dict
21+
2122
llvm_i8ptr = llvmlite.ir.IntType(8).as_pointer()
2223
llvm_i8 = llvmlite.ir.IntType(8)
2324
llvm_i32 = llvmlite.ir.IntType(32)
@@ -642,6 +643,7 @@ def namedCallTargetToLLVM(self, target: native_ast.NamedCallTarget) -> TypedLLVM
642643
2. The function is in function_definitions, in which case we grab the function definition and make an inlining decision.
643644
3. We have a compiler cache, and the function is in it. We make an inlining decision based on its cached complexity; if we don't inline it, we add to external_function_references.
644645
"""
646+
assert isinstance(target, native_ast.NamedCallTarget)
645647
if target.external:
646648
if target.name not in self.external_function_references:
647649
func_type = llvmlite.ir.FunctionType(
@@ -673,24 +675,29 @@ def namedCallTargetToLLVM(self, target: native_ast.NamedCallTarget) -> TypedLLVM
673675

674676
func = self.external_function_references[target.name]
675677
else:
676-
# TODO (Will): decide whether to inline cached code
677678
assert self.compilerCache is not None and self.compilerCache.hasSymbol(target.name)
678679
# this function is defined in a shared object that we've loaded from a prior
679-
# invocation
680-
if target.name not in self.external_function_references:
681-
func_type = llvmlite.ir.FunctionType(
682-
type_to_llvm_type(target.output_type),
683-
[type_to_llvm_type(x) for x in target.arg_types],
684-
var_arg=target.varargs
685-
)
680+
# invocation. Again, first make an inlining decision.
681+
if (
682+
self.compilerCache.complexityForSymbol(target.name) < CROSS_MODULE_INLINE_COMPLEXITY
683+
):
684+
self.converter.generateDefinition(target.name)
685+
func = self.converter.repeatFunctionInModule(target.name, self.module)
686+
else:
687+
if target.name not in self.external_function_references:
688+
func_type = llvmlite.ir.FunctionType(
689+
type_to_llvm_type(target.output_type),
690+
[type_to_llvm_type(x) for x in target.arg_types],
691+
var_arg=target.varargs
692+
)
686693

687-
assert target.name not in self.converter._function_definitions, target.name
694+
assert target.name not in self.converter._function_definitions, target.name
688695

689-
self.external_function_references[target.name] = (
690-
llvmlite.ir.Function(self.module, func_type, target.name)
691-
)
696+
self.external_function_references[target.name] = (
697+
llvmlite.ir.Function(self.module, func_type, target.name)
698+
)
692699

693-
func = self.external_function_references[target.name]
700+
func = self.external_function_references[target.name]
694701

695702
return TypedLLVMValue(
696703
func,
@@ -1528,6 +1535,18 @@ def totalFunctionComplexity(self, name):
15281535

15291536
return res
15301537

1538+
def generateDefinition(self, name: str) -> None:
1539+
"""Pull the native function definition and the llvm IR function matching `name`
1540+
from the cache, and add them to _function_definitions and _functions_by_name.
1541+
"""
1542+
assert self.compilerCache is not None
1543+
1544+
definition = self.compilerCache.getDefinition(name)
1545+
llvm_func = self.compilerCache.getIR(name)
1546+
1547+
self._functions_by_name[name] = llvm_func
1548+
self._function_definitions[name] = definition
1549+
15311550
def repeatFunctionInModule(self, name, module):
15321551
"""Request that the function given by 'name' be inlined into 'module'.
15331552
@@ -1580,7 +1599,6 @@ def add_functions(self, names_to_definitions):
15801599
[type_to_llvm_type(x[1]) for x in function.args]
15811600
)
15821601
self._functions_by_name[name] = llvmlite.ir.Function(module, func_type, name)
1583-
15841602
self._functions_by_name[name].linkage = 'external'
15851603
self._function_definitions[name] = function
15861604

@@ -1664,6 +1682,7 @@ def add_functions(self, names_to_definitions):
16641682
# want to repeat its definition in this particular module.
16651683
for name in self._inlineRequests:
16661684
names_to_definitions[name] = self._function_definitions[name]
1685+
16671686
self._inlineRequests.clear()
16681687

16691688
# define a function that accepts a pointer and fills it out with a table of pointer values
@@ -1674,7 +1693,6 @@ def add_functions(self, names_to_definitions):
16741693
output=native_ast.Void,
16751694
args=[native_ast.Void.pointer().pointer()]
16761695
)
1677-
16781696
return ModuleDefinition(
16791697
str(module),
16801698
functionTypes,

typed_python/compiler/python_to_native_converter.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from typed_python.hash import Hash
1919
from types import ModuleType
20-
from typing import Dict
20+
from typing import Dict, Optional
2121
from typed_python import Class
2222
import typed_python.python_ast as python_ast
2323
import typed_python._types as _types
@@ -288,10 +288,10 @@ def deleteTarget(self, linkName):
288288
self._targets.pop(linkName)
289289

290290
def setTarget(self, linkName, target):
291-
assert(isinstance(target, TypedCallTarget))
291+
assert (isinstance(target, TypedCallTarget))
292292
self._targets[linkName] = target
293293

294-
def getTarget(self, linkName) -> TypedCallTarget:
294+
def getTarget(self, linkName) -> Optional[TypedCallTarget]:
295295
if linkName in self._targets:
296296
return self._targets[linkName]
297297

0 commit comments

Comments
 (0)