Skip to content

Commit e2d754a

Browse files
author
William Grant
committed
Disallow compilation during module import
During the compilation of one Entrypointed function it's possible to import a module which calls a second Entrypointed function. This breaks our model of the compilation process and could cause deadlocks, so instead throw an ImportError.
1 parent ad77e7d commit e2d754a

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

typed_python/compiler/compiler_cache_test.py

+35
Original file line numberDiff line numberDiff line change
@@ -654,3 +654,38 @@ def g(x):
654654
# run twice to check cached code can be retrieved
655655
assert evaluateExprInFreshProcess(MODULE, 'x.getX()', compilerCacheDir) == 1
656656
assert evaluateExprInFreshProcess(MODULE, 'x.getX()', compilerCacheDir) == 1
657+
658+
659+
@pytest.mark.skipif('sys.platform=="darwin"')
660+
def test_compiler_cache_throws_on_import_loop():
661+
"""It is possible, when compiling a module, to attempt to deserialise
662+
a callTarget containing a module import which runs an Entrypointed function.
663+
This results in a 'compilation loop' where one iteration of the conversion
664+
is waiting on another, which currently breaks our model of the compilation
665+
process.
666+
"""
667+
module1 = """
668+
@Entrypoint
669+
def f(x):
670+
return x+1
671+
f(1)
672+
""".replace('\n ', '\n')
673+
module2 = """
674+
@Entrypoint
675+
def g():
676+
import x
677+
""".replace('\n ', '\n')
678+
module3 = """
679+
import y
680+
def rung():
681+
try:
682+
y.g()
683+
except ImportError as e:
684+
return 'ImportError caught'
685+
rung()
686+
""".replace('\n ', '\n')
687+
with tempfile.TemporaryDirectory() as compilerCacheDir:
688+
689+
evaluateExprInFreshProcess({'x.py': module1}, 'x.f(1)', compilerCacheDir)
690+
exception_string = evaluateExprInFreshProcess({'z.py': module3, 'y.py': module2, 'x.py': module1, }, 'z.rung()', compilerCacheDir)
691+
assert exception_string == 'ImportError caught'

typed_python/compiler/runtime.py

+26
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import inspect
1516
import threading
1617
import os
1718
import time
@@ -501,6 +502,14 @@ def Entrypoint(pyFunc):
501502

502503
if not callable(typedFunc):
503504
raise Exception(f"Can only compile functions, not {typedFunc}")
505+
# check if we are already in the middle of the compilation process, due to the Entrypointed
506+
# code being called through a module import, and throw an error if so.
507+
if is_importing():
508+
compiling_func = Runtime.singleton().converter._currentlyConverting
509+
compiling_func_link_name = Runtime.singleton().converter._link_name_for_identity[compiling_func]
510+
error_message = f"Can't import Entrypointed code {pyFunc.__module__}.{pyFunc.__qualname__} \
511+
while {compiling_func_link_name} is being compiled."
512+
raise ImportError(error_message)
504513

505514
typedFunc = Function(typedFunc)
506515

@@ -534,3 +543,20 @@ def Compiled(pyFunc):
534543
f.resultTypeFor(*types)
535544

536545
return f
546+
547+
548+
def is_importing():
549+
"""Walk the stack to check if we are currently importing a module.
550+
551+
In this case, we will have an 'importlib' between two 'typed_python.compiler.runtime' frames.
552+
"""
553+
in_runtime = False
554+
assert __name__ == 'typed_python.compiler.runtime', 'is_importing() should only be called from typed_python.compiler.runtime'
555+
for frame, *_ in inspect.stack()[::-1]:
556+
frame_name = frame.f_globals.get("__name__")
557+
if frame_name == 'typed_python.compiler.runtime':
558+
in_runtime = True
559+
if in_runtime and frame_name == 'importlib':
560+
return True
561+
562+
return False

0 commit comments

Comments
 (0)