diff --git a/Cargo.lock b/Cargo.lock index 5ca2a48..61d3476 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -346,13 +346,26 @@ dependencies = [ "burn-central-fleet", "burn-central-inference", "burn-central-macros", - "burn-central-registry", "burn-central-runtime", ] [[package]] -name = "burn-central-client" +name = "burn-central-artifact" version = "0.5.0" +dependencies = [ + "burn-central-core", + "crossbeam", + "reqwest", + "serde", + "serde_json", + "sha2", + "thiserror 2.0.18", +] + +[[package]] +name = "burn-central-client" +version = "0.6.0" +source = "git+https://github.com/tracel-ai/burn-central-client.git?branch=feat%2Ffleets#ffda122ef0a1ea2c8dba86955d36e2f41caac4c3" dependencies = [ "reqwest", "serde", @@ -390,9 +403,9 @@ version = "0.5.0" dependencies = [ "arc-swap", "burn", + "burn-central-artifact", "burn-central-client", "burn-central-inference", - "burn-central-registry", "chrono", "crossbeam-queue", "directories", @@ -429,22 +442,6 @@ dependencies = [ "syn-serde", ] -[[package]] -name = "burn-central-registry" -version = "0.5.0" -dependencies = [ - "burn-central-client", - "burn-central-core", - "crossbeam", - "directories", - "reqwest", - "serde", - "serde_json", - "sha2", - "thiserror 2.0.18", - "url", -] - [[package]] name = "burn-central-runtime" version = "0.5.0" diff --git a/Cargo.toml b/Cargo.toml index 8001045..c993e49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,10 +57,9 @@ tracing-subscriber = "0.3.22" opentelemetry = { version = "0.31.0", features = ["metrics"] } ## External Crate -burn-central-client = { version = "0.5.0", path = "../burn-central-client/burn-central-client" } +burn-central-client = { version = "0.6.0" } # burn-central-client = "0.5.0" -burn-central-registry = { path = "crates/burn-central-registry", version = "0.5.0" } ## Crate burn-central-core = { path = "crates/burn-central-core", version = "0.5.0" } @@ -68,6 +67,7 @@ burn-central-runtime = { path = "crates/burn-central-runtime", version = "0.5.0" burn-central-macros 
= { path = "crates/burn-central-macros", version = "0.5.0" } burn-central-inference = { path = "crates/burn-central-inference", version = "0.5.0" } burn-central-fleet = { path = "crates/burn-central-fleet", version = "0.5.0" } +burn-central-artifact = { path = "crates/burn-central-artifact", version = "0.5.0" } ### For xtask crate ### tracel-xtask = "4.5.0" @@ -76,4 +76,4 @@ tracel-xtask = "4.5.0" debug = 0 # Speed up compilation time and not necessary. [patch.crates-io] -burn-central-client = { path = "../burn-central-client/burn-central-client" } +burn-central-client = { git = "https://github.com/tracel-ai/burn-central-client.git", branch = "feat/fleets" } diff --git a/crates/burn-central-artifact/Cargo.toml b/crates/burn-central-artifact/Cargo.toml new file mode 100644 index 0000000..cdd29db --- /dev/null +++ b/crates/burn-central-artifact/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "burn-central-artifact" +edition.workspace = true +version.workspace = true +readme.workspace = true +license.workspace = true +rust-version.workspace = true +authors.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true + +[dependencies] +crossbeam = { workspace = true } +reqwest = { version = "0.13.2", features = ["blocking"] } +serde = { workspace = true, features = ["derive"] } +sha2 = { workspace = true } +serde_json = { workspace = true } +burn-central-core = { workspace = true } +thiserror = { workspace = true } diff --git a/crates/burn-central-artifact/src/artifact_download.rs b/crates/burn-central-artifact/src/artifact_download.rs new file mode 100644 index 0000000..e208db3 --- /dev/null +++ b/crates/burn-central-artifact/src/artifact_download.rs @@ -0,0 +1,80 @@ +use std::fs::{self, File}; +use std::io::BufWriter; +use std::path::{Path, PathBuf}; + +use burn_central_core::bundle::normalize_bundle_path; + +use crate::download::{DownloadError, DownloadTask, download_tasks}; +use crate::tools::path::safe_join; + +/// Generic 
download descriptor for any model artifact file. +#[derive(Debug, Clone)] +pub struct ArtifactDownloadFile { + pub rel_path: String, + pub url: String, + pub size_bytes: u64, + pub checksum: String, +} + +/// Download artifact files into a destination directory, validating size and checksum. +pub fn download_artifacts_to_dir( + dest_root: &Path, + files: &[ArtifactDownloadFile], +) -> Result<(), DownloadError> { + fs::create_dir_all(dest_root)?; + + if files.is_empty() { + return Ok(()); + } + + let mut tmps = Vec::with_capacity(files.len()); + let mut tasks = Vec::with_capacity(files.len()); + for file in files { + let rel_path = normalize_bundle_path(&file.rel_path); + let dest = safe_join(dest_root, &rel_path) + .map_err(|e| DownloadError::InvalidPath(e.to_string()))?; + + if let Some(parent) = dest.parent() { + fs::create_dir_all(parent)?; + } + let tmp = temp_path(&dest)?; + tmps.push((dest.clone(), tmp.clone())); + + let dest_file = File::create(dest)?; + let writer = BufWriter::new(dest_file); + + tasks.push(DownloadTask { + rel_path: rel_path.clone(), + url: file.url.clone(), + writer, + expected_size: file.size_bytes, + expected_checksum: file.checksum.clone(), + }); + } + + let parallelism = std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(4); + let http = reqwest::blocking::Client::new(); + let res = download_tasks(&http, tasks, parallelism); + + for (tmp_dest, tmp) in tmps { + if tmp_dest.exists() { + fs::remove_file(&tmp_dest)?; + } + if tmp.exists() { + fs::rename(tmp, tmp_dest)?; + } + } + + res +} + +/// Generate a temporary file path for downloads. +fn temp_path(dest: &Path) -> Result { + let file_name = dest + .file_name() + .ok_or_else(|| DownloadError::InvalidPath("missing file name".to_string()))? 
+ .to_string_lossy(); + Ok(dest.with_file_name(format!(".{file_name}.partial"))) +} diff --git a/crates/burn-central-artifact/src/download.rs b/crates/burn-central-artifact/src/download.rs new file mode 100644 index 0000000..0cde48b --- /dev/null +++ b/crates/burn-central-artifact/src/download.rs @@ -0,0 +1,165 @@ +use std::io::{Read, Write}; + +use crossbeam::channel; +use reqwest::blocking::Client as HttpClient; +use sha2::Digest; + +use crate::tools::validation::normalize_checksum; + +#[derive(Debug, thiserror::Error)] +pub enum DownloadError { + #[error("failed to download {path}: {details}")] + DownloadFailed { path: String, details: String }, + #[error("size mismatch for {path}: expected {expected} bytes, got {actual} bytes")] + SizeMismatch { + path: String, + expected: u64, + actual: u64, + }, + #[error("checksum mismatch for {path}: expected {expected}, got {actual}")] + ChecksumMismatch { + path: String, + expected: String, + actual: String, + }, + #[error("invalid checksum: {0}")] + InvalidChecksum(String), + #[error("writer error: {0}")] + WriterError(#[from] std::io::Error), + #[error("invalid path: {0}")] + InvalidPath(String), + } + +/// A single file download task. +#[derive(Clone)] +pub struct DownloadTask<W> { + pub rel_path: String, + pub url: String, + pub writer: W, + pub expected_size: u64, + pub expected_checksum: String, +} + +/// Download multiple files in parallel.
+pub fn download_tasks( + http: &HttpClient, + tasks: Vec>, + max_parallel: usize, +) -> Result<(), DownloadError> { + if tasks.is_empty() { + return Ok(()); + } + + if max_parallel <= 1 || tasks.len() == 1 { + for mut task in tasks { + download_one(http, &mut task)?; + } + return Ok(()); + } + + let (tx, rx) = channel::unbounded::>(); + for task in tasks { + tx.send(task).expect("channel open"); + } + drop(tx); + + crossbeam::scope(|scope| { + let mut handles = Vec::new(); + let worker_count = max_parallel.min(rx.len().max(1)); + for _ in 0..worker_count { + let rx = rx.clone(); + let http = http.clone(); + handles.push(scope.spawn(move |_| { + for mut task in rx.iter() { + download_one(&http, &mut task)?; + } + Ok::<(), DownloadError>(()) + })); + } + + for handle in handles { + handle.join().expect("thread panicked")?; + } + + Ok(()) + }) + .expect("scope failed") +} + +/// Download a single file with checksum verification. +fn download_one( + http: &HttpClient, + task: &mut DownloadTask, +) -> Result<(), DownloadError> { + // if let Some(parent) = task.dest.parent() { + // fs::create_dir_all(parent)?; + // } + + // let tmp = temp_path(&task.dest)?; + + let mut resp = http + .get(&task.url) + .send() + .map_err(|e| DownloadError::DownloadFailed { + path: task.rel_path.clone(), + details: e.to_string(), + })?; + + if !resp.status().is_success() { + return Err(DownloadError::DownloadFailed { + path: task.rel_path.clone(), + details: format!("HTTP {}", resp.status()), + }); + } + + let sink = &mut task.writer; + let mut hasher = sha2::Sha256::new(); + let mut buf = [0u8; 1024 * 64]; + let mut total = 0u64; + + loop { + let read = resp.read(&mut buf)?; + if read == 0 { + break; + } + sink.write_all(&buf[..read])?; + hasher.update(&buf[..read]); + total += read as u64; + } + + let digest = format!("{:x}", hasher.finalize()); + let expected_checksum = + normalize_checksum(&task.expected_checksum).map_err(DownloadError::InvalidChecksum)?; + + if total != 
task.expected_size { + return Err(DownloadError::SizeMismatch { + path: task.rel_path.clone(), + expected: task.expected_size, + actual: total, + }); + } + if digest != expected_checksum { + return Err(DownloadError::ChecksumMismatch { + path: task.rel_path.clone(), + expected: expected_checksum, + actual: digest, + }); + } + + // if task.dest.exists() { + // fs::remove_file(&task.dest)?; + // } + + // fs::rename(tmp, &task.dest)?; + + Ok(()) +} + +// /// Generate a temporary file path for downloads. +// fn temp_path(dest: &Path) -> Result { +// let file_name = dest +// .file_name() +// .ok_or_else(|| RegistryError::InvalidPath("missing file name".to_string()))? +// .to_string_lossy(); +// Ok(dest.with_file_name(format!(".{file_name}.partial"))) +// } diff --git a/crates/burn-central-artifact/src/lib.rs b/crates/burn-central-artifact/src/lib.rs new file mode 100644 index 0000000..fac1aa3 --- /dev/null +++ b/crates/burn-central-artifact/src/lib.rs @@ -0,0 +1,8 @@ +//! This crate centralizes traits, structures and utilities for handling artifacts and models in Burn Central. + +mod artifact_download; +mod download; +mod tools; + +pub use artifact_download::{ArtifactDownloadFile, download_artifacts_to_dir}; +pub use download::DownloadError; diff --git a/crates/burn-central-artifact/src/tools/mod.rs b/crates/burn-central-artifact/src/tools/mod.rs new file mode 100644 index 0000000..5fed2d8 --- /dev/null +++ b/crates/burn-central-artifact/src/tools/mod.rs @@ -0,0 +1,2 @@ +pub mod path; +pub mod validation; diff --git a/crates/burn-central-artifact/src/tools/path.rs b/crates/burn-central-artifact/src/tools/path.rs new file mode 100644 index 0000000..0a2661b --- /dev/null +++ b/crates/burn-central-artifact/src/tools/path.rs @@ -0,0 +1,28 @@ +use std::path::{Path, PathBuf}; + +use burn_central_core::bundle::normalize_bundle_path; + +/// Sanitize a relative path to prevent directory traversal attacks. 
+pub fn sanitize_rel_path(path: &str) -> Result<PathBuf, String> { + let normalized = normalize_bundle_path(path); + let rel = Path::new(&normalized); + for component in rel.components() { + use std::path::Component; + match component { + Component::ParentDir | Component::RootDir | Component::Prefix(_) => { + return Err(format!("invalid path component: {path}")); + } + Component::CurDir => { + return Err(format!("invalid path component: {path}")); + } + Component::Normal(_) => {} + } + } + Ok(PathBuf::from(normalized)) +} + +/// Safely join a root path with a relative path. +pub fn safe_join(root: &Path, rel: &str) -> Result<PathBuf, String> { + let rel = sanitize_rel_path(rel)?; + Ok(root.join(rel)) +} diff --git a/crates/burn-central-artifact/src/tools/validation.rs b/crates/burn-central-artifact/src/tools/validation.rs new file mode 100644 index 0000000..0795ea9 --- /dev/null +++ b/crates/burn-central-artifact/src/tools/validation.rs @@ -0,0 +1,15 @@ +/// Normalize a checksum string (strip prefixes, lowercase). +pub fn normalize_checksum(value: &str) -> Result<String, String> { + let trimmed = value.trim(); + if trimmed.is_empty() { + return Err("checksum is empty".to_string()); + } + let lower = trimmed.to_ascii_lowercase(); + if let Some(rest) = lower.strip_prefix("sha256:") { + return Ok(rest.to_string()); + } + if lower.contains(':') { + return Err(format!("unsupported checksum format: {trimmed}")); + } + Ok(lower) +} diff --git a/crates/burn-central-fleet/Cargo.toml b/crates/burn-central-fleet/Cargo.toml index ba18379..90df3c5 100644 --- a/crates/burn-central-fleet/Cargo.toml +++ b/crates/burn-central-fleet/Cargo.toml @@ -15,7 +15,7 @@ rust-version.workspace = true burn.workspace = true burn-central-client.workspace = true burn-central-inference.workspace = true -burn-central-registry.workspace = true +burn-central-artifact.workspace = true thiserror.workspace = true serde.workspace = true serde_json.workspace = true diff --git a/crates/burn-central-fleet/src/model.rs
b/crates/burn-central-fleet/src/model.rs index a35406e..b7e409a 100644 --- a/crates/burn-central-fleet/src/model.rs +++ b/crates/burn-central-fleet/src/model.rs @@ -3,8 +3,8 @@ use std::io; use std::path::Path; use std::path::PathBuf; +use burn_central_artifact::{ArtifactDownloadFile, DownloadError, download_artifacts_to_dir}; use burn_central_client::fleet::response::FleetModelDownloadResponse; -use burn_central_registry::{ArtifactDownloadFile, download_artifacts_to_dir}; use serde::{Deserialize, Serialize}; #[derive(Debug, thiserror::Error)] @@ -20,7 +20,7 @@ pub enum ModelCacheError { #[error("cached model file missing: {0}")] MissingCachedFile(String), #[error(transparent)] - Registry(#[from] burn_central_registry::RegistryError), + Registry(#[from] DownloadError), } /// Source information for loading an assigned model. diff --git a/crates/burn-central-fleet/src/session.rs b/crates/burn-central-fleet/src/session.rs index cfa6650..dc9fa13 100644 --- a/crates/burn-central-fleet/src/session.rs +++ b/crates/burn-central-fleet/src/session.rs @@ -183,10 +183,9 @@ impl FleetDeviceSession { self.state .set_auth_token(auth_response.access_token, auth_response.expires_in_seconds); - let mut telemetry_auth_token = self - .telemetry_auth_token - .write() - .map_err(|_| FleetError::SyncFailed("telemetry auth token lock poisoned".to_string()))?; + let mut telemetry_auth_token = self.telemetry_auth_token.write().map_err(|_| { + FleetError::SyncFailed("telemetry auth token lock poisoned".to_string()) + })?; *telemetry_auth_token = self.state.auth_token().map(|auth| auth.token().to_string()); Ok(()) diff --git a/crates/burn-central-fleet/src/telemetry/pipeline/mod.rs b/crates/burn-central-fleet/src/telemetry/pipeline/mod.rs index 2fb5272..e6399c5 100644 --- a/crates/burn-central-fleet/src/telemetry/pipeline/mod.rs +++ b/crates/burn-central-fleet/src/telemetry/pipeline/mod.rs @@ -4,9 +4,8 @@ use crossbeam_queue::SegQueue; use std::{ path::{Path, PathBuf}, sync::{ - Arc, + Arc, 
RwLock, atomic::{AtomicUsize, Ordering}, - RwLock, }, time::Duration, }; @@ -145,7 +144,9 @@ impl TelemetryPipeline { let shipper_handle = shipper::start( outbox, - Arc::new(shipper::BurnCentralFleetShipperTransport::new(auth_token, client)), + Arc::new(shipper::BurnCentralFleetShipperTransport::new( + auth_token, client, + )), Duration::from_secs(5), ); diff --git a/crates/burn-central-registry/Cargo.toml b/crates/burn-central-registry/Cargo.toml deleted file mode 100644 index 4b59a07..0000000 --- a/crates/burn-central-registry/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "burn-central-registry" -authors.workspace = true -categories.workspace = true -description = "Burn Central local registry/cache for model artifacts." -keywords.workspace = true -repository.workspace = true -edition.workspace = true -license.workspace = true -readme.workspace = true -version.workspace = true -rust-version.workspace = true - -[dependencies] -burn-central-client.workspace = true -burn-central-core.workspace = true -serde.workspace = true -serde_json.workspace = true -thiserror.workspace = true -sha2.workspace = true -crossbeam.workspace = true -url.workspace = true -reqwest = { version = "0.13.2", features = ["blocking"] } -directories = "6.0.0" diff --git a/crates/burn-central-registry/src/artifact_download.rs b/crates/burn-central-registry/src/artifact_download.rs deleted file mode 100644 index 2279f7d..0000000 --- a/crates/burn-central-registry/src/artifact_download.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::path::Path; - -use burn_central_core::bundle::normalize_bundle_path; - -use crate::RegistryError; -use crate::cache::safe_join; -use crate::download::{DownloadTask, download_tasks}; -use crate::manifest::ManifestFile; - -/// Generic download descriptor for any model artifact file. 
-#[derive(Debug, Clone)] -pub struct ArtifactDownloadFile { - pub rel_path: String, - pub url: String, - pub size_bytes: u64, - pub checksum: String, -} - -/// Download artifact files into a destination directory, validating size and checksum. -pub fn download_artifacts_to_dir( - dest_root: &Path, - files: &[ArtifactDownloadFile], -) -> Result<(), RegistryError> { - std::fs::create_dir_all(dest_root)?; - - if files.is_empty() { - return Ok(()); - } - - let mut tasks = Vec::with_capacity(files.len()); - for file in files { - let rel_path = normalize_bundle_path(&file.rel_path); - let dest = safe_join(dest_root, &rel_path)?; - - tasks.push(DownloadTask { - rel_path: rel_path.clone(), - url: file.url.clone(), - dest, - expected: ManifestFile { - rel_path, - size_bytes: file.size_bytes, - checksum: file.checksum.clone(), - }, - }); - } - - let parallelism = std::thread::available_parallelism() - .map(|n| n.get()) - .unwrap_or(4); - let http = reqwest::blocking::Client::new(); - download_tasks(&http, tasks, parallelism) -} diff --git a/crates/burn-central-registry/src/builder.rs b/crates/burn-central-registry/src/builder.rs deleted file mode 100644 index 26dc43c..0000000 --- a/crates/burn-central-registry/src/builder.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::path::PathBuf; - -use burn_central_client::{BurnCentralCredentials, Client, Env}; -use directories::{BaseDirs, ProjectDirs}; - -use crate::error::RegistryError; -use crate::registry::Registry; - -/// Builder for a registry client. -#[derive(Debug, Clone)] -pub struct RegistryBuilder { - credentials: BurnCentralCredentials, - env: Option, - cache_dir: Option, -} - -impl RegistryBuilder { - /// Create a new registry builder with the given credentials. - pub fn new(credentials: impl Into) -> Self { - Self { - credentials: credentials.into(), - env: None, - cache_dir: None, - } - } - - /// Use a specific environment (production by default). 
- pub fn with_env(mut self, env: Env) -> Self { - self.env = Some(env); - self - } - - /// Build the registry client. - pub fn build(self) -> Result { - let client = { - let env = self.env.unwrap_or(Env::Production); - Client::new(env, &self.credentials)? - }; - - let cache_dir = match self.cache_dir { - Some(dir) => dir, - None => default_cache_dir(client.get_env())?, - }; - - Ok(Registry::new(client, cache_dir)) - } -} - -/// Get the default cache directory for a given environment. -fn default_cache_dir(env: &Env) -> Result { - let registry_subdir = match env { - Env::Production => "registry".to_string(), - Env::Staging(version) => format!("registry-staging-{}", version), - Env::Development => "registry-dev".to_string(), - }; - if let Some(project) = ProjectDirs::from("ai", "tracel", "burn-central") { - return Ok(project.cache_dir().join(registry_subdir)); - } - if let Some(base) = BaseDirs::new() { - return Ok(base.cache_dir().join("burn-central").join(registry_subdir)); - } - Err(RegistryError::CacheDirUnavailable) -} diff --git a/crates/burn-central-registry/src/cache.rs b/crates/burn-central-registry/src/cache.rs deleted file mode 100644 index 566fcc5..0000000 --- a/crates/burn-central-registry/src/cache.rs +++ /dev/null @@ -1,126 +0,0 @@ -use std::fs::{self, File}; -use std::io::Read; -use std::path::{Path, PathBuf}; - -use burn_central_core::bundle::normalize_bundle_path; -use sha2::Digest; - -use crate::error::RegistryError; -use crate::manifest::{ManifestFile, ModelManifest}; - -/// Check if a cached model version is valid. 
-pub fn cache_is_valid( - root: &Path, - manifest: &ModelManifest, - verify_checksums: bool, -) -> Result { - for file in &manifest.files { - let path = safe_join(root, &file.rel_path)?; - if !path.exists() { - return Ok(false); - } - let metadata = fs::metadata(&path)?; - if metadata.len() != file.size_bytes { - return Ok(false); - } - if verify_checksums { - let (digest, _) = sha256_file(&path)?; - let expected = normalize_checksum(&file.checksum)?; - if digest != expected { - return Ok(false); - } - } - } - - Ok(true) -} - -/// Check if a single file is valid according to its manifest entry. -pub fn file_is_valid( - path: &Path, - expected: &ManifestFile, - verify_checksums: bool, -) -> Result { - if !path.exists() { - return Ok(false); - } - let metadata = fs::metadata(path)?; - if metadata.len() != expected.size_bytes { - return Ok(false); - } - if verify_checksums { - let (digest, _) = sha256_file(path)?; - let expected_checksum = normalize_checksum(&expected.checksum)?; - if digest != expected_checksum { - return Ok(false); - } - } - Ok(true) -} - -/// Compute SHA256 checksum of a file. -pub fn sha256_file(path: &Path) -> Result<(String, u64), RegistryError> { - let mut file = File::open(path)?; - let mut hasher = sha2::Sha256::new(); - let mut buf = [0u8; 1024 * 64]; - let mut total = 0u64; - loop { - let read = file.read(&mut buf)?; - if read == 0 { - break; - } - hasher.update(&buf[..read]); - total += read as u64; - } - let digest = format!("{:x}", hasher.finalize()); - Ok((digest, total)) -} - -/// Normalize a checksum string (strip prefixes, lowercase). 
-pub fn normalize_checksum(value: &str) -> Result { - let trimmed = value.trim(); - if trimmed.is_empty() { - return Err(RegistryError::InvalidManifest( - "checksum is empty".to_string(), - )); - } - let lower = trimmed.to_ascii_lowercase(); - if let Some(rest) = lower.strip_prefix("sha256:") { - return Ok(rest.to_string()); - } - if lower.contains(':') { - return Err(RegistryError::InvalidManifest(format!( - "unsupported checksum format: {trimmed}" - ))); - } - Ok(lower) -} - -/// Sanitize a relative path to prevent directory traversal attacks. -pub fn sanitize_rel_path(path: &str) -> Result { - let normalized = normalize_bundle_path(path); - let rel = Path::new(&normalized); - for component in rel.components() { - use std::path::Component; - match component { - Component::ParentDir | Component::RootDir | Component::Prefix(_) => { - return Err(RegistryError::InvalidPath(format!( - "invalid path component: {path}" - ))); - } - Component::CurDir => { - return Err(RegistryError::InvalidPath(format!( - "invalid path component: {path}" - ))); - } - Component::Normal(_) => {} - } - } - Ok(PathBuf::from(normalized)) -} - -/// Safely join a root path with a relative path. -pub fn safe_join(root: &Path, rel: &str) -> Result { - let rel = sanitize_rel_path(rel)?; - Ok(root.join(rel)) -} diff --git a/crates/burn-central-registry/src/diagnostics.rs b/crates/burn-central-registry/src/diagnostics.rs deleted file mode 100644 index 6abd852..0000000 --- a/crates/burn-central-registry/src/diagnostics.rs +++ /dev/null @@ -1,25 +0,0 @@ -use std::collections::BTreeMap; - -use crate::cache::{safe_join, sha256_file}; -use crate::error::RegistryError; -use crate::model::CachedModel; - -/// Helper to check file hashes in the cache (useful for debugging). -#[derive(Debug, Clone)] -pub struct CacheDiagnostics { - /// Files and their computed checksums. - pub files: BTreeMap, -} - -impl CacheDiagnostics { - /// Compute checksums for a cached model version. 
- pub fn from_cached(model: &CachedModel) -> Result { - let mut files = BTreeMap::new(); - for file in &model.manifest().files { - let path = safe_join(model.path(), &file.rel_path)?; - let (digest, _) = sha256_file(&path)?; - files.insert(file.rel_path.clone(), digest); - } - Ok(Self { files }) - } -} diff --git a/crates/burn-central-registry/src/download.rs b/crates/burn-central-registry/src/download.rs deleted file mode 100644 index 92b0a14..0000000 --- a/crates/burn-central-registry/src/download.rs +++ /dev/null @@ -1,140 +0,0 @@ -use std::fs::{self, File}; -use std::io::{Read, Write}; -use std::path::{Path, PathBuf}; - -use crossbeam::channel; -use reqwest::blocking::Client as HttpClient; -use sha2::Digest; - -use crate::cache::normalize_checksum; -use crate::error::RegistryError; -use crate::manifest::ManifestFile; - -/// A single file download task. -#[derive(Clone)] -pub struct DownloadTask { - pub rel_path: String, - pub url: String, - pub dest: PathBuf, - pub expected: ManifestFile, -} - -/// Download multiple files in parallel. -pub fn download_tasks( - http: &HttpClient, - tasks: Vec, - max_parallel: usize, -) -> Result<(), RegistryError> { - if tasks.is_empty() { - return Ok(()); - } - - if max_parallel <= 1 || tasks.len() == 1 { - for task in tasks { - download_one(http, &task)?; - } - return Ok(()); - } - - let (tx, rx) = channel::unbounded::(); - for task in tasks { - tx.send(task).expect("channel open"); - } - drop(tx); - - crossbeam::scope(|scope| { - let mut handles = Vec::new(); - let worker_count = max_parallel.min(rx.len().max(1)); - for _ in 0..worker_count { - let rx = rx.clone(); - let http = http.clone(); - handles.push(scope.spawn(move |_| { - for task in rx.iter() { - download_one(&http, &task)?; - } - Ok::<(), RegistryError>(()) - })); - } - - for handle in handles { - handle.join().expect("thread panicked")?; - } - - Ok(()) - }) - .expect("scope failed") -} - -/// Download a single file with checksum verification. 
-fn download_one(http: &HttpClient, task: &DownloadTask) -> Result<(), RegistryError> { - if let Some(parent) = task.dest.parent() { - fs::create_dir_all(parent)?; - } - - let tmp = temp_path(&task.dest)?; - - let mut resp = http - .get(&task.url) - .send() - .map_err(|e| RegistryError::DownloadFailed { - path: task.rel_path.clone(), - details: e.to_string(), - })?; - - if !resp.status().is_success() { - return Err(RegistryError::DownloadFailed { - path: task.rel_path.clone(), - details: format!("HTTP {}", resp.status()), - }); - } - - let mut file = File::create(&tmp)?; - let mut hasher = sha2::Sha256::new(); - let mut buf = [0u8; 1024 * 64]; - let mut total = 0u64; - - loop { - let read = resp.read(&mut buf)?; - if read == 0 { - break; - } - file.write_all(&buf[..read])?; - hasher.update(&buf[..read]); - total += read as u64; - } - - let digest = format!("{:x}", hasher.finalize()); - let expected_checksum = normalize_checksum(&task.expected.checksum)?; - - if total != task.expected.size_bytes { - return Err(RegistryError::SizeMismatch { - path: task.rel_path.clone(), - expected: task.expected.size_bytes, - actual: total, - }); - } - if digest != expected_checksum { - return Err(RegistryError::ChecksumMismatch { - path: task.rel_path.clone(), - expected: expected_checksum, - actual: digest, - }); - } - - if task.dest.exists() { - fs::remove_file(&task.dest)?; - } - - fs::rename(tmp, &task.dest)?; - - Ok(()) -} - -/// Generate a temporary file path for downloads. -fn temp_path(dest: &Path) -> Result { - let file_name = dest - .file_name() - .ok_or_else(|| RegistryError::InvalidPath("missing file name".to_string()))? 
- .to_string_lossy(); - Ok(dest.with_file_name(format!(".{file_name}.partial"))) -} diff --git a/crates/burn-central-registry/src/error.rs b/crates/burn-central-registry/src/error.rs deleted file mode 100644 index 43b2b01..0000000 --- a/crates/burn-central-registry/src/error.rs +++ /dev/null @@ -1,44 +0,0 @@ -use burn_central_client::ClientError; - -/// Errors returned by the registry. -#[derive(Debug, thiserror::Error)] -pub enum RegistryError { - /// Errors returned by the Burn Central HTTP client. - #[error("Client error: {0}")] - Client(#[from] ClientError), - /// IO errors. - #[error("IO error: {0}")] - Io(#[from] std::io::Error), - /// HTTP download error. - #[error("Download failed for {path}: {details}")] - DownloadFailed { path: String, details: String }, - /// Manifest is missing or invalid. - #[error("Invalid manifest: {0}")] - InvalidManifest(String), - /// Cache directory could not be resolved. - #[error("Cache directory unavailable")] - CacheDirUnavailable, - /// Invalid path. - #[error("Invalid path: {0}")] - InvalidPath(String), - /// A file is missing from the cache. - #[error("Missing file: {0}")] - MissingFile(String), - /// A file checksum does not match the manifest. - #[error("Checksum mismatch for {path}: expected {expected}, got {actual}")] - ChecksumMismatch { - path: String, - expected: String, - actual: String, - }, - /// A file size does not match the manifest. - #[error("Size mismatch for {path}: expected {expected} bytes, got {actual} bytes")] - SizeMismatch { - path: String, - expected: u64, - actual: u64, - }, - /// Error while decoding a bundle from the cache. - #[error("Decode error: {0}")] - Decode(String), -} diff --git a/crates/burn-central-registry/src/lib.rs b/crates/burn-central-registry/src/lib.rs deleted file mode 100644 index 1894de4..0000000 --- a/crates/burn-central-registry/src/lib.rs +++ /dev/null @@ -1,46 +0,0 @@ -//! Local registry/cache for Burn Central model artifacts. -//! -//! 
This crate provides a client for downloading and caching model artifacts from Burn Central. -//! It handles parallel downloads, checksum verification, and local caching of model files. -//! -//! # Example -//! -//! ```no_run -//! use burn_central_registry::{RegistryBuilder, CachedModel}; -//! use burn_central_client::BurnCentralCredentials; -//! -//! # fn example() -> Result<(), Box> { -//! // Create a registry client -//! let credentials = BurnCentralCredentials::from_env()?; -//! let registry = RegistryBuilder::new(credentials).build()?; -//! -//! // Get a model handle -//! let model = registry.model("namespace", "project", "model")?; -//! -//! // Ensure the model is cached locally -//! let cached = model.ensure(1)?; -//! -//! // Access the cached model files -//! let path = cached.path(); -//! # Ok(()) -//! # } -//! ``` - -mod artifact_download; -mod builder; -mod cache; -mod diagnostics; -mod download; -mod error; -mod manifest; -mod model; -mod registry; - -// Public API exports -pub use artifact_download::{ArtifactDownloadFile, download_artifacts_to_dir}; -pub use builder::RegistryBuilder; -pub use diagnostics::CacheDiagnostics; -pub use error::RegistryError; -pub use manifest::{ManifestFile, ModelManifest}; -pub use model::{CachedModel, ModelHandle, ModelRef, ModelVersion, ModelVersionSelector}; -pub use registry::Registry; diff --git a/crates/burn-central-registry/src/manifest.rs b/crates/burn-central-registry/src/manifest.rs deleted file mode 100644 index dd5fd69..0000000 --- a/crates/burn-central-registry/src/manifest.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::collections::{HashMap, HashSet}; -use std::fs::{self, File}; -use std::io::Write; -use std::path::Path; - -use burn_central_core::bundle::normalize_bundle_path; -use serde::{Deserialize, Serialize}; - -use crate::cache::sanitize_rel_path; -use crate::error::RegistryError; - -/// Model version manifest (subset of backend schema). 
-#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ModelManifest { - /// Files stored in the bundle. - pub files: Vec, -} - -/// Model version file descriptor. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ManifestFile { - /// Path within the bundle. - pub rel_path: String, - /// Size in bytes. - pub size_bytes: u64, - /// Checksum (sha256). - pub checksum: String, -} - -const MANIFEST_FILE: &str = "manifest.json"; - -/// Load manifest from a cached version directory. -pub fn load_manifest(version_dir: &Path) -> Result { - let path = version_dir.join(MANIFEST_FILE); - let bytes = fs::read(path)?; - serde_json::from_slice::(&bytes) - .map_err(|e| RegistryError::InvalidManifest(e.to_string())) -} - -/// Write manifest to a cached version directory. -pub fn write_manifest(version_dir: &Path, manifest: &ModelManifest) -> Result<(), RegistryError> { - let path = version_dir.join(MANIFEST_FILE); - let mut file = File::create(path)?; - let bytes = serde_json::to_vec_pretty(manifest) - .map_err(|e| RegistryError::InvalidManifest(e.to_string()))?; - file.write_all(&bytes)?; - Ok(()) -} - -/// Parse and validate a manifest from JSON. -pub fn parse_manifest(value: serde_json::Value) -> Result { - let mut manifest: ModelManifest = - serde_json::from_value(value).map_err(|e| RegistryError::InvalidManifest(e.to_string()))?; - - let mut seen = HashSet::new(); - for file in &mut manifest.files { - file.rel_path = normalize_bundle_path(&file.rel_path); - sanitize_rel_path(&file.rel_path)?; - if file.rel_path.is_empty() { - return Err(RegistryError::InvalidManifest( - "manifest file path is empty".to_string(), - )); - } - if !seen.insert(file.rel_path.clone()) { - return Err(RegistryError::InvalidManifest(format!( - "duplicate file path in manifest: {}", - file.rel_path - ))); - } - } - - Ok(manifest) -} - -/// Create a hashmap of manifest files by relative path. 
-pub fn manifest_map( - manifest: &ModelManifest, -) -> Result, RegistryError> { - let mut map = HashMap::new(); - for file in &manifest.files { - if map.insert(file.rel_path.clone(), file.clone()).is_some() { - return Err(RegistryError::InvalidManifest(format!( - "duplicate file path in manifest: {}", - file.rel_path - ))); - } - } - Ok(map) -} diff --git a/crates/burn-central-registry/src/model.rs b/crates/burn-central-registry/src/model.rs deleted file mode 100644 index 969733f..0000000 --- a/crates/burn-central-registry/src/model.rs +++ /dev/null @@ -1,250 +0,0 @@ -use std::fs; -use std::path::{Path, PathBuf}; - -use burn_central_core::bundle::{BundleDecode, FsBundleReader, normalize_bundle_path}; - -use crate::cache::{cache_is_valid, file_is_valid, safe_join}; -use crate::download::{DownloadTask, download_tasks}; -use crate::error::RegistryError; -use crate::manifest::{ModelManifest, load_manifest, manifest_map, parse_manifest, write_manifest}; -use crate::registry::Registry; - -/// Selector for which model version to load. -#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)] -pub enum ModelVersionSelector { - #[default] - Latest, - Version(u64), -} - -/// Type alias for model version numbers. -pub type ModelVersion = u64; - -impl From for ModelVersionSelector { - fn from(value: u32) -> Self { - ModelVersionSelector::Version(value as u64) - } -} - -impl From for ModelVersionSelector { - fn from(value: u64) -> Self { - ModelVersionSelector::Version(value) - } -} - -/// Reference to a model in the registry. 
-#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ModelRef { - namespace: String, - project: String, - model: String, -} - -impl ModelRef { - pub fn new( - namespace: impl Into, - project: impl Into, - model: impl Into, - ) -> Result { - let namespace = namespace.into(); - let project = project.into(); - let model = model.into(); - validate_path_component(&namespace, "namespace")?; - validate_path_component(&project, "project")?; - validate_path_component(&model, "model")?; - Ok(Self { - namespace, - project, - model, - }) - } - - fn version_dir(&self, root: &Path, version: ModelVersion) -> PathBuf { - root.join(&self.namespace) - .join(&self.project) - .join("models") - .join(&self.model) - .join("versions") - .join(version.to_string()) - } - - pub fn namespace(&self) -> &str { - &self.namespace - } - - pub fn project(&self) -> &str { - &self.project - } - - pub fn model(&self) -> &str { - &self.model - } -} - -/// Validate that a path component doesn't contain directory traversal characters. -fn validate_path_component(value: &str, label: &str) -> Result<(), RegistryError> { - if value.is_empty() { - return Err(RegistryError::InvalidPath(format!( - "{label} must not be empty" - ))); - } - let path = Path::new(value); - for component in path.components() { - use std::path::Component; - match component { - Component::ParentDir | Component::RootDir | Component::Prefix(_) => { - return Err(RegistryError::InvalidPath(format!( - "{label} contains invalid path segments" - ))); - } - Component::Normal(_) => {} - Component::CurDir => { - return Err(RegistryError::InvalidPath(format!( - "{label} contains invalid path segments" - ))); - } - } - } - Ok(()) -} - -/// Handle for downloading and caching a specific model. -#[derive(Clone)] -pub struct ModelHandle { - registry: Registry, - model: ModelRef, -} - -impl ModelHandle { - pub fn new(registry: Registry, model: ModelRef) -> Self { - Self { registry, model } - } - - /// Ensure a model version is cached locally. 
- pub fn ensure( - &self, - version: impl Into, - ) -> Result { - let version = match version.into() { - ModelVersionSelector::Latest => { - let info = self.registry.client().get_model( - self.model.namespace(), - self.model.project(), - self.model.model(), - )?; - info.version_count - 1 - } - ModelVersionSelector::Version(v) => v, - }; - let version_dir = self.model.version_dir(self.registry.cache_dir(), version); - - if let Ok(manifest) = load_manifest(&version_dir) { - if cache_is_valid(&version_dir, &manifest, true)? { - return Ok(CachedModel::new(version_dir, manifest)); - } - } - - self.download_version(version, &version_dir) - } - - /// Download and decode a model version using the bundle decoder. - pub fn load( - &self, - version: impl Into, - settings: &T::Settings, - ) -> Result { - let cached = self.ensure(version)?; - T::decode(&cached.reader(), settings) - .map_err(|e| RegistryError::Decode(e.into().to_string())) - } - - fn download_version( - &self, - version: ModelVersion, - version_dir: &Path, - ) -> Result { - fs::create_dir_all(version_dir)?; - - let version_info = self.registry.client().get_model_version( - self.model.namespace(), - self.model.project(), - self.model.model(), - version as _, - )?; - - let manifest = parse_manifest(version_info.manifest)?; - let download = self.registry.client().presign_model_download( - self.model.namespace(), - self.model.project(), - self.model.model(), - version as _, - )?; - - let manifest_map = manifest_map(&manifest)?; - let mut tasks = Vec::new(); - - for file in download.files { - let rel_path = normalize_bundle_path(&file.rel_path); - let expected = manifest_map.get(&rel_path).ok_or_else(|| { - RegistryError::InvalidManifest(format!( - "download file {rel_path} missing from manifest" - )) - })?; - - let dest = safe_join(version_dir, &rel_path)?; - let needs_download = !file_is_valid(&dest, expected, true)?; - - if needs_download { - tasks.push(DownloadTask { - rel_path, - url: file.url, - dest, - 
expected: expected.clone(), - }); - } - } - - let parallelism = std::thread::available_parallelism() - .map(|n| n.get()) - .unwrap_or(4); - download_tasks(self.registry.http(), tasks, parallelism)?; - - write_manifest(version_dir, &manifest)?; - Ok(CachedModel::new(version_dir.to_path_buf(), manifest)) - } -} - -/// Cached model version stored on disk. -#[derive(Debug, Clone)] -pub struct CachedModel { - root: PathBuf, - manifest: ModelManifest, -} - -impl CachedModel { - pub fn new(root: PathBuf, manifest: ModelManifest) -> Self { - Self { root, manifest } - } - - /// Root directory of the cached model version. - pub fn path(&self) -> &Path { - &self.root - } - - /// Manifest for the cached model version. - pub fn manifest(&self) -> &ModelManifest { - &self.manifest - } - - /// Build a file-backed bundle reader. - pub fn reader(&self) -> FsBundleReader { - FsBundleReader::new( - self.root.clone(), - self.manifest - .files - .iter() - .map(|f| f.rel_path.clone()) - .collect(), - ) - } -} diff --git a/crates/burn-central-registry/src/registry.rs b/crates/burn-central-registry/src/registry.rs deleted file mode 100644 index a3904aa..0000000 --- a/crates/burn-central-registry/src/registry.rs +++ /dev/null @@ -1,49 +0,0 @@ -use std::path::{Path, PathBuf}; - -use burn_central_client::Client; -use reqwest::blocking::Client as HttpClient; - -use crate::error::RegistryError; -use crate::model::{ModelHandle, ModelRef}; - -/// Registry client for downloading and caching model artifacts. -#[derive(Clone)] -pub struct Registry { - client: Client, - http: HttpClient, - cache_dir: PathBuf, -} - -impl Registry { - /// Create a registry client from a Burn Central HTTP client and config. - pub fn new(client: Client, cache_dir: PathBuf) -> Self { - Self { - client, - http: HttpClient::new(), - cache_dir, - } - } - - /// Create a model handle scoped to a project. 
- pub fn model( - &self, - namespace: impl Into, - project: impl Into, - model: impl Into, - ) -> Result { - let model = ModelRef::new(namespace, project, model)?; - Ok(ModelHandle::new(self.clone(), model)) - } - - pub(crate) fn client(&self) -> &Client { - &self.client - } - - pub(crate) fn http(&self) -> &HttpClient { - &self.http - } - - pub(crate) fn cache_dir(&self) -> &Path { - &self.cache_dir - } -} diff --git a/crates/burn-central/Cargo.toml b/crates/burn-central/Cargo.toml index e0d68aa..cb41feb 100644 --- a/crates/burn-central/Cargo.toml +++ b/crates/burn-central/Cargo.toml @@ -15,6 +15,5 @@ rust-version.workspace = true burn-central-core.workspace = true burn-central-macros.workspace = true burn-central-runtime.workspace = true -burn-central-registry.workspace = true burn-central-inference.workspace = true burn-central-fleet.workspace = true diff --git a/crates/burn-central/src/lib.rs b/crates/burn-central/src/lib.rs index 82e35ea..6f38b50 100644 --- a/crates/burn-central/src/lib.rs +++ b/crates/burn-central/src/lib.rs @@ -64,7 +64,3 @@ pub use burn_central_inference as inference; /// On-device fleet synchronization helpers. #[doc(inline)] pub use burn_central_fleet as fleet; - -/// Local registry/cache helpers for downloading models. -#[doc(inline)] -pub use burn_central_registry as registry;