diff --git a/Cargo.lock b/Cargo.lock index 9a3a5e8b..452d2a1f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -179,6 +179,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atomic_refcell" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41e67cd8309bbd06cd603a9e693a784ac2e5d1e955f11286e355089fcab3047c" + [[package]] name = "attohttpc" version = "0.24.1" @@ -1601,6 +1607,17 @@ dependencies = [ "web-sys", ] +[[package]] +name = "io-uring" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "libc", +] + [[package]] name = "ipconfig" version = "0.3.2" @@ -1713,6 +1730,7 @@ version = "0.90.0" dependencies = [ "anyhow", "arrayvec", + "atomic_refcell", "bao-tree", "bytes", "chrono", @@ -3936,17 +3954,19 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.45.1" +version = "1.46.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75ef51a33ef1da925cea3e4eb122833cb377c61439ca401b770f54902b806779" +checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17" dependencies = [ "backtrace", "bytes", + "io-uring", "libc", "mio", "parking_lot", "pin-project-lite", "signal-hook-registry", + "slab", "socket2", "tokio-macros", "windows-sys 0.52.0", diff --git a/Cargo.toml b/Cargo.toml index 6fb6efce..c73016fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,6 +58,7 @@ testresult = "0.4.1" tracing-subscriber = { version = "0.3.19", features = ["fmt"] } tracing-test = "0.2.5" walkdir = "2.5.0" +atomic_refcell = "0.1.13" [features] hide-proto-docs = [] diff --git a/src/api/blobs.rs b/src/api/blobs.rs index 74c6e1f2..0f79838f 100644 --- a/src/api/blobs.rs +++ b/src/api/blobs.rs @@ -127,7 +127,7 @@ impl Blobs { .await } - pub fn add_slice(&self, data: impl AsRef<[u8]>) -> AddProgress { + pub fn add_slice(&self, data: impl AsRef<[u8]>) -> AddProgress<'_> { let options = ImportBytesRequest { data: Bytes::copy_from_slice(data.as_ref()), format: crate::BlobFormat::Raw, @@ -136,7 +136,7 @@ impl Blobs { self.add_bytes_impl(options) } - pub fn add_bytes(&self, data: impl Into) -> AddProgress { + pub fn add_bytes(&self, data: impl Into) -> AddProgress<'_> { let options = ImportBytesRequest { data: data.into(), format: crate::BlobFormat::Raw, @@ -145,7 +145,7 @@ impl Blobs { self.add_bytes_impl(options) } - pub fn add_bytes_with_opts(&self, options: impl Into) -> AddProgress { + pub fn add_bytes_with_opts(&self, options: impl Into) -> AddProgress<'_> { let options = options.into(); let request = ImportBytesRequest { data: options.data, @@ -155,7 +155,7 @@ impl Blobs { self.add_bytes_impl(request) } - fn add_bytes_impl(&self, options: ImportBytesRequest) -> AddProgress { + fn add_bytes_impl(&self, options: ImportBytesRequest) -> AddProgress<'_> { trace!("{options:?}"); let this = self.clone(); let stream = Gen::new(|co| async move { @@ -180,7 +180,7 @@ impl Blobs { AddProgress::new(self, stream) } - pub fn add_path_with_opts(&self, options: impl Into) -> AddProgress { + pub fn add_path_with_opts(&self, options: impl Into) -> AddProgress<'_> { let options = options.into(); self.add_path_with_opts_impl(ImportPathRequest { path: options.path, @@ -190,7 +190,7 @@ impl Blobs 
{ }) } - fn add_path_with_opts_impl(&self, options: ImportPathRequest) -> AddProgress { + fn add_path_with_opts_impl(&self, options: ImportPathRequest) -> AddProgress<'_> { trace!("{:?}", options); let client = self.client.clone(); let stream = Gen::new(|co| async move { @@ -215,7 +215,7 @@ impl Blobs { AddProgress::new(self, stream) } - pub fn add_path(&self, path: impl AsRef) -> AddProgress { + pub fn add_path(&self, path: impl AsRef) -> AddProgress<'_> { self.add_path_with_opts(AddPathOptions { path: path.as_ref().to_owned(), mode: ImportMode::Copy, @@ -226,7 +226,7 @@ impl Blobs { pub async fn add_stream( &self, data: impl Stream> + Send + Sync + 'static, - ) -> AddProgress { + ) -> AddProgress<'_> { let inner = ImportByteStreamRequest { format: crate::BlobFormat::Raw, scope: Scope::default(), @@ -521,7 +521,7 @@ pub struct Batch<'a> { } impl<'a> Batch<'a> { - pub fn add_bytes(&self, data: impl Into) -> BatchAddProgress { + pub fn add_bytes(&self, data: impl Into) -> BatchAddProgress<'_> { let options = ImportBytesRequest { data: data.into(), format: crate::BlobFormat::Raw, @@ -530,7 +530,7 @@ impl<'a> Batch<'a> { BatchAddProgress(self.blobs.add_bytes_impl(options)) } - pub fn add_bytes_with_opts(&self, options: impl Into) -> BatchAddProgress { + pub fn add_bytes_with_opts(&self, options: impl Into) -> BatchAddProgress<'_> { let options = options.into(); BatchAddProgress(self.blobs.add_bytes_impl(ImportBytesRequest { data: options.data, @@ -539,7 +539,7 @@ impl<'a> Batch<'a> { })) } - pub fn add_slice(&self, data: impl AsRef<[u8]>) -> BatchAddProgress { + pub fn add_slice(&self, data: impl AsRef<[u8]>) -> BatchAddProgress<'_> { let options = ImportBytesRequest { data: Bytes::copy_from_slice(data.as_ref()), format: crate::BlobFormat::Raw, @@ -548,7 +548,7 @@ impl<'a> Batch<'a> { BatchAddProgress(self.blobs.add_bytes_impl(options)) } - pub fn add_path_with_opts(&self, options: impl Into) -> BatchAddProgress { + pub fn add_path_with_opts(&self, options: impl Into) -> BatchAddProgress<'_> { let options = options.into(); BatchAddProgress(self.blobs.add_path_with_opts_impl(ImportPathRequest { path: options.path, diff --git a/src/hash.rs b/src/hash.rs index 8190009a..006f4a9d 100644 --- a/src/hash.rs +++ b/src/hash.rs @@ -111,7 +111,7 @@ impl From<&[u8; 32]> for Hash { impl PartialOrd for Hash { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.0.as_bytes().cmp(other.0.as_bytes())) + Some(self.cmp(other)) } } diff --git a/src/metrics.rs b/src/metrics.rs index c47fb6ea..0ff5cd2a 100644 --- a/src/metrics.rs +++ b/src/metrics.rs @@ -4,6 +4,7 @@ use iroh_metrics::{Counter, MetricsGroup}; /// Enum of metrics for the module #[allow(missing_docs)] +#[allow(dead_code)] #[derive(Debug, Default, MetricsGroup)] #[metrics(name = "iroh-blobs")] pub struct Metrics { diff --git a/src/store/fs.rs b/src/store/fs.rs index 024d9786..53ed163a 100644 --- a/src/store/fs.rs +++ b/src/store/fs.rs @@ -64,7 +64,6 @@ //! safely shut down as well. Any store refs you are holding will be inoperable //! after this. 
use std::{ - collections::{HashMap, HashSet}, fmt, fs, future::Future, io::Write, @@ -84,15 +83,16 @@ use bao_tree::{ }; use bytes::Bytes; use delete_set::{BaoFilePart, ProtectHandle}; +use entity_manager::{EntityManagerState, SpawnArg}; use entry_state::{DataLocation, OutboardLocation}; use gc::run_gc; use import::{ImportEntry, ImportSource}; use irpc::channel::mpsc; -use meta::{list_blobs, Snapshot}; +use meta::list_blobs; use n0_future::{future::yield_now, io}; use nested_enum_utils::enum_conversions; use range_collections::range_set::RangeSetRange; -use tokio::task::{Id, JoinError, JoinSet}; +use tokio::task::{JoinError, JoinSet}; use tracing::{error, instrument, trace}; use crate::{ @@ -106,6 +106,7 @@ use crate::{ ApiClient, }, store::{ + fs::util::entity_manager::{self, ActiveEntityState}, util::{BaoTreeSender, FixedSize, MemOrFile, ValueOrPoisioned}, Hash, }, @@ -116,7 +117,7 @@ use crate::{ }, }; mod bao_file; -use bao_file::{BaoFileHandle, BaoFileHandleWeak}; +use bao_file::BaoFileHandle; mod delete_set; mod entry_state; mod import; @@ -200,6 +201,29 @@ impl TaskContext { } } +impl entity_manager::Reset for Slot { + fn reset(&mut self) { + self.0 = Arc::new(tokio::sync::Mutex::new(None)); + } +} + +#[derive(Debug)] +struct EmParams; + +impl entity_manager::Params for EmParams { + type EntityId = Hash; + + type GlobalState = Arc; + + type EntityState = Slot; + + async fn on_shutdown( + _state: entity_manager::ActiveEntityState, + _cause: entity_manager::ShutdownCause, + ) { + } +} + #[derive(Debug)] struct Actor { // Context that can be cheaply shared with tasks. @@ -210,56 +234,36 @@ struct Actor { fs_cmd_rx: tokio::sync::mpsc::Receiver, // Tasks for import and export operations. tasks: JoinSet<()>, - // Running tasks - running: HashSet, - // handles - handles: HashMap, + // Entity manager that handles concurrency for entities. + handles: EntityManagerState, // temp tags temp_tags: TempTags, // our private tokio runtime. It has to live somewhere. _rt: RtWrapper, } -/// Wraps a slot and the task context. -/// -/// This contains everything a hash-specific task should need. -struct HashContext { - slot: Slot, - ctx: Arc, -} +type HashContext = ActiveEntityState; impl HashContext { pub fn db(&self) -> &meta::Db { - &self.ctx.db + &self.global.db } pub fn options(&self) -> &Arc { - &self.ctx.options + &self.global.options } - pub async fn lock(&self) -> tokio::sync::MutexGuard<'_, Option> { - self.slot.0.lock().await + pub async fn lock(&self) -> tokio::sync::MutexGuard<'_, Option> { + self.state.0.lock().await } pub fn protect(&self, hash: Hash, parts: impl IntoIterator) { - self.ctx.protect.protect(hash, parts); + self.global.protect.protect(hash, parts); } /// Update the entry state in the database, and wait for completion. 
- pub async fn update(&self, hash: Hash, state: EntryState) -> io::Result<()> { - let (tx, rx) = oneshot::channel(); - self.db() - .send( - meta::Update { - hash, - state, - tx: Some(tx), - span: tracing::Span::current(), - } - .into(), - ) - .await?; - rx.await.map_err(|_e| io::Error::other(""))??; + pub async fn update_await(&self, hash: Hash, state: EntryState) -> io::Result<()> { + self.db().update_await(hash, state).await?; Ok(()) } @@ -269,60 +273,25 @@ impl HashContext { data_location: DataLocation::Inline(Bytes::new()), outboard_location: OutboardLocation::NotNeeded, })); - } - let (tx, rx) = oneshot::channel(); - self.db() - .send( - meta::Get { - hash, - tx, - span: tracing::Span::current(), - } - .into(), - ) - .await - .ok(); - let res = rx.await.map_err(io::Error::other)?; - Ok(res.state?) + }; + self.db().get(hash).await } /// Update the entry state in the database, and wait for completion. pub async fn set(&self, hash: Hash, state: EntryState) -> io::Result<()> { - let (tx, rx) = oneshot::channel(); - self.db() - .send( - meta::Set { - hash, - state, - tx, - span: tracing::Span::current(), - } - .into(), - ) - .await - .map_err(io::Error::other)?; - rx.await.map_err(|_e| io::Error::other(""))??; - Ok(()) - } - - pub async fn get_maybe_create(&self, hash: Hash, create: bool) -> api::Result { - if create { - self.get_or_create(hash).await - } else { - self.get(hash).await - } + self.db().set(hash, state).await } pub async fn get(&self, hash: Hash) -> api::Result { if hash == Hash::EMPTY { - return Ok(self.ctx.empty.clone()); + return Ok(self.global.empty.clone()); } let res = self - .slot + .state .get_or_create(|| async { let res = self.db().get(hash).await.map_err(io::Error::other)?; let res = match res { - Some(state) => open_bao_file(&hash, state, &self.ctx).await, + Some(state) => open_bao_file(&hash, state, &self.global).await, None => Err(io::Error::new(io::ErrorKind::NotFound, "hash not found")), }; Ok((res?, ())) @@ -335,17 +304,17 @@ impl HashContext { pub async fn get_or_create(&self, hash: Hash) -> api::Result { if hash == Hash::EMPTY { - return Ok(self.ctx.empty.clone()); + return Ok(self.global.empty.clone()); } let res = self - .slot + .state .get_or_create(|| async { let res = self.db().get(hash).await.map_err(io::Error::other)?; let res = match res { - Some(state) => open_bao_file(&hash, state, &self.ctx).await, + Some(state) => open_bao_file(&hash, state, &self.global).await, None => Ok(BaoFileHandle::new_partial_mem( hash, - self.ctx.options.clone(), + self.global.options.clone(), )), }; Ok((res?, ())) @@ -402,14 +371,9 @@ async fn open_bao_file( /// An entry for each hash, containing a weak reference to a BaoFileHandle /// wrapped in a tokio mutex so handle creation is sequential. #[derive(Debug, Clone, Default)] -pub(crate) struct Slot(Arc>>); +pub(crate) struct Slot(Arc>>); impl Slot { - pub async fn is_live(&self) -> bool { - let slot = self.0.lock().await; - slot.as_ref().map(|weak| !weak.is_dead()).unwrap_or(false) - } - /// Get the handle if it exists and is still alive, otherwise load it from the database. /// If there is nothing in the database, create a new in-memory handle. 
/// @@ -421,14 +385,12 @@ impl Slot { T: Default, { let mut slot = self.0.lock().await; - if let Some(weak) = &*slot { - if let Some(handle) = weak.upgrade() { - return Ok((handle, Default::default())); - } + if let Some(handle) = &*slot { + return Ok((handle.clone(), Default::default())); } let handle = make().await; if let Ok((handle, _)) = &handle { - *slot = Some(handle.downgrade()); + *slot = Some(handle.clone()); } handle } @@ -445,17 +407,12 @@ impl Actor { fn spawn(&mut self, fut: impl Future + Send + 'static) { let span = tracing::Span::current(); - let id = self.tasks.spawn(fut.instrument(span)).id(); - self.running.insert(id); + self.tasks.spawn(fut.instrument(span)); } - fn log_task_result(&mut self, res: Result<(Id, ()), JoinError>) { + fn log_task_result(res: Result<(), JoinError>) { match res { - Ok((id, _)) => { - // println!("task {id} finished"); - self.running.remove(&id); - // println!("{:?}", self.running); - } + Ok(_) => {} Err(e) => { error!("task failed: {e}"); } @@ -471,26 +428,6 @@ impl Actor { tx.send(tt).await.ok(); } - async fn clear_dead_handles(&mut self) { - let mut to_remove = Vec::new(); - for (hash, slot) in &self.handles { - if !slot.is_live().await { - to_remove.push(*hash); - } - } - for hash in to_remove { - if let Some(slot) = self.handles.remove(&hash) { - // do a quick check if the handle has become alive in the meantime, and reinsert it - let guard = slot.0.lock().await; - let is_live = guard.as_ref().map(|x| !x.is_dead()).unwrap_or_default(); - if is_live { - drop(guard); - self.handles.insert(hash, slot); - } - } - } - } - async fn handle_command(&mut self, cmd: Command) { let span = cmd.parent_span(); let _entered = span.enter(); @@ -525,34 +462,22 @@ impl Actor { } Command::ClearProtected(cmd) => { trace!("{cmd:?}"); - self.clear_dead_handles().await; self.db().send(cmd.into()).await.ok(); } Command::BlobStatus(cmd) => { trace!("{cmd:?}"); self.db().send(cmd.into()).await.ok(); } + Command::DeleteBlobs(cmd) => { + trace!("{cmd:?}"); + self.db().send(cmd.into()).await.ok(); + } Command::ListBlobs(cmd) => { trace!("{cmd:?}"); - let (tx, rx) = tokio::sync::oneshot::channel(); - self.db() - .send( - Snapshot { - tx, - span: cmd.span.clone(), - } - .into(), - ) - .await - .ok(); - if let Ok(snapshot) = rx.await { + if let Ok(snapshot) = self.db().snapshot(cmd.span.clone()).await { self.spawn(list_blobs(snapshot, cmd)); } } - Command::DeleteBlobs(cmd) => { - trace!("{cmd:?}"); - self.db().send(cmd.into()).await.ok(); - } Command::Batch(cmd) => { trace!("{cmd:?}"); let (id, scope) = self.temp_tags.create_scope(); @@ -581,40 +506,27 @@ impl Actor { } Command::ExportPath(cmd) => { trace!("{cmd:?}"); - let ctx = self.hash_context(cmd.hash); - self.spawn(export_path(cmd, ctx)); + cmd.spawn(&mut self.handles, &mut self.tasks).await; } Command::ExportBao(cmd) => { trace!("{cmd:?}"); - let ctx = self.hash_context(cmd.hash); - self.spawn(export_bao(cmd, ctx)); + cmd.spawn(&mut self.handles, &mut self.tasks).await; } Command::ExportRanges(cmd) => { trace!("{cmd:?}"); - let ctx = self.hash_context(cmd.hash); - self.spawn(export_ranges(cmd, ctx)); + cmd.spawn(&mut self.handles, &mut self.tasks).await; } Command::ImportBao(cmd) => { trace!("{cmd:?}"); - let ctx = self.hash_context(cmd.hash); - self.spawn(import_bao(cmd, ctx)); + cmd.spawn(&mut self.handles, &mut self.tasks).await; } Command::Observe(cmd) => { trace!("{cmd:?}"); - let ctx = self.hash_context(cmd.hash); - self.spawn(observe(cmd, ctx)); + cmd.spawn(&mut self.handles, &mut self.tasks).await; } } } - 
/// Create a hash context for a given hash. - fn hash_context(&mut self, hash: Hash) -> HashContext { - HashContext { - slot: self.handles.entry(hash).or_default().clone(), - ctx: self.context.clone(), - } - } - async fn handle_fs_command(&mut self, cmd: InternalCommand) { let span = cmd.parent_span(); let _entered = span.enter(); @@ -642,8 +554,7 @@ impl Actor { format: cmd.format, }, ); - let ctx = self.hash_context(cmd.hash); - self.spawn(finish_import(cmd, tt, ctx)); + (tt, cmd).spawn(&mut self.handles, &mut self.tasks).await; } } } @@ -652,6 +563,11 @@ impl Actor { async fn run(mut self) { loop { tokio::select! { + task = self.handles.tick() => { + if let Some(task) = task { + self.spawn(task); + } + } cmd = self.cmd_rx.recv() => { let Some(cmd) = cmd else { break; @@ -661,11 +577,15 @@ impl Actor { Some(cmd) = self.fs_cmd_rx.recv() => { self.handle_fs_command(cmd).await; } - Some(res) = self.tasks.join_next_with_id(), if !self.tasks.is_empty() => { - self.log_task_result(res); + Some(res) = self.tasks.join_next(), if !self.tasks.is_empty() => { + Self::log_task_result(res); } } } + self.handles.shutdown().await; + while let Some(res) = self.tasks.join_next().await { + Self::log_task_result(res); + } } async fn new( @@ -708,18 +628,98 @@ impl Actor { }); rt.spawn(db_actor.run()); Ok(Self { - context: slot_context, + context: slot_context.clone(), cmd_rx, fs_cmd_rx: fs_commands_rx, tasks: JoinSet::new(), - running: HashSet::new(), - handles: Default::default(), + handles: EntityManagerState::new(slot_context, 1024, 32, 32, 2), temp_tags: Default::default(), _rt: rt, }) } } +trait HashSpecificCommand: HashSpecific + Send + 'static { + fn handle(self, ctx: HashContext) -> impl Future + Send + 'static; + + fn on_error(self) -> impl Future + Send + 'static; + + async fn spawn( + self, + manager: &mut entity_manager::EntityManagerState, + tasks: &mut JoinSet<()>, + ) where + Self: Sized, + { + let task = manager + .spawn_boxed( + self.hash(), + Box::new(|x| { + Box::pin(async move { + match x { + SpawnArg::Active(state) => { + self.handle(state).await; + } + SpawnArg::Busy => { + self.on_error().await; + } + SpawnArg::Dead => { + self.on_error().await; + } + } + }) + }), + ) + .await; + if let Some(task) = task { + tasks.spawn(task); + } + } +} + +impl HashSpecificCommand for ObserveMsg { + async fn handle(self, ctx: HashContext) { + observe(self, ctx).await + } + async fn on_error(self) {} +} +impl HashSpecificCommand for ExportPathMsg { + async fn handle(self, ctx: HashContext) { + export_path(self, ctx).await + } + async fn on_error(self) {} +} +impl HashSpecificCommand for ExportBaoMsg { + async fn handle(self, ctx: HashContext) { + export_bao(self, ctx).await + } + async fn on_error(self) {} +} +impl HashSpecificCommand for ExportRangesMsg { + async fn handle(self, ctx: HashContext) { + export_ranges(self, ctx).await + } + async fn on_error(self) {} +} +impl HashSpecificCommand for ImportBaoMsg { + async fn handle(self, ctx: HashContext) { + import_bao(self, ctx).await + } + async fn on_error(self) {} +} +impl HashSpecific for (TempTag, ImportEntryMsg) { + fn hash(&self) -> Hash { + self.1.hash() + } +} +impl HashSpecificCommand for (TempTag, ImportEntryMsg) { + async fn handle(self, ctx: HashContext) { + let (tt, cmd) = self; + finish_import(cmd, tt, ctx).await + } + async fn on_error(self) {} +} + struct RtWrapper(Option); impl From for RtWrapper { @@ -811,7 +811,7 @@ async fn finish_import_impl(import_data: ImportEntry, ctx: HashContext) -> io::R } } let guard = ctx.lock().await; - 
let handle = guard.as_ref().and_then(|x| x.upgrade()); + let handle = guard.as_ref().map(|x| x.clone()); // if I do have an existing handle, I have to possibly deal with observers. // if I don't have an existing handle, there are 2 cases: // the entry exists in the db, but we don't have a handle @@ -892,7 +892,7 @@ async fn finish_import_impl(import_data: ImportEntry, ctx: HashContext) -> io::R data_location, outboard_location, }; - ctx.update(hash, state).await?; + ctx.update_await(hash, state).await?; Ok(()) } @@ -936,7 +936,7 @@ async fn import_bao_impl( // if the batch is not empty, the last item is a leaf and the current item is a parent, write the batch if !batch.is_empty() && batch[batch.len() - 1].is_leaf() && item.is_parent() { let bitfield = Bitfield::new_unchecked(ranges, size.into()); - handle.write_batch(&batch, &bitfield, &ctx.ctx).await?; + handle.write_batch(&batch, &bitfield, &ctx.global).await?; batch.clear(); ranges = ChunkRanges::empty(); } @@ -952,7 +952,7 @@ async fn import_bao_impl( } if !batch.is_empty() { let bitfield = Bitfield::new_unchecked(ranges, size.into()); - handle.write_batch(&batch, &bitfield, &ctx.ctx).await?; + handle.write_batch(&batch, &bitfield, &ctx.global).await?; } Ok(()) } @@ -1028,7 +1028,7 @@ async fn export_ranges_impl( #[instrument(skip_all, fields(hash = %cmd.hash_short()))] async fn export_bao(mut cmd: ExportBaoMsg, ctx: HashContext) { - match ctx.get_maybe_create(cmd.hash, false).await { + match ctx.get(cmd.hash).await { Ok(handle) => { if let Err(cause) = export_bao_impl(cmd.inner, &mut cmd.tx, handle).await { cmd.tx diff --git a/src/store/fs/bao_file.rs b/src/store/fs/bao_file.rs index 410317c2..bf150ae8 100644 --- a/src/store/fs/bao_file.rs +++ b/src/store/fs/bao_file.rs @@ -4,7 +4,7 @@ use std::{ io, ops::Deref, path::Path, - sync::{Arc, Weak}, + sync::Arc, }; use bao_tree::{ @@ -21,24 +21,20 @@ use bytes::{Bytes, BytesMut}; use derive_more::Debug; use irpc::channel::mpsc; use tokio::sync::watch; -use tracing::{debug, error, info, trace, Span}; +use tracing::{debug, error, info, trace}; use super::{ entry_state::{DataLocation, EntryState, OutboardLocation}, - meta::Update, options::{Options, PathOptions}, BaoFilePart, }; use crate::{ api::blobs::Bitfield, store::{ - fs::{ - meta::{raw_outboard_size, Set}, - TaskContext, - }, + fs::{meta::raw_outboard_size, TaskContext}, util::{ read_checksummed_and_truncate, write_checksummed, FixedSize, MemOrFile, - PartialMemStorage, SizeInfo, SparseMemFile, DD, + PartialMemStorage, DD, }, Hash, IROH_BLOCK_SIZE, }, @@ -507,27 +503,6 @@ impl BaoFileStorage { } } -/// A weak reference to a bao file handle. -#[derive(Debug, Clone)] -pub struct BaoFileHandleWeak(Weak); - -impl BaoFileHandleWeak { - /// Upgrade to a strong reference if possible. - pub fn upgrade(&self) -> Option { - let inner = self.0.upgrade()?; - if let &BaoFileStorage::Poisoned = inner.storage.borrow().deref() { - trace!("poisoned storage, cannot upgrade"); - return None; - }; - Some(BaoFileHandle(inner)) - } - - /// True if the handle is definitely dead. - pub fn is_dead(&self) -> bool { - self.0.strong_count() == 0 - } -} - /// The inner part of a bao file handle. 
pub struct BaoFileHandleInner { pub(crate) storage: watch::Sender, @@ -550,19 +525,12 @@ impl fmt::Debug for BaoFileHandleInner { #[derive(Debug, Clone, derive_more::Deref)] pub struct BaoFileHandle(Arc); -impl Drop for BaoFileHandle { - fn drop(&mut self) { +impl BaoFileHandle { + pub fn persist(&mut self) { self.0.storage.send_if_modified(|guard| { if Arc::strong_count(&self.0) > 1 { return false; } - // there is the possibility that somebody else will increase the strong count - // here. there is nothing we can do about it, but they won't be able to - // access the internals of the handle because we have the lock. - // - // We poison the storage. A poisoned storage is considered dead and will - // have to be recreated, but only *after* we are done with persisting - // the bitfield. let BaoFileStorage::Partial(fs) = guard.take() else { return false; }; @@ -586,6 +554,12 @@ impl Drop for BaoFileHandle { } } +impl Drop for BaoFileHandle { + fn drop(&mut self) { + self.persist(); + } +} + /// A reader for a bao file, reading just the data. #[derive(Debug)] pub struct DataReader(BaoFileHandle); @@ -644,21 +618,7 @@ impl BaoFileHandle { let size = storage.bitfield.size; let (storage, entry_state) = storage.into_complete(size, &options)?; debug!("File was reconstructed as complete"); - let (tx, rx) = crate::util::channel::oneshot::channel(); - ctx.db - .sender - .send( - Set { - hash, - state: entry_state, - tx, - span: Span::current(), - } - .into(), - ) - .await - .map_err(|_| io::Error::other("send update"))?; - rx.await.map_err(|_| io::Error::other("receive update"))??; + ctx.db.set(hash, entry_state).await?; storage.into() } else { storage.into() @@ -771,11 +731,6 @@ impl BaoFileHandle { self.hash } - /// Downgrade to a weak reference. - pub fn downgrade(&self) -> BaoFileHandleWeak { - BaoFileHandleWeak(Arc::downgrade(&self.0)) - } - /// Write a batch and notify the db pub(super) async fn write_batch( &self, @@ -796,26 +751,14 @@ impl BaoFileHandle { true }); if let Some(update) = res? { - ctx.db - .sender - .send( - Update { - hash: self.hash, - state: update, - tx: None, - span: Span::current(), - } - .into(), - ) - .await - .map_err(|_| io::Error::other("send update"))?; + ctx.db.update(self.hash, update).await?; } Ok(()) } } impl PartialMemStorage { - /// Persist the batch to disk, creating a FileBatch. + /// Persist the batch to disk. 
fn persist(self, ctx: &TaskContext, hash: &Hash) -> io::Result { let options = &ctx.options.path; ctx.protect.protect( @@ -843,12 +786,6 @@ impl PartialMemStorage { bitfield: self.bitfield, }) } - - /// Get the parts data, outboard and sizes - #[allow(dead_code)] - pub fn into_parts(self) -> (SparseMemFile, SparseMemFile, SizeInfo) { - (self.data, self.outboard, self.size) - } } pub struct BaoFileStorageSubscriber { diff --git a/src/store/fs/gc.rs b/src/store/fs/gc.rs index a394dc19..70333f3e 100644 --- a/src/store/fs/gc.rs +++ b/src/store/fs/gc.rs @@ -192,6 +192,7 @@ mod tests { use std::{ io::{self}, path::Path, + time::Duration, }; use bao_tree::{io::EncodeError, ChunkNum}; @@ -299,6 +300,7 @@ mod tests { let outboard_path = options.outboard_path(&bh); let sizes_path = options.sizes_path(&bh); let bitfield_path = options.bitfield_path(&bh); + tokio::time::sleep(Duration::from_millis(100)).await; // allow for some time for the file to be written assert!(data_path.exists()); assert!(outboard_path.exists()); assert!(sizes_path.exists()); diff --git a/src/store/fs/meta.rs b/src/store/fs/meta.rs index 617db98c..21fbd9ed 100644 --- a/src/store/fs/meta.rs +++ b/src/store/fs/meta.rs @@ -34,7 +34,7 @@ mod proto; pub use proto::*; pub(crate) mod tables; use tables::{ReadOnlyTables, ReadableTables, Tables}; -use tracing::{debug, error, info_span, trace}; +use tracing::{debug, error, info_span, trace, Span}; use super::{ delete_set::DeleteHandle, @@ -88,7 +88,7 @@ pub type ActorResult = Result; #[derive(Debug, Clone)] pub struct Db { - pub sender: tokio::sync::mpsc::Sender, + sender: tokio::sync::mpsc::Sender, } impl Db { @@ -96,8 +96,71 @@ impl Db { Self { sender } } + pub async fn snapshot(&self, span: tracing::Span) -> io::Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + self.sender + .send(Snapshot { tx, span }.into()) + .await + .map_err(|_| io::Error::other("send snapshot"))?; + rx.await.map_err(|_| io::Error::other("receive snapshot")) + } + + pub async fn update_await(&self, hash: Hash, state: EntryState) -> io::Result<()> { + let (tx, rx) = oneshot::channel(); + self.sender + .send( + Update { + hash, + state, + tx: Some(tx), + span: tracing::Span::current(), + } + .into(), + ) + .await + .map_err(|_| io::Error::other("send update"))?; + rx.await + .map_err(|_e| io::Error::other("receive update"))??; + Ok(()) + } + + /// Update the entry state for a hash, without awaiting completion. + pub async fn update(&self, hash: Hash, state: EntryState) -> io::Result<()> { + self.sender + .send( + Update { + hash, + state, + tx: None, + span: Span::current(), + } + .into(), + ) + .await + .map_err(|_| io::Error::other("send update")) + } + + /// Set the entry state and await completion. + pub async fn set(&self, hash: Hash, entry_state: EntryState) -> io::Result<()> { + let (tx, rx) = oneshot::channel(); + self.sender + .send( + Set { + hash, + state: entry_state, + tx, + span: Span::current(), + } + .into(), + ) + .await + .map_err(|_| io::Error::other("send update"))?; + rx.await.map_err(|_| io::Error::other("receive update"))??; + Ok(()) + } + /// Get the entry state for a hash, if any. - pub async fn get(&self, hash: Hash) -> anyhow::Result>> { + pub async fn get(&self, hash: Hash) -> io::Result>> { let (tx, rx) = oneshot::channel(); self.sender .send( @@ -108,8 +171,9 @@ impl Db { } .into(), ) - .await?; - let res = rx.await?; + .await + .map_err(|_| io::Error::other("send get"))?; + let res = rx.await.map_err(|_| io::Error::other("receive get"))?; Ok(res.state?) 
} diff --git a/src/store/fs/util.rs b/src/store/fs/util.rs index f2949a7c..1cbd01bc 100644 --- a/src/store/fs/util.rs +++ b/src/store/fs/util.rs @@ -1,6 +1,7 @@ use std::future::Future; use tokio::{select, sync::mpsc}; +pub(crate) mod entity_manager; /// A wrapper for a tokio mpsc receiver that allows peeking at the next message. #[derive(Debug)] diff --git a/src/store/fs/util/entity_manager.rs b/src/store/fs/util/entity_manager.rs new file mode 100644 index 00000000..f9628434 --- /dev/null +++ b/src/store/fs/util/entity_manager.rs @@ -0,0 +1,1322 @@ +#![allow(dead_code)] +use std::{fmt::Debug, future::Future, hash::Hash}; + +use n0_future::{future, FuturesUnordered}; +use tokio::sync::{mpsc, oneshot}; + +/// Trait to reset an entity state in place. +/// +/// In many cases this is just assigning the default value, but e.g. for an +/// `Arc>` resetting to the default value means an allocation, whereas +/// reset can be done without. +pub trait Reset: Default { + /// Reset the state to its default value. + fn reset(&mut self); +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ShutdownCause { + /// The entity is shutting down gracefully because the entity is idle. + Idle, + /// The entity is shutting down because the entity manager is shutting down. + Soft, + /// The entity is shutting down because the sender was dropped. + Drop, +} + +/// Parameters for the entity manager system. +pub trait Params: Send + Sync + 'static { + /// Entity id type. + /// + /// This does not require Copy to allow for more complex types, such as `String`, + /// but you have to make sure that ids are small and cheap to clone, since they are + /// used as keys in maps. + type EntityId: Debug + Hash + Eq + Clone + Send + Sync + 'static; + /// Global state type. + /// + /// This is passed into all entity actors. It also needs to be cheap handle. + /// If you don't need it, just set it to `()`. + type GlobalState: Debug + Clone + Send + Sync + 'static; + /// Entity state type. + /// + /// This is the actual distinct per-entity state. This needs to implement + /// `Default` and a matching `Reset`. It also needs to implement `Clone` + /// since we unfortunately need to pass an owned copy of the state to the + /// callback - otherwise we run into some rust lifetime limitations + /// . + /// + /// Frequently this is an `Arc>` or similar. Note that per entity + /// access is concurrent but not parallel, so you can use a more efficient + /// synchronization primitive like [`AtomicRefCell`](https://crates.io/crates/atomic_refcell) if you want to. + type EntityState: Default + Debug + Reset + Clone + Send + Sync + 'static; + /// Function being called when an entity actor is shutting down. + fn on_shutdown( + state: entity_actor::State, + cause: ShutdownCause, + ) -> impl Future + Send + 'static + where + Self: Sized; +} + +/// Sent to the main actor and then delegated to the entity actor to spawn a new task. +pub(crate) struct Spawn { + id: P::EntityId, + f: Box) -> future::Boxed<()> + Send>, +} + +pub(crate) struct EntityShutdown; + +/// Argument for the `EntityManager::spawn` function. +pub enum SpawnArg { + /// The entity is active, and we were able to spawn a task. + Active(ActiveEntityState
<P>
), + /// The entity is busy and cannot spawn a new task. + Busy, + /// The entity is dead. + Dead, +} + +/// Sent from the entity actor to the main actor to notify that it is shutting down. +/// +/// With this message the entity actor gives back the receiver for its command channel, +/// so it can be reusd either immediately if commands come in during shutdown, or later +/// if the entity actor is reused for a different entity. +struct Shutdown { + id: P::EntityId, + receiver: mpsc::Receiver>, +} + +struct ShutdownAll { + tx: oneshot::Sender<()>, +} + +/// Sent from the main actor to the entity actor to notify that it has completed shutdown. +/// +/// With this message the entity actor sends back the remaining state. The tasks set +/// at this point must be empty, as the entity actor has already completed all tasks. +struct ShutdownComplete { + state: ActiveEntityState
<P>
, + tasks: FuturesUnordered>, +} + +mod entity_actor { + #![allow(dead_code)] + use n0_future::{future, FuturesUnordered, StreamExt}; + use tokio::sync::mpsc; + + use super::{ + EntityShutdown, Params, Reset, Shutdown, ShutdownCause, ShutdownComplete, Spawn, SpawnArg, + }; + + /// State of an active entity. + #[derive(Debug)] + pub struct State { + /// The entity id. + pub id: P::EntityId, + /// A copy of the global state. + pub global: P::GlobalState, + /// The per-entity state which might have internal mutability. + pub state: P::EntityState, + } + + impl Clone for State
<P>
{ + fn clone(&self) -> Self { + Self { + id: self.id.clone(), + global: self.global.clone(), + state: self.state.clone(), + } + } + } + + pub enum Command { + Spawn(Spawn
<P>
), + EntityShutdown(EntityShutdown), + } + + impl From for Command
<P>
{ + fn from(_: EntityShutdown) -> Self { + Self::EntityShutdown(EntityShutdown) + } + } + + #[derive(Debug)] + pub struct Actor { + pub recv: mpsc::Receiver>, + pub main: mpsc::Sender>, + pub state: State
<P>
, + pub tasks: FuturesUnordered>, + } + + impl Actor
<P>
{ + pub async fn run(mut self) { + loop { + tokio::select! { + command = self.recv.recv() => { + let Some(command) = command else { + // Channel closed, this means that the main actor is shutting down. + self.drop_shutdown_state().await; + break; + }; + match command { + Command::Spawn(spawn) => { + let task = (spawn.f)(SpawnArg::Active(self.state.clone())); + self.tasks.push(task); + } + Command::EntityShutdown(_) => { + self.soft_shutdown_state().await; + break; + } + } + } + Some(_) = self.tasks.next(), if !self.tasks.is_empty() => {} + } + if self.tasks.is_empty() && self.recv.is_empty() { + // No more tasks and no more commands, we can recycle the actor. + self.recycle_state().await; + break; // Exit the loop, actor is done. + } + } + } + + /// drop shutdown state. + /// + /// All senders for our receive channel were dropped, so we shut down without waiting for any tasks to complete. + async fn drop_shutdown_state(self) { + let Self { state, .. } = self; + P::on_shutdown(state, ShutdownCause::Drop).await; + } + + /// Soft shutdown state. + /// + /// We have received an explicit shutdown command, so we wait for all tasks to complete and then call the shutdown function. + async fn soft_shutdown_state(mut self) { + while (self.tasks.next().await).is_some() {} + P::on_shutdown(self.state.clone(), ShutdownCause::Soft).await; + } + + async fn recycle_state(self) { + // we can't check if recv is empty here, since new messages might come in while we are in recycle_state. + assert!( + self.tasks.is_empty(), + "Tasks must be empty before recycling" + ); + // notify main actor that we are starting to shut down. + // if the main actor is shutting down, this could fail, but we don't care. + self.main + .send( + Shutdown { + id: self.state.id.clone(), + receiver: self.recv, + } + .into(), + ) + .await + .ok(); + P::on_shutdown(self.state.clone(), ShutdownCause::Idle).await; + // Notify the main actor that we have completed shutdown. + // here we also give back the rest of ourselves so the main actor can recycle us. + self.main + .send( + ShutdownComplete { + state: self.state, + tasks: self.tasks, + } + .into(), + ) + .await + .ok(); + } + + /// Recycle the actor for reuse by setting its state to default. + /// + /// This also checks several invariants: + /// - There must be no pending messages in the receive channel. + /// - The sender must have a strong count of 1, meaning no other references exist + /// - The tasks set must be empty, meaning no tasks are running. + /// - The global state must match the scope provided. + /// - The state must be unique to the actor, meaning no other references exist. + pub fn recycle(&mut self) { + assert!( + self.recv.is_empty(), + "Cannot recycle actor with pending messages" + ); + assert!( + self.recv.sender_strong_count() == 1, + "There must be only one sender left" + ); + assert!( + self.tasks.is_empty(), + "Tasks must be empty before recycling" + ); + self.state.state.reset(); + } + } +} +pub use entity_actor::State as ActiveEntityState; +pub use main_actor::ActorState as EntityManagerState; + +mod main_actor { + #![allow(dead_code)] + use std::{collections::HashMap, future::Future}; + + use n0_future::{future, FuturesUnordered}; + use tokio::{sync::mpsc, task::JoinSet}; + use tracing::{error, warn}; + + use super::{ + entity_actor, EntityShutdown, Params, Reset, Shutdown, ShutdownAll, ShutdownComplete, + Spawn, SpawnArg, + }; + + pub(super) enum Command { + Spawn(Spawn
<P>
), + ShutdownAll(ShutdownAll), + } + + impl From for Command
<P>
{ + fn from(shutdown_all: ShutdownAll) -> Self { + Self::ShutdownAll(shutdown_all) + } + } + + pub(super) enum InternalCommand { + ShutdownComplete(ShutdownComplete
<P>
), + Shutdown(Shutdown
<P>
), + } + + impl From> for InternalCommand
<P>
{ + fn from(shutdown: Shutdown
<P>
) -> Self { + Self::Shutdown(shutdown) + } + } + + impl From> for InternalCommand
<P>
{ + fn from(shutdown_complete: ShutdownComplete
<P>
) -> Self { + Self::ShutdownComplete(shutdown_complete) + } + } + + #[derive(Debug)] + pub enum EntityHandle { + /// A running entity actor. + Live { + send: mpsc::Sender>, + }, + ShuttingDown { + send: mpsc::Sender>, + recv: mpsc::Receiver>, + }, + } + + impl EntityHandle
<P>
{ + pub fn send(&self) -> &mpsc::Sender> { + match self { + EntityHandle::Live { send } => send, + EntityHandle::ShuttingDown { send, .. } => send, + } + } + } + + /// State machine for an entity actor manager. + /// + /// This is if you don't want a separate manager actor, but want to inline the entity + /// actor management into your main actor. + #[derive(Debug)] + pub struct ActorState { + /// Channel to receive internal commands from the entity actors. + /// This channel will never be closed since we also hold a sender to it. + internal_recv: mpsc::Receiver>, + /// Channel to send internal commands to ourselves, to hand out to entity actors. + internal_send: mpsc::Sender>, + /// Map of live entity actors. + live: HashMap>, + /// Global state shared across all entity actors. + state: P::GlobalState, + /// Pool of inactive entity actors to reuse. + pool: Vec<( + mpsc::Sender>, + entity_actor::Actor
<P>
, + )>, + /// Maximum size of the inbox of an entity actor. + entity_inbox_size: usize, + /// Initial capacity of the futures set for entity actors. + entity_futures_initial_capacity: usize, + } + + impl ActorState
<P>
{ + pub fn new( + state: P::GlobalState, + pool_capacity: usize, + entity_inbox_size: usize, + entity_response_inbox_size: usize, + entity_futures_initial_capacity: usize, + ) -> Self { + let (internal_send, internal_recv) = mpsc::channel(entity_response_inbox_size); + Self { + internal_recv, + internal_send, + live: HashMap::new(), + state, + pool: Vec::with_capacity(pool_capacity), + entity_inbox_size, + entity_futures_initial_capacity, + } + } + + #[must_use = "this function may return a future that must be spawned by the caller"] + /// Friendly version of `spawn_boxed` that does the boxing + pub async fn spawn( + &mut self, + id: P::EntityId, + f: F, + ) -> Option + Send + 'static> + where + F: FnOnce(SpawnArg
<P>
) -> Fut + Send + 'static, + Fut: Future + Send + 'static, + { + self.spawn_boxed( + id, + Box::new(|x| { + Box::pin(async move { + f(x).await; + }) + }), + ) + .await + } + + #[must_use = "this function may return a future that must be spawned by the caller"] + pub async fn spawn_boxed( + &mut self, + id: P::EntityId, + f: Box) -> future::Boxed<()> + Send>, + ) -> Option + Send + 'static> { + let (entity_handle, task) = self.get_or_create(id.clone()); + let sender = entity_handle.send(); + if let Err(e) = + sender.try_send(entity_actor::Command::Spawn(Spawn { id: id.clone(), f })) + { + match e { + mpsc::error::TrySendError::Full(cmd) => { + let entity_actor::Command::Spawn(spawn) = cmd else { + panic!() + }; + warn!( + "Entity actor inbox is full, cannot send command to entity actor {:?}.", + id + ); + // we await in the select here, but I think this is fine, since the actor is busy. + // maybe slowing things down a bit is helpful. + (spawn.f)(SpawnArg::Busy).await; + } + mpsc::error::TrySendError::Closed(cmd) => { + let entity_actor::Command::Spawn(spawn) = cmd else { + panic!() + }; + error!( + "Entity actor inbox is closed, cannot send command to entity actor {:?}.", + id + ); + // give the caller a chance to react to this bad news. + // at this point we are in trouble anyway, so awaiting is going to be the least of our problems. + (spawn.f)(SpawnArg::Dead).await; + } + } + }; + task + } + + /// This function needs to be polled by the owner of the actor state to advance the + /// entity manager state machine. If it returns a future, that future must be spawned + /// by the caller. + #[must_use = "this function may return a future that must be spawned by the caller"] + pub async fn tick(&mut self) -> Option + Send + 'static> { + if let Some(cmd) = self.internal_recv.recv().await { + match cmd { + InternalCommand::Shutdown(Shutdown { id, receiver }) => { + let Some(entity_handle) = self.live.remove(&id) else { + error!("Received shutdown command for unknown entity actor {id:?}"); + return None; + }; + let EntityHandle::Live { send } = entity_handle else { + error!( + "Received shutdown command for entity actor {id:?} that is already shutting down" + ); + return None; + }; + self.live.insert( + id.clone(), + EntityHandle::ShuttingDown { + send, + recv: receiver, + }, + ); + } + InternalCommand::ShutdownComplete(ShutdownComplete { state, tasks }) => { + let id = state.id.clone(); + let Some(entity_handle) = self.live.remove(&id) else { + error!( + "Received shutdown complete command for unknown entity actor {id:?}" + ); + return None; + }; + let EntityHandle::ShuttingDown { send, recv } = entity_handle else { + error!( + "Received shutdown complete command for entity actor {id:?} that is not shutting down" + ); + return None; + }; + // re-assemble the actor from the parts + let mut actor = entity_actor::Actor { + main: self.internal_send.clone(), + recv, + state, + tasks, + }; + if actor.recv.is_empty() { + // No commands during shutdown, we can recycle the actor. + self.recycle(send, actor); + } else { + actor.state.state.reset(); + self.live.insert(id.clone(), EntityHandle::Live { send }); + return Some(actor.run()); + } + } + } + } + None + } + + /// Send a shutdown command to all live entity actors. + pub async fn shutdown(self) { + for handle in self.live.values() { + handle.send().send(EntityShutdown {}.into()).await.ok(); + } + } + + /// Get or create an entity actor for the given id. + /// + /// If this function returns a future, it must be spawned by the caller. 
+ fn get_or_create( + &mut self, + id: P::EntityId, + ) -> ( + &mut EntityHandle
<P>
, + Option + Send + 'static>, + ) { + let mut task = None; + let handle = self.live.entry(id.clone()).or_insert_with(|| { + if let Some((send, mut actor)) = self.pool.pop() { + // Get an actor from the pool of inactive actors and initialize it. + actor.state.id = id.clone(); + actor.state.global = self.state.clone(); + // strictly speaking this is not needed, since we reset the state when adding the actor to the pool. + actor.state.state.reset(); + task = Some(actor.run()); + EntityHandle::Live { send } + } else { + // Create a new entity actor and inbox. + let (send, recv) = mpsc::channel(self.entity_inbox_size); + let state: entity_actor::State
<P>
= entity_actor::State { + id: id.clone(), + global: self.state.clone(), + state: Default::default(), + }; + let actor = entity_actor::Actor { + main: self.internal_send.clone(), + recv, + state, + tasks: FuturesUnordered::with_capacity( + self.entity_futures_initial_capacity, + ), + }; + task = Some(actor.run()); + EntityHandle::Live { send } + } + }); + (handle, task) + } + + fn recycle( + &mut self, + sender: mpsc::Sender>, + mut actor: entity_actor::Actor
<P>
, + ) { + assert!(sender.strong_count() == 1); + // todo: check that sender and receiver are the same channel. tokio does not have an api for this, unfortunately. + // reset the actor in any case, just to check the invariants. + actor.recycle(); + // Recycle the actor for later use. + if self.pool.len() < self.pool.capacity() { + self.pool.push((sender, actor)); + } + } + } + + pub struct Actor { + /// Channel to receive commands from the outside world. + /// If this channel is closed, it means we need to shut down in a hurry. + recv: mpsc::Receiver>, + /// Tasks that are currently running. + tasks: JoinSet<()>, + /// Internal state of the actor + state: ActorState
<P>
, + } + + impl Actor
<P>
{ + pub fn new( + state: P::GlobalState, + recv: tokio::sync::mpsc::Receiver>, + pool_capacity: usize, + entity_inbox_size: usize, + entity_response_inbox_size: usize, + entity_futures_initial_capacity: usize, + ) -> Self { + Self { + recv, + tasks: JoinSet::new(), + state: ActorState::new( + state, + pool_capacity, + entity_inbox_size, + entity_response_inbox_size, + entity_futures_initial_capacity, + ), + } + } + + pub async fn run(mut self) { + enum SelectOutcome { + Command(A), + Tick(B), + TaskDone(C), + } + loop { + let res = tokio::select! { + x = self.recv.recv() => SelectOutcome::Command(x), + x = self.state.tick() => SelectOutcome::Tick(x), + Some(task) = self.tasks.join_next(), if !self.tasks.is_empty() => SelectOutcome::TaskDone(task), + }; + match res { + SelectOutcome::Command(cmd) => { + let Some(cmd) = cmd else { + // Channel closed, this means that the main actor is shutting down. + self.hard_shutdown().await; + break; + }; + match cmd { + Command::Spawn(spawn) => { + if let Some(task) = self.state.spawn_boxed(spawn.id, spawn.f).await + { + self.tasks.spawn(task); + } + } + Command::ShutdownAll(arg) => { + self.soft_shutdown().await; + arg.tx.send(()).ok(); + break; + } + } + // Handle incoming command + } + SelectOutcome::Tick(future) => { + if let Some(task) = future { + self.tasks.spawn(task); + } + } + SelectOutcome::TaskDone(result) => { + // Handle completed task + if let Err(e) = result { + error!("Task failed: {e:?}"); + } + } + } + } + } + + async fn soft_shutdown(self) { + let Self { + mut tasks, state, .. + } = self; + state.shutdown().await; + while let Some(res) = tasks.join_next().await { + if let Err(e) = res { + eprintln!("Task failed during shutdown: {e:?}"); + } + } + } + + async fn hard_shutdown(self) { + let Self { + mut tasks, state, .. + } = self; + // this is needed so calls to internal_send in idle shutdown fail fast. + // otherwise we would have to drain the channel, but we don't care about the messages at + // this point. + drop(state); + while let Some(res) = tasks.join_next().await { + if let Err(e) = res { + eprintln!("Task failed during shutdown: {e:?}"); + } + } + } + } +} + +/// A manager for entities identified by an entity id. +/// +/// The manager provides parallelism between entities, but just concurrency within a single entity. +/// This is useful if the entity wraps an external resource such as a file that does not benefit +/// from parallelism. +/// +/// The entity manager internally uses a main actor and per-entity actors. Per entity actors +/// and their inbox queues are recycled when they become idle, to save allocations. +/// +/// You can mostly ignore these implementation details, except when you want to customize the +/// queue sizes in the [`Options`] struct. +/// +/// The main entry point is the [`EntityManager::spawn`] function. +/// +/// Dropping the `EntityManager` will shut down the entity actors without waiting for their +/// tasks to complete. For a more gentle shutdown, use the [`EntityManager::shutdown`] function +/// that does wait for tasks to complete. +#[derive(Debug, Clone)] +pub struct EntityManager(mpsc::Sender>); + +#[derive(Debug, Clone, Copy)] +pub struct Options { + /// Maximum number of inactive entity actors that are being pooled for reuse. + pub pool_capacity: usize, + /// Size of the inbox for the manager actor. + pub inbox_size: usize, + /// Size of the inbox for entity actors. + pub entity_inbox_size: usize, + /// Size of the inbox for entity actor responses to the manager actor. 
+ pub entity_response_inbox_size: usize, + /// Initial capacity of the futures set for entity actors. + /// + /// Set this to the expected average concurrency level of your entities. + pub entity_futures_initial_capacity: usize, +} + +impl Default for Options { + fn default() -> Self { + Self { + pool_capacity: 10, + inbox_size: 10, + entity_inbox_size: 10, + entity_response_inbox_size: 100, + entity_futures_initial_capacity: 16, + } + } +} + +impl EntityManager
<P>
{ + pub fn new(state: P::GlobalState, options: Options) -> Self { + let (send, recv) = mpsc::channel(options.inbox_size); + let actor = main_actor::Actor::new( + state, + recv, + options.pool_capacity, + options.entity_inbox_size, + options.entity_response_inbox_size, + options.entity_futures_initial_capacity, + ); + tokio::spawn(actor.run()); + Self(send) + } + + /// Spawn a new task on the entity actor with the given id. + /// + /// Unless the world is ending - e.g. tokio runtime is shutting down - the passed function + /// is guaranteed to be called. However, there is no guarantee that the entity actor is + /// alive and responsive. See [`SpawnArg`] for details. + /// + /// Multiple callbacks for the same entity will be executed sequentially. There is no + /// parallelism within a single entity. So you can use synchronization primitives that + /// assume unique access in P::EntityState. And even if you do use multithreaded synchronization + /// primitives, they will never be contended. + /// + /// The future returned by `f` will be executed concurrently with other tasks, but again + /// there will be no real parallelism within a single entity actor. + pub async fn spawn(&self, id: P::EntityId, f: F) -> Result<(), &'static str> + where + F: FnOnce(SpawnArg
<P>
) -> Fut + Send + 'static, + Fut: future::Future + Send + 'static, + { + let spawn = Spawn { + id, + f: Box::new(|arg| { + Box::pin(async move { + f(arg).await; + }) + }), + }; + self.0 + .send(main_actor::Command::Spawn(spawn)) + .await + .map_err(|_| "Failed to send spawn command") + } + + pub async fn shutdown(&self) -> std::result::Result<(), &'static str> { + let (tx, rx) = oneshot::channel(); + self.0 + .send(ShutdownAll { tx }.into()) + .await + .map_err(|_| "Failed to send shutdown command")?; + rx.await + .map_err(|_| "Failed to receive shutdown confirmation") + } +} + +#[cfg(test)] +mod tests { + //! Tests for the entity manager. + //! + //! We implement a simple database for u128 counters, identified by u64 ids, + //! with both an in-memory and a file-based implementation. + //! + //! The database does internal consistency checks, to ensure that each + //! entity is only ever accessed by a single tokio task at a time, and to + //! ensure that wakeup and shutdown events are interleaved. + //! + //! We also check that the database behaves correctly by comparing with an + //! in-memory implementation. + //! + //! Database operations are done in parallel, so the fact that we are using + //! AtomicRefCell provides another test - if there was parallel write access + //! to a single entity due to a bug, it would panic. + use std::collections::HashMap; + + use n0_future::{BufferedStreamExt, StreamExt}; + use testresult::TestResult; + + use super::*; + + // a simple database for u128 counters, identified by u64 ids. + trait CounterDb { + async fn add(&self, id: u64, value: u128) -> Result<(), &'static str>; + async fn get(&self, id: u64) -> Result; + async fn shutdown(&self) -> Result<(), &'static str>; + async fn check_consistency(&self, values: HashMap); + } + + #[derive(Debug, PartialEq, Eq)] + enum Event { + Wakeup, + Shutdown, + } + + mod mem { + //! The in-memory database uses a HashMap in the global state to store + //! the values of the counters. Loading means reading from the global + //! state into the entity state, and persisting means writing to the + //! global state from the entity state. 
+ use std::{ + collections::{HashMap, HashSet}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, + time::Instant, + }; + + use atomic_refcell::AtomicRefCell; + + use super::*; + + #[derive(Debug, Default)] + struct Inner { + value: Option, + tasks: HashSet, + } + + #[derive(Debug, Clone, Default)] + struct State(Arc>); + + impl Reset for State { + fn reset(&mut self) { + *self.0.borrow_mut() = Default::default(); + } + } + + #[derive(Debug, Default)] + struct Global { + // the "database" of entity values + data: HashMap, + // log of awake and shutdown events + log: HashMap>, + } + + struct Counters; + impl Params for Counters { + type EntityId = u64; + type GlobalState = Arc>; + type EntityState = State; + async fn on_shutdown(entity: entity_actor::State, _cause: ShutdownCause) { + let state = entity.state.0.borrow(); + let mut global = entity.global.lock().unwrap(); + assert_eq!(state.tasks.len(), 1); + // persist the state + if let Some(value) = state.value { + global.data.insert(entity.id, value); + } + // log the shutdown event + global + .log + .entry(entity.id) + .or_default() + .push((Event::Shutdown, Instant::now())); + } + } + + pub struct MemDb { + m: EntityManager, + global: Arc>, + } + + impl entity_actor::State { + async fn with_value(&self, f: impl FnOnce(&mut u128)) -> Result<(), &'static str> { + let mut state = self.state.0.borrow_mut(); + // lazily load the data from the database + if state.value.is_none() { + let mut global = self.global.lock().unwrap(); + state.value = Some(global.data.get(&self.id).copied().unwrap_or_default()); + // log the wakeup event + global + .log + .entry(self.id) + .or_default() + .push((Event::Wakeup, Instant::now())); + } + // insert the task id into the tasks set to check that access is always + // from the same tokio task (not necessarily the same thread). 
+ state.tasks.insert(tokio::task::id()); + // do the actual work + let r = state.value.as_mut().unwrap(); + f(r); + Ok(()) + } + } + + impl MemDb { + pub fn new() -> Self { + let global = Arc::new(Mutex::new(Global::default())); + Self { + global: global.clone(), + m: EntityManager::::new(global, Options::default()), + } + } + } + + impl super::CounterDb for MemDb { + async fn add(&self, id: u64, value: u128) -> Result<(), &'static str> { + self.m + .spawn(id, move |arg| async move { + match arg { + SpawnArg::Active(state) => { + state + .with_value(|v| *v = v.wrapping_add(value)) + .await + .unwrap(); + } + SpawnArg::Busy => println!("Entity actor is busy"), + SpawnArg::Dead => println!("Entity actor is dead"), + } + }) + .await + } + + async fn get(&self, id: u64) -> Result { + let (tx, rx) = oneshot::channel(); + self.m + .spawn(id, move |arg| async move { + match arg { + SpawnArg::Active(state) => { + state + .with_value(|v| { + tx.send(*v) + .unwrap_or_else(|_| println!("Failed to send value")) + }) + .await + .unwrap(); + } + SpawnArg::Busy => println!("Entity actor is busy"), + SpawnArg::Dead => println!("Entity actor is dead"), + } + }) + .await?; + rx.await.map_err(|_| "Failed to receive value") + } + + async fn shutdown(&self) -> Result<(), &'static str> { + self.m.shutdown().await + } + + async fn check_consistency(&self, values: HashMap) { + let global = self.global.lock().unwrap(); + assert_eq!(global.data, values, "Data mismatch"); + for id in values.keys() { + let log = global.log.get(id).unwrap(); + assert!( + log.len() % 2 == 0, + "Log must contain alternating wakeup and shutdown events" + ); + for (i, (event, _)) in log.iter().enumerate() { + assert_eq!( + *event, + if i % 2 == 0 { + Event::Wakeup + } else { + Event::Shutdown + }, + "Unexpected event type" + ); + } + } + } + } + + /// If a task is so busy that it can't drain it's inbox in time, we will + /// get a SpawnArg::Busy instead of access to the actual state. + /// + /// This will only happen if the system is seriously overloaded, since + /// the entity actor just spawns tasks for each message. So here we + /// simulate it by just not spawning the task as we are supposed to. + #[tokio::test] + async fn test_busy() -> TestResult<()> { + let mut state = EntityManagerState::::new( + Arc::new(Mutex::new(Global::default())), + 1024, + 8, + 8, + 2, + ); + let active = Arc::new(AtomicUsize::new(0)); + let busy = Arc::new(AtomicUsize::new(0)); + let inc = || { + let active = active.clone(); + let busy = busy.clone(); + |arg: SpawnArg| async move { + match arg { + SpawnArg::Active(_) => { + active.fetch_add(1, Ordering::SeqCst); + } + SpawnArg::Busy => { + busy.fetch_add(1, Ordering::SeqCst); + } + SpawnArg::Dead => { + println!("Entity actor is dead"); + } + } + } + }; + let fut1 = state.spawn(1, inc()).await; + assert!(fut1.is_some(), "First spawn should give us a task to spawn"); + for _ in 0..9 { + let fut = state.spawn(1, inc()).await; + assert!( + fut.is_none(), + "Subsequent spawns should assume first task has been spawned" + ); + } + assert_eq!( + active.load(Ordering::SeqCst), + 0, + "Active should have never been called, since we did not spawn the task!" + ); + assert_eq!(busy.load(Ordering::SeqCst), 2, "Busy should have been called two times, since we sent 10 msgs to a queue with capacity 8, and nobody is draining it"); + Ok(()) + } + + /// If there is a panic in any of the fns that run on an entity actor, + /// the entire entity becomes dead. 
+        /// trying to spawn a new task on the dead entity actor will result in
+        /// a SpawnArg::Dead.
+        #[tokio::test]
+        async fn test_dead() -> TestResult<()> {
+            let manager = EntityManager::<Counters>::new(
+                Arc::new(Mutex::new(Global::default())),
+                Options::default(),
+            );
+            let (tx, rx) = oneshot::channel();
+            let killer = |arg: SpawnArg<Counters>| async move {
+                if let SpawnArg::Active(_) = arg {
+                    tx.send(()).ok();
+                    panic!("Panic to kill the task");
+                }
+            };
+            // spawn a task that kills the entity actor
+            manager.spawn(1, killer).await?;
+            rx.await.expect("Failed to receive kill confirmation");
+            let (tx, rx) = oneshot::channel();
+            let counter = |arg: SpawnArg<Counters>| async move {
+                if let SpawnArg::Dead = arg {
+                    tx.send(()).ok();
+                }
+            };
+            // spawn another task on the now dead entity actor
+            manager.spawn(1, counter).await?;
+            rx.await.expect("Failed to receive dead confirmation");
+            Ok(())
+        }
+    }
+
+    mod fs {
+        //! The fs db uses one file per counter, stored as a 16-byte big-endian u128.
+        use std::{
+            collections::HashSet,
+            path::{Path, PathBuf},
+            sync::{Arc, Mutex},
+            time::Instant,
+        };
+
+        use atomic_refcell::AtomicRefCell;
+
+        use super::*;
+
+        #[derive(Debug, Clone, Default)]
+        struct State {
+            value: Option<u128>,
+            tasks: HashSet<tokio::task::Id>,
+        }
+
+        #[derive(Debug)]
+        struct Global {
+            path: PathBuf,
+            log: HashMap<u64, Vec<(Event, Instant)>>,
+        }
+
+        #[derive(Debug, Clone, Default)]
+        struct EntityState(Arc<AtomicRefCell<State>>);
+
+        impl Reset for EntityState {
+            fn reset(&mut self) {
+                *self.0.borrow_mut() = Default::default();
+            }
+        }
+
+        fn get_path(root: impl AsRef<Path>, id: u64) -> PathBuf {
+            root.as_ref().join(hex::encode(id.to_be_bytes()))
+        }
+
+        impl entity_actor::State<Counters> {
+            async fn with_value(&self, f: impl FnOnce(&mut u128)) -> Result<(), &'static str> {
+                let Ok(mut r) = self.state.0.try_borrow_mut() else {
+                    panic!("failed to borrow state mutably");
+                };
+                if r.value.is_none() {
+                    let mut global = self.global.lock().unwrap();
+                    global
+                        .log
+                        .entry(self.id)
+                        .or_default()
+                        .push((Event::Wakeup, Instant::now()));
+                    let path = get_path(&global.path, self.id);
+                    // note: if we were to use async IO, we would need to make sure not to hold the
+                    // lock guard over an await point. The entity manager makes sure that all fns
+                    // are run on the same tokio task, but there is still concurrency, which
+                    // a mutable borrow of the state does not allow.
+                    let value = match std::fs::read(path) {
+                        Ok(value) => value,
+                        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
+                            // If the file does not exist, we initialize it to 0.
+                            vec![0; 16]
+                        }
+                        Err(_) => return Err("Failed to read disk state"),
+                    };
+                    let value = u128::from_be_bytes(
+                        value.try_into().map_err(|_| "Invalid disk state format")?,
+                    );
+                    r.value = Some(value);
+                }
+                let Some(value) = r.value.as_mut() else {
+                    panic!("State must be Memory at this point");
+                };
+                f(value);
+                Ok(())
+            }
+        }
+
+        struct Counters;
+        impl Params for Counters {
+            type EntityId = u64;
+            type GlobalState = Arc<Mutex<Global>>;
+            type EntityState = EntityState;
+            async fn on_shutdown(state: entity_actor::State<Self>, _cause: ShutdownCause) {
+                let r = state.state.0.borrow();
+                let mut global = state.global.lock().unwrap();
+                if let Some(value) = r.value {
+                    let path = get_path(&global.path, state.id);
+                    let value_bytes = value.to_be_bytes();
+                    std::fs::write(&path, value_bytes).expect("Failed to write disk state");
+                }
+                global
+                    .log
+                    .entry(state.id)
+                    .or_default()
+                    .push((Event::Shutdown, Instant::now()));
+            }
+        }
+
+        pub struct FsDb {
+            global: Arc<Mutex<Global>>,
+            m: EntityManager<Counters>,
+        }
+
+        impl FsDb {
+            pub fn new(path: impl AsRef<Path>) -> Self {
+                let global = Global {
+                    path: path.as_ref().to_owned(),
+                    log: HashMap::new(),
+                };
+                let global = Arc::new(Mutex::new(global));
+                Self {
+                    global: global.clone(),
+                    m: EntityManager::<Counters>::new(global, Options::default()),
+                }
+            }
+        }
+
+        impl super::CounterDb for FsDb {
+            async fn add(&self, id: u64, value: u128) -> Result<(), &'static str> {
+                self.m
+                    .spawn(id, move |arg| async move {
+                        match arg {
+                            SpawnArg::Active(state) => {
+                                println!(
+                                    "Adding value {} to entity actor with id {:?}",
+                                    value, state.id
+                                );
+                                state
+                                    .with_value(|v| *v = v.wrapping_add(value))
+                                    .await
+                                    .unwrap();
+                            }
+                            SpawnArg::Busy => println!("Entity actor is busy"),
+                            SpawnArg::Dead => println!("Entity actor is dead"),
+                        }
+                    })
+                    .await
+            }
+
+            async fn get(&self, id: u64) -> Result<u128, &'static str> {
+                let (tx, rx) = oneshot::channel();
+                self.m
+                    .spawn(id, move |arg| async move {
+                        match arg {
+                            SpawnArg::Active(state) => {
+                                state
+                                    .with_value(|v| {
+                                        tx.send(*v)
+                                            .unwrap_or_else(|_| println!("Failed to send value"))
+                                    })
+                                    .await
+                                    .unwrap();
+                            }
+                            SpawnArg::Busy => println!("Entity actor is busy"),
+                            SpawnArg::Dead => println!("Entity actor is dead"),
+                        }
+                    })
+                    .await?;
+                rx.await.map_err(|_| "Failed to receive value in get")
+            }
+
+            async fn shutdown(&self) -> Result<(), &'static str> {
+                self.m.shutdown().await
+            }
+
+            async fn check_consistency(&self, values: HashMap<u64, u128>) {
+                let global = self.global.lock().unwrap();
+                for (id, value) in &values {
+                    let path = get_path(&global.path, *id);
+                    let disk_value = match std::fs::read(path) {
+                        Ok(data) => u128::from_be_bytes(data.try_into().unwrap()),
+                        Err(e) if e.kind() == std::io::ErrorKind::NotFound => 0,
+                        Err(_) => panic!("Failed to read disk state for id {id}"),
+                    };
+                    assert_eq!(disk_value, *value, "Disk value mismatch for id {id}");
+                }
+                for id in values.keys() {
+                    let log = global.log.get(id).unwrap();
+                    assert!(
+                        log.len() % 2 == 0,
+                        "Log must contain alternating wakeup and shutdown events"
+                    );
+                    for (i, (event, _)) in log.iter().enumerate() {
+                        assert_eq!(
+                            *event,
+                            if i % 2 == 0 {
+                                Event::Wakeup
+                            } else {
+                                Event::Shutdown
+                            },
+                            "Unexpected event type"
+                        );
+                    }
+                }
+            }
+        }
+    }
+
+    async fn test_random(
+        db: impl CounterDb,
+        entries: &[(u64, u128)],
+    ) -> testresult::TestResult<()> {
+        // compute the expected values
+        let mut reference = HashMap::new();
+        for (id, value) in entries {
+            let v: &mut u128 = reference.entry(*id).or_default();
+            *v = v.wrapping_add(*value);
+        }
+        // do the same computation using the database, and some concurrency
+        // and parallelism (we will get parallelism if we are using a multi-threaded runtime).
+        let mut errors = Vec::new();
+        n0_future::stream::iter(entries)
+            .map(|(id, value)| db.add(*id, *value))
+            .buffered_unordered(16)
+            .for_each(|result| {
+                if let Err(e) = result {
+                    errors.push(e);
+                }
+            })
+            .await;
+        assert!(errors.is_empty(), "Failed to add some entries: {errors:?}");
+        // check that the db contains the expected values
+        let ids = reference.keys().copied().collect::<Vec<_>>();
+        for id in &ids {
+            let res = db.get(*id).await?;
+            assert_eq!(res, reference.get(id).copied().unwrap_or_default());
+        }
+        db.shutdown().await?;
+        // check that the db is consistent with the reference
+        db.check_consistency(reference).await;
+        Ok(())
+    }
+
+    #[test_strategy::proptest]
+    fn test_counters_manager_proptest_mem(entries: Vec<(u64, u128)>) {
+        let rt = tokio::runtime::Builder::new_multi_thread()
+            .build()
+            .expect("Failed to create tokio runtime");
+        rt.block_on(async move {
+            let db = mem::MemDb::new();
+            test_random(db, &entries).await
+        })
+        .expect("Test failed");
+    }
+
+    #[test_strategy::proptest]
+    fn test_counters_manager_proptest_fs(entries: Vec<(u64, u128)>) {
+        let dir = tempfile::tempdir().unwrap();
+        let rt = tokio::runtime::Builder::new_multi_thread()
+            .build()
+            .expect("Failed to create tokio runtime");
+        rt.block_on(async move {
+            let db = fs::FsDb::new(dir.path());
+            test_random(db, &entries).await
+        })
+        .expect("Test failed");
+    }
+}
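+
+// Usage sketch for the counter databases defined in the tests above (illustrative
+// only, based on what the tests already exercise): every `add`/`get` goes through
+// `EntityManager::spawn`, which hands the callback a `SpawnArg::Active(state)` on
+// the entity actor for that id, and `shutdown` runs `Params::on_shutdown` for each
+// live entity, persisting its state.
+//
+//     let db = mem::MemDb::new();
+//     db.add(1, 2).await?;             // handled by the entity actor for id 1
+//     assert_eq!(db.get(1).await?, 2); // lazily loads, then reads the value
+//     db.shutdown().await?;            // persists via Params::on_shutdown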