README.md (2 changes: 1 addition & 1 deletion)

@@ -96,7 +96,7 @@ use fork_union as fu;
 fn heavy_math(_: usize) {}
 
 fn main() -> Result<(), Box<dyn Error>> {
-    let mut pool = fu::ThreadPool::try_spawn(4)?;
+    // Create a named thread pool for easier debugging and profiling
+    let mut pool = fu::ThreadPool::try_named_spawn("heavy-math", 4)?;
     pool.for_n_dynamic(400, |prong| {
         heavy_math(prong.task_index);
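For context, the full README example now reads as below; this sketch is reconstructed from the visible hunk (the `for_n_dynamic` call and `prong.task_index` come from it verbatim). Note that on Linux a thread name is capped at 15 bytes plus a NUL terminator, which the Rust binding later in this diff enforces by truncating:

```rust
use std::error::Error;
use fork_union as fu;

fn heavy_math(_: usize) {}

fn main() -> Result<(), Box<dyn Error>> {
    // The name shows up in debuggers and profilers; keep it within the
    // 15-byte limit that the binding truncates to.
    let mut pool = fu::ThreadPool::try_named_spawn("heavy-math", 4)?;
    pool.for_n_dynamic(400, |prong| {
        heavy_math(prong.task_index);
    });
    Ok(())
}
```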
c/lib.cpp (9 changes: 6 additions & 3 deletions)

@@ -317,6 +317,9 @@ size_t fu_volume_any_pages(void) { return fu::get_ram_total_volume(); }

 size_t fu_volume_huge_pages_in(FU_MAYBE_UNUSED_ size_t numa_node_index) {
 #if FU_ENABLE_NUMA
+    if (!globals_initialize()) return 0;
+    if (numa_node_index >= global_numa_topology.nodes_count()) return 0;
+
     size_t total_volume = 0;
     auto const &node = global_numa_topology.node(numa_node_index);
     for (auto const &page_size : node.page_sizes) total_volume += page_size.bytes_per_page * page_size.free_pages;
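Beyond the new guards, the loop in this hunk folds each NUMA node's per-page-size capacities into a byte total. A minimal Rust sketch of the same computation, with hypothetical `PageSize`/`Node` types standing in for the C++ topology structs:

```rust
// Hypothetical stand-ins for the C++ `global_numa_topology` node structures.
struct PageSize {
    bytes_per_page: usize,
    free_pages: usize,
}

struct Node {
    page_sizes: Vec<PageSize>,
}

// Mirrors fu_volume_huge_pages_in: sum of bytes_per_page * free_pages
// across every huge-page size the NUMA node reports.
fn volume_huge_pages(node: &Node) -> usize {
    node.page_sizes
        .iter()
        .map(|p| p.bytes_per_page * p.free_pages)
        .sum()
}

fn main() {
    // E.g. 100 free 2 MiB pages plus 2 free 1 GiB pages ≈ 2.2 GiB total.
    let node = Node {
        page_sizes: vec![
            PageSize { bytes_per_page: 2 << 20, free_pages: 100 },
            PageSize { bytes_per_page: 1 << 30, free_pages: 2 },
        ],
    };
    assert_eq!(volume_huge_pages(&node), (2 << 20) * 100 + (1 << 30) * 2);
    println!("{} bytes", volume_huge_pages(&node));
}
```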
@@ -590,17 +593,17 @@ void fu_pool_for_slices(fu_pool_t *pool, size_t n, fu_for_slices_t callback, fu_
 #pragma region - Flexible API
 
 void fu_pool_unsafe_for_threads(fu_pool_t *pool, fu_for_threads_t callback, fu_lambda_context_t context) {
-    assert(pool != nullptr && callback != nullptr);
+    if (pool == nullptr || callback == nullptr) return;
     opaque_pool_t *opaque = upcast_pool(pool);
     opaque->current_context = context;
     opaque->current_callback = callback;
     visit([&](auto &variant) { variant.unsafe_for_threads(*opaque); }, opaque->variants);
 }
 
 void fu_pool_unsafe_join(fu_pool_t *pool) {
-    assert(pool != nullptr);
+    if (pool == nullptr) return;
     opaque_pool_t *opaque = upcast_pool(pool);
-    assert(opaque->current_context != nullptr);
+    if (opaque->current_context == nullptr) return; // No broadcast was issued
     visit([](auto &variant) { variant.unsafe_join(); }, opaque->variants);
     opaque->current_context = nullptr;
     opaque->current_callback = nullptr;
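Both guards replace `assert`, which compiles away under `NDEBUG`: previously a null `pool` would have reached `upcast_pool` and been dereferenced in release builds, and joining with no broadcast in flight hit the second assert. A Rust analogue of the same defensive pattern, with a hypothetical `Pool` standing in for `opaque_pool_t`:

```rust
// Hypothetical stand-in for `opaque_pool_t`; not the crate's real type.
struct Pool {
    current_callback: Option<fn(usize, usize)>,
}

fn unsafe_join(pool: Option<&mut Pool>) {
    // Was `assert(pool != nullptr)`: a no-op under NDEBUG, so misuse
    // became undefined behavior in release builds. Now it returns early.
    let Some(pool) = pool else { return };
    // Was `assert(current_context != nullptr)`: joining without a prior
    // broadcast is now a defined no-op instead of a contract violation.
    if pool.current_callback.is_none() {
        return;
    }
    // ... wait for the workers, then clear the broadcast state.
    pool.current_callback = None;
}

fn main() {
    unsafe_join(None); // null pool: no crash, just a no-op
    let mut pool = Pool { current_callback: None };
    unsafe_join(Some(&mut pool)); // join without a broadcast: also a no-op
}
```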
rust/lib.rs (43 changes: 27 additions & 16 deletions)

@@ -571,6 +571,7 @@ unsafe impl Send for ThreadPool {}
 unsafe impl Sync for ThreadPool {}
 
 impl ThreadPool {
+    #[must_use]
     pub fn try_spawn_with_exclusivity(
         threads: usize,
         exclusivity: CallerExclusivity,
@@ -587,18 +588,20 @@ impl ThreadPool {
             return Err(Error::InvalidParameter);
         }
 
-        unsafe {
-            let name_ptr = if let Some(name_str) = name {
-                let mut name_buffer = [0u8; 16];
-                let name_bytes = name_str.as_bytes();
-                let copy_len = core::cmp::min(name_bytes.len(), 15); // Leave space for null terminator
-                name_buffer[..copy_len].copy_from_slice(&name_bytes[..copy_len]);
-                // name_buffer[copy_len] is already 0 from initialization
-                name_buffer.as_ptr() as *const c_char
-            } else {
-                core::ptr::null()
-            };
+        // SAFETY: Buffer must outlive the fu_pool_new call, so we declare it here
+        // before the unsafe block to ensure it lives long enough.
+        let mut name_buffer = [0u8; 16];
+        let name_ptr = if let Some(name_str) = name {
+            let name_bytes = name_str.as_bytes();
+            let copy_len = core::cmp::min(name_bytes.len(), 15); // Leave space for null terminator
+            name_buffer[..copy_len].copy_from_slice(&name_bytes[..copy_len]);
+            // name_buffer[copy_len] is already 0 from initialization
+            name_buffer.as_ptr() as *const c_char
+        } else {
+            core::ptr::null()
+        };
+
+        unsafe {
             let inner = fu_pool_new(name_ptr);
             if inner.is_null() {
                 return Err(Error::CreationFailed);
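The reshuffle above fixes a genuine use-after-free: in the deleted version, `name_buffer` was declared inside the `if let` arm, so the array was dropped before `fu_pool_new` ever saw the pointer. A distilled sketch of that bug class (the helper below is hypothetical, not the crate's code):

```rust
// Distilled version of the bug the hunk above fixes: the raw pointer
// produced inside the `if let` arm outlives the buffer it points into.
fn dangling_name_ptr(have_name: bool) -> *const u8 {
    if have_name {
        let name_buffer = [0u8; 16]; // dropped at the end of this arm
        name_buffer.as_ptr() // already dangling when the caller sees it
    } else {
        core::ptr::null()
    }
}

fn main() {
    let p = dangling_name_ptr(true);
    // Reading through `p` here would be a use-after-free. The fix hoists
    // the buffer out of the branch so it outlives every use of the pointer.
    let _ = p;
}
```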
@@ -632,6 +635,7 @@ impl ThreadPool {
     /// let pool = ThreadPool::try_spawn(4).expect("Failed to create thread pool");
     /// assert_eq!(pool.threads(), 4);
     /// ```
+    #[must_use]
     pub fn try_spawn(threads: usize) -> Result<Self, Error> {
         Self::try_spawn_with_exclusivity(threads, CallerExclusivity::Inclusive)
     }
@@ -655,6 +659,7 @@ impl ThreadPool {
     /// let pool = ThreadPool::try_named_spawn("worker_pool", 4).expect("Failed to create thread pool");
     /// assert_eq!(pool.threads(), 4);
     /// ```
+    #[must_use]
     pub fn try_named_spawn(name: &str, threads: usize) -> Result<Self, Error> {
         Self::try_named_spawn_with_exclusivity(Some(name), threads, CallerExclusivity::Inclusive)
     }
@@ -2333,23 +2338,24 @@ impl<T> RoundRobinVec<T> {
     /// }
     ///
     /// // Fill all vectors with random values
-    /// rr_vec.fill_with(|| rand::random::<i32>(), &mut pool);
+    /// rr_vec.fill_with(|| 42, &mut pool);
     /// ```
-    pub fn fill_with<F>(&mut self, mut f: F, pool: &mut ThreadPool)
+    pub fn fill_with<F>(&mut self, f: F, pool: &mut ThreadPool)
     where
-        F: FnMut() -> T + Send + Sync,
+        F: Fn() -> T + Send + Sync,
         T: Send + Sync,
     {
         let colocations_count = self.colocations_count();
         let safe_ptr = SafePtr(self.colocations.as_mut_ptr());
-        let f_ptr = SafePtr(&mut f as *mut F);
+        let f_ptr = SyncConstPtr::new(&f as *const F);
         let pool_ptr = SafePtr(pool as *const ThreadPool as *mut ThreadPool);
 
         pool.for_threads(move |thread_index, colocation_index| {
             if colocation_index < colocations_count {
                 // Get the specific pinned vector for this NUMA node
                 let node_vec = safe_ptr.get_mut_at(colocation_index);
-                let f_ref = f_ptr.get_mut();
+                // SAFETY: f is Fn (not FnMut), so concurrent calls are safe
+                let f_ref = unsafe { &*f_ptr.as_ptr() };
                 let pool = pool_ptr.get_mut();
 
                 let threads_in_colocation = pool.count_threads_in(colocation_index);
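The bound change from `FnMut` to `Fn` is what makes the shared `SyncConstPtr` sound: every worker thread invokes the generator through the same `&F`, and only `Fn` can be called through a shared reference. A standalone sketch of the same idea, using `std::thread::scope` rather than the crate's pool:

```rust
use std::thread;

// Many threads call `f` concurrently through shared references, so the
// generator must be `Fn + Sync`; `FnMut` would require an aliased `&mut F`.
fn fill_in_parallel<T, F>(out: &mut [T], f: F)
where
    F: Fn() -> T + Sync,
    T: Send,
{
    thread::scope(|s| {
        for chunk in out.chunks_mut(2) {
            let f = &f; // every thread shares the same &F
            s.spawn(move || {
                for slot in chunk {
                    *slot = f(); // concurrent calls are fine for Fn + Sync
                }
            });
        }
    });
}

fn main() {
    let mut data = [0i32; 8];
    fill_in_parallel(&mut data, || 42);
    assert_eq!(data, [42; 8]);
}
```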
@@ -2614,11 +2620,13 @@ impl<T> SyncConstPtr<T> {
     /// # Returns
     ///
     /// A reference to the element at the given index.
+    #[inline]
     pub unsafe fn get(&self, index: usize) -> &T {
         &*self.ptr.add(index)
     }
 
     /// Returns the raw pointer.
+    #[inline]
     pub fn as_ptr(&self) -> *const T {
         self.ptr
     }
@@ -2650,10 +2658,12 @@ impl<T> SyncMutPtr<T> {
     /// - No overlapping mutable access occurs from multiple threads
     /// - Each thread accesses disjoint indices when used concurrently
     /// - The pointer remains valid for the duration of access
+    #[inline]
     pub unsafe fn get(&self, index: usize) -> *mut T {
         self.ptr.add(index)
     }
 
+    #[inline]
     pub fn as_ptr(&self) -> *mut T {
         self.ptr
     }
@@ -3708,6 +3718,7 @@ impl IndexedSplit {
     }
 
     /// Returns the range for a specific thread index.
+    #[inline]
     pub fn get(&self, thread_index: usize) -> core::ops::Range<usize> {
         let begin = self.quotient * thread_index + thread_index.min(self.remainder);
         let count = self.quotient + if thread_index < self.remainder { 1 } else { 0 };
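The newly `#[inline]`-hinted `IndexedSplit::get` computes balanced per-thread ranges from the formula in the hunk. A standalone sketch of the arithmetic, assuming `quotient = n / threads` and `remainder = n % threads` as the field names suggest:

```rust
// Balanced split: the first `remainder` threads each take one extra item,
// so all n items are covered with range sizes differing by at most one.
fn split_range(n: usize, threads: usize, thread_index: usize) -> core::ops::Range<usize> {
    let quotient = n / threads;
    let remainder = n % threads;
    let begin = quotient * thread_index + thread_index.min(remainder);
    let count = quotient + if thread_index < remainder { 1 } else { 0 };
    begin..begin + count
}

fn main() {
    // 10 tasks over 4 threads: quotient = 2, remainder = 2, so the first
    // two threads get 3 tasks each and the last two get 2 each.
    let ranges: Vec<_> = (0..4).map(|i| split_range(10, 4, i)).collect();
    assert_eq!(ranges, vec![0..3, 3..6, 6..8, 8..10]);
}
```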