@@ -8,6 +8,7 @@ using EnumX: @enumx
88using Enzyme: Compiler
99using Preferences: load_preference
1010using UUIDs: UUID
11+ using ScopedValues: ScopedValue, with
1112
1213using Setfield: Setfield, @set!
1314
@@ -48,6 +49,8 @@ include("IFRT/IFRT.jl")
4849
4950include (" CompileOptions.jl" )
5051
52+ const BACKENDS_TO_INITIALIZE = ScopedValue {Union{Missing,Set{String}}} (missing )
53+
5154abstract type AbstractBackendState end
5255
5356function finalize_backend_state end
@@ -145,7 +148,9 @@ function set_default_backend(backend::AbstractClient)
145148end
146149
147150function set_default_backend (backend:: String )
148- global_backend_state. default_client = client (backend)
151+ with (BACKENDS_TO_INITIALIZE => Set {String} ([backend])) do
152+ global_backend_state. default_client = client (backend)
153+ end
149154 return nothing
150155end
151156
235240
236241for runtime in (:PJRT , :IFRT )
237242 @eval function initialize_default_clients! (state:: $ (Symbol (runtime, :BackendState )))
243+ backends_to_initialize = BACKENDS_TO_INITIALIZE[]
244+
238245 was_initialized = state. initialized
239246 state. initialized = true
240247 distributed_runtime_client = if global_state. num_processes > 1
@@ -249,10 +256,9 @@ for runtime in (:PJRT, :IFRT)
249256 state,
250257 was_initialized;
251258 allow_initialization= backend -> begin
252- if Reactant. precompiling ()
253- return backend. platform_name == " cpu"
254- end
255- return true
259+ Reactant. precompiling () && return backend. platform_name == " cpu"
260+ backends_to_initialize === missing && return true
261+ return backend. platform_name in backends_to_initialize
256262 end ,
257263 node_id= global_state. process_id,
258264 num_nodes= global_state. num_processes,
0 commit comments