Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ concurrent-queue = "2.0.0"
fastrand = "2.0.0"
futures-lite = { version = "2.0.0", default-features = false }
slab = "0.4.4"
thread_local = { git = "https://github.com/james7132/thread_local-rs", branch = "fix-iter-ub" }

[target.'cfg(target_family = "wasm")'.dependencies]
futures-lite = { version = "2.0.0", default-features = false, features = ["std"] }
Expand Down
71 changes: 43 additions & 28 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
use std::fmt;
use std::future::Future;
use std::marker::PhantomData;
use std::ops::Deref;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
Expand All @@ -47,6 +48,7 @@ use async_task::{Builder, Runnable};
use concurrent_queue::ConcurrentQueue;
use futures_lite::{future, prelude::*};
use slab::Slab;
use thread_local::ThreadLocal;

#[doc(no_inline)]
pub use async_task::Task;
Expand Down Expand Up @@ -265,8 +267,17 @@ impl<'a> Executor<'a> {
fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
let state = self.state().clone();

// TODO: If possible, push into the current local queue and notify the ticker.
move |runnable| {
move |mut runnable| {
// If possible, push into the current local queue and notify the ticker.
let local_queue = state.local_queues.get();
if let Some(queue) = local_queue {
runnable = if let Err(err) = queue.push(runnable) {
err.into_inner()
} else {
state.notify();
return;
}
}
state.queue.push(runnable).unwrap();
state.notify();
}
Expand Down Expand Up @@ -508,7 +519,7 @@ struct State {
queue: ConcurrentQueue<Runnable>,

/// Local queues created by runners.
local_queues: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,
local_queues: ThreadLocal<LocalQueue>,

/// Set to `true` when a sleeping ticker is notified or no tickers are sleeping.
notified: AtomicBool,
Expand All @@ -525,7 +536,7 @@ impl State {
fn new() -> State {
State {
queue: ConcurrentQueue::unbounded(),
local_queues: RwLock::new(Vec::new()),
local_queues: ThreadLocal::new(),
notified: AtomicBool::new(true),
sleepers: Mutex::new(Sleepers {
count: 0,
Expand Down Expand Up @@ -756,9 +767,6 @@ struct Runner<'a> {
/// Inner ticker.
ticker: Ticker<'a>,

/// The local queue.
local: Arc<ConcurrentQueue<Runnable>>,

/// Bumped every time a runnable task is found.
ticks: AtomicUsize,
}
Expand All @@ -769,38 +777,34 @@ impl Runner<'_> {
let runner = Runner {
state,
ticker: Ticker::new(state),
local: Arc::new(ConcurrentQueue::bounded(512)),
ticks: AtomicUsize::new(0),
};
state
.local_queues
.write()
.unwrap()
.push(runner.local.clone());
runner
}

/// Waits for the next runnable task to run.
async fn runnable(&self, rng: &mut fastrand::Rng) -> Runnable {
let local_queue = self.state.local_queues.get_or_default();

let runnable = self
.ticker
.runnable_with(|| {
// Try the local queue.
if let Ok(r) = self.local.pop() {
if let Ok(r) = local_queue.pop() {
return Some(r);
}

// Try stealing from the global queue.
if let Ok(r) = self.state.queue.pop() {
steal(&self.state.queue, &self.local);
steal(&self.state.queue, local_queue);
return Some(r);
}

// Try stealing from other runners.
let local_queues = self.state.local_queues.read().unwrap();
let local_queues = &self.state.local_queues;

// Pick a random starting point in the iterator list and rotate the list.
let n = local_queues.len();
let n = local_queues.iter().count();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a cold operation? It seems like this would take a while.

Copy link
Contributor Author

@james7132 james7132 Feb 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is one part I'm not so sure about. Generally this shouldn't be under contention, since the cost to spin up new threads is going to be higher than it is to scan over the entire container, unless you have literally thousands of threads. It otherwise is just a scan through fairly small buckets.

We could use an atomic counter to track how many there are, but since you can't remove items from the ThreadLocal, there will be residual thread locals from currently unused threads (as thread IDs are reused), that may get out of sync.

let start = rng.usize(..n);
let iter = local_queues
.iter()
Expand All @@ -809,12 +813,12 @@ impl Runner<'_> {
.take(n);

// Remove this runner's local queue.
let iter = iter.filter(|local| !Arc::ptr_eq(local, &self.local));
let iter = iter.filter(|local| !core::ptr::eq(*local, local_queue));

// Try stealing from each local queue in the list.
for local in iter {
steal(local, &self.local);
if let Ok(r) = self.local.pop() {
steal(local, local_queue);
if let Ok(r) = local_queue.pop() {
return Some(r);
}
}
Expand All @@ -828,7 +832,7 @@ impl Runner<'_> {

if ticks % 64 == 0 {
// Steal tasks from the global queue to ensure fair task scheduling.
steal(&self.state.queue, &self.local);
steal(&self.state.queue, local_queue);
}

runnable
Expand All @@ -838,14 +842,10 @@ impl Runner<'_> {
impl Drop for Runner<'_> {
fn drop(&mut self) {
// Remove the local queue.
self.state
.local_queues
.write()
.unwrap()
.retain(|local| !Arc::ptr_eq(local, &self.local));
let local_queue = self.state.local_queues.get_or_default();

// Re-schedule remaining tasks in the local queue.
while let Ok(r) = self.local.pop() {
while let Ok(r) = local_queue.pop() {
r.schedule();
}
}
Expand Down Expand Up @@ -937,11 +937,26 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_
f.debug_struct(name)
.field("active", &ActiveTasks(&state.active))
.field("global_tasks", &state.queue.len())
.field("local_runners", &LocalRunners(&state.local_queues))
.field("sleepers", &SleepCount(&state.sleepers))
.finish()
}

struct LocalQueue(ConcurrentQueue<Runnable>);

impl Default for LocalQueue {
fn default() -> Self {
Self(ConcurrentQueue::bounded(512))
}
}

impl Deref for LocalQueue {
type Target = ConcurrentQueue<Runnable>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

/// Runs a closure when dropped.
struct CallOnDrop<F: FnMut()>(F);

Expand Down