diff --git a/src/backends/onnx.js b/src/backends/onnx.js
index 38cd71337..00267fe4d 100644
--- a/src/backends/onnx.js
+++ b/src/backends/onnx.js
@@ -152,7 +152,12 @@ export async function createInferenceSession(buffer, session_options, session_co
         // so we wait for it to resolve before creating this new session.
         await wasmInitPromise;
     }
-
+    if (ONNX_ENV?.webgpu) {
+        const adapter = await navigator.gpu.requestAdapter({ powerPreference: 'high-performance' });
+        if (adapter) {
+            ONNX_ENV.webgpu.device = await adapter.requestDevice();
+        }
+    }
     const sessionPromise = InferenceSession.create(buffer, session_options);
     wasmInitPromise ??= sessionPromise;
     const session = await sessionPromise;
@@ -194,10 +199,6 @@ if (ONNX_ENV?.wasm) {
     }
 }
 
-if (ONNX_ENV?.webgpu) {
-    ONNX_ENV.webgpu.powerPreference = 'high-performance';
-}
-
 /**
  * Check if ONNX's WASM backend is being proxied.
  * @returns {boolean} Whether ONNX's WASM backend is being proxied.
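
For context, the change replaces the removed environment-level ONNX_ENV.webgpu.powerPreference = 'high-performance' hint with an adapter and device requested explicitly at session-creation time. Below is a minimal standalone sketch of that WebGPU request pattern, assuming a browser where navigator.gpu is available; the helper name requestHighPerformanceDevice is illustrative and not part of the codebase.

// Illustrative sketch only: mirrors the adapter/device request added in the diff above.
// Assumes a browser context with WebGPU; returns null when WebGPU is unavailable.
async function requestHighPerformanceDevice() {
    if (!navigator.gpu) return null;  // WebGPU not supported in this browser
    const adapter = await navigator.gpu.requestAdapter({ powerPreference: 'high-performance' });
    if (!adapter) return null;        // no suitable GPU adapter found
    return await adapter.requestDevice();  // GPUDevice; the diff assigns this to ONNX_ENV.webgpu.device
}

Requesting the device up front lets session creation reuse a GPU chosen with the high-performance preference, rather than relying on the global hint that the second hunk removes.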