Skip to content

Commit ff35704

Browse files
committed
Add callback lock
1 parent a685f29 commit ff35704

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

src/util.rs

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ pub mod progress;
2222
pub use mem_or_file::{FileAndSize, MemOrFile};
2323
mod sparse_mem_file;
2424
pub use sparse_mem_file::SparseMemFile;
25+
pub mod callback_lock;
2526
pub mod local_pool;
2627

2728
#[cfg(test)]

src/util/callback_lock.rs

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
//! This module defines a wrapper around a [`tokio::sync::RwLock`] that runs a callback
2+
//! After any write operation occurs
3+
4+
use std::future::Future;
5+
6+
#[derive(derive_more::Debug)]
7+
pub struct CallbackLock<T, F> {
8+
inner: tokio::sync::RwLock<T>,
9+
#[debug(skip)]
10+
callback: F,
11+
}
12+
13+
pub struct CallbackLockWriteGuard<'a, T, F: Fn(&T)> {
14+
inner: tokio::sync::RwLockWriteGuard<'a, T>,
15+
callback: &'a F,
16+
}
17+
18+
impl<'a, T, F: Fn(&T)> std::ops::Deref for CallbackLockWriteGuard<'a, T, F> {
19+
type Target = T;
20+
21+
fn deref(&self) -> &Self::Target {
22+
&*self.inner
23+
}
24+
}
25+
26+
impl<'a, T, F: Fn(&T)> std::ops::DerefMut for CallbackLockWriteGuard<'a, T, F> {
27+
fn deref_mut(&mut self) -> &mut Self::Target {
28+
&mut *self.inner
29+
}
30+
}
31+
32+
impl<'a, T, F: Fn(&T)> Drop for CallbackLockWriteGuard<'a, T, F> {
33+
fn drop(&mut self) {
34+
(self.callback)(&*self.inner);
35+
}
36+
}
37+
38+
impl<T, F> CallbackLock<T, F>
39+
where
40+
F: Fn(&T),
41+
{
42+
pub fn new(val: T, callback: F) -> Self {
43+
CallbackLock {
44+
inner: tokio::sync::RwLock::new(val),
45+
callback,
46+
}
47+
}
48+
49+
pub async fn write<'a>(&'a self) -> CallbackLockWriteGuard<'a, T, F> {
50+
let guard = self.inner.write().await;
51+
52+
CallbackLockWriteGuard {
53+
inner: guard,
54+
callback: &self.callback,
55+
}
56+
}
57+
58+
pub fn read(&self) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, T>> {
59+
self.inner.read()
60+
}
61+
62+
pub fn try_read(
63+
&self,
64+
) -> Result<tokio::sync::RwLockReadGuard<'_, T>, tokio::sync::TryLockError> {
65+
self.inner.try_read()
66+
}
67+
}

0 commit comments

Comments
 (0)