Skip to content

Commit c93edd0

Browse files
committed
Fix download progress reporting
1 parent 07f6fa2 commit c93edd0

3 files changed

Lines changed: 91 additions & 26 deletions

File tree

src/cpp/include/lemon/backends/backend_utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,8 @@ namespace lemon::backends {
7979
static std::string get_therock_install_dir(const std::string& arch, const std::string& version);
8080

8181
/** Download and install TheRock ROCm tarball for the specified architecture (Linux only) */
82-
static void install_therock(const std::string& arch, const std::string& version);
82+
static void install_therock(const std::string& arch, const std::string& version,
83+
DownloadProgressCallback progress_cb = nullptr);
8384

8485
/** Clean up old TheRock versions, keeping only the specified version */
8586
static void cleanup_old_therock_versions(const std::string& current_version);

src/cpp/server/backend_manager.cpp

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ void install_rocm_stable_runtime_if_needed(const std::string& os,
194194
http_progress_cb = [&progress_cb, &filename](size_t downloaded, size_t total) -> bool {
195195
DownloadProgress p;
196196
p.file = filename;
197-
p.file_index = 1;
198-
p.total_files = 1;
197+
p.file_index = 2; // ROCm runtime is the second file
198+
p.total_files = 2; // Backend binary + ROCm runtime
199199
p.bytes_downloaded = downloaded;
200200
p.bytes_total = total;
201201
p.percent = total > 0 ? static_cast<int>((downloaded * 100) / total) : 0;
@@ -223,12 +223,12 @@ void install_rocm_stable_runtime_if_needed(const std::string& os,
223223
if (progress_cb) {
224224
DownloadProgress p;
225225
p.file = filename;
226-
p.file_index = 1;
227-
p.total_files = 1;
226+
p.file_index = 2; // ROCm runtime is the second file
227+
p.total_files = 2; // Backend binary + ROCm runtime
228228
p.bytes_downloaded = download_result.bytes_downloaded;
229229
p.bytes_total = download_result.total_bytes;
230230
p.percent = 100;
231-
p.complete = true;
231+
p.complete = true; // Now we can send the completion event
232232
progress_cb(p);
233233
}
234234
}
@@ -249,32 +249,27 @@ void uninstall_rocm_stable_runtime_if_needed(const std::string& os) {
249249
}
250250
}
251251

252-
void install_therock_if_needed(const std::string& os, const json& backend_versions) {
252+
bool will_install_therock(const std::string& os, const json& backend_versions) {
253253
// TheRock is only needed on Linux for ROCm preview channel.
254254
if (os != "linux") {
255-
return;
255+
return false;
256256
}
257257

258258
// Check if system ROCm is available - if so, don't need TheRock
259259
if (backends::BackendUtils::is_rocm_installed_system_wide()) {
260-
LOG(DEBUG, "BackendManager")
261-
<< "System ROCm detected, skipping TheRock installation" << std::endl;
262-
return;
260+
return false;
263261
}
264262

265263
// Get ROCm architecture
266264
std::string rocm_arch = SystemInfo::get_rocm_arch();
267265
if (rocm_arch.empty()) {
268-
LOG(DEBUG, "BackendManager")
269-
<< "No ROCm architecture detected, skipping TheRock installation" << std::endl;
270-
return;
266+
return false;
271267
}
272268

273269
// Get TheRock version from backend_versions.json
274270
if (!backend_versions.contains("therock") || !backend_versions["therock"].contains("version")) {
275-
throw std::runtime_error("backend_versions.json is missing 'therock.version'");
271+
return false;
276272
}
277-
std::string version = backend_versions["therock"]["version"].get<std::string>();
278273

279274
// Check if this architecture is supported
280275
if (backend_versions["therock"].contains("architectures") &&
@@ -287,14 +282,24 @@ void install_therock_if_needed(const std::string& os, const json& backend_versio
287282
}
288283
}
289284
if (!arch_supported) {
290-
LOG(DEBUG, "BackendManager")
291-
<< "Architecture " << rocm_arch << " not supported by TheRock" << std::endl;
292-
return;
285+
return false;
293286
}
294287
}
295288

289+
return true;
290+
}
291+
292+
void install_therock_if_needed(const std::string& os, const json& backend_versions,
293+
DownloadProgressCallback progress_cb = nullptr) {
294+
if (!will_install_therock(os, backend_versions)) {
295+
return;
296+
}
297+
298+
std::string rocm_arch = SystemInfo::get_rocm_arch();
299+
std::string version = backend_versions["therock"]["version"].get<std::string>();
300+
296301
// Install TheRock for this architecture
297-
backends::BackendUtils::install_therock(rocm_arch, version);
302+
backends::BackendUtils::install_therock(rocm_arch, version, progress_cb);
298303
}
299304

300305
} // namespace
@@ -374,15 +379,42 @@ void BackendManager::install_backend(const std::string& recipe, const std::strin
374379
throw std::runtime_error("[BackendManager] Unknown recipe: " + recipe);
375380
}
376381

382+
// Check if we need to install additional runtime components after the main backend
383+
bool needs_rocm_stable_runtime = (recipe == "llamacpp" || recipe == "sd-cpp") &&
384+
resolved_backend == "rocm-stable" &&
385+
get_current_os() == "linux";
386+
387+
bool needs_therock = (recipe == "llamacpp" || recipe == "sd-cpp") &&
388+
resolved_backend == "rocm-preview" &&
389+
will_install_therock(get_current_os(), backend_versions_);
390+
391+
// Wrap the progress callback to adjust file indices if runtime download follows
392+
DownloadProgressCallback wrapped_progress_cb;
393+
if (progress_cb && (needs_rocm_stable_runtime || needs_therock)) {
394+
wrapped_progress_cb = [progress_cb](const DownloadProgress& p) -> bool {
395+
DownloadProgress adjusted = p;
396+
// Adjust to indicate this is file 1 of 2
397+
adjusted.file_index = 1;
398+
adjusted.total_files = 2;
399+
// Suppress the completion event - we'll send it after the runtime download
400+
if (p.complete) {
401+
adjusted.complete = false;
402+
}
403+
return progress_cb(adjusted);
404+
};
405+
} else {
406+
wrapped_progress_cb = progress_cb;
407+
}
408+
377409
backends::BackendUtils::install_from_github(
378-
*spec, params.version, params.repo, params.filename, resolved_backend, progress_cb);
410+
*spec, params.version, params.repo, params.filename, resolved_backend, wrapped_progress_cb);
379411

