@@ -194,8 +194,8 @@ void install_rocm_stable_runtime_if_needed(const std::string& os,
194194 http_progress_cb = [&progress_cb, &filename](size_t downloaded, size_t total) -> bool {
195195 DownloadProgress p;
196196 p.file = filename;
197- p.file_index = 1 ;
198- p.total_files = 1 ;
197+ p.file_index = 2 ; // ROCm runtime is the second file
198+ p.total_files = 2 ; // Backend binary + ROCm runtime
199199 p.bytes_downloaded = downloaded;
200200 p.bytes_total = total;
201201 p.percent = total > 0 ? static_cast <int >((downloaded * 100 ) / total) : 0 ;
@@ -223,12 +223,12 @@ void install_rocm_stable_runtime_if_needed(const std::string& os,
223223 if (progress_cb) {
224224 DownloadProgress p;
225225 p.file = filename;
226- p.file_index = 1 ;
227- p.total_files = 1 ;
226+ p.file_index = 2 ; // ROCm runtime is the second file
227+ p.total_files = 2 ; // Backend binary + ROCm runtime
228228 p.bytes_downloaded = download_result.bytes_downloaded ;
229229 p.bytes_total = download_result.total_bytes ;
230230 p.percent = 100 ;
231- p.complete = true ;
231+ p.complete = true ; // Now we can send the completion event
232232 progress_cb (p);
233233 }
234234}
@@ -249,32 +249,27 @@ void uninstall_rocm_stable_runtime_if_needed(const std::string& os) {
249249 }
250250}
251251
252- void install_therock_if_needed (const std::string& os, const json& backend_versions) {
252+ bool will_install_therock (const std::string& os, const json& backend_versions) {
253253 // TheRock is only needed on Linux for ROCm preview channel.
254254 if (os != " linux" ) {
255- return ;
255+ return false ;
256256 }
257257
258258 // Check if system ROCm is available - if so, don't need TheRock
259259 if (backends::BackendUtils::is_rocm_installed_system_wide ()) {
260- LOG (DEBUG, " BackendManager" )
261- << " System ROCm detected, skipping TheRock installation" << std::endl;
262- return ;
260+ return false ;
263261 }
264262
265263 // Get ROCm architecture
266264 std::string rocm_arch = SystemInfo::get_rocm_arch ();
267265 if (rocm_arch.empty ()) {
268- LOG (DEBUG, " BackendManager" )
269- << " No ROCm architecture detected, skipping TheRock installation" << std::endl;
270- return ;
266+ return false ;
271267 }
272268
273269 // Get TheRock version from backend_versions.json
274270 if (!backend_versions.contains (" therock" ) || !backend_versions[" therock" ].contains (" version" )) {
275- throw std::runtime_error ( " backend_versions.json is missing 'therock.version' " ) ;
271+ return false ;
276272 }
277- std::string version = backend_versions[" therock" ][" version" ].get <std::string>();
278273
279274 // Check if this architecture is supported
280275 if (backend_versions[" therock" ].contains (" architectures" ) &&
@@ -287,14 +282,24 @@ void install_therock_if_needed(const std::string& os, const json& backend_versio
287282 }
288283 }
289284 if (!arch_supported) {
290- LOG (DEBUG, " BackendManager" )
291- << " Architecture " << rocm_arch << " not supported by TheRock" << std::endl;
292- return ;
285+ return false ;
293286 }
294287 }
295288
289+ return true ;
290+ }
291+
292+ void install_therock_if_needed (const std::string& os, const json& backend_versions,
293+ DownloadProgressCallback progress_cb = nullptr ) {
294+ if (!will_install_therock (os, backend_versions)) {
295+ return ;
296+ }
297+
298+ std::string rocm_arch = SystemInfo::get_rocm_arch ();
299+ std::string version = backend_versions[" therock" ][" version" ].get <std::string>();
300+
296301 // Install TheRock for this architecture
297- backends::BackendUtils::install_therock (rocm_arch, version);
302+ backends::BackendUtils::install_therock (rocm_arch, version, progress_cb );
298303}
299304
300305} // namespace
@@ -374,15 +379,42 @@ void BackendManager::install_backend(const std::string& recipe, const std::strin
374379 throw std::runtime_error (" [BackendManager] Unknown recipe: " + recipe);
375380 }
376381
382+ // Check if we need to install additional runtime components after the main backend
383+ bool needs_rocm_stable_runtime = (recipe == " llamacpp" || recipe == " sd-cpp" ) &&
384+ resolved_backend == " rocm-stable" &&
385+ get_current_os () == " linux" ;
386+
387+ bool needs_therock = (recipe == " llamacpp" || recipe == " sd-cpp" ) &&
388+ resolved_backend == " rocm-preview" &&
389+ will_install_therock (get_current_os (), backend_versions_);
390+
391+ // Wrap the progress callback to adjust file indices if runtime download follows
392+ DownloadProgressCallback wrapped_progress_cb;
393+ if (progress_cb && (needs_rocm_stable_runtime || needs_therock)) {
394+ wrapped_progress_cb = [progress_cb](const DownloadProgress& p) -> bool {
395+ DownloadProgress adjusted = p;
396+ // Adjust to indicate this is file 1 of 2
397+ adjusted.file_index = 1 ;
398+ adjusted.total_files = 2 ;
399+ // Suppress the completion event - we'll send it after the runtime download
400+ if (p.complete ) {
401+ adjusted.complete = false ;
402+ }
403+ return progress_cb (adjusted);
404+ };
405+ } else {
406+ wrapped_progress_cb = progress_cb;
407+ }
408+
377409 backends::BackendUtils::install_from_github (
378- *spec, params.version , params.repo , params.filename , resolved_backend, progress_cb );
410+ *spec, params.version , params.repo , params.filename , resolved_backend, wrapped_progress_cb );
379411
380- if ((recipe == " llamacpp " || recipe == " sd-cpp " ) && resolved_backend == " rocm-stable " ) {
412+ if (needs_rocm_stable_runtime ) {
381413 install_rocm_stable_runtime_if_needed (get_current_os (), *spec, backend_versions_, progress_cb);
382414 }
383415
384- if ((recipe == " llamacpp " || recipe == " sd-cpp " ) && resolved_backend == " rocm-preview " ) {
385- install_therock_if_needed (get_current_os (), backend_versions_);
416+ if (needs_therock ) {
417+ install_therock_if_needed (get_current_os (), backend_versions_, progress_cb );
386418 }
387419}
388420
0 commit comments