diff --git a/src/net_protocol.rs b/src/net_protocol.rs index 770e3e65e..16d9fcb94 100644 --- a/src/net_protocol.rs +++ b/src/net_protocol.rs @@ -25,7 +25,8 @@ use crate::{ Stats, }, provider::EventSender, - store::GcConfig, + rpc::{client::blobs::MemClient, MemRpcHandler, RpcHandler}, + store::{GcConfig, Store}, util::{ local_pool::{self, LocalPoolHandle}, progress::{AsyncChannelProgressSender, ProgressSender}, @@ -53,8 +54,8 @@ impl Default for GcState { } #[derive(Debug)] -pub struct Blobs { - rt: LocalPoolHandle, +pub(crate) struct BlobsInner { + pub(crate) rt: LocalPoolHandle, pub(crate) store: S, events: EventSender, downloader: Downloader, @@ -62,7 +63,7 @@ pub struct Blobs { endpoint: Endpoint, gc_state: Arc>, #[cfg(feature = "rpc")] - pub(crate) rpc_handler: Arc>, + pub(crate) rpc_handler: Arc>, } /// Name used for logging when new node addresses are added from gossip. @@ -135,9 +136,21 @@ impl Builder { /// Build the Blobs protocol handler. /// You need to provide a local pool handle and an endpoint. - pub fn build(self, rt: &LocalPoolHandle, endpoint: &Endpoint) -> Arc> { + pub fn build(self, rt: &LocalPoolHandle, endpoint: &Endpoint) -> Arc { + let inner = self.build_inner(rt, endpoint); + Arc::new(Blobs { inner }) + } + + pub fn build_rpc_handler(self, rt: &LocalPoolHandle, endpoint: &Endpoint) -> RpcHandler { + let inner = self.build_inner(rt, endpoint); + RpcHandler::from_blobs(inner) + } + + /// Build the Blobs protocol handler. + /// You need to provide a local pool handle and an endpoint. + fn build_inner(self, rt: &LocalPoolHandle, endpoint: &Endpoint) -> Arc> { let downloader = Downloader::new(self.store.clone(), endpoint.clone(), rt.clone()); - Arc::new(Blobs::new( + Arc::new(BlobsInner::new( self.store, rt.clone(), self.events.unwrap_or_default(), @@ -147,24 +160,20 @@ impl Builder { } } -impl Blobs { +impl Blobs { /// Create a new Blobs protocol handler builder, given a store. - pub fn builder(store: S) -> Builder { + pub fn builder(store: S) -> Builder { Builder { store, events: None, } } -} -impl Blobs { /// Create a new memory-backed Blobs protocol handler. pub fn memory() -> Builder { Self::builder(crate::store::mem::Store::new()) } -} -impl Blobs { /// Load a persistent Blobs protocol handler from a path. pub async fn persistent( path: impl AsRef, @@ -173,8 +182,8 @@ impl Blobs { } } -impl Blobs { - pub fn new( +impl BlobsInner { + fn new( store: S, rt: LocalPoolHandle, events: EventSender, @@ -194,18 +203,6 @@ impl Blobs { } } - pub fn store(&self) -> &S { - &self.store - } - - pub fn rt(&self) -> &LocalPoolHandle { - &self.rt - } - - pub fn downloader(&self) -> &Downloader { - &self.downloader - } - pub fn endpoint(&self) -> &Endpoint { &self.endpoint } @@ -390,66 +387,71 @@ impl Blobs { } } -// trait BlobsInner: Debug + Send + Sync + 'static { -// fn shutdown(self: Arc) -> BoxedFuture<()>; -// fn accept(self: Arc, conn: Connecting) -> BoxedFuture>; -// fn client(self: Arc) -> MemClient; -// fn local_pool_handle(&self) -> &LocalPoolHandle; -// fn downloader(&self) -> &Downloader; -// } - -// #[derive(Debug)] -// struct Blobs2 { -// inner: Arc, -// } - -// impl Blobs2 { -// fn client(&self) -> MemClient { -// self.inner.clone().client() -// } - -// fn local_pool_handle(&self) -> &LocalPoolHandle { -// self.inner.local_pool_handle() -// } - -// fn downloader(&self) -> &Downloader { -// self.inner.downloader() -// } -// } - -// impl BlobsInner for Blobs { -// fn shutdown(self: Arc) -> BoxedFuture<()> { -// ProtocolHandler::shutdown(self) -// } - -// fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { -// ProtocolHandler::accept(self, conn) -// } - -// fn client(self: Arc) -> MemClient { -// Blobs::client(self) -// } - -// fn local_pool_handle(&self) -> &LocalPoolHandle { -// self.rt() -// } - -// fn downloader(&self) -> &Downloader { -// self.downloader() -// } -// } - -// impl ProtocolHandler for Blobs2 { -// fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { -// self.inner.clone().accept(conn) -// } - -// fn shutdown(self: Arc) -> BoxedFuture<()> { -// self.inner.clone().shutdown() -// } -// } - -impl ProtocolHandler for Blobs { +trait DynBlobs: Debug + Send + Sync + 'static { + fn shutdown(self: Arc) -> BoxedFuture<()>; + fn accept(self: Arc, conn: Connecting) -> BoxedFuture>; + fn client(self: Arc) -> MemClient; + fn local_pool_handle(&self) -> &LocalPoolHandle; + fn downloader(&self) -> &Downloader; + fn endpoint(&self) -> &Endpoint; + fn start_gc(&self, config: GcConfig) -> Result<()>; + fn add_protected(&self, cb: ProtectCb) -> Result<()>; + fn stop_rpc_task(&self); +} + +#[derive(Debug)] +pub struct Blobs { + inner: Arc, +} + +impl Blobs { + pub(crate) fn from_inner(inner: Arc>) -> Self { + Self { inner } + } + + pub fn client(&self) -> MemClient { + self.inner.clone().client() + } + + pub fn local_pool_handle(&self) -> &LocalPoolHandle { + self.inner.local_pool_handle() + } + + pub fn downloader(&self) -> &Downloader { + self.inner.downloader() + } + + pub fn endpoint(&self) -> &Endpoint { + self.inner.endpoint() + } + + pub fn add_protected(&self, cb: ProtectCb) -> Result<()> { + self.inner.add_protected(cb) + } + + pub fn start_gc(&self, config: GcConfig) -> Result<()> { + self.inner.start_gc(config) + } + + pub fn new( + store: S, + rt: LocalPoolHandle, + events: EventSender, + downloader: Downloader, + endpoint: Endpoint, + ) -> Self { + let inner = Arc::new(BlobsInner::new(store, rt, events, downloader, endpoint)); + Self { inner } + } +} + +impl Drop for Blobs { + fn drop(&mut self) { + self.inner.stop_rpc_task(); + } +} + +impl DynBlobs for BlobsInner { fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { Box::pin(async move { crate::provider::handle_connection( @@ -465,9 +467,55 @@ impl ProtocolHandler for Blobs { fn shutdown(self: Arc) -> BoxedFuture<()> { Box::pin(async move { + self.stop_rpc_task(); self.store.shutdown().await; }) } + + fn stop_rpc_task(&self) { + if let Some(rpc_handler) = self.rpc_handler.get() { + rpc_handler.shutdown(); + } + } + + fn client(self: Arc) -> MemClient { + let client = self + .rpc_handler + .get_or_init(|| MemRpcHandler::new(&self)) + .client + .clone(); + MemClient::new(client) + } + + fn local_pool_handle(&self) -> &LocalPoolHandle { + &self.rt + } + + fn downloader(&self) -> &Downloader { + &self.downloader + } + + fn start_gc(&self, config: GcConfig) -> Result<()> { + self.start_gc(config) + } + + fn add_protected(&self, cb: ProtectCb) -> Result<()> { + self.add_protected(cb) + } + + fn endpoint(&self) -> &Endpoint { + &self.endpoint + } +} + +impl ProtocolHandler for Blobs { + fn accept(self: Arc, conn: Connecting) -> BoxedFuture> { + self.inner.clone().accept(conn) + } + + fn shutdown(self: Arc) -> BoxedFuture<()> { + self.inner.clone().shutdown() + } } /// A request to the node to download and share the data specified by the hash. diff --git a/src/rpc.rs b/src/rpc.rs index ff8d13fa7..3e77e1d20 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -2,13 +2,12 @@ use std::{ io, - ops::Deref, sync::{Arc, Mutex}, }; use anyhow::anyhow; use client::{ - blobs::{self, BlobInfo, BlobStatus, IncompleteBlobInfo, WrapOption}, + blobs::{BlobInfo, BlobStatus, IncompleteBlobInfo, WrapOption}, tags::TagInfo, MemConnector, }; @@ -44,10 +43,11 @@ use crate::{ export::ExportProgress, format::collection::Collection, get::db::DownloadProgress, - net_protocol::{BlobDownloadRequest, Blobs}, + net_protocol::{BlobDownloadRequest, Blobs, BlobsInner}, provider::{AddProgress, BatchAddPathProgress}, store::{ConsistencyCheckProgress, ImportProgress, MapEntry, ValidateProgress}, util::{ + local_pool::LocalPoolHandle, progress::{AsyncChannelProgressSender, ProgressSender}, SetTagOption, }, @@ -61,17 +61,7 @@ const RPC_BLOB_GET_CHUNK_SIZE: usize = 1024 * 64; /// Channel cap for getting blobs over RPC const RPC_BLOB_GET_CHANNEL_CAP: usize = 2; -impl Blobs { - /// Get a client for the blobs protocol - pub fn client(self: Arc) -> blobs::MemClient { - let client = self - .rpc_handler - .get_or_init(|| RpcHandler::new(&self)) - .client - .clone(); - blobs::Client::new(client) - } - +impl BlobsInner { /// Handle an RPC request pub async fn handle_rpc_request( self: Arc, @@ -81,33 +71,56 @@ impl Blobs { where C: ChannelTypes, { - use Request::*; - let handler = Handler(self); - match msg { - Blobs(msg) => handler.handle_blobs_request(msg, chan).await, - Tags(msg) => handler.handle_tags_request(msg, chan).await, + RpcHandler { + blobs: self.clone(), } + .handle_rpc_request(msg, chan) + .await } } -#[derive(Clone)] -struct Handler(Arc>); +/// RPC handler for the blobs protocol +#[derive(Debug, Clone)] +pub struct RpcHandler { + blobs: Arc>, +} -impl Deref for Handler { - type Target = Blobs; +impl RpcHandler { + fn store(&self) -> &D { + &self.blobs.store + } - fn deref(&self) -> &Self::Target { - &self.0 + fn rt(&self) -> &LocalPoolHandle { + &self.blobs.rt } -} -impl Handler { - fn store(&self) -> &D { - &self.0.store + pub(crate) fn from_blobs(blobs: Arc>) -> Self { + Self { blobs } + } + + /// Get the blobs ProtocolHandler + pub fn blobs(&self) -> Arc { + Arc::new(Blobs::from_inner(self.blobs.clone())) + } + + /// Handle an RPC request + pub async fn handle_rpc_request( + self, + msg: Request, + chan: RpcChannel, + ) -> std::result::Result<(), RpcServerError> + where + C: ChannelTypes, + { + use Request::*; + match msg { + Blobs(msg) => self.handle_blobs_request(msg, chan).await, + Tags(msg) => self.handle_tags_request(msg, chan).await, + } } /// Handle a tags request - pub async fn handle_tags_request( + async fn handle_tags_request( self, msg: proto::tags::Request, chan: RpcChannel, @@ -125,7 +138,7 @@ impl Handler { } /// Handle a blobs request - pub async fn handle_blobs_request( + async fn handle_blobs_request( self, msg: proto::blobs::Request, chan: RpcChannel, @@ -170,9 +183,9 @@ impl Handler { } async fn blob_status(self, msg: BlobStatusRequest) -> RpcResult { - let blobs = self; + let blobs = self.blobs; let entry = blobs - .store() + .store .get(&msg.hash) .await .map_err(|e| RpcError::new(&e))?; @@ -193,8 +206,8 @@ impl Handler { async fn blob_list_impl(self, co: &Co>) -> io::Result<()> { use bao_tree::io::fsm::Outboard; - let blobs = self; - let db = blobs.store(); + let blobs = self.blobs; + let db = &blobs.store; for blob in db.blobs().await? { let blob = blob?; let Some(entry) = db.get(&blob).await? else { @@ -346,14 +359,14 @@ impl Handler { } async fn tags_set(self, msg: TagsSetRequest) -> RpcResult<()> { - let blobs = self; + let blobs = self.blobs; blobs - .store() + .store .set_tag(msg.name, msg.value) .await .map_err(|e| RpcError::new(&e))?; if let SyncMode::Full = msg.sync { - blobs.store().sync().await.map_err(|e| RpcError::new(&e))?; + blobs.store.sync().await.map_err(|e| RpcError::new(&e))?; } if let Some(batch) = msg.batch { if let Some(content) = msg.value.as_ref() { @@ -368,14 +381,14 @@ impl Handler { } async fn tags_create(self, msg: TagsCreateRequest) -> RpcResult { - let blobs = self; + let blobs = self.blobs; let tag = blobs - .store() + .store .create_tag(msg.value) .await .map_err(|e| RpcError::new(&e))?; if let SyncMode::Full = msg.sync { - blobs.store().sync().await.map_err(|e| RpcError::new(&e))?; + blobs.store.sync().await.map_err(|e| RpcError::new(&e))?; } if let Some(batch) = msg.batch { blobs @@ -389,13 +402,14 @@ impl Handler { fn blob_download(self, msg: BlobDownloadRequest) -> impl Stream { let (sender, receiver) = async_channel::bounded(1024); - let endpoint = self.endpoint().clone(); + let endpoint = self.blobs.endpoint().clone(); let progress = AsyncChannelProgressSender::new(sender); let blobs_protocol = self.clone(); self.rt().spawn_detached(move || async move { if let Err(err) = blobs_protocol + .blobs .download(endpoint, msg, progress.clone()) .await { @@ -554,8 +568,8 @@ impl Handler { } async fn batch_create_temp_tag(self, msg: BatchCreateTempTagRequest) -> RpcResult<()> { - let blobs = self; - let tag = blobs.store().temp_tag(msg.content); + let blobs = self.blobs; + let tag = blobs.store.temp_tag(msg.content); blobs.batches().await.store(msg.batch, tag); Ok(()) } @@ -602,7 +616,7 @@ impl Handler { stream: impl Stream + Send + Unpin + 'static, progress: async_channel::Sender, ) -> anyhow::Result<()> { - let blobs = self; + let blobs = self.blobs; let progress = AsyncChannelProgressSender::new(progress); let stream = stream.map(|item| match item { @@ -619,7 +633,7 @@ impl Handler { _ => None, }); let (temp_tag, _len) = blobs - .store() + .store .import_stream(stream, msg.format, import_progress) .await?; let hash = temp_tag.inner().hash; @@ -658,9 +672,9 @@ impl Handler { "trying to add missing path: {}", root.display() ); - let blobs = self; + let blobs = self.blobs; let (tag, _) = blobs - .store() + .store .import_file(root, import_mode, format, import_progress) .await?; let hash = *tag.hash(); @@ -827,7 +841,7 @@ impl Handler { _: BatchCreateRequest, mut updates: impl Stream + Send + Unpin + 'static, ) -> impl Stream { - let blobs = self; + let blobs = self.blobs; async move { let batch = blobs.batches().await.create(); tokio::spawn(async move { @@ -895,21 +909,28 @@ impl Handler { } #[derive(Debug)] -pub(crate) struct RpcHandler { +pub(crate) struct MemRpcHandler { /// Client to hand out - client: RpcClient, + pub(crate) client: RpcClient, /// Handler task - _handler: AbortOnDropHandle<()>, + handler: AbortOnDropHandle<()>, } -impl RpcHandler { - fn new(blobs: &Arc>) -> Self { +impl MemRpcHandler { + pub fn new(blobs: &Arc>) -> Self { let blobs = blobs.clone(); let (listener, connector) = quic_rpc::transport::flume::channel(1); let listener = RpcServer::new(listener); let client = RpcClient::new(connector); let _handler = listener .spawn_accept_loop(move |req, chan| blobs.clone().handle_rpc_request(req, chan)); - Self { client, _handler } + Self { + client, + handler: _handler, + } + } + + pub fn shutdown(&self) { + self.handler.abort(); } } diff --git a/src/rpc/client/blobs.rs b/src/rpc/client/blobs.rs index 3d8e1e182..f1e75cb4e 100644 --- a/src/rpc/client/blobs.rs +++ b/src/rpc/client/blobs.rs @@ -1003,29 +1003,24 @@ mod tests { mod node { //! An iroh node that just has the blobs transport - use std::{path::Path, sync::Arc}; + use std::path::Path; use iroh::{protocol::Router, Endpoint, NodeAddr, NodeId}; - use tokio_util::task::AbortOnDropHandle; - use super::RpcService; + use super::MemClient; use crate::{ - downloader::Downloader, net_protocol::Blobs, provider::{CustomEventSender, EventSender}, rpc::client::{blobs, tags}, util::local_pool::LocalPool, }; - type RpcClient = quic_rpc::RpcClient; - /// An iroh node that just has the blobs transport #[derive(Debug)] pub struct Node { router: iroh::protocol::Router, - client: RpcClient, + client: MemClient, _local_pool: LocalPool, - _rpc_task: AbortOnDropHandle<()>, } /// An iroh node builder @@ -1066,31 +1061,19 @@ mod tests { let mut router = Router::builder(endpoint.clone()); // Setup blobs - let downloader = - Downloader::new(store.clone(), endpoint.clone(), local_pool.handle().clone()); - let blobs = Arc::new(Blobs::new( - store.clone(), - local_pool.handle().clone(), - events, - downloader, - endpoint.clone(), - )); + let blobs = Blobs::builder(store.clone()) + .events(events) + .build(&local_pool, &endpoint); router = router.accept(crate::ALPN, blobs.clone()); // Build the router let router = router.spawn().await?; // Setup RPC - let (internal_rpc, controller) = quic_rpc::transport::flume::channel(32); - let internal_rpc = quic_rpc::RpcServer::new(internal_rpc).boxed(); - let _rpc_task = internal_rpc.spawn_accept_loop(move |msg, chan| { - blobs.clone().handle_rpc_request(msg, chan) - }); - let client = quic_rpc::RpcClient::new(controller).boxed(); + let client = blobs.client(); Ok(Node { router, client, - _rpc_task, _local_pool: local_pool, }) } @@ -1129,17 +1112,18 @@ mod tests { /// Shuts down the node pub async fn shutdown(self) -> anyhow::Result<()> { - self.router.shutdown().await + self.router.shutdown().await?; + Ok(()) } /// Returns an in-memory blobs client - pub fn blobs(&self) -> blobs::Client { - blobs::Client::new(self.client.clone()) + pub fn blobs(&self) -> blobs::MemClient { + self.client.clone() } /// Returns an in-memory tags client - pub fn tags(&self) -> tags::Client { - tags::Client::new(self.client.clone()) + pub fn tags(&self) -> tags::MemClient { + self.blobs().tags() } } } @@ -1509,7 +1493,7 @@ mod tests { #[tokio::test] async fn test_blob_provide_events() -> Result<()> { - let _guard = iroh_test::logging::setup(); + // let _guard = iroh_test::logging::setup(); let (node1_events, mut node1_events_r) = BlobEvents::new(16); let node1 = node::Node::memory() diff --git a/tests/gc.rs b/tests/gc.rs index 10cc74e6e..eeb7954cf 100644 --- a/tests/gc.rs +++ b/tests/gc.rs @@ -41,7 +41,7 @@ use tokio::io::AsyncReadExt; #[derive(Debug)] pub struct Node { pub router: iroh::protocol::Router, - pub blobs: Arc>, + pub blobs: Arc, pub store: S, pub _local_pool: LocalPool, }