Skip to content

Commit db85867

Browse files
leofangkkraus14
andauthored
Make populating the internal symbol table thread-safe (#835)
* protect cuPythonInit in driver * add lock for all modules * fixes * fix identation, make consistent * relocate setting __cuPythonInit to avoid deadlock since we use cuGetProcAddress in the init function... * move init check inside lock * make cuPythonInit reentrant + ensure GIL is released when calling underlying C APIs * fix indentation --------- Co-authored-by: Keith Kraus <[email protected]>
1 parent 8f1dd40 commit db85867

File tree

7 files changed

+8872
-8854
lines changed

7 files changed

+8872
-8854
lines changed

cuda_bindings/cuda/bindings/_bindings/cydriver.pyx.in

Lines changed: 8198 additions & 8189 deletions
Large diffs are not rendered by default.

cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

Lines changed: 108 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ from libc.stdint cimport uintptr_t
1111
{{endif}}
1212
from cuda.pathfinder import load_nvidia_dynamic_lib
1313
from libc.stdint cimport intptr_t
14+
import threading
1415

16+
cdef object __symbol_lock = threading.Lock()
1517
cdef bint __cuPythonInit = False
1618
{{if 'nvrtcGetErrorString' in found_functions}}cdef void *__nvrtcGetErrorString = NULL{{endif}}
1719
{{if 'nvrtcVersion' in found_functions}}cdef void *__nvrtcVersion = NULL{{endif}}
@@ -42,21 +44,12 @@ cdef int cuPythonInit() except -1 nogil:
4244
global __cuPythonInit
4345
if __cuPythonInit:
4446
return 0
45-
__cuPythonInit = True
4647

47-
# Load library
48-
{{if 'Windows' == platform.system()}}
49-
with gil:
48+
with gil, __symbol_lock:
49+
{{if 'Windows' == platform.system()}}
5050
handle = load_nvidia_dynamic_lib("nvrtc")._handle_uint
51-
{{else}}
52-
with gil:
53-
handle = <void*><uintptr_t>load_nvidia_dynamic_lib("nvrtc")._handle_uint
54-
{{endif}}
5551

56-
57-
# Load function
58-
{{if 'Windows' == platform.system()}}
59-
with gil:
52+
# Load function
6053
{{if 'nvrtcGetErrorString' in found_functions}}
6154
try:
6255
global __nvrtcGetErrorString
@@ -226,105 +219,110 @@ cdef int cuPythonInit() except -1 nogil:
226219
pass
227220
{{endif}}
228221

229-
{{else}}
230-
{{if 'nvrtcGetErrorString' in found_functions}}
231-
global __nvrtcGetErrorString
232-
__nvrtcGetErrorString = dlfcn.dlsym(handle, 'nvrtcGetErrorString')
233-
{{endif}}
234-
{{if 'nvrtcVersion' in found_functions}}
235-
global __nvrtcVersion
236-
__nvrtcVersion = dlfcn.dlsym(handle, 'nvrtcVersion')
237-
{{endif}}
238-
{{if 'nvrtcGetNumSupportedArchs' in found_functions}}
239-
global __nvrtcGetNumSupportedArchs
240-
__nvrtcGetNumSupportedArchs = dlfcn.dlsym(handle, 'nvrtcGetNumSupportedArchs')
241-
{{endif}}
242-
{{if 'nvrtcGetSupportedArchs' in found_functions}}
243-
global __nvrtcGetSupportedArchs
244-
__nvrtcGetSupportedArchs = dlfcn.dlsym(handle, 'nvrtcGetSupportedArchs')
245-
{{endif}}
246-
{{if 'nvrtcCreateProgram' in found_functions}}
247-
global __nvrtcCreateProgram
248-
__nvrtcCreateProgram = dlfcn.dlsym(handle, 'nvrtcCreateProgram')
249-
{{endif}}
250-
{{if 'nvrtcDestroyProgram' in found_functions}}
251-
global __nvrtcDestroyProgram
252-
__nvrtcDestroyProgram = dlfcn.dlsym(handle, 'nvrtcDestroyProgram')
253-
{{endif}}
254-
{{if 'nvrtcCompileProgram' in found_functions}}
255-
global __nvrtcCompileProgram
256-
__nvrtcCompileProgram = dlfcn.dlsym(handle, 'nvrtcCompileProgram')
257-
{{endif}}
258-
{{if 'nvrtcGetPTXSize' in found_functions}}
259-
global __nvrtcGetPTXSize
260-
__nvrtcGetPTXSize = dlfcn.dlsym(handle, 'nvrtcGetPTXSize')
261-
{{endif}}
262-
{{if 'nvrtcGetPTX' in found_functions}}
263-
global __nvrtcGetPTX
264-
__nvrtcGetPTX = dlfcn.dlsym(handle, 'nvrtcGetPTX')
265-
{{endif}}
266-
{{if 'nvrtcGetCUBINSize' in found_functions}}
267-
global __nvrtcGetCUBINSize
268-
__nvrtcGetCUBINSize = dlfcn.dlsym(handle, 'nvrtcGetCUBINSize')
269-
{{endif}}
270-
{{if 'nvrtcGetCUBIN' in found_functions}}
271-
global __nvrtcGetCUBIN
272-
__nvrtcGetCUBIN = dlfcn.dlsym(handle, 'nvrtcGetCUBIN')
273-
{{endif}}
274-
{{if 'nvrtcGetLTOIRSize' in found_functions}}
275-
global __nvrtcGetLTOIRSize
276-
__nvrtcGetLTOIRSize = dlfcn.dlsym(handle, 'nvrtcGetLTOIRSize')
277-
{{endif}}
278-
{{if 'nvrtcGetLTOIR' in found_functions}}
279-
global __nvrtcGetLTOIR
280-
__nvrtcGetLTOIR = dlfcn.dlsym(handle, 'nvrtcGetLTOIR')
281-
{{endif}}
282-
{{if 'nvrtcGetOptiXIRSize' in found_functions}}
283-
global __nvrtcGetOptiXIRSize
284-
__nvrtcGetOptiXIRSize = dlfcn.dlsym(handle, 'nvrtcGetOptiXIRSize')
285-
{{endif}}
286-
{{if 'nvrtcGetOptiXIR' in found_functions}}
287-
global __nvrtcGetOptiXIR
288-
__nvrtcGetOptiXIR = dlfcn.dlsym(handle, 'nvrtcGetOptiXIR')
289-
{{endif}}
290-
{{if 'nvrtcGetProgramLogSize' in found_functions}}
291-
global __nvrtcGetProgramLogSize
292-
__nvrtcGetProgramLogSize = dlfcn.dlsym(handle, 'nvrtcGetProgramLogSize')
293-
{{endif}}
294-
{{if 'nvrtcGetProgramLog' in found_functions}}
295-
global __nvrtcGetProgramLog
296-
__nvrtcGetProgramLog = dlfcn.dlsym(handle, 'nvrtcGetProgramLog')
297-
{{endif}}
298-
{{if 'nvrtcAddNameExpression' in found_functions}}
299-
global __nvrtcAddNameExpression
300-
__nvrtcAddNameExpression = dlfcn.dlsym(handle, 'nvrtcAddNameExpression')
301-
{{endif}}
302-
{{if 'nvrtcGetLoweredName' in found_functions}}
303-
global __nvrtcGetLoweredName
304-
__nvrtcGetLoweredName = dlfcn.dlsym(handle, 'nvrtcGetLoweredName')
305-
{{endif}}
306-
{{if 'nvrtcGetPCHHeapSize' in found_functions}}
307-
global __nvrtcGetPCHHeapSize
308-
__nvrtcGetPCHHeapSize = dlfcn.dlsym(handle, 'nvrtcGetPCHHeapSize')
309-
{{endif}}
310-
{{if 'nvrtcSetPCHHeapSize' in found_functions}}
311-
global __nvrtcSetPCHHeapSize
312-
__nvrtcSetPCHHeapSize = dlfcn.dlsym(handle, 'nvrtcSetPCHHeapSize')
313-
{{endif}}
314-
{{if 'nvrtcGetPCHCreateStatus' in found_functions}}
315-
global __nvrtcGetPCHCreateStatus
316-
__nvrtcGetPCHCreateStatus = dlfcn.dlsym(handle, 'nvrtcGetPCHCreateStatus')
317-
{{endif}}
318-
{{if 'nvrtcGetPCHHeapSizeRequired' in found_functions}}
319-
global __nvrtcGetPCHHeapSizeRequired
320-
__nvrtcGetPCHHeapSizeRequired = dlfcn.dlsym(handle, 'nvrtcGetPCHHeapSizeRequired')
321-
{{endif}}
322-
{{if 'nvrtcSetFlowCallback' in found_functions}}
323-
global __nvrtcSetFlowCallback
324-
__nvrtcSetFlowCallback = dlfcn.dlsym(handle, 'nvrtcSetFlowCallback')
325-
{{endif}}
222+
{{else}}
223+
handle = <void*><uintptr_t>(load_nvidia_dynamic_lib("nvrtc")._handle_uint)
326224

327-
{{endif}}
225+
# Load function
226+
{{if 'nvrtcGetErrorString' in found_functions}}
227+
global __nvrtcGetErrorString
228+
__nvrtcGetErrorString = dlfcn.dlsym(handle, 'nvrtcGetErrorString')
229+
{{endif}}
230+
{{if 'nvrtcVersion' in found_functions}}
231+
global __nvrtcVersion
232+
__nvrtcVersion = dlfcn.dlsym(handle, 'nvrtcVersion')
233+
{{endif}}
234+
{{if 'nvrtcGetNumSupportedArchs' in found_functions}}
235+
global __nvrtcGetNumSupportedArchs
236+
__nvrtcGetNumSupportedArchs = dlfcn.dlsym(handle, 'nvrtcGetNumSupportedArchs')
237+
{{endif}}
238+
{{if 'nvrtcGetSupportedArchs' in found_functions}}
239+
global __nvrtcGetSupportedArchs
240+
__nvrtcGetSupportedArchs = dlfcn.dlsym(handle, 'nvrtcGetSupportedArchs')
241+
{{endif}}
242+
{{if 'nvrtcCreateProgram' in found_functions}}
243+
global __nvrtcCreateProgram
244+
__nvrtcCreateProgram = dlfcn.dlsym(handle, 'nvrtcCreateProgram')
245+
{{endif}}
246+
{{if 'nvrtcDestroyProgram' in found_functions}}
247+
global __nvrtcDestroyProgram
248+
__nvrtcDestroyProgram = dlfcn.dlsym(handle, 'nvrtcDestroyProgram')
249+
{{endif}}
250+
{{if 'nvrtcCompileProgram' in found_functions}}
251+
global __nvrtcCompileProgram
252+
__nvrtcCompileProgram = dlfcn.dlsym(handle, 'nvrtcCompileProgram')
253+
{{endif}}
254+
{{if 'nvrtcGetPTXSize' in found_functions}}
255+
global __nvrtcGetPTXSize
256+
__nvrtcGetPTXSize = dlfcn.dlsym(handle, 'nvrtcGetPTXSize')
257+
{{endif}}
258+
{{if 'nvrtcGetPTX' in found_functions}}
259+
global __nvrtcGetPTX
260+
__nvrtcGetPTX = dlfcn.dlsym(handle, 'nvrtcGetPTX')
261+
{{endif}}
262+
{{if 'nvrtcGetCUBINSize' in found_functions}}
263+
global __nvrtcGetCUBINSize
264+
__nvrtcGetCUBINSize = dlfcn.dlsym(handle, 'nvrtcGetCUBINSize')
265+
{{endif}}
266+
{{if 'nvrtcGetCUBIN' in found_functions}}
267+
global __nvrtcGetCUBIN
268+
__nvrtcGetCUBIN = dlfcn.dlsym(handle, 'nvrtcGetCUBIN')
269+
{{endif}}
270+
{{if 'nvrtcGetLTOIRSize' in found_functions}}
271+
global __nvrtcGetLTOIRSize
272+
__nvrtcGetLTOIRSize = dlfcn.dlsym(handle, 'nvrtcGetLTOIRSize')
273+
{{endif}}
274+
{{if 'nvrtcGetLTOIR' in found_functions}}
275+
global __nvrtcGetLTOIR
276+
__nvrtcGetLTOIR = dlfcn.dlsym(handle, 'nvrtcGetLTOIR')
277+
{{endif}}
278+
{{if 'nvrtcGetOptiXIRSize' in found_functions}}
279+
global __nvrtcGetOptiXIRSize
280+
__nvrtcGetOptiXIRSize = dlfcn.dlsym(handle, 'nvrtcGetOptiXIRSize')
281+
{{endif}}
282+
{{if 'nvrtcGetOptiXIR' in found_functions}}
283+
global __nvrtcGetOptiXIR
284+
__nvrtcGetOptiXIR = dlfcn.dlsym(handle, 'nvrtcGetOptiXIR')
285+
{{endif}}
286+
{{if 'nvrtcGetProgramLogSize' in found_functions}}
287+
global __nvrtcGetProgramLogSize
288+
__nvrtcGetProgramLogSize = dlfcn.dlsym(handle, 'nvrtcGetProgramLogSize')
289+
{{endif}}
290+
{{if 'nvrtcGetProgramLog' in found_functions}}
291+
global __nvrtcGetProgramLog
292+
__nvrtcGetProgramLog = dlfcn.dlsym(handle, 'nvrtcGetProgramLog')
293+
{{endif}}
294+
{{if 'nvrtcAddNameExpression' in found_functions}}
295+
global __nvrtcAddNameExpression
296+
__nvrtcAddNameExpression = dlfcn.dlsym(handle, 'nvrtcAddNameExpression')
297+
{{endif}}
298+
{{if 'nvrtcGetLoweredName' in found_functions}}
299+
global __nvrtcGetLoweredName
300+
__nvrtcGetLoweredName = dlfcn.dlsym(handle, 'nvrtcGetLoweredName')
301+
{{endif}}
302+
{{if 'nvrtcGetPCHHeapSize' in found_functions}}
303+
global __nvrtcGetPCHHeapSize
304+
__nvrtcGetPCHHeapSize = dlfcn.dlsym(handle, 'nvrtcGetPCHHeapSize')
305+
{{endif}}
306+
{{if 'nvrtcSetPCHHeapSize' in found_functions}}
307+
global __nvrtcSetPCHHeapSize
308+
__nvrtcSetPCHHeapSize = dlfcn.dlsym(handle, 'nvrtcSetPCHHeapSize')
309+
{{endif}}
310+
{{if 'nvrtcGetPCHCreateStatus' in found_functions}}
311+
global __nvrtcGetPCHCreateStatus
312+
__nvrtcGetPCHCreateStatus = dlfcn.dlsym(handle, 'nvrtcGetPCHCreateStatus')
313+
{{endif}}
314+
{{if 'nvrtcGetPCHHeapSizeRequired' in found_functions}}
315+
global __nvrtcGetPCHHeapSizeRequired
316+
__nvrtcGetPCHHeapSizeRequired = dlfcn.dlsym(handle, 'nvrtcGetPCHHeapSizeRequired')
317+
{{endif}}
318+
{{if 'nvrtcSetFlowCallback' in found_functions}}
319+
global __nvrtcSetFlowCallback
320+
__nvrtcSetFlowCallback = dlfcn.dlsym(handle, 'nvrtcSetFlowCallback')
321+
{{endif}}
322+
{{endif}}
323+
324+
__cuPythonInit = True
325+
return 0
328326

329327
{{if 'nvrtcGetErrorString' in found_functions}}
330328

0 commit comments

Comments
 (0)