Skip to content

Commit 8f4045a

Browse files
committed
feat: add burn-central-artifact crate with download functionality and validation tools
1 parent 00af202 commit 8f4045a

13 files changed

Lines changed: 344 additions & 12 deletions

File tree

Cargo.lock

Lines changed: 14 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,15 @@ opentelemetry = { version = "0.31.0", features = ["metrics"] }
6060
burn-central-client = { version = "0.5.0", path = "../burn-central-client/burn-central-client" }
6161
# burn-central-client = "0.5.0"
6262

63-
burn-central-registry = { path = "crates/burn-central-registry", version = "0.5.0" }
6463

6564
## Crate
6665
burn-central-core = { path = "crates/burn-central-core", version = "0.5.0" }
6766
burn-central-runtime = { path = "crates/burn-central-runtime", version = "0.5.0" }
6867
burn-central-macros = { path = "crates/burn-central-macros", version = "0.5.0" }
6968
burn-central-inference = { path = "crates/burn-central-inference", version = "0.5.0" }
7069
burn-central-fleet = { path = "crates/burn-central-fleet", version = "0.5.0" }
70+
burn-central-artifact = { path = "crates/burn-central-artifact", version = "0.5.0" }
71+
burn-central-registry = { path = "crates/burn-central-registry", version = "0.5.0" }
7172

