diff --git a/src/downloader.rs b/src/downloader.rs index 512bb3ab..e6a04508 100644 --- a/src/downloader.rs +++ b/src/downloader.rs @@ -140,7 +140,7 @@ pub enum GetOutput { } /// Concurrency limits for the [`Downloader`]. -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ConcurrencyLimits { /// Maximum number of requests the service performs concurrently. pub max_concurrent_requests: usize, @@ -192,7 +192,7 @@ impl ConcurrencyLimits { } /// Configuration for retry behavior of the [`Downloader`]. -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct RetryConfig { /// Maximum number of retry attempts for a node that failed to dial or failed with IO errors. pub max_retries_per_node: u32, @@ -324,13 +324,29 @@ impl Future for DownloadHandle { } } +/// All numerical config options for the downloader. +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub struct Config { + /// Concurrency limits for the downloader. + pub concurrency: ConcurrencyLimits, + /// Retry configuration for the downloader. + pub retry: RetryConfig, +} + /// Handle for the download services. -#[derive(Clone, Debug)] +#[derive(Debug, Clone)] pub struct Downloader { + inner: Arc, +} + +#[derive(Debug)] +struct Inner { /// Next id to use for a download intent. - next_id: Arc, + next_id: AtomicU64, /// Channel to communicate with the service. msg_tx: mpsc::Sender, + /// Configuration for the downloader. + config: Arc, metrics: Arc, } @@ -340,54 +356,48 @@ impl Downloader { where S: Store, { - Self::with_config(store, endpoint, rt, Default::default(), Default::default()) + Self::with_config(store, endpoint, rt, Default::default()) } /// Create a new Downloader with custom [`ConcurrencyLimits`] and [`RetryConfig`]. - pub fn with_config( - store: S, - endpoint: Endpoint, - rt: LocalPoolHandle, - concurrency_limits: ConcurrencyLimits, - retry_config: RetryConfig, - ) -> Self + pub fn with_config(store: S, endpoint: Endpoint, rt: LocalPoolHandle, config: Config) -> Self where S: Store, { let metrics = Arc::new(Metrics::default()); + let metrics2 = metrics.clone(); let me = endpoint.node_id().fmt_short(); let (msg_tx, msg_rx) = mpsc::channel(SERVICE_CHANNEL_CAPACITY); let dialer = Dialer::new(endpoint); - - let metrics_clone = metrics.clone(); + let config = Arc::new(config); + let config2 = config.clone(); let create_future = move || { let getter = get::IoGetter { store: store.clone(), }; - - let service = Service::new( - getter, - dialer, - concurrency_limits, - retry_config, - msg_rx, - metrics_clone, - ); - + let service = Service::new(getter, dialer, config2, msg_rx, metrics2); service.run().instrument(error_span!("downloader", %me)) }; rt.spawn_detached(create_future); Self { - next_id: Arc::new(AtomicU64::new(0)), - msg_tx, - metrics, + inner: Arc::new(Inner { + next_id: AtomicU64::new(0), + msg_tx, + config, + metrics, + }), } } + /// Get the current configuration. + pub fn config(&self) -> &Config { + &self.inner.config + } + /// Queue a download. pub async fn queue(&self, request: DownloadRequest) -> DownloadHandle { let kind = request.kind; - let intent_id = IntentId(self.next_id.fetch_add(1, Ordering::SeqCst)); + let intent_id = IntentId(self.inner.next_id.fetch_add(1, Ordering::SeqCst)); let (sender, receiver) = oneshot::channel(); let handle = DownloadHandle { id: intent_id, @@ -401,7 +411,7 @@ impl Downloader { }; // if this fails polling the handle will fail as well since the sender side of the oneshot // will be dropped - if let Err(send_err) = self.msg_tx.send(msg).await { + if let Err(send_err) = self.inner.msg_tx.send(msg).await { let msg = send_err.0; debug!(?msg, "download not sent"); } @@ -417,7 +427,7 @@ impl Downloader { receiver: _, } = handle; let msg = Message::CancelIntent { id, kind }; - if let Err(send_err) = self.msg_tx.send(msg).await { + if let Err(send_err) = self.inner.msg_tx.send(msg).await { let msg = send_err.0; debug!(?msg, "cancel not sent"); } @@ -429,7 +439,7 @@ impl Downloader { /// downloads. Use [`Self::queue`] to queue a download. pub async fn nodes_have(&mut self, hash: Hash, nodes: Vec) { let msg = Message::NodesHave { hash, nodes }; - if let Err(send_err) = self.msg_tx.send(msg).await { + if let Err(send_err) = self.inner.msg_tx.send(msg).await { let msg = send_err.0; debug!(?msg, "nodes have not been sent") } @@ -437,7 +447,7 @@ impl Downloader { /// Returns the metrics collected for this downloader. pub fn metrics(&self) -> &Arc { - &self.metrics + &self.inner.metrics } } @@ -586,8 +596,7 @@ impl, D: DialerT> Service { fn new( getter: G, dialer: D, - concurrency_limits: ConcurrencyLimits, - retry_config: RetryConfig, + config: Arc, msg_rx: mpsc::Receiver, metrics: Arc, ) -> Self { @@ -595,8 +604,8 @@ impl, D: DialerT> Service { getter, dialer, msg_rx, - concurrency_limits, - retry_config, + concurrency_limits: config.concurrency, + retry_config: config.retry, connected_nodes: Default::default(), retry_node_state: Default::default(), providers: Default::default(), diff --git a/src/downloader/test.rs b/src/downloader/test.rs index 75202db0..87ef11f2 100644 --- a/src/downloader/test.rs +++ b/src/downloader/test.rs @@ -49,24 +49,25 @@ impl Downloader { let lp = LocalPool::default(); let metrics_clone = metrics.clone(); + let config = Arc::new(Config { + concurrency: concurrency_limits, + retry: retry_config, + }); + let config2 = config.clone(); lp.spawn_detached(move || async move { // we want to see the logs of the service - let service = Service::new( - getter, - dialer, - concurrency_limits, - retry_config, - msg_rx, - metrics_clone, - ); + let service = Service::new(getter, dialer, config2, msg_rx, metrics_clone); service.run().await }); ( Downloader { - next_id: Arc::new(AtomicU64::new(0)), - msg_tx, - metrics, + inner: Arc::new(Inner { + next_id: AtomicU64::new(0), + msg_tx, + config, + metrics, + }), }, lp, ) diff --git a/src/net_protocol.rs b/src/net_protocol.rs index 198a2e32..8017248d 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -18,7 +18,7 @@ use serde::{Deserialize, Serialize}; use tracing::debug; use crate::{ - downloader::{ConcurrencyLimits, Downloader, RetryConfig}, + downloader::{self, ConcurrencyLimits, Downloader, RetryConfig}, metrics::Metrics, provider::EventSender, store::GcConfig, @@ -148,9 +148,8 @@ impl BlobBatches { pub struct Builder { store: S, events: Option, + downloader_config: Option, rt: Option, - concurrency_limits: Option, - retry_config: Option, } impl Builder { @@ -166,15 +165,23 @@ impl Builder { self } + /// Set custom downloader config + pub fn downloader_config(mut self, downloader_config: downloader::Config) -> Self { + self.downloader_config = Some(downloader_config); + self + } + /// Set custom [`ConcurrencyLimits`] to use. pub fn concurrency_limits(mut self, concurrency_limits: ConcurrencyLimits) -> Self { - self.concurrency_limits = Some(concurrency_limits); + let downloader_config = self.downloader_config.get_or_insert_with(Default::default); + downloader_config.concurrency = concurrency_limits; self } /// Set a custom [`RetryConfig`] to use. pub fn retry_config(mut self, retry_config: RetryConfig) -> Self { - self.retry_config = Some(retry_config); + let downloader_config = self.downloader_config.get_or_insert_with(Default::default); + downloader_config.retry = retry_config; self } @@ -185,12 +192,12 @@ impl Builder { .rt .map(Rt::Handle) .unwrap_or_else(|| Rt::Owned(LocalPool::default())); + let downloader_config = self.downloader_config.unwrap_or_default(); let downloader = Downloader::with_config( self.store.clone(), endpoint.clone(), rt.clone(), - self.concurrency_limits.unwrap_or_default(), - self.retry_config.unwrap_or_default(), + downloader_config, ); Blobs::new( self.store, @@ -208,9 +215,8 @@ impl Blobs { Builder { store, events: None, + downloader_config: None, rt: None, - concurrency_limits: None, - retry_config: None, } } } diff --git a/tests/rpc.rs b/tests/rpc.rs index 7dc12e7b..ab96c8f6 100644 --- a/tests/rpc.rs +++ b/tests/rpc.rs @@ -1,7 +1,7 @@ #![cfg(feature = "test")] use std::{net::SocketAddr, path::PathBuf, vec}; -use iroh_blobs::net_protocol::Blobs; +use iroh_blobs::{downloader, net_protocol::Blobs}; use quic_rpc::client::QuinnConnector; use tempfile::TempDir; use testresult::TestResult; @@ -85,3 +85,28 @@ async fn quinn_rpc_large() -> TestResult<()> { assert_eq!(data, &data2[..]); Ok(()) } + +#[tokio::test] +async fn downloader_config() -> TestResult<()> { + let _ = tracing_subscriber::fmt::try_init(); + let endpoint = iroh::Endpoint::builder().bind().await?; + let store = iroh_blobs::store::mem::Store::default(); + let expected = downloader::Config { + concurrency: downloader::ConcurrencyLimits { + max_concurrent_requests: usize::MAX, + max_concurrent_requests_per_node: usize::MAX, + max_open_connections: usize::MAX, + max_concurrent_dials_per_hash: usize::MAX, + }, + retry: downloader::RetryConfig { + max_retries_per_node: u32::MAX, + initial_retry_delay: std::time::Duration::from_secs(1), + }, + }; + let blobs = Blobs::builder(store) + .downloader_config(expected) + .build(&endpoint); + let actual = blobs.downloader().config(); + assert_eq!(&expected, actual); + Ok(()) +}