diff --git a/Cargo.lock b/Cargo.lock index 52bc191..5443637 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -142,6 +142,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-compression" +version = "0.4.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40f6024f3f856663b45fd0c9b6f2024034a702f453549449e0d84a305900dad4" +dependencies = [ + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "zstd", + "zstd-safe", +] + [[package]] name = "async-trait" version = "0.1.86" @@ -344,6 +358,8 @@ version = "1.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" dependencies = [ + "jobserver", + "libc", "shlex", ] @@ -2057,6 +2073,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +[[package]] +name = "jobserver" +version = "0.1.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" +dependencies = [ + "getrandom 0.3.3", + "libc", +] + [[package]] name = "js-sys" version = "0.3.77" @@ -2744,6 +2770,12 @@ dependencies = [ "spki", ] +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "pnet_base" version = "0.34.0" @@ -3535,6 +3567,7 @@ name = "sendme" version = "0.26.0" dependencies = [ "anyhow", + "async-compression", "base64", "clap", "console", @@ -3554,6 +3587,7 @@ dependencies = [ "serde_json", "tempfile", "tokio", + "tokio-util", "tracing", "tracing-subscriber", "walkdir", @@ -5287,3 +5321,31 @@ dependencies = [ "quote", "syn 2.0.98", ] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.15+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/Cargo.toml b/Cargo.toml index e5ea2c0..f9170d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,6 +37,8 @@ data-encoding = "2.6.0" n0-future = "0.1.2" base64 = { version = "0.22.1", optional = true } hex = "0.4.3" +async-compression = { version = "0.4.25", features = ["tokio", "zstd"], optional = true } +tokio-util = { version = "0.7.15",optional = true } [dev-dependencies] duct = "0.13.6" @@ -47,8 +49,9 @@ tempfile = "3.8.1" [features] clipboard = ["dep:base64"] -default = ["clipboard"] +zstd = ["async-compression","tokio-util"] +default = ["clipboard","zstd"] [patch.crates-io] iroh = { git = "https://github.com/n0-computer/iroh.git", branch = "main" } -iroh-blobs = { git = "https://github.com/n0-computer/iroh-blobs.git", branch = "main" } +iroh-blobs = { git = "https://github.com/n0-computer/iroh-blobs.git", branch = "main" } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index b562069..f8df3b0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,10 @@ use std::{ }; use anyhow::Context; +#[cfg(feature = "zstd")] +use async_compression::tokio::bufread::{ZstdDecoder, ZstdEncoder}; +#[cfg(feature = "zstd")] +use async_compression::Level; use clap::{ error::{ContextKind, ErrorKind}, CommandFactory, Parser, Subcommand, @@ -24,11 +28,15 @@ use iroh::{ discovery::{dns::DnsDiscovery, pkarr::PkarrPublisher}, Endpoint, NodeAddr, RelayMode, RelayUrl, SecretKey, Watcher, }; +#[cfg(feature = "zstd")] +use iroh_blobs::api::blobs::EncodedItem; +#[cfg(feature = "zstd")] +use iroh_blobs::protocol::ChunkRanges; use iroh_blobs::{ api::{ blobs::{ - AddPathOptions, AddProgressItem, ExportMode, ExportOptions, ExportProgressItem, - ImportMode, + AddPathOptions, AddProgress, AddProgressItem, ExportMode, ExportOptions, + ExportProgressItem, ImportMode, }, remote::GetProgressItem, Store, TempTag, @@ -44,7 +52,14 @@ use iroh_blobs::{ use n0_future::{task::AbortOnDropHandle, StreamExt}; use rand::Rng; use serde::{Deserialize, Serialize}; + +#[cfg(feature = "zstd")] +use tokio::fs::{create_dir_all, File}; +#[cfg(feature = "zstd")] +use tokio::io::{BufReader, BufWriter}; use tokio::{select, sync::mpsc}; +#[cfg(feature = "zstd")] +use tokio_util::io::{ReaderStream, StreamReader}; use tracing::{error, trace}; use walkdir::WalkDir; @@ -142,6 +157,20 @@ pub struct CommonArgs { #[clap(long)] pub show_secret: bool, + + /// Use zstd to compress outgoing and decompress incoming data + #[cfg(feature = "zstd")] + #[clap(short = 'z', long)] + pub zstd: bool, + + /// Compression level for zstd + #[cfg(feature = "zstd")] + #[clap(short = 'q', long, default_value_t = 3, requires("zstd"))] + pub compression_quality: u8, + + #[cfg(not(feature = "zstd"))] + #[clap(short = 'z', long, hide = true)] + pub zstd: bool, } /// Available command line options for configuring relays. @@ -355,6 +384,8 @@ async fn import( path: PathBuf, db: &Store, mp: &mut MultiProgress, + _do_compress: bool, + _compression_level: u8, ) -> anyhow::Result<(TempTag, u64, Collection)> { let parallelism = num_cpus::get(); let path = path.canonicalize()?; @@ -391,11 +422,36 @@ async fn import( op.inc(1); let pb = mp.add(make_import_item_progress()); pb.set_message(format!("copying {name}")); - let import = db.add_path_with_opts(AddPathOptions { - path, - mode: ImportMode::TryReference, - format: BlobFormat::Raw, - }); + let import: AddProgress; + + #[cfg(feature = "zstd")] + if _do_compress { + let file_stream = File::open(&path).await?; + pb.set_message(format!("Compressing {name}")); + pb.set_length(file_stream.metadata().await?.len()); + let reader = BufReader::new(file_stream); + let encoder = + ZstdEncoder::with_quality(reader, Level::Precise(_compression_level as _)); + + let compressed_stream = ReaderStream::new(encoder); + import = db.add_stream(compressed_stream).await; + } else { + import = db.add_path_with_opts(AddPathOptions { + path, + mode: ImportMode::TryReference, + format: BlobFormat::Raw, + }); + } + + #[cfg(not(feature = "zstd"))] + { + import = db.add_path_with_opts(AddPathOptions { + path, + mode: ImportMode::TryReference, + format: BlobFormat::Raw, + }); + } + let mut stream = import.stream().await; let mut item_size = 0; let temp_tag = loop { @@ -464,13 +520,75 @@ fn get_export_path(root: &Path, name: &str) -> anyhow::Result { Ok(path) } -async fn export(db: &Store, collection: Collection, mp: &mut MultiProgress) -> anyhow::Result<()> { +async fn export_single_file( + db: &Store, + mp: &MultiProgress, + hash: &Hash, + target: PathBuf, + name: &String, +) -> anyhow::Result<()> { + let mut stream = db + .export_with_opts(ExportOptions { + hash: *hash, + target: target.clone(), + mode: ExportMode::TryReference, + }) + .stream() + .await; + + let pb = mp.add(make_export_item_progress()); + pb.set_message(format!("exporting {name}")); + + while let Some(item) = stream.next().await { + match item { + ExportProgressItem::Size(size) => { + pb.set_length(size); + } + ExportProgressItem::CopyProgress(offset) => { + pb.set_position(offset); + } + ExportProgressItem::Done => { + pb.finish_and_clear(); + } + ExportProgressItem::Error(cause) => { + pb.finish_and_clear(); + anyhow::bail!("error exporting {}: {}", name, cause); + } + } + } + + Ok(()) +} + +async fn export( + db: &Store, + collection: Collection, + mp: &mut MultiProgress, + _decompress: bool, + _postfix_target: bool, +) -> anyhow::Result<()> { let root = std::env::current_dir()?; let op = mp.add(make_export_overall_progress()); op.set_length(collection.len() as u64); for (i, (name, hash)) in collection.iter().enumerate() { op.set_position(i as u64); let target = get_export_path(&root, name)?; + + #[cfg(not(feature = "zstd"))] + let target = if _postfix_target { + let file_name = target + .file_name() + .and_then(|n| n.to_str()) + .map(|n| format!("{}.zst", n)) + .ok_or_else(|| { + std::io::Error::new(std::io::ErrorKind::Other, "Invalid file name") + })?; + + target.with_file_name(file_name) + } else { + target + }; + if target.exists() { eprintln!( "target {} already exists. Export stopped.", @@ -479,32 +597,47 @@ async fn export(db: &Store, collection: Collection, mp: &mut MultiProgress) -> a eprintln!("You can remove the file or directory and try again. The download will not be repeated."); anyhow::bail!("target {} already exists", target.display()); } - let mut stream = db - .export_with_opts(ExportOptions { - hash: *hash, - target, - mode: ExportMode::TryReference, - }) - .stream() - .await; - let pb = mp.add(make_export_item_progress()); - pb.set_message(format!("exporting {name}")); - while let Some(item) = stream.next().await { - match item { - ExportProgressItem::Size(size) => { - pb.set_length(size); - } - ExportProgressItem::CopyProgress(offset) => { - pb.set_position(offset); - } - ExportProgressItem::Done => { - pb.finish_and_clear(); - } - ExportProgressItem::Error(cause) => { - pb.finish_and_clear(); - anyhow::bail!("error exporting {}: {}", name, cause); - } + + #[cfg(feature = "zstd")] + if _decompress { + let pb = mp.add(make_export_item_progress()); + pb.set_message(format!("Decompressing {name}")); + let byte_stream = db + .export_bao(*hash, ChunkRanges::all()) + .stream() + .inspect(|res| match res { + EncodedItem::Size(size) => { + pb.set_length(*size); + } + EncodedItem::Leaf(leaf) => { + pb.set_position(leaf.offset); + } + EncodedItem::Done => { + pb.finish_and_clear(); + } + _ => {} + }) + .filter_map(|res| match res { + EncodedItem::Leaf(leaf) => Some(Ok(leaf.data)), + EncodedItem::Error(err) => Some(Err(tokio::io::Error::other(err.to_string()))), + _ => None, + }); + + let reader = StreamReader::new(byte_stream); + let mut decoder = ZstdDecoder::new(reader); + if let Some(parent) = target.parent() { + create_dir_all(parent).await?; } + let target_file = File::create(&target).await?; + let mut output_writer = BufWriter::new(target_file); + tokio::io::copy(&mut decoder, &mut output_writer).await?; + } else { + export_single_file(db, mp, hash, target, name).await?; + } + + #[cfg(not(feature = "zstd"))] + { + export_single_file(db, mp, hash, target, name).await?; } } op.finish_and_clear(); @@ -567,7 +700,7 @@ async fn show_provide_progress( ProgressStyle::with_template( "{msg}{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes}", )? - .progress_chars("#>-"), + .progress_chars("#>-"), ); pb.set_message(format!("{request_id} {hash}")); let Some(connection) = connections.get_mut(&connection_id) else { @@ -640,6 +773,25 @@ async fn show_provide_progress( Ok(()) } +fn zstd_enabled(zstd_requested: bool, _is_sending: bool) -> bool { + #[cfg(feature = "zstd")] + return zstd_requested; + + #[cfg(not(feature = "zstd"))] + { + if zstd_requested { + if _is_sending { + eprintln!( + "Warning: --zstd ignored (no support in this build). Sending uncompressed." + ); + } else { + eprintln!("Warning: This build does not support zstd decompression. Files will be saved with a `.zst` extension. You can manually decompress them using `unzstd `."); + } + } + return false; + } +} + async fn send(args: SendArgs) -> anyhow::Result<()> { let secret_key = get_or_create_secret(args.common.verbose > 0)?; if args.common.show_secret { @@ -672,6 +824,7 @@ async fn send(args: SendArgs) -> anyhow::Result<()> { ); std::process::exit(1); } + let do_compress = zstd_enabled(args.common.zstd, true); let mut mp = MultiProgress::new(); let mp2 = mp.clone(); @@ -697,7 +850,20 @@ async fn send(args: SendArgs) -> anyhow::Result<()> { let store = FsStore::load(&blobs_data_dir2).await?; let blobs = Blobs::new(&store, endpoint.clone(), Some(progress_tx)); - let import_result = import(path2, blobs.store(), &mut mp).await?; + #[cfg(feature = "zstd")] + let compression_quality = args.common.compression_quality.clamp(1, 22); + + #[cfg(not(feature = "zstd"))] + let compression_quality = 0; + + let import_result = import( + path2, + blobs.store(), + &mut mp, + do_compress, + compression_quality, + ) + .await?; let dt = t0.elapsed(); let router = iroh::protocol::Router::builder(endpoint) @@ -739,7 +905,11 @@ async fn send(args: SendArgs) -> anyhow::Result<()> { } println!("to get this data, use"); - println!("sendme receive {ticket}"); + println!( + "sendme receive{} {}", + if do_compress { " -z" } else { "" }, + ticket + ); #[cfg(feature = "clipboard")] { @@ -747,7 +917,7 @@ async fn send(args: SendArgs) -> anyhow::Result<()> { // Add command to the clipboard if args.clipboard { - add_to_clipboard(&ticket); + add_to_clipboard(&ticket, do_compress); } let _keyboard = tokio::task::spawn(async move { @@ -755,7 +925,7 @@ async fn send(args: SendArgs) -> anyhow::Result<()> { println!("press c to copy command to clipboard, or use the --clipboard argument"); loop { if let Ok(Key::Char('c')) = term.read_key() { - add_to_clipboard(&ticket); + add_to_clipboard(&ticket, do_compress); } } }); @@ -777,7 +947,7 @@ async fn send(args: SendArgs) -> anyhow::Result<()> { } #[cfg(feature = "clipboard")] -fn add_to_clipboard(ticket: &BlobTicket) { +fn add_to_clipboard(ticket: &BlobTicket, add_decompress_tag: bool) { use std::io::{stdout, Write}; use base64::prelude::{Engine, BASE64_STANDARD}; @@ -785,7 +955,10 @@ fn add_to_clipboard(ticket: &BlobTicket) { // Use OSC 52 to copy content to clipboard. print!( "\x1B]52;c;{}\x07", - BASE64_STANDARD.encode(format!("sendme receive {ticket}")) + BASE64_STANDARD.encode(format!( + "sendme receive{} {ticket}", + if add_decompress_tag { " -z" } else { "" } + )) ); stdout() @@ -875,8 +1048,8 @@ fn make_export_item_progress() -> ProgressBar { ProgressStyle::with_template( "{msg}{spinner:.green} [{elapsed_precise}] [{wide_bar:.cyan/blue}] {bytes}/{total_bytes}", ) - .unwrap() - .progress_chars("#>-"), + .unwrap() + .progress_chars("#>-"), ); pb } @@ -940,6 +1113,9 @@ async fn receive(args: ReceiveArgs) -> anyhow::Result<()> { let iroh_data_dir = std::env::current_dir()?.join(dir_name); let db = iroh_blobs::store::fs::FsStore::load(&iroh_data_dir).await?; let db2 = db.clone(); + + let do_decompress = zstd_enabled(args.common.zstd, false); + trace!("load done!"); let fut = async move { trace!("running"); @@ -1026,10 +1202,25 @@ async fn receive(args: ReceiveArgs) -> anyhow::Result<()> { } if let Some((name, _)) = collection.iter().next() { if let Some(first) = name.split('/').next() { - println!("exporting to {first}"); + println!( + "exporting to {first}{}", + if do_decompress != args.common.zstd && collection.len() == 1 { + ".zst" + } else { + "" + } + ); } } - export(&db, collection, &mut mp).await?; + + export( + &db, + collection, + &mut mp, + do_decompress, + do_decompress != args.common.zstd, + ) + .await?; anyhow::Ok((total_files, payload_size, stats)) }; let (total_files, payload_size, stats) = select! {