diff --git a/Cargo.toml b/Cargo.toml index c73016fa..df74fd43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,6 +44,7 @@ iroh-base = "0.90" reflink-copy = "0.1.24" irpc = { version = "0.5.0", features = ["rpc", "quinn_endpoint_setup", "message_spans", "stream", "derive"], default-features = false } iroh-metrics = { version = "0.35" } +atomic_refcell = "0.1.13" [dev-dependencies] clap = { version = "4.5.31", features = ["derive"] } @@ -58,7 +59,6 @@ testresult = "0.4.1" tracing-subscriber = { version = "0.3.19", features = ["fmt"] } tracing-test = "0.2.5" walkdir = "2.5.0" -atomic_refcell = "0.1.13" [features] hide-proto-docs = [] diff --git a/src/store/fs.rs b/src/store/fs.rs index 53ed163a..9a1ca375 100644 --- a/src/store/fs.rs +++ b/src/store/fs.rs @@ -106,7 +106,7 @@ use crate::{ ApiClient, }, store::{ - fs::util::entity_manager::{self, ActiveEntityState}, + fs::util::entity_manager::{self, ActiveEntityState, ShutdownCause}, util::{BaoTreeSender, FixedSize, MemOrFile, ValueOrPoisioned}, Hash, }, @@ -217,10 +217,17 @@ impl entity_manager::Params for EmParams { type EntityState = Slot; - async fn on_shutdown( - _state: entity_manager::ActiveEntityState, - _cause: entity_manager::ShutdownCause, - ) { + async fn on_shutdown(state: HashContext, cause: ShutdownCause) { + // this isn't strictly necessary. Drop will run anyway as soon as the + // state is reset to it's default value. Doing it here means that we + // have exact control over where it happens. + if let Some(handle) = state.state.0.lock().await.take() { + trace!( + "shutting down entity manager for hash: {}, cause: {cause:?}", + state.id + ); + drop(handle); + } } } diff --git a/src/store/fs/bao_file.rs b/src/store/fs/bao_file.rs index bf150ae8..65b86723 100644 --- a/src/store/fs/bao_file.rs +++ b/src/store/fs/bao_file.rs @@ -20,12 +20,12 @@ use bao_tree::{ use bytes::{Bytes, BytesMut}; use derive_more::Debug; use irpc::channel::mpsc; -use tokio::sync::watch; use tracing::{debug, error, info, trace}; use super::{ entry_state::{DataLocation, EntryState, OutboardLocation}, options::{Options, PathOptions}, + util::watch, BaoFilePart, }; use crate::{ diff --git a/src/store/fs/util.rs b/src/store/fs/util.rs index 1cbd01bc..b739394a 100644 --- a/src/store/fs/util.rs +++ b/src/store/fs/util.rs @@ -2,6 +2,7 @@ use std::future::Future; use tokio::{select, sync::mpsc}; pub(crate) mod entity_manager; +pub(crate) mod watch; /// A wrapper for a tokio mpsc receiver that allows peeking at the next message. #[derive(Debug)] diff --git a/src/store/fs/util/watch.rs b/src/store/fs/util/watch.rs new file mode 100644 index 00000000..58a56a0a --- /dev/null +++ b/src/store/fs/util/watch.rs @@ -0,0 +1,87 @@ +use std::{ops::Deref, sync::Arc}; + +use atomic_refcell::{AtomicRef, AtomicRefCell}; + +struct State { + value: T, + dropped: bool, +} + +struct Shared { + value: AtomicRefCell>, + notify: tokio::sync::Notify, +} + +pub struct Sender(Arc>); + +pub struct Receiver(Arc>); + +impl Sender { + pub fn new(value: T) -> Self { + Self(Arc::new(Shared { + value: AtomicRefCell::new(State { + value, + dropped: false, + }), + notify: tokio::sync::Notify::new(), + })) + } + + pub fn send_if_modified(&self, modify: F) -> bool + where + F: FnOnce(&mut T) -> bool, + { + let mut state = self.0.value.borrow_mut(); + let modified = modify(&mut state.value); + if modified { + self.0.notify.notify_waiters(); + } + modified + } + + pub fn borrow(&self) -> impl Deref + '_ { + AtomicRef::map(self.0.value.borrow(), |state| &state.value) + } + + pub fn subscribe(&self) -> Receiver { + Receiver(self.0.clone()) + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.0.value.borrow_mut().dropped = true; + self.0.notify.notify_waiters(); + } +} + +impl Receiver { + pub async fn changed(&self) -> Result<(), error::RecvError> { + self.0.notify.notified().await; + if self.0.value.borrow().dropped { + Err(error::RecvError(())) + } else { + Ok(()) + } + } + + pub fn borrow(&self) -> impl Deref + '_ { + AtomicRef::map(self.0.value.borrow(), |state| &state.value) + } +} + +pub mod error { + use std::{error::Error, fmt}; + + /// Error produced when receiving a change notification. + #[derive(Debug, Clone)] + pub struct RecvError(pub(super) ()); + + impl fmt::Display for RecvError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "channel closed") + } + } + + impl Error for RecvError {} +}