7273
### For xtask crate ###
7374
tracel-xtask = "4.5.0"
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
[package]
2+
name = "burn-central-artifact"
3+
edition.workspace = true
4+
version.workspace = true
5+
readme.workspace = true
6+
license.workspace = true
7+
rust-version.workspace = true
8+
authors.workspace = true
9+
repository.workspace = true
10+
keywords.workspace = true
11+
categories.workspace = true
12+
13+
[dependencies]
14+
crossbeam = { workspace = true }
15+
reqwest = { version = "0.13.2", features = ["blocking"] }
16+
serde = { workspace = true, features = ["derive"] }
17+
sha2 = { workspace = true }
18+
serde_json = { workspace = true }
19+
burn-central-core = { workspace = true }
20+
thiserror = { workspace = true }
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
use std::fs::{self, File};
2+
use std::io::BufWriter;
3+
use std::path::{Path, PathBuf};
4+
5+
use burn_central_core::bundle::normalize_bundle_path;
6+
7+
use crate::download::{DownloadError, DownloadTask, download_tasks};
8+
use crate::tools::path::safe_join;
9+
10+
/// Describes one artifact file to fetch: where it goes, where it comes from,
/// and what the downloaded bytes are expected to look like.
#[derive(Clone, Debug)]
pub struct ArtifactDownloadFile {
    /// Path of the file relative to the destination directory.
    pub rel_path: String,
    /// URL the file content is fetched from.
    pub url: String,
    /// Expected size of the downloaded content, in bytes.
    pub size_bytes: u64,
    /// Expected checksum of the downloaded content.
    pub checksum: String,
}
18+
19+
/// Download artifact files into a destination directory, validating size and checksum.
20+
pub fn download_artifacts_to_dir(
21+
dest_root: &Path,
22+
files: &[ArtifactDownloadFile],
23+
) -> Result<(), DownloadError> {
24+
fs::create_dir_all(dest_root)?;
25+
26+
if files.is_empty() {
27+
return Ok(());
28+
}
29+
30+
let mut tmps = Vec::with_capacity(files.len());
31+
let mut tasks = Vec::with_capacity(files.len());
32+
for file in files {
33+
let rel_path = normalize_bundle_path(&file.rel_path);
34+
let dest = safe_join(dest_root, &rel_path)
35+
.map_err(|e| DownloadError::InvalidPath(e.to_string()))?;
36+
37+
if let Some(parent) = dest.parent() {
38+
fs::create_dir_all(parent)?;
39+
}
40+
let tmp = temp_path(&dest)?;
41+
tmps.push((dest.clone(), tmp.clone()));
42+
43+
let dest_file = File::create(dest)?;
44+
let writer = BufWriter::new(dest_file);
45+
46+
tasks.push(DownloadTask {
47+
rel_path: rel_path.clone(),
48+
url: file.url.clone(),
49+
writer,
50+
expected_size: file.size_bytes,
51+
expected_checksum: file.checksum.clone(),
52+
});
53+
}
54+
55+
let parallelism = std::thread::available_parallelism()
56+
.map(|n| n.get())
57+
.unwrap_or(4);
58+
let http = reqwest::blocking::Client::new();
59+
let res = download_tasks(&http, tasks, parallelism);
60+
61+
for (tmp_dest, tmp) in tmps {
62+
if tmp_dest.exists() {
63+
fs::remove_file(&tmp_dest)?;
64+
}
65+
if tmp.exists() {
66+
fs::rename(tmp, tmp_dest)?;
67+
}
68+
}
69+
70+
res
71+
}
72+
73+
/// Generate a temporary file path for downloads.
74+
fn temp_path(dest: &Path) -> Result<PathBuf, DownloadError> {
75+
let file_name = dest
76+
.file_name()
77+
.ok_or_else(|| DownloadError::InvalidPath("missing file name".to_string()))?
78+
.to_string_lossy();
79+
Ok(dest.with_file_name(format!(".{file_name}.partial")))
80+
}
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
use std::io::{Read, Write};
2+
3+
use crossbeam::channel;
4+
use reqwest::blocking::Client as HttpClient;
5+
use sha2::Digest;
6+
7+
use crate::tools::validation::normalize_checksum;
8+
9+
#[derive(Debug, thiserror::Error)]
10+
pub enum DownloadError {
11+
#[error("failed to download {path}: {details}")]
12+
DownloadFailed { path: String, details: String },
13+
#[error("size mismatch for {path}: expected {expected} bytes, got {actual} bytes")]
14+
SizeMismatch {
15+
path: String,
16+
expected: u64,
17+
actual: u64,
18+
},
19+
#[error("checksum mismatch for {path}: expected {expected}, got {actual}")]
20+
ChecksumMismatch {
21+
path: String,
22+
expected: String,
23+
actual: String,
24+
},
25+
#[error("invalid checksum: {0}")]
26+
InvalidChecksum(String),
27+
#[error("writer error: {0}")]
28+
WriterError(#[from] std::io::Error),
29+
#[error("invalid path: {0}")]
30+
InvalidPath(String),
31+
}
32+
33+
/// One unit of download work: a source URL, the sink that receives the bytes,
/// and the size/checksum the received content is verified against.
#[derive(Clone)]
pub struct DownloadTask<W> {
    /// Relative path of the file; used to identify the file in error messages.
    pub rel_path: String,
    /// URL the content is fetched from.
    pub url: String,
    /// Sink the downloaded bytes are written to.
    pub writer: W,
    /// Number of bytes the download is expected to produce.
    pub expected_size: u64,
    /// Checksum the downloaded content is expected to have.
    pub expected_checksum: String,
}
42+
43+
/// Download multiple files in parallel.
44+
pub fn download_tasks<W: Write + Send>(
45+
http: &HttpClient,
46+
tasks: Vec<DownloadTask<W>>,
47+
max_parallel: usize,
48+
) -> Result<(), DownloadError> {
49+
if tasks.is_empty() {
50+
return Ok(());
51+
}
52+
53+
if max_parallel <= 1 || tasks.len() == 1 {
54+
for mut task in tasks {
55+
download_one(http, &mut task)?;
56+
}
57+
return Ok(());
58+
}
59+
60+
let (tx, rx) = channel::unbounded::<DownloadTask<W>>();
61+
for task in tasks {
62+
tx.send(task).expect("channel open");
63+
}
64+
drop(tx);
65+
66+
crossbeam::scope(|scope| {
67+
let mut handles = Vec::new();
68+
let worker_count = max_parallel.min(rx.len().max(1));
69+
for _ in 0..worker_count {
70+
let rx = rx.clone();
71+
let http = http.clone();
72+
handles.push(scope.spawn(move |_| {
73+
for mut task in rx.iter() {
74+
download_one(&http, &mut task)?;
75+
}
76+
Ok::<(), DownloadError>(())
77+
}));
78+
}
79+
80+
for handle in handles {
81+
handle.join().expect("thread panicked")?;
82+
}
83+
84+
Ok(())
85+
})
86+
.expect("scope failed")
87+
}
88+
89+
/// Download a single file with checksum verification.
90+
fn download_one<W: Write>(
91+
http: &HttpClient,
92+
task: &mut DownloadTask<W>,
93+
) -> Result<(), DownloadError> {
94+
// if let Some(parent) = task.dest.parent() {
95+
// fs::create_dir_all(parent)?;
96+
// }
97+
98+
// let tmp = temp_path(&task.dest)?;
99+
100+
let mut resp = http
101+
.get(&task.url)
102+
.send()
103+
.map_err(|e| DownloadError::DownloadFailed {
104+
path: task.rel_path.clone(),
105+
details: e.to_string(),
106+
})?;
107+
108+
if !resp.status().is_success() {
109+
return Err(DownloadError::DownloadFailed {
110+
path: task.rel_path.clone(),
111+
details: format!("HTTP {}", resp.status()),
112+
});
113+
}
114+
115+
let sink = &mut task.writer;
116+
let mut hasher = sha2::Sha256::new();
117+
let mut buf = [0u8; 1024 * 64];
118+
let mut total = 0u64;
119+
120+
loop {
121+
let read = resp.read(&mut buf)?;
122+
if read == 0 {
123+
break;
124+
}
125+
sink.write_all(&buf[..read])?;
126+
hasher.update(&buf[..read]);
127+
total += read as u64;
128+
}
129+
130+
let digest = format!("{:x}", hasher.finalize());
131+
let expected_checksum =
132+
normalize_checksum(&task.expected_checksum).map_err(DownloadError::InvalidChecksum)?;
133+
134+
if total != task.expected_size {
135+
return Err(DownloadError::SizeMismatch {
136+
path: task.rel_path.clone(),
137+
expected: task.expected_size,
138+
actual: total,
139+
});
140+
}
141+
if digest != expected_checksum {
142+
return Err(DownloadError::ChecksumMismatch {
143+
path: task.rel_path.clone(),
144+
expected: expected_checksum,
145+
actual: digest,
146+
});
147+
}
148+
149+
// if task.dest.exists() {
150+
// fs::remove_file(&task.dest)?;
151+
// }
152+
153+
// fs::rename(tmp, &task.dest)?;
154+
155+
Ok(())
156+
}
157+
158+
// /// Generate a temporary file path for downloads.
159+
// fn temp_path(dest: &Path) -> Result<PathBuf, RegistryError> {
160+
// let file_name = dest
161+
// .file_name()
162+
// .ok_or_else(|| RegistryError::InvalidPath("missing file name".to_string()))?
163+
// .to_string_lossy();
164+
// Ok(dest.with_file_name(format!(".{file_name}.partial")))
165+
// }
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
//! This crate centralizes traits, structures and utilities for handling artifacts and models in Burn Central.
2+
3+
mod artifact_download;
4+
mod download;
5+
mod tools;
6+
7+
pub use artifact_download::{ArtifactDownloadFile, download_artifacts_to_dir};
8+
pub use download::DownloadError;
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub mod path;
2+
pub mod validation;
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
use std::path::{Path, PathBuf};
2+
3+
use burn_central_core::bundle::normalize_bundle_path;
4+
5+
/// Sanitize a relative path to prevent directory traversal attacks.
6+
pub fn sanitize_rel_path(path: &str) -> Result<PathBuf, String> {
7+
let normalized = normalize_bundle_path(path);
8+
let rel = Path::new(&normalized);
9+
for component in rel.components() {
10+
use std::path::Component;
11+
match component {
12+
Component::ParentDir | Component::RootDir | Component::Prefix(_) => {
13+
return Err(format!("invalid path component: {path}"));
14+
}
15+
Component::CurDir => {
16+
return Err(format!("invalid path component: {path}"));
17+
}
18+
Component::Normal(_) => {}
19+
}
20+
}
21+
Ok(PathBuf::from(normalized))
22+
}
23+
24+
/// Safely join a root path with a relative path.
25+
pub fn safe_join(root: &Path, rel: &str) -> Result<PathBuf, String> {
26+
let rel = sanitize_rel_path(rel)?;
27+
Ok(root.join(rel))
28+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
/// Normalize a checksum string (strip prefixes, lowercase).
///
/// Surrounding whitespace is removed, the value is lowercased, and an optional
/// `sha256:` prefix is stripped (whitespace following the prefix is removed too).
///
/// # Errors
///
/// - `"checksum is empty"` if nothing remains after trimming, including the case of a
///   bare `sha256:` prefix with no digest.
/// - `"unsupported checksum format: …"` if the value still contains a `:` after the
///   optional `sha256:` prefix has been removed (for example `md5:…`).
pub fn normalize_checksum(value: &str) -> Result<String, String> {
    let trimmed = value.trim();
    if trimmed.is_empty() {
        return Err("checksum is empty".to_string());
    }

    let lower = trimmed.to_ascii_lowercase();
    let digest = match lower.strip_prefix("sha256:") {
        Some(rest) => rest.trim(),
        None => lower.as_str(),
    };

    // A bare prefix must not be accepted as a valid (empty) checksum.
    if digest.is_empty() {
        return Err("checksum is empty".to_string());
    }
    if digest.contains(':') {
        return Err(format!("unsupported checksum format: {trimmed}"));
    }

    Ok(digest.to_string())
}

crates/burn-central-fleet/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ rust-version.workspace = true
1515
burn.workspace = true
1616
burn-central-client.workspace = true
1717
burn-central-inference.workspace = true
18-
burn-central-registry.workspace = true
18+
burn-central-artifact.workspace = true
1919
thiserror.workspace = true
2020
serde.workspace = true
2121
serde_json.workspace = true

0 commit comments

Comments
 (0)