diff --git a/Cargo.lock b/Cargo.lock index b7c43af2..c168d8f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -353,12 +353,12 @@ dependencies = [ name = "burn-central-artifact" version = "0.5.0" dependencies = [ - "burn-central-core", "crossbeam", "reqwest", "serde", "serde_json", "sha2", + "tempfile", "thiserror 2.0.18", ] @@ -383,14 +383,13 @@ name = "burn-central-core" version = "0.5.0" dependencies = [ "burn", + "burn-central-artifact", "burn-central-client", "crossbeam", - "reqwest", "serde", "serde_json", "sha2", "strum", - "tempfile", "thiserror 2.0.18", "tracing", "tracing-core", @@ -449,6 +448,7 @@ version = "0.5.0" dependencies = [ "anyhow", "burn", + "burn-central-artifact", "burn-central-core", "burn-central-fleet", "burn-central-inference", @@ -2536,9 +2536,9 @@ checksum = "2c4a545a15244c7d945065b5d392b2d2d7f21526fba56ce51467b06ed445e8f7" [[package]] name = "libc" -version = "0.2.180" +version = "0.2.182" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" [[package]] name = "libloading" @@ -2568,9 +2568,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" [[package]] name = "litemap" @@ -3667,9 +3667,9 @@ dependencies = [ [[package]] name = "rustix" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" dependencies = [ "bitflags", "errno", @@ -4142,9 +4142,9 @@ dependencies = [ [[package]] name = "tempfile" -version = 
"3.24.0" +version = "3.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0" dependencies = [ "fastrand", "getrandom 0.3.4", diff --git a/crates/burn-central-artifact/Cargo.toml b/crates/burn-central-artifact/Cargo.toml index cdd29db2..26d76073 100644 --- a/crates/burn-central-artifact/Cargo.toml +++ b/crates/burn-central-artifact/Cargo.toml @@ -16,5 +16,5 @@ reqwest = { version = "0.13.2", features = ["blocking"] } serde = { workspace = true, features = ["derive"] } sha2 = { workspace = true } serde_json = { workspace = true } -burn-central-core = { workspace = true } thiserror = { workspace = true } +tempfile = "3.26.0" diff --git a/crates/burn-central-artifact/src/artifact_download.rs b/crates/burn-central-artifact/src/artifact_download.rs deleted file mode 100644 index 942f41f5..00000000 --- a/crates/burn-central-artifact/src/artifact_download.rs +++ /dev/null @@ -1,92 +0,0 @@ -use std::fs::{self, File}; -use std::io::BufWriter; -use std::path::{Path, PathBuf}; - -use burn_central_core::bundle::normalize_bundle_path; - -use crate::download::{DownloadError, DownloadTask, download_tasks}; -use crate::tools::path::safe_join; - -/// Generic download descriptor for any model artifact file. -#[derive(Debug, Clone)] -pub struct ArtifactDownloadFile { - pub rel_path: String, - pub url: String, - pub size_bytes: u64, - pub checksum: String, -} - -/// Download artifact files into a destination directory, validating size and checksum. 
-pub fn download_artifacts_to_dir( - dest_root: &Path, - files: &[ArtifactDownloadFile], -) -> Result<(), DownloadError> { - fs::create_dir_all(dest_root)?; - - if files.is_empty() { - return Ok(()); - } - - let mut staged_files = Vec::with_capacity(files.len()); - let mut tasks = Vec::with_capacity(files.len()); - for file in files { - let rel_path = normalize_bundle_path(&file.rel_path); - let dest = safe_join(dest_root, &rel_path) - .map_err(|e| DownloadError::InvalidPath(e.to_string()))?; - - if let Some(parent) = dest.parent() { - fs::create_dir_all(parent)?; - } - let tmp = temp_path(&dest)?; - if tmp.exists() { - fs::remove_file(&tmp)?; - } - staged_files.push((dest.clone(), tmp.clone(), rel_path.clone())); - - let tmp_file = File::create(tmp)?; - let writer = BufWriter::new(tmp_file); - - tasks.push(DownloadTask { - rel_path: rel_path.clone(), - url: file.url.clone(), - writer, - expected_size: file.size_bytes, - expected_checksum: file.checksum.clone(), - }); - } - - let parallelism = std::thread::available_parallelism() - .map(|n| n.get()) - .unwrap_or(4); - let http = reqwest::blocking::Client::new(); - if let Err(err) = download_tasks(&http, tasks, parallelism) { - for (_, tmp, _) in staged_files { - let _ = fs::remove_file(tmp); - } - return Err(err); - } - - for (dest, tmp, rel_path) in staged_files { - if !tmp.exists() { - return Err(DownloadError::DownloadFailed { - path: rel_path, - details: "temporary downloaded file is missing".to_string(), - }); - } - if dest.exists() { - fs::remove_file(&dest)?; - } - fs::rename(tmp, dest)?; - } - - Ok(()) -} - -/// Generate a temporary file path for downloads. -fn temp_path(dest: &Path) -> Result { - let file_name = dest - .file_name() - .ok_or_else(|| DownloadError::InvalidPath("missing file name".to_string()))? 
- .to_string_lossy(); - Ok(dest.with_file_name(format!(".{file_name}.partial"))) -} diff --git a/crates/burn-central-artifact/src/bundle/fs.rs b/crates/burn-central-artifact/src/bundle/fs.rs new file mode 100644 index 00000000..932ad7a0 --- /dev/null +++ b/crates/burn-central-artifact/src/bundle/fs.rs @@ -0,0 +1,284 @@ +//! File-backed bundle implementation for Burn Central artifacts. +//! +//! This module provides an implementation of the `BundleSink` and `BundleSource` traits that uses the local filesystem to store artifact files. +//! It supports both temporary bundles (which clean up after themselves) and persistent bundles rooted at a specified directory. +//! The implementation ensures that file paths are sanitized to prevent directory traversal, and that concurrent writes to the same path are handled safely using temporary files and atomic renames. + +use std::collections::HashSet; +use std::fs::{self, File}; +use std::io::{Read, Seek, SeekFrom, Write}; +use std::path::{Path, PathBuf}; + +use sha2::Digest; +use tempfile::TempDir; + +use crate::bundle::{BundleSink, BundleSource}; +use crate::tools::path::{safe_join, sanitize_rel_path}; +use crate::upload::{MultipartUploadSource, UploadError}; + +/// File-backed bundle that can both read and write artifact files. +#[derive(Debug)] +pub struct FsBundle { + root: PathBuf, + files: Vec, + seen: HashSet, + _temp: Option, +} + +impl FsBundle { + /// Create a writable bundle rooted at the provided directory. + pub fn create(root: impl Into) -> Result { + let root = root.into(); + fs::create_dir_all(&root)?; + Ok(Self { + root, + files: Vec::new(), + seen: HashSet::new(), + _temp: None, + }) + } + + /// Create a temporary writable bundle that cleans up on drop. 
+ pub fn temp() -> Result { + let temp = TempDir::new()?; + let root = temp.path().to_path_buf(); + Ok(Self { + root, + files: Vec::new(), + seen: HashSet::new(), + _temp: Some(temp), + }) + } + + /// Create a read-oriented bundle backed by an existing root + file list. + pub fn with_files(root: PathBuf, files: Vec) -> Result { + let mut bundle = Self { + root, + files: Vec::new(), + seen: HashSet::new(), + _temp: None, + }; + for path in files { + bundle.register_file(&path)?; + } + Ok(bundle) + } + + /// Root directory for the bundle files. + pub fn root(&self) -> &Path { + &self.root + } + + /// Files written into the bundle. + pub fn files(&self) -> &[FsBundleFile] { + &self.files + } + + /// Relative file paths currently indexed by this bundle. + pub fn file_paths(&self) -> Vec { + self.files.iter().map(|f| f.rel_path.clone()).collect() + } + + /// Register an existing file path in the bundle, ensuring it is valid and not duplicated. + fn register_file(&mut self, path: &str) -> Result<(), String> { + let rel = sanitize_rel_path(path)?.to_string_lossy().to_string(); + if rel.is_empty() || !self.seen.insert(rel.clone()) { + return Err(format!("Duplicate bundle path: {rel}")); + } + + self.files.push(FsBundleFile { + rel_path: rel.clone(), + abs_path: self.root.join(&rel), + size_bytes: None, + checksum: None, + }); + + Ok(()) + } + + /// Delete all files in this bundle from the filesystem. This is idempotent and can be used for cleanup. 
+ pub fn delete(self) -> Result<(), std::io::Error> { + for file in &self.files { + let path = safe_join(&self.root, &file.rel_path); + if let Ok(path) = path { + match fs::remove_file(path) { + Ok(()) => {} + Err(err) if err.kind() == std::io::ErrorKind::NotFound => {} + Err(err) => return Err(err), + } + } + } + Ok(()) + } + + fn clear_temp_files(&self) { + for file in &self.files { + let tmp = temp_path(&file.abs_path); + if let Ok(tmp) = tmp { + if tmp.exists() { + let _ = fs::remove_file(tmp); + } + } + } + } +} + +impl Drop for FsBundle { + fn drop(&mut self) { + self.clear_temp_files(); + } +} + +impl BundleSink for FsBundle { + fn put_file(&mut self, path: &str, reader: &mut R) -> Result<(), String> { + let rel = sanitize_rel_path(path).map_err(|e| e.to_string())?; + let rel = rel.to_string_lossy().to_string(); + + if !self.seen.insert(rel.clone()) { + return Err(format!("Duplicate bundle path: {rel}")); + } + + let dest = safe_join(&self.root, &rel).map_err(|e| e.to_string())?; + if let Some(parent) = dest.parent() { + fs::create_dir_all(parent).map_err(|e| e.to_string())?; + } + + let tmp = temp_path(&dest).map_err(|e| e.to_string())?; + let mut file = match File::create(&tmp) { + Ok(file) => file, + Err(e) => { + self.seen.remove(&rel); + return Err(e.to_string()); + } + }; + + let mut hasher = sha2::Sha256::new(); + let mut buf = [0u8; 1024 * 64]; + let mut total = 0u64; + + loop { + let read = match reader.read(&mut buf) { + Ok(read) => read, + Err(e) => { + let _ = fs::remove_file(&tmp); + self.seen.remove(&rel); + return Err(e.to_string()); + } + }; + if read == 0 { + break; + } + if let Err(e) = file.write_all(&buf[..read]) { + let _ = fs::remove_file(&tmp); + self.seen.remove(&rel); + return Err(e.to_string()); + } + hasher.update(&buf[..read]); + total += read as u64; + } + + let checksum = format!("{:x}", hasher.finalize()); + + if let Err(err) = finalize_temp_file(&tmp, &dest) { + self.seen.remove(&rel); + let _ = fs::remove_file(&tmp); + return 
Err(err.to_string()); + } + + self.files.push(FsBundleFile { + rel_path: rel, + abs_path: dest, + size_bytes: Some(total), + checksum: Some(checksum), + }); + + Ok(()) + } +} + +/// File descriptor emitted by a file-backed bundle. +#[derive(Debug, Clone)] +pub struct FsBundleFile { + /// Relative path within the bundle. + pub rel_path: String, + /// Absolute file system path for the cached file. + pub abs_path: PathBuf, + /// Size in bytes, when known. + pub size_bytes: Option, + /// SHA-256 checksum (hex), when known. + pub checksum: Option, +} + +impl BundleSource for FsBundle { + fn open(&self, path: &str) -> Result, String> { + let rel = sanitize_rel_path(path).map_err(|e| e.to_string())?; + let rel = rel.to_string_lossy().to_string(); + + if !self.seen.contains(&rel) { + return Err(format!("Bundle path not found: {rel}")); + } + + let file_path = safe_join(&self.root, &rel).map_err(|e| e.to_string())?; + let file = File::open(&file_path).map_err(|e| e.to_string())?; + Ok(Box::new(file)) + } + + fn list(&self) -> Result, String> { + Ok(self.file_paths()) + } +} + +fn temp_path(dest: &Path) -> Result { + let file_name = dest + .file_name() + .ok_or_else(|| std::io::Error::other("Missing file name"))? 
/// Move a fully written temporary file into its final location.
///
/// `fs::rename` replaces an existing destination file on its own, so the
/// destination is never removed up front. If the rename fails (for example
/// because `tmp` is missing), a previously existing `dest` is left intact
/// instead of being deleted first.
///
/// # Errors
///
/// Returns the underlying I/O error when the rename fails.
fn finalize_temp_file(tmp: &Path, dest: &Path) -> Result<(), std::io::Error> {
    fs::rename(tmp, dest)
}
-1,7 +1,8 @@ use std::collections::BTreeMap; use std::io::Read; -use crate::bundle::{BundleSource, normalize_bundle_path}; +use crate::bundle::BundleSource; +use crate::tools::path::normalize_bundle_path; /// In-memory reader for synthetic or cached bundles. pub struct InMemoryBundleReader { diff --git a/crates/burn-central-core/src/bundle/memory/sources.rs b/crates/burn-central-artifact/src/bundle/memory/sources.rs similarity index 91% rename from crates/burn-central-core/src/bundle/memory/sources.rs rename to crates/burn-central-artifact/src/bundle/memory/sources.rs index 97e4c33b..c53d79a7 100644 --- a/crates/burn-central-core/src/bundle/memory/sources.rs +++ b/crates/burn-central-artifact/src/bundle/memory/sources.rs @@ -1,6 +1,6 @@ use std::io::Read; -use crate::bundle::{BundleSink, normalize_bundle_path}; +use crate::{bundle::BundleSink, tools::path::normalize_bundle_path}; /// A builder for creating bundles with multiple files #[derive(Default, Clone)] @@ -61,7 +61,10 @@ impl BundleSink for InMemoryBundleSources { reader .read_to_end(&mut buf) .map_err(|e| format!("Failed to read from source: {}", e))?; - *self = self.clone().add_bytes(buf, path); + self.files.push(PendingFile { + dest_path: normalize_bundle_path(path), + source: buf, + }); Ok(()) } } diff --git a/crates/burn-central-core/src/bundle/memory/tests.rs b/crates/burn-central-artifact/src/bundle/memory/tests.rs similarity index 100% rename from crates/burn-central-core/src/bundle/memory/tests.rs rename to crates/burn-central-artifact/src/bundle/memory/tests.rs diff --git a/crates/burn-central-artifact/src/bundle/mod.rs b/crates/burn-central-artifact/src/bundle/mod.rs new file mode 100644 index 00000000..76502f62 --- /dev/null +++ b/crates/burn-central-artifact/src/bundle/mod.rs @@ -0,0 +1,98 @@ +//! This module defines the core traits and structures for working with bundles of artifact files in Burn Central. +//! +//! 
Bundles are a way to group multiple related files together as a single artifact, with support for encoding/decoding complex data structures into bundles of files, and abstracting over different storage backends for reading/writing those bundles. +//! +//! As a user of Burn Central, you will typically interact with bundles indirectly through higher-level APIs for logging experiment artifacts, registering models, etc. +//! However, if you need to implement custom artifact handling logic (e.g. for a new model format), you may need to implement the BundleEncode/BundleDecode traits for your data structures, and use the BundleSink/BundleSource traits to read/write files from/to bundles in a storage-agnostic way. +//! +//! # Examples +//! +//! ``` +//! use burn_central_artifact::bundle::{BundleEncode, BundleDecode, BundleSink, BundleSource}; +//! use serde::{Serialize, Deserialize}; +//! use std::io::Read; +//! +//! #[derive(Serialize, Deserialize)] +//! struct MyModel { +//! name: String, +//! parameters: Vec<f32>, +//! } +//! impl BundleEncode for MyModel { +//! type Settings = (); +//! type Error = String; +//! fn encode<O: BundleSink>(self, sink: &mut O, _settings: &Self::Settings) -> Result<(), Self::Error> { +//! let json = serde_json::to_string(&self).map_err(|e| e.to_string())?; +//! sink.put_bytes("model.json", json.as_bytes()).map_err(|e| e.to_string())?; +//! Ok(()) +//! } +//! } +//! impl BundleDecode for MyModel { +//! type Settings = (); +//! type Error = String; +//! fn decode<I: BundleSource>(source: &I, _settings: &Self::Settings) -> Result<Self, Self::Error> { +//! let mut reader = source.open("model.json").map_err(|e| e.to_string())?; +//! let mut json = String::new(); +//! reader.read_to_string(&mut json).map_err(|e| e.to_string())?; +//! serde_json::from_str(&json).map_err(|e| e.to_string()) +//! } +//! } +//! 
``` + +mod fs; +mod memory; + +pub use fs::*; +pub use memory::*; + +use serde::{Serialize, de::DeserializeOwned}; +use std::io::Read; + +/// Trait for encoding data into a bundle of files +/// +/// Implementors should write their data to the provided BundleSink, which abstracts over the underlying storage mechanism. The Settings associated type can be used to pass any necessary configuration for encoding (e.g. compression level, file naming conventions, etc). +pub trait BundleEncode { + /// Settings type for encoding, which can include any necessary configuration (e.g. compression level, file naming conventions, etc). + type Settings: Default + Serialize + DeserializeOwned; + /// Error type for encoding failures. Should be convertible to a generic error type for ease of use in higher-level APIs. + type Error: Into>; + + /// Encode the data into the provided BundleSink. The sink should be used to write all files that are part of the bundle, and the implementation should return an error if encoding fails for any reason (e.g. serialization errors, I/O errors, etc). + fn encode( + self, + sink: &mut O, + settings: &Self::Settings, + ) -> Result<(), Self::Error>; +} + +/// Trait for decoding data from a bundle of files +/// +/// Implementors should read their data from the provided BundleSource, which abstracts over the underlying storage mechanism. The Settings associated type can be used to pass any necessary configuration for decoding (e.g. expected file names, compression settings, etc). +pub trait BundleDecode: Sized { + /// Settings type for decoding, which can include any necessary configuration (e.g. expected file names, compression settings, etc). + type Settings: Default + Serialize + DeserializeOwned; + /// Error type for decoding failures. Should be convertible to a generic error type for ease of use in higher-level APIs. + type Error: Into>; + + /// Decode the data from the provided BundleSource. 
/// Trait for writing files to a bundle.
pub trait BundleSink {
    /// Add a file at `path` by streaming its bytes from `reader`.
    ///
    /// Nothing is handed back to the caller on success: the return value is
    /// `Ok(())`. Callers that need a size or checksum have to compute it
    /// themselves (for example by wrapping `reader`) or query the concrete
    /// sink afterwards.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the sink cannot store the file.
    fn put_file<R: Read>(&mut self, path: &str, reader: &mut R) -> Result<(), String>;

    /// Convenience wrapper that stores an in-memory byte slice at `path` by
    /// forwarding it to [`BundleSink::put_file`] through a cursor.
    fn put_bytes(&mut self, path: &str, bytes: &[u8]) -> Result<(), String> {
        let mut cursor = std::io::Cursor::new(bytes);
        self.put_file(path, &mut cursor)
    }
}
#[derive(Debug, thiserror::Error)] pub enum DownloadError { - #[error("failed to download {path}: {details}")] - DownloadFailed { path: String, details: String }, + /// Errors from the transfer client (e.g. network errors, HTTP errors). + #[error("transfer error for {rel_path}: {source}")] + Transfer { + rel_path: String, + #[source] + source: crate::transfer::TransferError, + }, + /// Errors related to file size mismatches after download. #[error("size mismatch for {path}: expected {expected} bytes, got {actual} bytes")] SizeMismatch { path: String, expected: u64, actual: u64, }, + /// Errors related to checksum mismatches after download. #[error("checksum mismatch for {path}: expected {expected}, got {actual}")] ChecksumMismatch { path: String, expected: String, actual: String, }, + /// Errors related to invalid checksums (e.g. non-hex, wrong length). #[error("invalid checksum: {0}")] InvalidChecksum(String), - #[error("writer error: {0}")] - WriterError(#[from] std::io::Error), + /// Errors related to invalid relative paths (e.g. empty, duplicates, unsafe). #[error("invalid path: {0}")] InvalidPath(String), + /// Errors from the target bundle sink (e.g. file system errors). + #[error("target error: {0}")] + TargetError(String), } -/// A single file download task. -#[derive(Clone)] -pub struct DownloadTask { +/// Generic download descriptor for any model artifact file. +#[derive(Debug, Clone)] +pub struct ArtifactDownloadFile { pub rel_path: String, pub url: String, - pub writer: W, - pub expected_size: u64, - pub expected_checksum: String, + /// Optional expected file size in bytes. + pub size_bytes: Option, + /// Optional expected SHA-256 checksum. + pub checksum: Option, } -/// Download multiple files in parallel. -pub fn download_tasks( - http: &HttpClient, - tasks: Vec>, - max_parallel: usize, +/// Download artifact files into any bundle sink implementation. 
+pub fn download_artifacts_to_sink( + sink: &mut S, + files: &[ArtifactDownloadFile], ) -> Result<(), DownloadError> { - if tasks.is_empty() { - return Ok(()); + let client = ReqwestTransferClient::new(); + download_artifacts_to_sink_with_client(&client, sink, files) +} + +/// Download artifact files into any bundle sink implementation using a custom transfer client. +pub fn download_artifacts_to_sink_with_client( + client: &FTC, + sink: &mut S, + files: &[ArtifactDownloadFile], +) -> Result<(), DownloadError> { + let files = validated_download_files(files)?; + for (rel_path, file) in files { + let reader = client + .get_reader(&file.url) + .map_err(|e| DownloadError::Transfer { + rel_path: rel_path.clone(), + source: e, + })?; + let mut verifying_reader = VerifyingReader::new(reader); + + sink.put_file(&rel_path, &mut verifying_reader) + .map_err(DownloadError::TargetError)?; + + let (total, digest) = verifying_reader.finish(); + validate_download( + &rel_path, + total, + digest, + file.size_bytes, + file.checksum.as_deref(), + )?; + } + + Ok(()) +} + +fn validated_download_files( + files: &[ArtifactDownloadFile], +) -> Result, DownloadError> { + let mut seen = HashSet::with_capacity(files.len()); + let mut out = Vec::with_capacity(files.len()); + for file in files { + let rel_path = normalize_bundle_path(&file.rel_path); + if rel_path.is_empty() { + return Err(DownloadError::InvalidPath( + "empty relative artifact path".to_string(), + )); + } + if !seen.insert(rel_path.clone()) { + return Err(DownloadError::InvalidPath(format!( + "duplicate relative artifact path: {rel_path}" + ))); + } + + out.push((rel_path, file)); } - if max_parallel <= 1 || tasks.len() == 1 { - for mut task in tasks { - download_one(http, &mut task)?; + Ok(out) +} + +struct VerifyingReader { + inner: R, + hasher: sha2::Sha256, + total: u64, +} + +impl VerifyingReader { + fn new(inner: R) -> Self { + Self { + inner, + hasher: sha2::Sha256::new(), + total: 0, } - return Ok(()); } - let (tx, 
rx) = channel::unbounded::>(); - for task in tasks { - tx.send(task).expect("channel open"); + fn finish(self) -> (u64, String) { + (self.total, format!("{:x}", self.hasher.finalize())) } - drop(tx); - - crossbeam::scope(|scope| { - let mut handles = Vec::new(); - let worker_count = max_parallel.min(rx.len().max(1)); - for _ in 0..worker_count { - let rx = rx.clone(); - let http = http.clone(); - handles.push(scope.spawn(move |_| { - for mut task in rx.iter() { - download_one(&http, &mut task)?; - } - Ok::<(), DownloadError>(()) - })); +} + +impl Read for VerifyingReader { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let read = self.inner.read(buf)?; + self.hasher.update(&buf[..read]); + self.total += read as u64; + Ok(read) + } +} + +fn validate_download( + rel_path: &str, + total: u64, + digest: String, + expected_size: Option, + expected_checksum: Option<&str>, +) -> Result<(), DownloadError> { + if let Some(expected_size) = expected_size { + if total != expected_size { + return Err(DownloadError::SizeMismatch { + path: rel_path.to_string(), + expected: expected_size, + actual: total, + }); } + } - for handle in handles { - handle.join().expect("thread panicked")?; + if let Some(expected_checksum) = expected_checksum { + let expected_checksum = + normalize_checksum(expected_checksum).map_err(DownloadError::InvalidChecksum)?; + if digest != expected_checksum { + return Err(DownloadError::ChecksumMismatch { + path: rel_path.to_string(), + expected: expected_checksum, + actual: digest, + }); } + } - Ok(()) - }) - .expect("scope failed") + Ok(()) } -/// Download a single file with checksum verification. 
-fn download_one( - http: &HttpClient, - task: &mut DownloadTask, -) -> Result<(), DownloadError> { - let mut resp = http - .get(&task.url) - .send() - .map_err(|e| DownloadError::DownloadFailed { - path: task.rel_path.clone(), - details: e.to_string(), - })?; - - if !resp.status().is_success() { - return Err(DownloadError::DownloadFailed { - path: task.rel_path.clone(), - details: format!("HTTP {}", resp.status()), - }); +#[cfg(test)] +mod tests { + use super::*; + use crate::bundle::InMemoryBundleSources; + use crate::transfer::TransferError; + use std::collections::HashMap; + use std::io::{Cursor, Read}; + use std::sync::Arc; + + #[derive(Clone)] + struct MockClient { + files: Arc>>, } - let sink = &mut task.writer; - let mut hasher = sha2::Sha256::new(); - let mut buf = [0u8; 1024 * 64]; - let mut total = 0u64; + impl MockClient { + fn new(files: HashMap>) -> Self { + Self { + files: Arc::new(files), + } + } + } - loop { - let read = resp.read(&mut buf)?; - if read == 0 { - break; + impl FileTransferClient for MockClient { + fn put_reader( + &self, + _url: &str, + mut reader: R, + _size_bytes: u64, + ) -> Result<(), TransferError> { + let mut buf = Vec::new(); + reader + .read_to_end(&mut buf) + .map_err(|e| TransferError::Transport(e.to_string()))?; + Ok(()) } - sink.write_all(&buf[..read])?; - hasher.update(&buf[..read]); - total += read as u64; + + fn get_reader(&self, url: &str) -> Result, TransferError> { + let bytes = self + .files + .get(url) + .ok_or_else(|| TransferError::Transport(format!("missing url in mock: {url}")))?; + Ok(Box::new(Cursor::new(bytes.clone()))) + } + } + + fn sha256_hex(bytes: &[u8]) -> String { + let mut hasher = sha2::Sha256::new(); + hasher.update(bytes); + format!("{:x}", hasher.finalize()) } - let digest = format!("{:x}", hasher.finalize()); - let expected_checksum = - normalize_checksum(&task.expected_checksum).map_err(DownloadError::InvalidChecksum)?; + #[test] + fn downloads_to_sink_and_validates_checksum_and_size() { + let 
data = b"hello world".to_vec(); + let checksum = sha256_hex(&data); + let mut sink = InMemoryBundleSources::new(); + let client = MockClient::new(HashMap::from([("mock://f1".to_string(), data.clone())])); + let files = vec![ArtifactDownloadFile { + rel_path: "weights.bin".to_string(), + url: "mock://f1".to_string(), + size_bytes: Some(data.len() as u64), + checksum: Some(checksum), + }]; - if total != task.expected_size { - return Err(DownloadError::SizeMismatch { - path: task.rel_path.clone(), - expected: task.expected_size, - actual: total, - }); + download_artifacts_to_sink_with_client(&client, &mut sink, &files) + .expect("download should succeed"); + + assert_eq!(sink.len(), 1); + assert_eq!(sink.files()[0].dest_path(), "weights.bin"); + assert_eq!(sink.files()[0].source(), data); } - if digest != expected_checksum { - return Err(DownloadError::ChecksumMismatch { - path: task.rel_path.clone(), - expected: expected_checksum, - actual: digest, - }); + + #[test] + fn rejects_duplicate_relative_paths() { + let client = MockClient::new(HashMap::new()); + let mut sink = InMemoryBundleSources::new(); + let files = vec![ + ArtifactDownloadFile { + rel_path: "a.bin".to_string(), + url: "mock://a".to_string(), + size_bytes: None, + checksum: None, + }, + ArtifactDownloadFile { + rel_path: "a.bin".to_string(), + url: "mock://b".to_string(), + size_bytes: None, + checksum: None, + }, + ]; + + let err = download_artifacts_to_sink_with_client(&client, &mut sink, &files) + .expect_err("duplicate paths should fail"); + + match err { + DownloadError::InvalidPath(msg) => assert!(msg.contains("duplicate")), + other => panic!("unexpected error: {other:?}"), + } } - Ok(()) + #[test] + fn fails_on_checksum_mismatch() { + let data = b"payload".to_vec(); + let mut sink = InMemoryBundleSources::new(); + let client = MockClient::new(HashMap::from([("mock://f2".to_string(), data.clone())])); + let files = vec![ArtifactDownloadFile { + rel_path: "params.bin".to_string(), + url: 
"mock://f2".to_string(), + size_bytes: Some(data.len() as u64), + checksum: Some("00".repeat(32)), + }]; + + let err = download_artifacts_to_sink_with_client(&client, &mut sink, &files) + .expect_err("checksum mismatch should fail"); + + match err { + DownloadError::ChecksumMismatch { path, .. } => assert_eq!(path, "params.bin"), + other => panic!("unexpected error: {other:?}"), + } + } + + #[test] + fn fails_on_size_mismatch() { + let data = b"payload".to_vec(); + let mut sink = InMemoryBundleSources::new(); + let client = MockClient::new(HashMap::from([("mock://f3".to_string(), data.clone())])); + let files = vec![ArtifactDownloadFile { + rel_path: "params.bin".to_string(), + url: "mock://f3".to_string(), + size_bytes: Some((data.len() as u64) + 1), + checksum: None, + }]; + + let err = download_artifacts_to_sink_with_client(&client, &mut sink, &files) + .expect_err("size mismatch should fail"); + + match err { + DownloadError::SizeMismatch { path, .. } => assert_eq!(path, "params.bin"), + other => panic!("unexpected error: {other:?}"), + } + } } diff --git a/crates/burn-central-artifact/src/lib.rs b/crates/burn-central-artifact/src/lib.rs index fac1aa3b..8a6a4cc6 100644 --- a/crates/burn-central-artifact/src/lib.rs +++ b/crates/burn-central-artifact/src/lib.rs @@ -1,8 +1,11 @@ -//! This crate centralizes traits, structures and utilities for handling artifacts and models in Burn Central. +//! This crate centralizes traits, structures and utilities for handling artifacts in Burn Central. 
-mod artifact_download; -mod download; mod tools; +mod transfer; -pub use artifact_download::{ArtifactDownloadFile, download_artifacts_to_dir}; -pub use download::DownloadError; +pub mod bundle; +pub mod download; +pub mod upload; + +pub use tools::validation::normalize_checksum; +pub use transfer::{FileTransferClient, ReqwestTransferClient}; diff --git a/crates/burn-central-artifact/src/tools/path.rs b/crates/burn-central-artifact/src/tools/path.rs index 0a2661b6..fe876ffc 100644 --- a/crates/burn-central-artifact/src/tools/path.rs +++ b/crates/burn-central-artifact/src/tools/path.rs @@ -1,10 +1,20 @@ use std::path::{Path, PathBuf}; -use burn_central_core::bundle::normalize_bundle_path; +/// Normalize a path within a bundle (use forward slashes, remove leading slash) +pub fn normalize_bundle_path>(s: S) -> String { + s.as_ref() + .replace('\\', "/") + .trim_start_matches('/') + .to_string() +} /// Sanitize a relative path to prevent directory traversal attacks. pub fn sanitize_rel_path(path: &str) -> Result { let normalized = normalize_bundle_path(path); + if normalized.is_empty() { + return Err("invalid path component: empty path".to_string()); + } + let rel = Path::new(&normalized); for component in rel.components() { use std::path::Component; diff --git a/crates/burn-central-artifact/src/transfer.rs b/crates/burn-central-artifact/src/transfer.rs new file mode 100644 index 00000000..8ecf3637 --- /dev/null +++ b/crates/burn-central-artifact/src/transfer.rs @@ -0,0 +1,86 @@ +use std::io::Read; + +#[derive(Debug, thiserror::Error)] +pub enum TransferError { + #[error("Transport error: {0}")] + Transport(String), +} + +/// Generic client interface used for uploading and downloading files, abstracting over the underlying HTTP client or other transport mechanism. +pub trait FileTransferClient: Clone + Send + Sync + 'static { + /// Upload data from a reader to the given URL with known size. 
+ fn put_reader( + &self, + url: &str, + reader: R, + size_bytes: u64, + ) -> Result<(), TransferError>; + + /// Download data from the given URL as a reader. + fn get_reader(&self, url: &str) -> Result, TransferError>; +} + +/// Reqwest-based transfer client. +#[derive(Clone)] +pub struct ReqwestTransferClient { + http: reqwest::blocking::Client, +} + +impl ReqwestTransferClient { + pub fn new() -> Self { + Self { + http: reqwest::blocking::Client::new(), + } + } + + pub fn with_client(http: reqwest::blocking::Client) -> Self { + Self { http } + } +} + +impl Default for ReqwestTransferClient { + fn default() -> Self { + Self::new() + } +} + +impl FileTransferClient for ReqwestTransferClient { + fn put_reader( + &self, + url: &str, + reader: R, + size_bytes: u64, + ) -> Result<(), TransferError> { + let body = reqwest::blocking::Body::sized(reader, size_bytes); + let response = self + .http + .put(url) + .body(body) + .send() + .map_err(|e| TransferError::Transport(e.to_string()))?; + + if !response.status().is_success() { + return Err(TransferError::Transport( + response.error_for_status().err().unwrap().to_string(), + )); + } + + Ok(()) + } + + fn get_reader(&self, url: &str) -> Result, TransferError> { + let response = self + .http + .get(url) + .send() + .map_err(|e| TransferError::Transport(e.to_string()))?; + + if !response.status().is_success() { + return Err(TransferError::Transport( + response.error_for_status().err().unwrap().to_string(), + )); + } + + Ok(Box::new(response)) + } +} diff --git a/crates/burn-central-artifact/src/upload.rs b/crates/burn-central-artifact/src/upload.rs new file mode 100644 index 00000000..bdee717c --- /dev/null +++ b/crates/burn-central-artifact/src/upload.rs @@ -0,0 +1,336 @@ +//! This module provides utilities for uploading artifact files from any source to any target bundle sink using multipart uploads with presigned URLs. +//! +//! 
The upload process can be customized with any implementation of the FileTransferClient trait (e.g. for custom HTTP clients, authentication, retries, etc), and multipart file sources can be abstracted behind the MultipartUploadSource trait for maximum flexibility (e.g. to support streaming from large files without loading them fully into memory). + +use crate::transfer::TransferError; +use crate::{FileTransferClient, ReqwestTransferClient}; +use std::collections::HashSet; +use std::io::Read; + +/// Errors that can occur during artifact file uploads. +#[derive(Debug, thiserror::Error)] +pub enum UploadError { + /// Errors from the transfer client (e.g. network errors, HTTP errors). + #[error("transfer error for part {part_index} of {total_parts} for {rel_path}: {source}")] + Transfer { + part_index: usize, + total_parts: usize, + rel_path: String, + #[source] + source: TransferError, + }, + /// Errors related to invalid multipart upload plans (e.g. duplicate paths, invalid part numbering). + #[error("invalid multipart upload plan: {0}")] + InvalidMultipart(String), + /// Errors related to multipart reader issues (e.g. file access errors). + #[error("multipart reader error for {rel_path}: {source}")] + MultipartReader { + rel_path: String, + #[source] + source: Box, + }, +} + +/// One multipart upload part descriptor. +#[derive(Debug, Clone)] +pub struct MultipartUploadPart { + pub part: u32, + pub url: String, + pub size_bytes: u64, +} + +/// One file multipart upload descriptor. +#[derive(Debug, Clone)] +pub struct MultipartUploadFile { + pub rel_path: String, + pub parts: Vec, +} + +/// Source abstraction for multipart uploads. +pub trait MultipartUploadSource { + /// Return the file length in bytes for a relative path. + fn file_len(&self, rel_path: &str) -> Result; + + /// Open a reader for one file chunk. 
+ fn open_part( + &self, + rel_path: &str, + offset: u64, + size: u64, + ) -> Result, UploadError>; +} + +/// Upload multiple files from a multipart source using presigned URLs. +pub fn upload_bundle_multipart( + source: &S, + files: &[MultipartUploadFile], +) -> Result<(), UploadError> { + let client = ReqwestTransferClient::new(); + upload_bundle_multipart_with_client(&client, source, files) +} + +/// Upload multiple files from a multipart source using presigned URLs and a custom client. +pub fn upload_bundle_multipart_with_client( + client: &FTC, + source: &S, + files: &[MultipartUploadFile], +) -> Result<(), UploadError> { + let mut seen = HashSet::new(); + + for file in files { + if !seen.insert(file.rel_path.clone()) { + return Err(UploadError::InvalidMultipart(format!( + "Duplicate multipart upload descriptor for {}", + file.rel_path + ))); + } + + upload_source_file_multipart_streaming(client, source, &file.rel_path, &file.parts)?; + } + + Ok(()) +} + +fn upload_source_file_multipart_streaming( + client: &FTC, + source: &S, + rel_path: &str, + parts: &[MultipartUploadPart], +) -> Result<(), UploadError> { + let file_len = source.file_len(rel_path)?; + + let mut part_indices: Vec = (0..parts.len()).collect(); + part_indices.sort_by_key(|&i| parts[i].part); + + for (i, &part_idx) in part_indices.iter().enumerate() { + let part = &parts[part_idx]; + if part.part != (i as u32 + 1) { + return Err(UploadError::InvalidMultipart(format!( + "Invalid part numbering for {}: expected {}, got {}", + rel_path, + i + 1, + part.part + ))); + } + } + + let mut offset = 0u64; + + for (part_index, &part_idx) in part_indices.iter().enumerate() { + let part = &parts[part_idx]; + let size = part.size_bytes; + + if offset + size > file_len { + return Err(UploadError::InvalidMultipart(format!( + "Part {} exceeds file length for {}", + part_index + 1, + rel_path + ))); + } + + let reader = source.open_part(rel_path, offset, size)?; + client + .put_reader(&part.url, reader, size) + 
.map_err(|e| UploadError::Transfer { + part_index: part_index + 1, + total_parts: parts.len(), + rel_path: rel_path.to_string(), + source: e, + })?; + + offset += size; + } + + if offset != file_len { + return Err(UploadError::InvalidMultipart(format!( + "Multipart size mismatch for {} (uploaded {}, expected {})", + rel_path, offset, file_len + ))); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::transfer::TransferError; + use std::collections::HashMap; + use std::io::{Cursor, Read}; + use std::sync::{Arc, Mutex}; + + #[derive(Clone, Default)] + struct MockClient { + puts: Arc)>>>, + } + + impl FileTransferClient for MockClient { + fn put_reader( + &self, + url: &str, + mut reader: R, + size_bytes: u64, + ) -> Result<(), TransferError> { + let mut bytes = Vec::new(); + reader + .read_to_end(&mut bytes) + .map_err(|e| TransferError::Transport(e.to_string()))?; + self.puts + .lock() + .expect("lock puts") + .push((url.to_string(), size_bytes, bytes)); + Ok(()) + } + + fn get_reader(&self, _url: &str) -> Result, TransferError> { + Err(TransferError::Transport( + "get_reader should not be used in upload tests".to_string(), + )) + } + } + + struct MockSource { + files: HashMap>, + } + + impl MockSource { + fn new(files: HashMap>) -> Self { + Self { files } + } + } + + impl MultipartUploadSource for MockSource { + fn file_len(&self, rel_path: &str) -> Result { + let bytes = self.files.get(rel_path).ok_or_else(|| { + UploadError::InvalidMultipart(format!("missing file: {rel_path}")) + })?; + Ok(bytes.len() as u64) + } + + fn open_part( + &self, + rel_path: &str, + offset: u64, + size: u64, + ) -> Result, UploadError> { + let bytes = self.files.get(rel_path).ok_or_else(|| { + UploadError::InvalidMultipart(format!("missing file: {rel_path}")) + })?; + let start = offset as usize; + let end = (offset + size) as usize; + let slice = bytes + .get(start..end) + .ok_or_else(|| UploadError::MultipartReader { + rel_path: rel_path.to_string(), + source: 
format!( + "invalid part range [{start}..{end}) for file of len {}", + bytes.len() + ) + .into(), + })?; + Ok(Box::new(Cursor::new(slice.to_vec()))) + } + } + + #[test] + fn rejects_non_contiguous_part_numbering() { + let client = MockClient::default(); + let source = MockSource::new(HashMap::from([( + "weights.bin".to_string(), + b"abcd".to_vec(), + )])); + let files = vec![MultipartUploadFile { + rel_path: "weights.bin".to_string(), + parts: vec![ + MultipartUploadPart { + part: 1, + url: "u1".to_string(), + size_bytes: 2, + }, + MultipartUploadPart { + part: 3, + url: "u3".to_string(), + size_bytes: 2, + }, + ], + }]; + + let err = upload_bundle_multipart_with_client(&client, &source, &files) + .expect_err("part numbering must be contiguous"); + + match err { + UploadError::InvalidMultipart(msg) => assert!(msg.contains("expected 2, got 3")), + other => panic!("unexpected error: {other:?}"), + } + } + + #[test] + fn rejects_part_plan_exceeding_file_len() { + let client = MockClient::default(); + let source = MockSource::new(HashMap::from([( + "weights.bin".to_string(), + b"abc".to_vec(), + )])); + let files = vec![MultipartUploadFile { + rel_path: "weights.bin".to_string(), + parts: vec![ + MultipartUploadPart { + part: 1, + url: "u1".to_string(), + size_bytes: 2, + }, + MultipartUploadPart { + part: 2, + url: "u2".to_string(), + size_bytes: 2, + }, + ], + }]; + + let err = upload_bundle_multipart_with_client(&client, &source, &files) + .expect_err("total part sizes cannot exceed file len"); + + match err { + UploadError::InvalidMultipart(msg) => assert!(msg.contains("exceeds file length")), + other => panic!("unexpected error: {other:?}"), + } + } + + #[test] + fn uploads_all_parts_with_expected_content() { + let client = MockClient::default(); + let source = MockSource::new(HashMap::from([( + "weights.bin".to_string(), + b"abcdef".to_vec(), + )])); + let files = vec![MultipartUploadFile { + rel_path: "weights.bin".to_string(), + parts: vec![ + 
MultipartUploadPart { + part: 1, + url: "u1".to_string(), + size_bytes: 2, + }, + MultipartUploadPart { + part: 2, + url: "u2".to_string(), + size_bytes: 2, + }, + MultipartUploadPart { + part: 3, + url: "u3".to_string(), + size_bytes: 2, + }, + ], + }]; + + upload_bundle_multipart_with_client(&client, &source, &files) + .expect("valid multipart plan should upload"); + + let puts = client.puts.lock().expect("lock puts"); + assert_eq!(puts.len(), 3); + assert_eq!(puts[0], ("u1".to_string(), 2, b"ab".to_vec())); + assert_eq!(puts[1], ("u2".to_string(), 2, b"cd".to_vec())); + assert_eq!(puts[2], ("u3".to_string(), 2, b"ef".to_vec())); + } +} diff --git a/crates/burn-central-core/Cargo.toml b/crates/burn-central-core/Cargo.toml index 35863bb3..03bc74b0 100644 --- a/crates/burn-central-core/Cargo.toml +++ b/crates/burn-central-core/Cargo.toml @@ -17,8 +17,6 @@ url.workspace = true serde.workspace = true serde_json.workspace = true sha2.workspace = true -tempfile = "3.24.0" -reqwest = { version = "0.13.2", features = ["blocking"] } strum.workspace = true thiserror.workspace = true tracing.workspace = true @@ -26,3 +24,4 @@ tracing-core.workspace = true tracing-subscriber.workspace = true crossbeam.workspace = true burn-central-client.workspace = true +burn-central-artifact.workspace = true diff --git a/crates/burn-central-core/src/artifacts/mod.rs b/crates/burn-central-core/src/artifacts/mod.rs index ea8edb47..e87b52e9 100644 --- a/crates/burn-central-core/src/artifacts/mod.rs +++ b/crates/burn-central-core/src/artifacts/mod.rs @@ -1,13 +1,16 @@ -use burn_central_client::response::{ArtifactResponse, MultipartUploadResponse}; +use burn_central_artifact::bundle::{BundleDecode, BundleEncode, FsBundle}; +use burn_central_artifact::download::{ + ArtifactDownloadFile, DownloadError, download_artifacts_to_sink, +}; +use burn_central_artifact::upload::{ + MultipartUploadFile, MultipartUploadPart, UploadError, upload_bundle_multipart, +}; +use 
burn_central_client::request::{ArtifactFileSpecRequest, CreateArtifactRequest}; +use burn_central_client::response::ArtifactResponse; use burn_central_client::{Client, ClientError}; use std::collections::BTreeMap; -use std::fs::{self, File}; -use std::io::{Read, Seek, SeekFrom}; -use std::path::Path; -use crate::bundle::{BundleDecode, BundleEncode, FsBundleReader, FsBundleSink}; use crate::schemas::ExperimentPath; -use burn_central_client::request::{ArtifactFileSpecRequest, CreateArtifactRequest}; #[derive(Debug, Clone, strum::Display, strum::EnumString)] #[strum(serialize_all = "snake_case")] @@ -17,7 +20,7 @@ pub enum ArtifactKind { Other, } -/// A scope for artifact operations within a specific experiment +/// A scope for artifact operations within a specific experiment. #[derive(Clone)] pub struct ExperimentArtifactClient { client: Client, @@ -38,19 +41,25 @@ impl ExperimentArtifactClient { settings: &E::Settings, ) -> Result { let name = name.into(); - let mut sink = FsBundleSink::temp() + let mut bundle = FsBundle::temp() .map_err(|e| ArtifactError::Internal(format!("Failed to create temp bundle: {e}")))?; - artifact.encode(&mut sink, settings).map_err(|e| { + artifact.encode(&mut bundle, settings).map_err(|e| { ArtifactError::Encoding(format!("Failed to encode artifact: {}", e.into())) })?; - let mut specs = Vec::with_capacity(sink.files().len()); - for f in sink.files() { + let mut specs = Vec::with_capacity(bundle.files().len()); + for f in bundle.files() { + let size_bytes = f.size_bytes.ok_or_else(|| { + ArtifactError::Internal(format!("Missing file size for {}", f.rel_path)) + })?; + let checksum = f.checksum.clone().ok_or_else(|| { + ArtifactError::Internal(format!("Missing checksum for {}", f.rel_path)) + })?; specs.push(ArtifactFileSpecRequest { rel_path: f.rel_path.clone(), - size_bytes: f.size_bytes, - checksum: f.checksum.clone(), + size_bytes, + checksum, }); } @@ -65,14 +74,14 @@ impl ExperimentArtifactClient { }, )?; - let mut multipart_map: 
BTreeMap = BTreeMap::new(); + let mut multipart_map = BTreeMap::new(); for f in &res.files { multipart_map.insert(f.rel_path.clone(), &f.urls); } - let files = sink.files().to_vec(); + let mut uploads = Vec::with_capacity(bundle.files().len()); - for f in files { + for f in bundle.files() { let multipart_info = multipart_map.get(&f.rel_path).ok_or_else(|| { ArtifactError::Internal(format!( "Missing multipart upload info for file {}", @@ -80,8 +89,22 @@ impl ExperimentArtifactClient { )) })?; - self.upload_file_multipart_streaming(&f.abs_path, &f.rel_path, multipart_info)?; + let parts = multipart_info + .parts + .iter() + .map(|part| MultipartUploadPart { + part: part.part, + url: part.url.clone(), + size_bytes: part.size_bytes, + }) + .collect::>(); + + uploads.push(MultipartUploadFile { + rel_path: f.rel_path.clone(), + parts, + }); } + upload_bundle_multipart(&bundle, &uploads)?; self.client.complete_artifact_upload( self.exp_path.owner_name(), @@ -94,7 +117,7 @@ impl ExperimentArtifactClient { Ok(res.id) } - /// Download an artifact and decode it using the BundleDecode trait (filesystem-backed) + /// Download an artifact and decode it using the BundleDecode trait (filesystem-backed). pub fn download( &self, name: impl AsRef, @@ -110,8 +133,8 @@ impl ExperimentArtifactClient { }) } - /// Download an artifact as a filesystem-backed bundle reader - pub fn download_raw(&self, name: impl AsRef) -> Result { + /// Download an artifact as a filesystem-backed bundle. 
+ pub fn download_raw(&self, name: impl AsRef) -> Result { let name = name.as_ref(); let artifact = self.fetch(name)?; let resp = self.client.presign_artifact_download( @@ -121,62 +144,25 @@ impl ExperimentArtifactClient { &artifact.id.to_string(), )?; - let mut file_list = Vec::new(); - for file_info in &resp.files { - file_list.push(file_info.rel_path.clone()); + let mut files = Vec::with_capacity(resp.files.len()); + for file in resp.files { + files.push(ArtifactDownloadFile { + rel_path: file.rel_path, + url: file.url, + size_bytes: None, + checksum: None, + }); } - // Create a temporary bundle reader that owns its temp directory - let reader = FsBundleReader::temp(file_list) + let mut bundle = FsBundle::temp() .map_err(|e| ArtifactError::Internal(format!("Failed to create temp bundle: {e}")))?; - for file_info in resp.files { - let rel_path = file_info.rel_path; - let dest_path = reader.root().join(&rel_path); - - // Create parent directories - if let Some(parent) = dest_path.parent() { - fs::create_dir_all(parent).map_err(|e| { - ArtifactError::Internal(format!( - "Failed to create directory for {}: {}", - rel_path, e - )) - })?; - } - - // Download file directly to disk - self.download_file_to_path(&file_info.url, &dest_path)?; - } - - Ok(reader) - } - - fn download_file_to_path(&self, url: &str, dest: &Path) -> Result<(), ArtifactError> { - let http = reqwest::blocking::Client::new(); - let mut response = http - .get(url) - .send() - .map_err(|e| ArtifactError::Internal(format!("Failed to download from URL: {}", e)))?; - - if !response.status().is_success() { - return Err(ArtifactError::Internal(format!( - "Failed to download file: HTTP {}", - response.status() - ))); - } - - let mut file = File::create(dest).map_err(|e| { - ArtifactError::Internal(format!("Failed to create file {}: {}", dest.display(), e)) - })?; + download_artifacts_to_sink(&mut bundle, &files)?; - std::io::copy(&mut response, &mut file).map_err(|e| { - 
ArtifactError::Internal(format!("Failed to write file {}: {}", dest.display(), e)) - })?; - - Ok(()) + Ok(bundle) } - /// Fetch information about an artifact by name + /// Fetch information about an artifact by name. pub fn fetch(&self, name: impl AsRef) -> Result { let name = name.as_ref(); self.client @@ -191,88 +177,6 @@ impl ExperimentArtifactClient { .next() .ok_or_else(|| ArtifactError::NotFound(name.to_owned())) } - - fn upload_file_multipart_streaming( - &self, - file_path: &Path, - rel_path: &str, - multipart_info: &MultipartUploadResponse, - ) -> Result<(), ArtifactError> { - let metadata = fs::metadata(file_path) - .map_err(|e| ArtifactError::Internal(format!("Failed to stat file {rel_path}: {e}")))?; - let file_len = metadata.len(); - - let mut part_indices: Vec = (0..multipart_info.parts.len()).collect(); - part_indices.sort_by_key(|&i| multipart_info.parts[i].part); - - for (i, &part_idx) in part_indices.iter().enumerate() { - let part = &multipart_info.parts[part_idx]; - if part.part != (i as u32 + 1) { - return Err(ArtifactError::Internal(format!( - "Invalid part numbering for {}: expected part {}, got part {}", - rel_path, - i + 1, - part.part - ))); - } - } - - let http = reqwest::blocking::Client::new(); - let mut offset = 0u64; - - for (part_index, &part_idx) in part_indices.iter().enumerate() { - let part_info = &multipart_info.parts[part_idx]; - let size = part_info.size_bytes; - - if offset + size > file_len { - return Err(ArtifactError::Internal(format!( - "Part {} exceeds file length for {}", - part_index + 1, - rel_path - ))); - } - - let mut file = File::open(file_path).map_err(|e| { - ArtifactError::Internal(format!("Failed to open file {rel_path}: {e}")) - })?; - file.seek(SeekFrom::Start(offset)).map_err(|e| { - ArtifactError::Internal(format!("Failed to seek file {rel_path}: {e}")) - })?; - - let reader = file.take(size); - let body = reqwest::blocking::Body::sized(reader, size); - let response = 
http.put(&part_info.url).body(body).send().map_err(|e| { - ArtifactError::Internal(format!( - "Failed to upload part {} of {} for {}: {}", - part_index + 1, - multipart_info.parts.len(), - rel_path, - e - )) - })?; - - if !response.status().is_success() { - return Err(ArtifactError::Internal(format!( - "Failed to upload part {} of {} for {}: HTTP {}", - part_index + 1, - multipart_info.parts.len(), - rel_path, - response.status() - ))); - } - - offset += size; - } - - if offset != file_len { - return Err(ArtifactError::Internal(format!( - "Multipart upload size mismatch for {} (uploaded {}, expected {})", - rel_path, offset, file_len - ))); - } - - Ok(()) - } } #[derive(Debug, thiserror::Error)] @@ -285,6 +189,10 @@ pub enum ArtifactError { Encoding(String), #[error("Error while decoding artifact: {0}")] Decoding(String), + #[error(transparent)] + Download(#[from] DownloadError), + #[error(transparent)] + Upload(#[from] UploadError), #[error("Internal error: {0}")] Internal(String), } diff --git a/crates/burn-central-core/src/bundle/core.rs b/crates/burn-central-core/src/bundle/core.rs deleted file mode 100644 index b590ad3e..00000000 --- a/crates/burn-central-core/src/bundle/core.rs +++ /dev/null @@ -1,51 +0,0 @@ -use serde::{Serialize, de::DeserializeOwned}; -use std::io::Read; - -/// Trait for encoding data into a bundle of files -pub trait BundleEncode { - type Settings: Default + Serialize + DeserializeOwned; - type Error: Into>; - - fn encode( - self, - sink: &mut O, - settings: &Self::Settings, - ) -> Result<(), Self::Error>; -} - -/// Trait for decoding data from a bundle of files -pub trait BundleDecode: Sized { - type Settings: Default + Serialize + DeserializeOwned; - type Error: Into>; - - fn decode(source: &I, settings: &Self::Settings) -> Result; -} - -/// Trait for writing files to a bundle -pub trait BundleSink { - /// Add a file by streaming its bytes. Returns computed checksum + size. 
- fn put_file(&mut self, path: &str, reader: &mut R) -> Result<(), String>; - - /// Convenience: write all bytes. - fn put_bytes(&mut self, path: &str, bytes: &[u8]) -> Result<(), String> { - let mut r = std::io::Cursor::new(bytes); - self.put_file(path, &mut r) - } -} - -/// Trait for reading files from a bundle -pub trait BundleSource { - /// Open the given path for streaming read. Must validate existence. - fn open(&self, path: &str) -> Result, String>; - - /// Optionally list available files (used by generic decoders; can be best-effort). - fn list(&self) -> Result, String>; -} - -/// Normalize a path within a bundle (use forward slashes, remove leading slash) -pub fn normalize_bundle_path>(s: S) -> String { - s.as_ref() - .replace('\\', "/") - .trim_start_matches('/') - .to_string() -} diff --git a/crates/burn-central-core/src/bundle/fs.rs b/crates/burn-central-core/src/bundle/fs.rs deleted file mode 100644 index 5f705221..00000000 --- a/crates/burn-central-core/src/bundle/fs.rs +++ /dev/null @@ -1,201 +0,0 @@ -use std::collections::HashSet; -use std::fs::{self, File}; -use std::io::{Read, Write}; -use std::path::{Path, PathBuf}; - -use sha2::Digest; -use tempfile::TempDir; - -use crate::bundle::{BundleSink, BundleSource, normalize_bundle_path}; - -/// File-backed bundle sink that streams files to disk and computes checksums. -#[derive(Debug)] -pub struct FsBundleSink { - root: PathBuf, - files: Vec, - seen: HashSet, - #[allow(unused)] - _temp: Option, -} - -impl FsBundleSink { - /// Create a bundle sink rooted at the provided directory. - pub fn new(root: impl Into) -> Result { - let root = root.into(); - fs::create_dir_all(&root)?; - Ok(Self { - root, - files: Vec::new(), - seen: HashSet::new(), - _temp: None, - }) - } - - /// Create a temporary bundle sink that cleans up on drop. 
- pub fn temp() -> Result { - let temp = TempDir::new()?; - let root = temp.path().to_path_buf(); - Ok(Self { - root, - files: Vec::new(), - seen: HashSet::new(), - _temp: Some(temp), - }) - } - - /// Root directory for the bundle files. - pub fn root(&self) -> &Path { - &self.root - } - - /// Files written into the bundle. - pub fn files(&self) -> &[FsBundleFile] { - &self.files - } - - /// Consume the sink and return the file list. - pub fn into_files(self) -> Vec { - self.files - } -} - -impl BundleSink for FsBundleSink { - fn put_file(&mut self, path: &str, reader: &mut R) -> Result<(), String> { - let rel = sanitize_rel_path(path)?; - - if !self.seen.insert(rel.to_string()) { - return Err(format!("Duplicate bundle path: {rel}")); - } - - let dest = self.root.join(&rel); - if let Some(parent) = dest.parent() { - fs::create_dir_all(parent).map_err(|e| e.to_string())?; - } - - let tmp = temp_path(&dest).map_err(|e| e.to_string())?; - let mut file = File::create(&tmp).map_err(|e| e.to_string())?; - - let mut hasher = sha2::Sha256::new(); - let mut buf = [0u8; 1024 * 64]; - let mut total = 0u64; - - loop { - let read = reader.read(&mut buf).map_err(|e| e.to_string())?; - if read == 0 { - break; - } - file.write_all(&buf[..read]).map_err(|e| e.to_string())?; - hasher.update(&buf[..read]); - total += read as u64; - } - - let checksum = format!("{:x}", hasher.finalize()); - - if dest.exists() { - fs::remove_file(&dest).map_err(|e| e.to_string())?; - } - - fs::rename(&tmp, &dest).map_err(|e| e.to_string())?; - - self.files.push(FsBundleFile { - rel_path: rel.to_string(), - abs_path: dest, - size_bytes: total, - checksum, - }); - - Ok(()) - } -} - -/// File descriptor emitted by a file-backed bundle sink. -#[derive(Debug, Clone)] -pub struct FsBundleFile { - /// Relative path within the bundle. - pub rel_path: String, - /// Absolute file system path for the cached file. - pub abs_path: PathBuf, - /// Size in bytes. - pub size_bytes: u64, - /// SHA-256 checksum (hex). 
- pub checksum: String, -} - -/// File-backed bundle reader for streaming decode. -pub struct FsBundleReader { - root: PathBuf, - files: Vec, - #[allow(unused)] - _temp: Option, -} - -impl FsBundleReader { - /// Create a file-backed bundle reader. - pub fn new(root: PathBuf, files: Vec) -> Self { - Self { - root, - files, - _temp: None, - } - } - - /// Create a temporary bundle reader that cleans up on drop. - pub fn temp(files: Vec) -> Result { - let temp = TempDir::new()?; - let root = temp.path().to_path_buf(); - Ok(Self { - root, - files, - _temp: Some(temp), - }) - } - - /// Root directory for the bundle files. - pub fn root(&self) -> &Path { - &self.root - } -} - -impl BundleSource for FsBundleReader { - fn open(&self, path: &str) -> Result, String> { - let rel = sanitize_rel_path(path)?; - let full = self.root.join(rel); - let file = File::open(full).map_err(|e| e.to_string())?; - Ok(Box::new(file)) - } - - fn list(&self) -> Result, String> { - Ok(self.files.clone()) - } -} - -fn sanitize_rel_path(path: &str) -> Result { - let normalized = normalize_bundle_path(path); - if normalized.is_empty() { - return Err("Empty bundle path".to_string()); - } - - let rel = Path::new(&normalized); - for component in rel.components() { - use std::path::Component; - match component { - Component::ParentDir - | Component::RootDir - | Component::Prefix(_) - | Component::CurDir => { - return Err(format!("Invalid bundle path: {path}")); - } - Component::Normal(_) => {} - } - } - - Ok(normalized) -} - -fn temp_path(dest: &Path) -> Result { - let file_name = dest - .file_name() - .ok_or_else(|| std::io::Error::other("Missing file name"))? 
- .to_string_lossy(); - Ok(dest.with_file_name(format!(".{file_name}.partial"))) -} diff --git a/crates/burn-central-core/src/bundle/mod.rs b/crates/burn-central-core/src/bundle/mod.rs deleted file mode 100644 index a74b50c6..00000000 --- a/crates/burn-central-core/src/bundle/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod core; -mod fs; -mod memory; - -pub use core::*; -pub use fs::*; -pub use memory::*; diff --git a/crates/burn-central-core/src/experiment/base.rs b/crates/burn-central-core/src/experiment/base.rs index 3a6e670e..8d100abe 100644 --- a/crates/burn-central-core/src/experiment/base.rs +++ b/crates/burn-central-core/src/experiment/base.rs @@ -1,11 +1,11 @@ use super::socket::ExperimentSocket; use crate::artifacts::{ArtifactKind, ExperimentArtifactClient}; -use crate::bundle::{BundleDecode, BundleEncode, FsBundleReader}; use crate::experiment::CancelToken; use crate::experiment::error::ExperimentTrackerError; use crate::experiment::log_store::TempLogStore; use crate::experiment::socket::ThreadError; use crate::schemas::ExperimentPath; +use burn_central_artifact::bundle::{BundleDecode, BundleEncode, FsBundle}; use burn_central_client::Client; pub use burn_central_client::websocket::MetricLog; use burn_central_client::websocket::{ExperimentCompletion, ExperimentMessage, InputUsed}; @@ -62,7 +62,7 @@ impl ExperimentRunHandle { pub fn load_artifact_raw( &self, name: impl AsRef, - ) -> Result { + ) -> Result { self.try_upgrade()?.load_artifact_raw(name) } @@ -206,7 +206,7 @@ impl ExperimentRunInner { pub fn load_artifact_raw( &self, name: impl AsRef, - ) -> Result { + ) -> Result { let scope = ExperimentArtifactClient::new(self.http_client.clone(), self.id.clone()); let artifact = scope.fetch(&name)?; self.send(ExperimentMessage::InputUsed(InputUsed::Artifact { diff --git a/crates/burn-central-core/src/integration/checkpoint/mod.rs b/crates/burn-central-core/src/integration/checkpoint/mod.rs index e06ca64b..04c68b67 100644 --- 
a/crates/burn-central-core/src/integration/checkpoint/mod.rs +++ b/crates/burn-central-core/src/integration/checkpoint/mod.rs @@ -1,12 +1,12 @@ use std::path::PathBuf; use crate::artifacts::ArtifactKind; -use crate::bundle::{BundleDecode, BundleEncode, BundleSink}; use crate::experiment::{ExperimentRun, ExperimentRunHandle}; use burn::record::{ FileRecorder, FullPrecisionSettings, NamedMpkBytesRecorder, Record, Recorder, RecorderError, }; use burn::tensor::backend::Backend; +use burn_central_artifact::bundle::{BundleDecode, BundleEncode, BundleSink, BundleSource}; use serde::Deserialize; use serde::{Serialize, de::DeserializeOwned}; @@ -67,10 +67,7 @@ where type Settings = CheckpointRecordArtifactSettings; type Error = String; - fn decode( - source: &I, - settings: &Self::Settings, - ) -> Result { + fn decode(source: &I, settings: &Self::Settings) -> Result { let mut reader = source.open(&settings.name).map_err(|e| { format!( "Failed to get reader for checkpoint artifact {}: {}", diff --git a/crates/burn-central-core/src/lib.rs b/crates/burn-central-core/src/lib.rs index fb772e37..107aa849 100644 --- a/crates/burn-central-core/src/lib.rs +++ b/crates/burn-central-core/src/lib.rs @@ -1,15 +1,20 @@ +//! This crate provides core functionalities for tracking experiments on the Burn Central platform. + mod client; +pub mod artifacts; pub mod experiment; pub mod integration; +pub mod models; mod schemas; pub use crate::client::*; pub type BurnCentralCredentials = burn_central_client::BurnCentralCredentials; pub type Env = burn_central_client::Env; - -pub mod artifacts; -pub mod bundle; -pub mod models; pub use schemas::*; + +/// This is a temporary re-export of the bundle traits for users to implement them for their artifacts. Later, these traits will be available in a separate crate `burn-central-artifact`. 
+pub mod bundle { + pub use burn_central_artifact::bundle::*; +} diff --git a/crates/burn-central-core/src/models/mod.rs b/crates/burn-central-core/src/models/mod.rs index a99ec787..91cd7779 100644 --- a/crates/burn-central-core/src/models/mod.rs +++ b/crates/burn-central-core/src/models/mod.rs @@ -1,6 +1,6 @@ use std::collections::BTreeMap; -use crate::bundle::{BundleDecode, InMemoryBundleReader}; +use burn_central_artifact::bundle::{BundleDecode, InMemoryBundleReader}; use burn_central_client::response::{ModelResponse, ModelVersionResponse}; use burn_central_client::{Client, ClientError}; diff --git a/crates/burn-central-fleet/src/inference.rs b/crates/burn-central-fleet/src/inference.rs index c225e053..933c301e 100644 --- a/crates/burn-central-fleet/src/inference.rs +++ b/crates/burn-central-fleet/src/inference.rs @@ -3,10 +3,10 @@ use std::time::{Duration, Instant}; use arc_swap::ArcSwapOption; use burn::prelude::Backend; +use burn_central_artifact::bundle::FsBundle; use burn_central_inference::{Inference, InferenceWriter}; use crate::FleetDeviceSession; -use crate::model::ModelSource; use crate::telemetry::{InferenceMetadata, InferenceWriterTelemetryObserver}; #[derive(Debug, thiserror::Error)] @@ -18,7 +18,7 @@ pub enum FleetManagedInferenceError { pub trait FleetManagedFactory: Send + Sync { fn build( &self, - model_source: ModelSource, + model_source: FsBundle, runtime_config: serde_json::Value, device: B::Device, ) -> Result; @@ -26,12 +26,12 @@ pub trait FleetManagedFactory: Send + Sync { impl FleetManagedFactory for F where - F: Fn(ModelSource, serde_json::Value, B::Device) -> Result + Send + Sync, + F: Fn(FsBundle, serde_json::Value, B::Device) -> Result + Send + Sync, B: Backend, { fn build( &self, - model_source: ModelSource, + model_source: FsBundle, runtime_config: serde_json::Value, device: B::Device, ) -> Result { diff --git a/crates/burn-central-fleet/src/lib.rs b/crates/burn-central-fleet/src/lib.rs index c3e72616..ed1c1a55 100644 --- 
a/crates/burn-central-fleet/src/lib.rs +++ b/crates/burn-central-fleet/src/lib.rs @@ -7,7 +7,6 @@ mod telemetry; pub use error::FleetError; pub use inference::{FleetManagedFactory, FleetManagedInference, FleetManagedInferenceError}; -pub use model::ModelSource; pub use session::FleetDeviceSession; pub use telemetry::{metrics_recorder, tracing_log_layer, tracing_metrics_layer}; diff --git a/crates/burn-central-fleet/src/model.rs b/crates/burn-central-fleet/src/model.rs index a159fdb6..41793d95 100644 --- a/crates/burn-central-fleet/src/model.rs +++ b/crates/burn-central-fleet/src/model.rs @@ -1,9 +1,12 @@ use std::fs; use std::io; use std::path::Path; -use std::path::PathBuf; -use burn_central_artifact::{ArtifactDownloadFile, DownloadError, download_artifacts_to_dir}; +use burn_central_artifact::bundle::FsBundle; +use burn_central_artifact::download::{ + ArtifactDownloadFile, DownloadError, download_artifacts_to_sink, +}; +use burn_central_artifact::normalize_checksum; use burn_central_client::fleet::response::FleetModelDownloadResponse; use serde::{Deserialize, Serialize}; @@ -20,20 +23,11 @@ pub enum ModelCacheError { #[error("cached model file missing: {0}")] MissingCachedFile(String), #[error(transparent)] - Registry(#[from] DownloadError), -} - -/// Source information for loading an assigned model. 
-#[derive(Debug, Clone)] -pub struct ModelSource { - pub root: PathBuf, - pub files: Vec, -} - -impl ModelSource { - pub fn new(root: PathBuf, files: Vec) -> Self { - Self { root, files } - } + Download(#[from] DownloadError), + #[error("invalid file path in model download manifest: {0}")] + InvalidRelPath(String), + #[error("invalid checksum in model metadata: {0}")] + InvalidChecksum(String), } #[derive(Serialize, Deserialize)] @@ -78,24 +72,34 @@ pub fn ensure_cached_model( .files .iter() .map(|f| { - ( + Ok(( f.rel_path.clone(), f.size_bytes, - normalize_checksum(&f.checksum), - ) + normalize_checksum(&f.checksum).map_err(|e| { + ModelCacheError::InvalidChecksum(format!( + "manifest checksum for {}: {}", + f.rel_path, e + )) + })?, + )) }) - .collect::>(); + .collect::, ModelCacheError>>()?; let mut download_files = download .files .iter() .map(|f| { - ( + Ok(( f.rel_path.clone(), f.size_bytes, - normalize_checksum(&f.checksum), - ) + normalize_checksum(&f.checksum).map_err(|e| { + ModelCacheError::InvalidChecksum(format!( + "download checksum for {}: {}", + f.rel_path, e + )) + })?, + )) }) - .collect::>(); + .collect::, ModelCacheError>>()?; manifest_files.sort_unstable(); download_files.sort_unstable(); @@ -118,8 +122,8 @@ pub fn ensure_cached_model( files.push(ArtifactDownloadFile { rel_path: entry.rel_path.clone(), url: entry.url.clone(), - size_bytes: entry.size_bytes, - checksum: entry.checksum.clone(), + size_bytes: Some(entry.size_bytes), + checksum: Some(entry.checksum.clone()), }); } @@ -129,7 +133,8 @@ pub fn ensure_cached_model( "new model version detected, downloading model files to local filesystem" ); - download_artifacts_to_dir(&model_root, &files)?; + let mut sink = FsBundle::create(model_root.clone()).map_err(ModelCacheError::Io)?; + download_artifacts_to_sink(&mut sink, &files)?; let manifest = ModelDownloadManifest { model_version_id: download.model_version_id.clone(), @@ -169,18 +174,10 @@ fn cached_files_present_and_sized( Ok(true) } -fn 
normalize_checksum(value: &str) -> String { - let trimmed = value.trim().to_ascii_lowercase(); - match trimmed.strip_prefix("sha256:") { - Some(rest) => rest.to_string(), - None => trimmed, - } -} - pub fn load_cached_model_source( models_root: &Path, model_version_id: &str, -) -> Result { +) -> Result { if model_version_id.is_empty() { tracing::error!("model version id is empty in fleet state"); return Err(ModelCacheError::MissingActiveModelVersion); @@ -223,7 +220,10 @@ pub fn load_cached_model_source( } } - Ok(ModelSource::new(model_root, files)) + let source = + FsBundle::with_files(model_root, files).map_err(ModelCacheError::InvalidRelPath)?; + + Ok(source) } fn write_manifest_if_changed( @@ -253,6 +253,7 @@ fn write_manifest_if_changed( #[cfg(test)] mod tests { use super::*; + use burn_central_client::fleet::response::FleetPresignedModelFileUrlResponse; use std::{ path::PathBuf, time::{SystemTime, UNIX_EPOCH}, @@ -310,4 +311,45 @@ mod tests { let _ = fs::remove_dir_all(root); } + + #[test] + fn manifest_with_invalid_checksum_is_rejected() { + let root = temp_path("invalid-checksum"); + let model_root = root.join("mv-1"); + fs::create_dir_all(&model_root).expect("model root should exist"); + + let manifest = ModelDownloadManifest { + model_version_id: "mv-1".to_string(), + files: vec![ModelDownloadManifestFile { + rel_path: "weights.bin".to_string(), + size_bytes: 10, + checksum: "md5:abc".to_string(), + }], + }; + let manifest_path = model_root.join("manifest.json"); + fs::write( + &manifest_path, + serde_json::to_vec_pretty(&manifest).expect("serialize manifest"), + ) + .expect("write manifest"); + + let download = FleetModelDownloadResponse { + model_version_id: "mv-1".to_string(), + files: vec![FleetPresignedModelFileUrlResponse { + rel_path: "weights.bin".to_string(), + url: "mock://weights".to_string(), + size_bytes: 10, + checksum: "sha256:00".to_string(), + }], + }; + + let err = ensure_cached_model(&root, "mv-1", &download) + .expect_err("invalid 
checksum metadata should fail early"); + match err { + ModelCacheError::InvalidChecksum(msg) => assert!(msg.contains("manifest checksum")), + other => panic!("unexpected error: {other:?}"), + } + + let _ = fs::remove_dir_all(root); + } } diff --git a/crates/burn-central-fleet/src/session.rs b/crates/burn-central-fleet/src/session.rs index dc9fa13f..ddfbe7fe 100644 --- a/crates/burn-central-fleet/src/session.rs +++ b/crates/burn-central-fleet/src/session.rs @@ -7,10 +7,7 @@ use burn_central_client::{Env, FleetClient}; use directories::{BaseDirs, ProjectDirs}; use crate::{ - DeviceMetadata, FleetRegistrationToken, - error::FleetError, - model::{self, ModelSource}, - state, + DeviceMetadata, FleetRegistrationToken, error::FleetError, model, state, telemetry::TelemetryPipeline, }; @@ -80,7 +77,7 @@ impl FleetDeviceSession { &self.fleet_key } - pub fn model_source(&self) -> Result { + pub fn model_source(&self) -> Result { model::load_cached_model_source( &self.store.models_dir(&self.fleet_key), self.state.active_model_version_id(), diff --git a/crates/burn-central-runtime/Cargo.toml b/crates/burn-central-runtime/Cargo.toml index 042e463b..dfb4a86e 100644 --- a/crates/burn-central-runtime/Cargo.toml +++ b/crates/burn-central-runtime/Cargo.toml @@ -16,6 +16,7 @@ anyhow.workspace = true burn-central-core.workspace = true burn-central-fleet.workspace = true burn-central-inference.workspace = true +burn-central-artifact.workspace = true burn.workspace = true thiserror.workspace = true variadics_please.workspace = true diff --git a/crates/burn-central-runtime/src/executor.rs b/crates/burn-central-runtime/src/executor.rs index 5d698bdd..97e9dd82 100644 --- a/crates/burn-central-runtime/src/executor.rs +++ b/crates/burn-central-runtime/src/executor.rs @@ -295,7 +295,7 @@ mod test { use burn::backend::{Autodiff, NdArray}; use burn::nn::{Linear, LinearConfig}; use burn::prelude::*; - use burn_central_core::bundle::{BundleEncode, BundleSink}; + use 
burn_central_artifact::bundle::{BundleEncode, BundleSink}; use serde::{Deserialize, Serialize}; impl ExecutorBuilder { diff --git a/crates/burn-central-runtime/src/inference/fleet.rs b/crates/burn-central-runtime/src/inference/fleet.rs index 8414e754..0c6323ad 100644 --- a/crates/burn-central-runtime/src/inference/fleet.rs +++ b/crates/burn-central-runtime/src/inference/fleet.rs @@ -38,11 +38,9 @@ where let error_name = inference_name.clone(); let inference_factory: Box> = Box::new( - move |model_source: burn_central_fleet::ModelSource, - runtime_config: serde_json::Value, - device: B::Device| { + move |model_source, runtime_config: serde_json::Value, device: B::Device| { let init = InferenceInit { - model: ModelSource::from(model_source), + model: Some(ModelSource::from(model_source)).into(), device, }; let mut ctx = InferenceContext::new(init, InferenceArgs::new(Some(runtime_config))); diff --git a/crates/burn-central-runtime/src/inference/registry.rs b/crates/burn-central-runtime/src/inference/registry.rs index 81a93e1c..b330e897 100644 --- a/crates/burn-central-runtime/src/inference/registry.rs +++ b/crates/burn-central-runtime/src/inference/registry.rs @@ -3,13 +3,13 @@ use crate::params::args::{LaunchArgs, deserialize_and_merge_with_default}; use crate::routine::{BoxedRoutine, IntoRoutine}; use crate::{Args, MultiDevice}; use burn::prelude::Backend; -use burn_central_core::bundle::BundleDecode; +use burn_central_artifact::bundle::{BundleDecode, FsBundle}; use burn_central_inference::{ErasedInference, Inference, JsonInference}; use derive_more::{Deref, From}; use serde::{Serialize, de::DeserializeOwned}; +use std::cell::RefCell; use std::collections::HashMap; use std::marker::PhantomData; -use std::path::PathBuf; #[derive(Debug, thiserror::Error)] pub enum InferenceError { @@ -20,23 +20,21 @@ pub enum InferenceError { } /// Runtime wrapper around fleet model sources to support routine param injection. 
-#[derive(Debug, Clone, Deref, From)] -pub struct ModelSource(burn_central_fleet::ModelSource); +#[derive(Debug, Deref, From)] +pub struct ModelSource(FsBundle); impl ModelSource { - pub fn new(root: PathBuf, files: Vec) -> Self { - Self(burn_central_fleet::ModelSource::new(root, files)) + pub fn new(source: FsBundle) -> Self { + Self(source) } pub fn load(&self, settings: &D::Settings) -> Result { - let reader = - burn_central_core::bundle::FsBundleReader::new(self.root.clone(), self.files.clone()); - D::decode(&reader, settings) + D::decode(&self.0, settings) } } pub struct InferenceInit { - pub model: ModelSource, + pub model: RefCell>, pub device: B::Device, } @@ -100,8 +98,11 @@ impl InferenceContext { self.args.merged_args_or_default() } - pub fn model(&self) -> &ModelSource { - &self.init.model + pub fn model(&self) -> ModelSource { + self.init + .model + .take() + .expect("model source should be set in inference context") } pub fn device(&self) -> &B::Device { @@ -113,7 +114,7 @@ impl RoutineParam> for ModelSource { type Item<'new> = ModelSource; fn try_retrieve(ctx: &InferenceContext) -> anyhow::Result> { - Ok(ctx.model().clone()) + Ok(ctx.model()) } } diff --git a/crates/burn-central-runtime/src/output.rs b/crates/burn-central-runtime/src/output.rs index 245a3fdc..bc67fa8d 100644 --- a/crates/burn-central-runtime/src/output.rs +++ b/crates/burn-central-runtime/src/output.rs @@ -1,8 +1,8 @@ use crate::executor::ExecutionContext; use crate::params::default::Model; use burn::prelude::Backend; +use burn_central_artifact::bundle::BundleEncode; use burn_central_core::artifacts::ArtifactKind; -use burn_central_core::bundle::BundleEncode; use std::fmt::Display; /// This trait defines how a specific return type (Output) from a handler apply its effects to the execution context. 
diff --git a/crates/burn-central-runtime/src/params/artifact_loader.rs b/crates/burn-central-runtime/src/params/artifact_loader.rs index 1670a63f..a80e614a 100644 --- a/crates/burn-central-runtime/src/params/artifact_loader.rs +++ b/crates/burn-central-runtime/src/params/artifact_loader.rs @@ -1,9 +1,9 @@ use crate::executor::ExecutionContext; use crate::params::RoutineParam; use burn::prelude::Backend; +use burn_central_artifact::bundle::BundleDecode; use burn_central_core::BurnCentral; use burn_central_core::artifacts::ArtifactError; -use burn_central_core::bundle::BundleDecode; /// Artifact loader for loading artifacts from Burn Central. It allow to fecth for instance other /// experiment endpoint to be able to restart from a certain point your experiment. @@ -14,7 +14,7 @@ use burn_central_core::bundle::BundleDecode; /// /// ```ignore /// # use burn_central_runtime::ArtifactLoader; -/// # use burn_central_core::bundle::BundleDecode; +/// # use burn_central_artifact::bundle::BundleDecode; /// # use burn_central::register; /// # use burn_central_runtime::Model; /// # use burn_central_runtime::MultiDevice;