380-
if ((recipe == "llamacpp" || recipe == "sd-cpp") && resolved_backend == "rocm-stable") {
412+
if (needs_rocm_stable_runtime) {
381413
install_rocm_stable_runtime_if_needed(get_current_os(), *spec, backend_versions_, progress_cb);
382414
}
383415

384-
if ((recipe == "llamacpp" || recipe == "sd-cpp") && resolved_backend == "rocm-preview") {
385-
install_therock_if_needed(get_current_os(), backend_versions_);
416+
if (needs_therock) {
417+
install_therock_if_needed(get_current_os(), backend_versions_, progress_cb);
386418
}
387419
}
388420

src/cpp/server/backends/backend_utils.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,8 @@ namespace lemon::backends {
483483
#endif
484484
}
485485

486-
void BackendUtils::install_therock(const std::string& arch, const std::string& version) {
486+
void BackendUtils::install_therock(const std::string& arch, const std::string& version,
487+
DownloadProgressCallback progress_cb) {
487488
#ifndef __linux__
488489
throw std::runtime_error("TheRock is only supported on Linux");
489490
#else
@@ -524,10 +525,28 @@ namespace lemon::backends {
524525
LOG(DEBUG, "BackendUtils") << "Downloading TheRock from: " << url << std::endl;
525526
LOG(DEBUG, "BackendUtils") << "Downloading to: " << tarball_path << std::endl;
526527

528+
// Create progress callback for download
529+
utils::ProgressCallback http_progress_cb;
530+
if (progress_cb) {
531+
http_progress_cb = [&progress_cb, &filename](size_t downloaded, size_t total) -> bool {
532+
DownloadProgress p;
533+
p.file = filename;
534+
p.file_index = 1;
535+
p.total_files = 1;
536+
p.bytes_downloaded = downloaded;
537+
p.bytes_total = total;
538+
p.percent = total > 0 ? static_cast<int>((downloaded * 100) / total) : 0;
539+
p.complete = false;
540+
return progress_cb(p);
541+
};
542+
} else {
543+
http_progress_cb = utils::create_throttled_progress_callback();
544+
}
545+
527546
auto download_result = utils::HttpClient::download_file(
528547
url,
529548
tarball_path,
530-
utils::create_throttled_progress_callback()
549+
http_progress_cb
531550
);
532551

533552
if (!download_result.success) {
@@ -567,6 +586,19 @@ namespace lemon::backends {
567586
fs::remove(tarball_path);
568587
cleanup_old_therock_versions(version);
569588

589+
// Send completion notification
590+
if (progress_cb) {
591+
DownloadProgress p;
592+
p.file = filename;
593+
p.file_index = 1;
594+
p.total_files = 1;
595+
p.bytes_downloaded = download_result.bytes_downloaded;
596+
p.bytes_total = download_result.total_bytes;
597+
p.percent = 100;
598+
p.complete = true;
599+
progress_cb(p);
600+
}
601+
570602
LOG(INFO, "BackendUtils") << "TheRock installation complete" << std::endl;
571603
#endif
572604
}

0 commit comments

Comments
 (0)