diff --git a/Cargo.toml b/Cargo.toml index 3b4eb3455..c404c11ca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ salsa-macros = { version = "0.23.0", path = "components/salsa-macros", optional boxcar = "0.2.13" crossbeam-queue = "0.3.12" crossbeam-utils = "0.8.21" +crossbeam-channel = "0.5.15" hashbrown = "0.15" hashlink = "0.10" indexmap = "2" @@ -55,7 +56,6 @@ salsa-macros = { version = "=0.23.0", path = "components/salsa-macros" } [dev-dependencies] # examples -crossbeam-channel = "0.5.15" dashmap = { version = "6", features = ["raw-api"] } eyre = "0.6.12" notify-debouncer-mini = "0.4.1" diff --git a/src/database.rs b/src/database.rs index 30178b2da..0d44f220c 100644 --- a/src/database.rs +++ b/src/database.rs @@ -42,7 +42,7 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase { /// is owned by the current thread, this could trigger deadlock. fn trigger_lru_eviction(&mut self) { let zalsa_mut = self.zalsa_mut(); - zalsa_mut.evict_lru(); + zalsa_mut.reset_for_new_revision(); } /// A "synthetic write" causes the system to act *as though* some diff --git a/src/function.rs b/src/function.rs index 891e3dbad..b711d08d1 100644 --- a/src/function.rs +++ b/src/function.rs @@ -1,9 +1,10 @@ pub(crate) use maybe_changed_after::VerifyResult; use std::any::Any; -use std::fmt; +use std::marker::PhantomData; use std::ptr::NonNull; use std::sync::atomic::Ordering; use std::sync::OnceLock; +use std::{fmt, mem}; pub(crate) use sync::SyncGuard; use crate::cycle::{ @@ -11,14 +12,13 @@ use crate::cycle::{ ProvisionalStatus, }; use crate::database::RawDatabase; -use crate::function::delete::DeletedEntries; use crate::function::sync::{ClaimResult, SyncTable}; use crate::ingredient::{Ingredient, WaitForResult}; use crate::key::DatabaseKeyIndex; use crate::plumbing::MemoIngredientMap; use crate::salsa_struct::SalsaStructInDb; use crate::sync::Arc; -use crate::table::memo::MemoTableTypes; +use crate::table::memo::{DeletedEntries, MemoTableTypes}; use crate::table::Table; use crate::views::DatabaseDownCaster; use crate::zalsa::{IngredientIndex, MemoIngredientIndex, Zalsa}; @@ -28,7 +28,6 @@ use crate::{Id, Revision}; #[cfg(feature = "accumulator")] mod accumulated; mod backdate; -mod delete; mod diff_outputs; mod execute; mod fetch; @@ -147,7 +146,8 @@ pub struct IngredientImpl { /// current revision: you would be right, but we are being defensive, because /// we don't know that we can trust the database to give us the same runtime /// everytime and so forth. - deleted_entries: DeletedEntries, + delete: DeletedEntries, + config: PhantomData C>, } impl IngredientImpl @@ -162,10 +162,11 @@ where Self { index, memo_ingredient_indices, - lru: lru::Lru::new(lru), - deleted_entries: Default::default(), view_caster: OnceLock::new(), + lru: lru::Lru::new(lru), + delete: DeletedEntries::default(), sync_table: SyncTable::new(index), + config: PhantomData, } } @@ -222,16 +223,7 @@ where // FIXME: Use `Box::into_non_null` once stable let memo = NonNull::from(Box::leak(Box::new(memo))); - if let Some(old_value) = - self.insert_memo_into_table_for(zalsa, id, memo, memo_ingredient_index) - { - // In case there is a reference to the old memo out there, we have to store it - // in the deleted entries. This will get cleared when a new revision starts. - // - // SAFETY: Once the revision starts, there will be no outstanding borrows to the - // memo contents, and so it will be safe to free. - unsafe { self.deleted_entries.push(old_value) }; - } + self.insert_memo_into_table_for(zalsa, id, memo, memo_ingredient_index); // SAFETY: memo has been inserted into the table unsafe { self.extend_memo_lifetime(memo.as_ref()) } } @@ -344,7 +336,11 @@ where true } - fn reset_for_new_revision(&mut self, table: &mut Table) { + fn reset_for_new_revision( + &mut self, + table: &mut Table, + new_buffer: DeletedEntries, + ) -> DeletedEntries { self.lru.for_each_evicted(|evict| { let ingredient_index = table.ingredient_index(evict); Self::evict_value_from_memo_for( @@ -352,8 +348,7 @@ where self.memo_ingredient_indices.get(ingredient_index), ) }); - - self.deleted_entries.clear(); + mem::replace(&mut self.delete, new_buffer) } fn debug_name(&self) -> &'static str { diff --git a/src/function/delete.rs b/src/function/delete.rs deleted file mode 100644 index d061917b0..000000000 --- a/src/function/delete.rs +++ /dev/null @@ -1,52 +0,0 @@ -use std::ptr::NonNull; - -use crate::function::memo::Memo; -use crate::function::Configuration; - -/// Stores the list of memos that have been deleted so they can be freed -/// once the next revision starts. See the comment on the field -/// `deleted_entries` of [`FunctionIngredient`][] for more details. -pub(super) struct DeletedEntries { - memos: boxcar::Vec>>, -} - -#[allow(clippy::undocumented_unsafe_blocks)] // TODO(#697) document safety -unsafe impl Send for SharedBox {} -#[allow(clippy::undocumented_unsafe_blocks)] // TODO(#697) document safety -unsafe impl Sync for SharedBox {} - -impl Default for DeletedEntries { - fn default() -> Self { - Self { - memos: Default::default(), - } - } -} - -impl DeletedEntries { - /// # Safety - /// - /// The memo must be valid and safe to free when the `DeletedEntries` list is cleared or dropped. - pub(super) unsafe fn push(&self, memo: NonNull>) { - // Safety: The memo must be valid and safe to free when the `DeletedEntries` list is cleared or dropped. - let memo = - unsafe { std::mem::transmute::>, NonNull>>(memo) }; - - self.memos.push(SharedBox(memo)); - } - - /// Free all deleted memos, keeping the list available for reuse. - pub(super) fn clear(&mut self) { - self.memos.clear(); - } -} - -/// A wrapper around `NonNull` that frees the allocation when it is dropped. -struct SharedBox(NonNull); - -impl Drop for SharedBox { - fn drop(&mut self) { - // SAFETY: Guaranteed by the caller of `DeletedEntries::push`. - unsafe { drop(Box::from_raw(self.0.as_ptr())) }; - } -} diff --git a/src/function/memo.rs b/src/function/memo.rs index 810e5b268..78a1345ec 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -19,24 +19,23 @@ use crate::{Event, EventKind, Id, Revision}; impl IngredientImpl { /// Inserts the memo for the given key; (atomically) overwrites and returns any previously existing memo pub(super) fn insert_memo_into_table_for<'db>( - &self, + &'db self, zalsa: &'db Zalsa, id: Id, memo: NonNull>, memo_ingredient_index: MemoIngredientIndex, - ) -> Option>> { + ) { // SAFETY: The table stores 'static memos (to support `Any`), the memos are in fact valid // for `'db` though as we delay their dropping to the end of a revision. let static_memo = unsafe { transmute::>, NonNull>>(memo) }; let old_static_memo = zalsa .memo_table_for::>(id) - .insert(memo_ingredient_index, static_memo)?; - // SAFETY: The table stores 'static memos (to support `Any`), the memos are in fact valid - // for `'db` though as we delay their dropping to the end of a revision. - Some(unsafe { - transmute::>, NonNull>>(old_static_memo) - }) + .insert(memo_ingredient_index, static_memo); + if let Some(old_memo) = old_static_memo { + // SAFETY: We delay clearing properly + unsafe { self.delete.push(old_memo) }; + } } /// Loads the current memo for `key_index`. This does not hold any sort of @@ -62,9 +61,10 @@ impl IngredientImpl { pub(super) fn evict_value_from_memo_for( table: MemoTableWithTypesMut<'_>, memo_ingredient_index: MemoIngredientIndex, + //FIXME should provide a page to move the value into so we can delay the drop ) { - let map = |memo: &mut Memo<'static, C>| { - match memo.revisions.origin.as_ref() { + if let Some(memo) = table.fetch::>(memo_ingredient_index) { + match &memo.revisions.origin.as_ref() { QueryOriginRef::Assigned(_) | QueryOriginRef::DerivedUntracked(_) | QueryOriginRef::FixpointInitial => { @@ -73,14 +73,9 @@ impl IngredientImpl { // or those with untracked inputs // as their values cannot be reconstructed. } - QueryOriginRef::Derived(_) => { - // Set the memo value to `None`. - memo.value = None; - } + QueryOriginRef::Derived(_) => _ = memo.value.take(), } - }; - - table.map_memo(memo_ingredient_index, map) + } } } @@ -333,6 +328,10 @@ where }, } } + + fn clear_value(&mut self) { + self.value = None; + } } pub(super) enum TryClaimHeadsResult<'me> { diff --git a/src/ingredient.rs b/src/ingredient.rs index 2ad7bb8ee..986f5dc1e 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -9,7 +9,7 @@ use crate::database::RawDatabase; use crate::function::VerifyResult; use crate::runtime::Running; use crate::sync::Arc; -use crate::table::memo::MemoTableTypes; +use crate::table::memo::{DeletedEntries, MemoTableTypes}; use crate::table::Table; use crate::zalsa::{transmute_data_mut_ptr, transmute_data_ptr, IngredientIndex, Zalsa}; use crate::zalsa_local::QueryOriginRef; @@ -128,8 +128,12 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// /// **Important:** to actually receive resets, the ingredient must set /// [`IngredientRequiresReset::RESET_ON_NEW_REVISION`] to true. - fn reset_for_new_revision(&mut self, table: &mut Table) { - _ = table; + fn reset_for_new_revision( + &mut self, + table: &mut Table, + new_buffer: DeletedEntries, + ) -> DeletedEntries { + _ = (table, new_buffer); panic!( "Ingredient `{}` set `Ingredient::requires_reset_for_new_revision` to true but does \ not overwrite `Ingredient::reset_for_new_revision`", diff --git a/src/lib.rs b/src/lib.rs index 66c346b20..4832c5cd4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -66,7 +66,7 @@ pub use self::revision::Revision; pub use self::runtime::Runtime; pub use self::storage::{Storage, StorageHandle}; pub use self::update::Update; -pub use self::zalsa::IngredientIndex; +pub use self::zalsa::{DeletedEntriesDropper, DropChannelReceiver, IngredientIndex}; pub use crate::attach::{attach, with_attached_database}; pub mod prelude { diff --git a/src/runtime.rs b/src/runtime.rs index bc2859a7e..72b48af51 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -2,9 +2,11 @@ use self::dependency_graph::DependencyGraph; use crate::durability::Durability; use crate::function::SyncGuard; use crate::key::DatabaseKeyIndex; +use crate::plumbing::Ingredient; use crate::sync::atomic::{AtomicBool, Ordering}; use crate::sync::thread::{self, ThreadId}; use crate::sync::Mutex; +use crate::table::memo::DeletedEntries; use crate::table::Table; use crate::zalsa::Zalsa; use crate::{Cancelled, Event, EventKind, Revision}; @@ -194,10 +196,6 @@ impl Runtime { &self.table } - pub(crate) fn table_mut(&mut self) -> &mut Table { - &mut self.table - } - /// Increments the "current revision" counter and clears /// the cancellation flag. /// @@ -263,4 +261,12 @@ impl Runtime { .lock() .unblock_runtimes_blocked_on(database_key, wait_result); } + + pub(crate) fn reset_ingredient_for_new_revision( + &mut self, + ingredient: &mut (dyn Ingredient + 'static), + new_buffer: DeletedEntries, + ) -> DeletedEntries { + ingredient.reset_for_new_revision(&mut self.table, new_buffer) + } } diff --git a/src/storage.rs b/src/storage.rs index f63981e4f..62337293c 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -4,7 +4,9 @@ use std::panic::RefUnwindSafe; use crate::database::RawDatabase; use crate::sync::{Arc, Condvar, Mutex}; -use crate::zalsa::{ErasedJar, HasJar, Zalsa, ZalsaDatabase}; +use crate::zalsa::{ + drop_channel, DropChannelReceiver, DropChannelSender, ErasedJar, HasJar, Zalsa, ZalsaDatabase, +}; use crate::zalsa_local::{self, ZalsaLocal}; use crate::{Database, Event, EventKind}; @@ -43,15 +45,27 @@ impl Default for StorageHandle { impl StorageHandle { pub fn new(event_callback: Option>) -> Self { - Self::with_jars(event_callback, Vec::new()) + Self::with_jars(None, event_callback, Vec::new()) + } + + pub fn new_with_drop_channel( + event_callback: Option>, + capacity: Option, + ) -> (Self, DropChannelReceiver) { + let (sender, receiver) = drop_channel(capacity); + ( + Self::with_jars(Some(sender), event_callback, Vec::new()), + receiver, + ) } fn with_jars( + drop_channel_sender: Option, event_callback: Option>, jars: Vec, ) -> Self { Self { - zalsa_impl: Arc::new(Zalsa::new::(event_callback, jars)), + zalsa_impl: Arc::new(Zalsa::new::(drop_channel_sender, event_callback, jars)), coordinate: CoordinateDrop(Arc::new(Coordinate { clones: Mutex::new(1), cvar: Default::default(), @@ -113,7 +127,7 @@ impl Default for Storage { } impl Storage { - /// Create a new database storage. + /// Create a new database storage that drops stale memoized results synchronously on a revision change. /// /// The `event_callback` function is invoked by the salsa runtime at various points during execution. pub fn new(event_callback: Option>) -> Self { @@ -123,6 +137,24 @@ impl Storage { } } + /// Create a new database storage with a drop channel that receives stale memoized results for + /// flexible dropping. + /// + /// The `event_callback` function is invoked by the salsa runtime at various points during execution. + pub fn new_with_drop_channel( + event_callback: Option>, + capacity: Option, + ) -> (Self, DropChannelReceiver) { + let (handle, receiver) = StorageHandle::new_with_drop_channel(event_callback, capacity); + ( + Self { + handle, + zalsa_local: ZalsaLocal::new(), + }, + receiver, + ) + } + /// Returns a builder for database storage. pub fn builder() -> StorageBuilder { StorageBuilder::default() @@ -220,10 +252,25 @@ impl StorageBuilder { self } + /// Construct the [`Storage`] using the provided builder options with a drop channel. + pub fn build_with_drop_channel( + self, + capacity: Option, + ) -> (Storage, DropChannelReceiver) { + let (sender, receiver) = drop_channel(capacity); + ( + Storage { + handle: StorageHandle::with_jars(Some(sender), self.event_callback, self.jars), + zalsa_local: ZalsaLocal::new(), + }, + receiver, + ) + } + /// Construct the [`Storage`] using the provided builder options. pub fn build(self) -> Storage { Storage { - handle: StorageHandle::with_jars(self.event_callback, self.jars), + handle: StorageHandle::with_jars(None, self.event_callback, self.jars), zalsa_local: ZalsaLocal::new(), } } diff --git a/src/table/memo.rs b/src/table/memo.rs index b7bc5fb7d..ac47ae209 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -14,7 +14,7 @@ pub struct MemoTable { } impl MemoTable { - /// Create a `MemoTable` with slots for memos from the provided `MemoTableTypes`. + /// Create a `MemoTable` with slots for memos from the provided `MemoTableTypes`. /// /// # Safety /// @@ -45,6 +45,58 @@ pub trait Memo: Any + Send + Sync { /// Returns memory usage information about the memoized value. #[cfg(feature = "salsa_unstable")] fn memory_usage(&self) -> crate::database::MemoInfo; + + fn clear_value(&mut self); +} + +/// An untyped memo that can only be dropped. +#[derive(Debug)] +pub struct MemoDrop(NonNull, unsafe fn(NonNull)); + +impl MemoDrop { + pub fn new(memo: NonNull) -> Self { + Self( + MemoEntryType::to_dummy(memo), + // SAFETY: `M` is the same as used in `to_dummy` + |memo| unsafe { drop(Box::from_raw(MemoEntryType::from_dummy::(memo).as_ptr())) }, + ) + } +} + +impl Drop for MemoDrop { + fn drop(&mut self) { + // SAFETY: We only construct this type with a valid drop function pointer + unsafe { self.1(self.0) }; + } +} + +/// SAFETY: `MemoDrop` is `Send` because only contains `Memo` types which are `Send` +unsafe impl Send for MemoDrop where DummyMemo: Send {} +/// SAFETY: `MemoDrop` is `Sync` because only contains `Memo` types which are `Sync` +unsafe impl Sync for MemoDrop where DummyMemo: Sync {} + +#[derive(Default)] +pub struct DeletedEntries { + memos: Box>, +} + +impl DeletedEntries { + /// # Safety + /// + /// The memo must be valid and safe to free when the `DeletedEntries` list is cleared or dropped. + pub(crate) unsafe fn push(&self, memo: NonNull) { + self.memos.push(MemoDrop::new(memo)); + } + + // FIXME: This implies that dropping `DeletedEntries` should be unsafe. + /// Free all deleted memos, keeping the list available for reuse. + pub(crate) unsafe fn clear(&mut self) { + self.memos.clear(); + } + + pub(crate) fn is_empty(&self) -> bool { + self.memos.is_empty() + } } /// Data for a memoized entry. @@ -131,6 +183,10 @@ impl Memo for DummyMemo { }, } } + + fn clear_value(&mut self) { + unreachable!("should not get here") + } } #[derive(Default)] @@ -254,26 +310,19 @@ impl MemoTableWithTypes<'_> { } } -pub(crate) struct MemoTableWithTypesMut<'a> { - types: &'a MemoTableTypes, - memos: &'a mut MemoTable, +pub(crate) struct MemoTableWithTypesMut<'db> { + types: &'db MemoTableTypes, + memos: &'db mut MemoTable, } -impl MemoTableWithTypesMut<'_> { +impl<'db> MemoTableWithTypesMut<'db> { /// Calls `f` on the memo at `memo_ingredient_index`. - /// - /// If the memo is not present, `f` is not called. - pub(crate) fn map_memo( + pub(crate) fn fetch( self, memo_ingredient_index: MemoIngredientIndex, - f: impl FnOnce(&mut M), - ) { - let Some(MemoEntry { atomic_memo }) = - self.memos.memos.get_mut(memo_ingredient_index.as_usize()) - else { - return; - }; - + ) -> Option<&'db mut M> { + let MemoEntry { atomic_memo } = + self.memos.memos.get_mut(memo_ingredient_index.as_usize())?; // SAFETY: Any indices that are in-bounds for the `MemoTable` are also in-bounds for its // corresponding `MemoTableTypes`, by construction. let type_ = unsafe { @@ -287,12 +336,10 @@ impl MemoTableWithTypesMut<'_> { type_assert_failed(memo_ingredient_index); } - let Some(memo) = NonNull::new(*atomic_memo.get_mut()) else { - return; - }; + let memo = NonNull::new(*atomic_memo.get_mut())?; // SAFETY: We asserted that the type is correct above. - f(unsafe { MemoEntryType::from_dummy(memo).as_mut() }); + Some(unsafe { MemoEntryType::from_dummy(memo).as_mut() }) } /// To drop an entry, we need its type, so we don't implement `Drop`, and instead have this method. diff --git a/src/tracing.rs b/src/tracing.rs index 47f95d00e..03626289a 100644 --- a/src/tracing.rs +++ b/src/tracing.rs @@ -27,11 +27,11 @@ macro_rules! debug_span { macro_rules! event { ($level:ident, $($x:tt)*) => {{ - let event = { - #[cold] #[inline(never)] || { ::tracing::event!(::tracing::Level::$level, $($x)*) } - }; - if ::tracing::enabled!(::tracing::Level::$level) { + let event = { + #[cold] #[inline(never)] || { ::tracing::event!(::tracing::Level::$level, $($x)*) } + }; + event(); } }}; @@ -39,11 +39,11 @@ macro_rules! event { macro_rules! span { ($level:ident, $($x:tt)*) => {{ - let span = { - #[cold] #[inline(never)] || { ::tracing::span!(::tracing::Level::$level, $($x)*) } - }; - if ::tracing::enabled!(::tracing::Level::$level) { + let span = { + #[cold] #[inline(never)] || { ::tracing::span!(::tracing::Level::$level, $($x)*) } + }; + span() } else { ::tracing::Span::none() diff --git a/src/zalsa.rs b/src/zalsa.rs index 1cc6ba5f5..3f0de4e6d 100644 --- a/src/zalsa.rs +++ b/src/zalsa.rs @@ -1,7 +1,9 @@ use std::any::{Any, TypeId}; use std::hash::BuildHasherDefault; +use std::mem::ManuallyDrop; use std::panic::RefUnwindSafe; +use crossbeam_channel::{Receiver, Sender}; use hashbrown::HashMap; use rustc_hash::FxHashMap; @@ -10,7 +12,7 @@ use crate::hash::TypeIdHasher; use crate::ingredient::{Ingredient, Jar}; use crate::plumbing::SalsaStructInDb; use crate::runtime::Runtime; -use crate::table::memo::MemoTableWithTypes; +use crate::table::memo::{DeletedEntries, MemoTableWithTypes}; use crate::table::Table; use crate::views::Views; use crate::zalsa_local::ZalsaLocal; @@ -168,6 +170,10 @@ pub struct Zalsa { /// Each handle gets its own runtime, but the runtimes have shared state between them. runtime: Runtime, + /// Contains either the channel primitives for user controlled dropping or double buffered + /// deleted entries for synchronous dropping. + deleted_entries_channel: MemoDropMode, + event_callback: Option>, } @@ -180,6 +186,7 @@ impl RefUnwindSafe for Zalsa {} impl Zalsa { pub(crate) fn new( + drop_channel_receiver: Option, event_callback: Option>, jars: Vec, ) -> Self { @@ -191,6 +198,13 @@ impl Zalsa { ingredients_requiring_reset: boxcar::Vec::new(), runtime: Runtime::default(), memo_ingredient_indices: Default::default(), + deleted_entries_channel: match drop_channel_receiver { + Some(drop_chan_sender) => { + let (sender, receiver) = crossbeam_channel::bounded(0); + MemoDropMode::Channel(sender, receiver, drop_chan_sender) + } + None => MemoDropMode::Synchronous(None), + }, event_callback, #[cfg(not(feature = "inventory"))] nonce: NONCE.nonce(), @@ -416,30 +430,59 @@ impl Zalsa { pub fn new_revision(&mut self) -> Revision { let new_revision = self.runtime.new_revision(); let _span = crate::tracing::debug_span!("new_revision", ?new_revision).entered(); - - for (_, index) in self.ingredients_requiring_reset.iter() { - let index = index.as_u32() as usize; - let ingredient = self - .ingredients_vec - .get_mut(index) - .unwrap_or_else(|| panic!("index `{index}` is uninitialized")); - - ingredient.reset_for_new_revision(self.runtime.table_mut()); - } - + self.reset_for_new_revision(); new_revision } /// **NOT SEMVER STABLE** #[doc(hidden)] - pub fn evict_lru(&mut self) { - let _span = crate::tracing::debug_span!("evict_lru").entered(); - for (_, index) in self.ingredients_requiring_reset.iter() { - let index = index.as_u32() as usize; - self.ingredients_vec - .get_mut(index) - .unwrap_or_else(|| panic!("index `{index}` is uninitialized")) - .reset_for_new_revision(self.runtime.table_mut()); + pub fn reset_for_new_revision(&mut self) { + let _span = crate::tracing::debug_span!("reset_for_new_revision").entered(); + + match &mut self.deleted_entries_channel { + MemoDropMode::Channel(pool_sender, pool_receiver, drop_chan_sender) => { + let len = self.ingredients_requiring_reset.count(); + if pool_receiver.capacity() != Some(len) { + (*pool_sender, *pool_receiver) = crossbeam_channel::bounded(len); + } + for (_, index) in self.ingredients_requiring_reset.iter() { + let index = index.as_u32() as usize; + let ingredient = &mut **self + .ingredients_vec + .get_mut(index) + .unwrap_or_else(|| panic!("index `{index}` is uninitialized")); + + let new_buffer = pool_receiver.try_recv().unwrap_or_default(); + let deleted_entries = self + .runtime + .reset_ingredient_for_new_revision(ingredient, new_buffer); + if deleted_entries.is_empty() { + _ = pool_sender.try_send(deleted_entries); + } else { + // if the user dropped the receiver or if its bounded and full, + // fall back to releasing the entries synchronously + _ = drop_chan_sender.0.try_send(DeletedEntriesDropper { + entries: ManuallyDrop::new(deleted_entries), + sender: pool_sender.clone(), + }); + } + } + } + MemoDropMode::Synchronous(buffer) => { + let mut new_buffer = buffer.take().unwrap_or_default(); + for (_, index) in self.ingredients_requiring_reset.iter() { + let index = index.as_u32() as usize; + let ingredient = &mut **self + .ingredients_vec + .get_mut(index) + .unwrap_or_else(|| panic!("index `{index}` is uninitialized")); + + new_buffer = self + .runtime + .reset_ingredient_for_new_revision(ingredient, new_buffer); + } + *buffer = Some(new_buffer); + } } } @@ -464,6 +507,86 @@ impl Zalsa { } } +#[derive(Debug)] +pub(crate) struct DropChannelSender(Sender); + +/// A channel receiver that receives [`DeletedEntriesDropper`] messages. +#[derive(Debug)] +pub struct DropChannelReceiver(Receiver); + +impl DropChannelReceiver { + pub fn capacity(&self) -> Option { + self.0.capacity() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn is_full(&self) -> bool { + self.0.is_full() + } + + pub fn recv(&self) -> Option { + self.0.recv().ok() + } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn iter(&self) -> impl Iterator + use<'_> { + self.0.iter() + } + + #[allow(clippy::should_implement_trait)] + pub fn into_iter(self) -> impl Iterator { + self.0.into_iter() + } + + pub fn try_iter(&self) -> impl Iterator + use<'_> { + self.0.try_iter() + } + + pub fn try_recv(&self) -> Option { + self.0.try_recv().ok() + } +} + +pub(crate) fn drop_channel(capacity: Option) -> (DropChannelSender, DropChannelReceiver) { + let (sender, receiver) = crossbeam_channel::bounded(capacity.unwrap_or(0)); + (DropChannelSender(sender), DropChannelReceiver(receiver)) +} + +/// A drop struct that runs destructors for deleted memoized values. +/// +/// This for example allows to flexibly drop memoized values in different threads. +pub struct DeletedEntriesDropper { + entries: ManuallyDrop, + sender: Sender, +} + +impl Drop for DeletedEntriesDropper { + fn drop(&mut self) { + // SAFETY: `DeletedEntriesDropper` only gets constructed once its safe to clear the entries. + unsafe { self.entries.clear() }; + // SAFETY: We no longer use `self.entries` after this call. + let msg = unsafe { ManuallyDrop::take(&mut self.entries) }; + // Either the receiver was dropped due to resizing or the channel is already + // full, discard the buffer + _ = self.sender.try_send(msg); + } +} + +enum MemoDropMode { + Channel( + Sender, + Receiver, + DropChannelSender, + ), + Synchronous(Option), +} + /// A type-erased `Jar`, used for ingredient registration. #[derive(Clone, Copy)] pub struct ErasedJar { diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs index 8d58d7171..dde3c2cea 100644 --- a/src/zalsa_local.rs +++ b/src/zalsa_local.rs @@ -413,7 +413,7 @@ impl std::panic::RefUnwindSafe for ZalsaLocal {} #[derive(Debug)] // #[derive(Clone)] cloning this is expensive, so we don't derive pub(crate) struct QueryRevisions { - /// The most revision in which some input changed. + /// The most recent revision in which some input changed. pub(crate) changed_at: Revision, /// Minimum durability of the inputs to this query. diff --git a/tests/common/mod.rs b/tests/common/mod.rs index f7aa79b31..47d04b95d 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -62,6 +62,19 @@ pub struct LoggerDatabase { logger: Logger, } +impl LoggerDatabase { + pub fn new_with_drop_channel() -> (LoggerDatabase, salsa::DropChannelReceiver) { + let (storage, drop_chan) = Storage::new_with_drop_channel(None, None); + ( + Self { + storage, + logger: Logger::default(), + }, + drop_chan, + ) + } +} + impl HasLogger for LoggerDatabase { fn logger(&self) -> &Logger { &self.logger diff --git a/tests/lru.rs b/tests/lru.rs index 1d417267a..eee99359a 100644 --- a/tests/lru.rs +++ b/tests/lru.rs @@ -9,26 +9,35 @@ use std::sync::Arc; mod common; use common::LogDatabase; -use salsa::Database as _; +use salsa::{Database as _, DropChannelReceiver}; use test_log::test; -#[derive(Debug, PartialEq, Eq)] -struct HotPotato(u32); +#[derive(Debug)] +struct HotPotato(u32, Arc); + +impl Eq for HotPotato {} +impl PartialEq for HotPotato { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} thread_local! { - static N_POTATOES: AtomicUsize = const { AtomicUsize::new(0) } + static N_POTATOES: Arc = Arc::new(AtomicUsize::new(0)) } impl HotPotato { fn new(id: u32) -> HotPotato { - N_POTATOES.with(|n| n.fetch_add(1, Ordering::SeqCst)); - HotPotato(id) + N_POTATOES.with(|n| { + n.fetch_add(1, Ordering::SeqCst); + HotPotato(id, n.clone()) + }) } } impl Drop for HotPotato { fn drop(&mut self) { - N_POTATOES.with(|n| n.fetch_sub(1, Ordering::SeqCst)); + self.1.fetch_sub(1, Ordering::SeqCst); } } @@ -53,6 +62,20 @@ fn load_n_potatoes() -> usize { N_POTATOES.with(|n| n.load(Ordering::SeqCst)) } +fn wait_until_n_potatoes(n: usize) { + let now = std::time::Instant::now(); + while load_n_potatoes() != n { + std::thread::yield_now(); + if now.elapsed().as_secs() > 10 { + panic!( + "timed out waiting for {} potatoes, we've got {} instead", + n, + load_n_potatoes() + ); + } + } +} + #[test] fn lru_works() { let mut db = common::LoggerDatabase::default(); @@ -67,11 +90,13 @@ fn lru_works() { assert_eq!(load_n_potatoes(), 32); // trigger the GC db.synthetic_write(salsa::Durability::HIGH); - assert_eq!(load_n_potatoes(), 8); + wait_until_n_potatoes(8); + drop(db); + assert_eq!(load_n_potatoes(), 0); } #[test] -fn lru_can_be_changed_at_runtime() { +fn lru_can_be_changed_at_runtime_sync() { let mut db = common::LoggerDatabase::default(); assert_eq!(load_n_potatoes(), 0); @@ -85,10 +110,11 @@ fn lru_can_be_changed_at_runtime() { assert_eq!(load_n_potatoes(), 32); // trigger the GC db.synthetic_write(salsa::Durability::HIGH); - assert_eq!(load_n_potatoes(), 8); + std::thread::sleep(std::time::Duration::from_millis(100)); + + wait_until_n_potatoes(8); get_hot_potato::set_lru_capacity(&mut db, 16); - assert_eq!(load_n_potatoes(), 8); for &(i, input) in inputs.iter() { let p = get_hot_potato(&db, input); assert_eq!(p.0, i); @@ -97,11 +123,10 @@ fn lru_can_be_changed_at_runtime() { assert_eq!(load_n_potatoes(), 32); // trigger the GC db.synthetic_write(salsa::Durability::HIGH); - assert_eq!(load_n_potatoes(), 16); + wait_until_n_potatoes(16); // Special case: setting capacity to zero disables LRU get_hot_potato::set_lru_capacity(&mut db, 0); - assert_eq!(load_n_potatoes(), 16); for &(i, input) in inputs.iter() { let p = get_hot_potato(&db, input); assert_eq!(p.0, i); @@ -110,15 +135,114 @@ fn lru_can_be_changed_at_runtime() { assert_eq!(load_n_potatoes(), 32); // trigger the GC db.synthetic_write(salsa::Durability::HIGH); + wait_until_n_potatoes(32); + + drop(db); + assert_eq!(load_n_potatoes(), 0); +} + +#[test] +fn lru_keeps_dependency_info_sync() { + let mut db = common::LoggerDatabase::default(); + let capacity = 8; + + // Invoke `get_hot_potato2` 33 times. This will (in turn) invoke + // `get_hot_potato`, which will trigger LRU after 8 executions. + let inputs: Vec = (0..(capacity + 1)) + .map(|i| MyInput::new(&db, i as u32)) + .collect(); + + for (i, input) in inputs.iter().enumerate() { + let x = get_hot_potato2(&db, *input); + assert_eq!(x as usize, i); + } + + db.synthetic_write(salsa::Durability::HIGH); + + // We want to test that calls to `get_hot_potato2` are still considered + // clean. Check that no new executions occur as we go here. + db.assert_logs_len((capacity + 1) * 2); + + // calling `get_hot_potato2(0)` has to check that `get_hot_potato(0)` is still valid; + // even though we've evicted it (LRU), we find that it is still good + let p = get_hot_potato2(&db, *inputs.first().unwrap()); + assert_eq!(p, 0); + db.assert_logs_len(0); +} + +#[test] +fn lru_works_async() { + let (mut db, drop_chan) = common::LoggerDatabase::new_with_drop_channel(); + let drop_thread = drop_thread(drop_chan); + assert_eq!(load_n_potatoes(), 0); + + for i in 0..32u32 { + let input = MyInput::new(&db, i); + let p = get_hot_potato(&db, input); + assert_eq!(p.0, i); + } + assert_eq!(load_n_potatoes(), 32); + // trigger the GC + db.synthetic_write(salsa::Durability::HIGH); + wait_until_n_potatoes(8); + drop(db); + wait_until_n_potatoes(0); + drop_thread.join().unwrap(); +} + +#[test] +fn lru_can_be_changed_at_runtime() { + let (mut db, drop_chan) = common::LoggerDatabase::new_with_drop_channel(); + let drop_thread = drop_thread(drop_chan); + assert_eq!(load_n_potatoes(), 0); + + let inputs: Vec<(u32, MyInput)> = (0..32).map(|i| (i, MyInput::new(&db, i))).collect(); + + for &(i, input) in inputs.iter() { + let p = get_hot_potato(&db, input); + assert_eq!(p.0, i); + } + + assert_eq!(load_n_potatoes(), 32); + // trigger the GC + db.synthetic_write(salsa::Durability::HIGH); + std::thread::sleep(std::time::Duration::from_millis(100)); + + wait_until_n_potatoes(8); + + get_hot_potato::set_lru_capacity(&mut db, 16); + for &(i, input) in inputs.iter() { + let p = get_hot_potato(&db, input); + assert_eq!(p.0, i); + } + + assert_eq!(load_n_potatoes(), 32); + // trigger the GC + db.synthetic_write(salsa::Durability::HIGH); + wait_until_n_potatoes(16); + + // Special case: setting capacity to zero disables LRU + get_hot_potato::set_lru_capacity(&mut db, 0); + for &(i, input) in inputs.iter() { + let p = get_hot_potato(&db, input); + assert_eq!(p.0, i); + } + + assert_eq!(load_n_potatoes(), 32); + // trigger the GC + db.synthetic_write(salsa::Durability::HIGH); + wait_until_n_potatoes(32); drop(db); assert_eq!(load_n_potatoes(), 0); + drop_thread.join().unwrap(); } #[test] fn lru_keeps_dependency_info() { - let mut db = common::LoggerDatabase::default(); + let (mut db, drop_chan) = common::LoggerDatabase::new_with_drop_channel(); + let drop_thread = drop_thread(drop_chan); let capacity = 8; // Invoke `get_hot_potato2` 33 times. This will (in turn) invoke @@ -143,4 +267,16 @@ fn lru_keeps_dependency_info() { let p = get_hot_potato2(&db, *inputs.first().unwrap()); assert_eq!(p, 0); db.assert_logs_len(0); + drop(db); + drop_thread.join().unwrap(); +} + +#[cfg(feature = "shuttle")] +fn drop_thread(receiver: DropChannelReceiver) -> shuttle::thread::JoinHandle<()> { + shuttle::thread::spawn(|| receiver.into_iter().for_each(|_| ())) +} + +#[cfg(not(feature = "shuttle"))] +fn drop_thread(receiver: DropChannelReceiver) -> std::thread::JoinHandle<()> { + std::thread::spawn(|| receiver.into_iter().for_each(|_| ())) }