Skip to content

Commit 6bf9707

Browse files
committed
fix: dont initialize all backends if specific backend is requested
1 parent 9d0aca7 commit 6bf9707

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

src/xla/XLA.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using EnumX: @enumx
88
using Enzyme: Compiler
99
using Preferences: load_preference
1010
using UUIDs: UUID
11+
using ScopedValues: ScopedValue, with
1112

1213
using Setfield: Setfield, @set!
1314

@@ -48,6 +49,8 @@ include("IFRT/IFRT.jl")
4849

4950
include("CompileOptions.jl")
5051

52+
const BACKENDS_TO_INITIALIZE = ScopedValue{Union{Missing,Set{String}}}(missing)
53+
5154
abstract type AbstractBackendState end
5255

5356
function finalize_backend_state end
@@ -145,7 +148,9 @@ function set_default_backend(backend::AbstractClient)
145148
end
146149

147150
function set_default_backend(backend::String)
148-
global_backend_state.default_client = client(backend)
151+
with(BACKENDS_TO_INITIALIZE => Set{String}([backend])) do
152+
global_backend_state.default_client = client(backend)
153+
end
149154
return nothing
150155
end
151156

@@ -235,6 +240,8 @@ end
235240

236241
for runtime in (:PJRT, :IFRT)
237242
@eval function initialize_default_clients!(state::$(Symbol(runtime, :BackendState)))
243+
backends_to_initialize = BACKENDS_TO_INITIALIZE[]
244+
238245
was_initialized = state.initialized
239246
state.initialized = true
240247
distributed_runtime_client = if global_state.num_processes > 1
@@ -249,10 +256,9 @@ for runtime in (:PJRT, :IFRT)
249256
state,
250257
was_initialized;
251258
allow_initialization=backend -> begin
252-
if Reactant.precompiling()
253-
return backend.platform_name == "cpu"
254-
end
255-
return true
259+
Reactant.precompiling() && return backend.platform_name == "cpu"
260+
backends_to_initialize === missing && return true
261+
return backend.platform_name in backends_to_initialize
256262
end,
257263
node_id=global_state.process_id,
258264
num_nodes=global_state.num_processes,

0 commit comments

Comments
 (0)