Add full_blocking feature to thread pool #1175

Status: Open. Wants to merge 13 commits into base: main.
26 changes: 26 additions & 0 deletions rayon-core/src/lib.rs
@@ -200,6 +200,9 @@ pub struct ThreadPoolBuilder<S = DefaultSpawn> {
/// Closure invoked on worker thread exit.
exit_handler: Option<Box<ExitHandler>>,

/// Affects the blocking/work-stealing behavior when using nested thread pools.
full_blocking: bool,

/// Closure invoked to spawn threads.
spawn_handler: S,

@@ -245,6 +248,7 @@ impl Default for ThreadPoolBuilder {
exit_handler: None,
spawn_handler: DefaultSpawn,
breadth_first: false,
full_blocking: false,
}
}
}
@@ -455,6 +459,7 @@ impl<S> ThreadPoolBuilder<S> {
start_handler: self.start_handler,
exit_handler: self.exit_handler,
breadth_first: self.breadth_first,
full_blocking: self.full_blocking,
}
}

@@ -672,6 +677,25 @@ impl<S> ThreadPoolBuilder<S> {
self.exit_handler = Some(Box::new(exit_handler));
self
}

/// Makes this thread pool fully blocking for jobs arriving from other thread pools.
///
/// If false, when a job running in a separate thread pool creates a job on this thread
/// pool and waits for it, the waiting thread may execute other jobs from its own pool
/// in the meantime.
///
/// If true, the waiting thread instead blocks until its jobs on this thread pool have
/// completed. This is useful for avoiding deadlock when the waiting job holds a mutex.
///
/// Default is false.
pub fn full_blocking(mut self) -> Self {
self.full_blocking = true;
self
}

fn get_full_blocking(&self) -> bool {
self.full_blocking
}
}

#[allow(deprecated)]
@@ -811,6 +835,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
ref exit_handler,
spawn_handler: _,
ref breadth_first,
ref full_blocking,
} = *self;

// Just print `Some(<closure>)` or `None` to the debug
@@ -835,6 +860,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
.field("start_handler", &start_handler)
.field("exit_handler", &exit_handler)
.field("breadth_first", &breadth_first)
.field("full_blocking", &full_blocking)
.finish()
}
}
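For reviewers, here is a minimal usage sketch of the new builder flag (assuming the API exactly as added above; the variable names are illustrative). It mirrors the `tests/issue592.rs` test added below: the mutex is taken first, and the parallel work then runs on the full-blocking pool while the lock is held.

```rust
use std::sync::Mutex;

use rayon::prelude::*;
use rayon::ThreadPoolBuilder;

fn main() {
    // Pool dedicated to parallel work performed while a mutex is held.
    // `full_blocking()` makes a cross-pool caller sleep until its jobs here
    // finish, instead of stealing unrelated jobs from its own pool that
    // might also try to take the mutex.
    let blocking_pool = ThreadPoolBuilder::new()
        .full_blocking()
        .num_threads(4)
        .build()
        .unwrap();

    let data = Mutex::new(vec![1, 2, 3, 4, 5]);
    let vec = data.lock().unwrap();
    let doubled: Vec<i32> = blocking_pool.install(|| vec.par_iter().map(|x| x * 2).collect());
    println!("{doubled:?}");
}
```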
32 changes: 31 additions & 1 deletion rayon-core/src/registry.rs
@@ -135,6 +135,7 @@ pub(super) struct Registry {
panic_handler: Option<Box<PanicHandler>>,
start_handler: Option<Box<StartHandler>>,
exit_handler: Option<Box<ExitHandler>>,
full_blocking: bool,

// When this latch reaches 0, it means that all work on this
// registry must be complete. This is ensured in the following ways:
@@ -267,6 +268,7 @@ impl Registry {
panic_handler: builder.take_panic_handler(),
start_handler: builder.take_start_handler(),
exit_handler: builder.take_exit_handler(),
full_blocking: builder.get_full_blocking(),
});

// If we return early or panic, make sure to terminate existing threads.
@@ -493,7 +495,11 @@ impl Registry {
if worker_thread.is_null() {
self.in_worker_cold(op)
} else if (*worker_thread).registry().id() != self.id() {
self.in_worker_cross(&*worker_thread, op)
if self.full_blocking {
self.in_worker_cross_blocking(op)
} else {
self.in_worker_cross(&*worker_thread, op)
}
} else {
// Perfectly valid to give them a `&T`: this is the
// current thread, so we know the data structure won't be
@@ -552,6 +558,30 @@ impl Registry {
job.into_result()
}

#[cold]
unsafe fn in_worker_cross_blocking<OP, R>(&self, op: OP) -> R
where
OP: FnOnce(&WorkerThread, bool) -> R + Send,
R: Send,
{
thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());

LOCK_LATCH.with(|l| {
let job = StackJob::new(
|injected| {
let worker_thread = WorkerThread::current();
assert!(injected && !worker_thread.is_null());
op(&*worker_thread, true)
},
LatchRef::new(l),
);
self.inject(job.as_job_ref());
job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.

job.into_result()
})
}

/// Increments the terminate counter. This increment should be
/// balanced by a call to `terminate`, which will decrement. This
/// is used when spawning asynchronous work, which needs to
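The core of the change is `in_worker_cross_blocking` above: the job is injected into this pool's queue and the calling thread then sleeps on a `LockLatch` until the job signals completion, rather than work-stealing while it waits as `in_worker_cross` does. Below is a self-contained sketch of that latch pattern using only std primitives; `SimpleLatch` and the spawned thread are illustrative stand-ins, not rayon-core API.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

// Illustrative stand-in for rayon-core's LockLatch: a one-shot gate that a
// waiter sleeps on until some other thread calls `set`.
struct SimpleLatch {
    done: Mutex<bool>,
    cond: Condvar,
}

impl SimpleLatch {
    fn new() -> Self {
        SimpleLatch { done: Mutex::new(false), cond: Condvar::new() }
    }

    fn set(&self) {
        *self.done.lock().unwrap() = true;
        self.cond.notify_all();
    }

    fn wait_and_reset(&self) {
        let mut done = self.done.lock().unwrap();
        while !*done {
            done = self.cond.wait(done).unwrap();
        }
        *done = false; // reset so the same latch can be reused, as in the diff
    }
}

fn main() {
    let latch = Arc::new(SimpleLatch::new());
    let latch2 = Arc::clone(&latch);

    // "Inject" the job into another thread (standing in for the other pool's
    // injector queue); the job sets the latch when it finishes.
    thread::spawn(move || {
        println!("job runs in the other pool");
        latch2.set();
    });

    // The calling thread fully blocks here: it does not go looking for other
    // work, which is exactly the behavioral change of full_blocking.
    latch.wait_and_reset();
    println!("caller resumes only after the injected job completed");
}
```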
45 changes: 45 additions & 0 deletions rayon-core/src/thread_pool/test.rs
@@ -3,6 +3,8 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc::channel;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};

use crate::{join, Scope, ScopeFifo, ThreadPool, ThreadPoolBuilder};

@@ -416,3 +418,46 @@ fn yield_local_to_spawn() {
// for it to finish if a different thread stole it first.
assert_eq!(22, rx.recv().unwrap());
}

#[test]
fn nested_thread_pools_deadlock() {
let global_pool = ThreadPoolBuilder::new().num_threads(1).build().unwrap();
// The lock thread pool must be full_blocking for this test to pass.
let lock_pool = Arc::new(
ThreadPoolBuilder::new()
.full_blocking()
.num_threads(1)
.build()
.unwrap(),
);
let mutex = Arc::new(Mutex::new(()));
let start_time = Instant::now();

global_pool.scope(|s| {
for i in 0..5 {
let mutex = mutex.clone();
let lock_pool = lock_pool.clone();
// Create 5 jobs that each try to acquire the lock.
// If a job cannot acquire the lock within 2 seconds, we assume the pool deadlocked.
s.spawn(move |_| {
let mut acquired = false;
while start_time.elapsed() < Duration::from_secs(2) {
if let Ok(_guard) = mutex.try_lock() {
println!("Thread {i} acquired the mutex");
lock_pool.scope(|lock_s| {
lock_s.spawn(|_| {
thread::sleep(Duration::from_millis(100));
});
});
acquired = true;
break;
}
thread::sleep(Duration::from_millis(10));
}
if !acquired {
panic!("Thread {i} failed to acquire the mutex within 2 seconds.");
}
});
}
});
}
25 changes: 25 additions & 0 deletions tests/issue592.rs
@@ -0,0 +1,25 @@
use std::sync::{Arc, Mutex};
use rayon::ThreadPoolBuilder;
use rayon::iter::IntoParallelRefIterator;
use rayon::iter::ParallelIterator;

fn mutex_and_par(mutex: Arc<Mutex<Vec<i32>>>, blocking_pool: &rayon::ThreadPool) {
// Lock the mutex, then collect the items on the full-blocking pool while the lock is held.
let vec = mutex.lock().unwrap();
let result: Vec<i32> = blocking_pool.install(|| vec.par_iter().cloned().collect());
println!("{:?}", result);
}

#[test]
fn test_issue592() {
let collection = vec![1, 2, 3, 4, 5];
let mutex = Arc::new(Mutex::new(collection));

let blocking_pool = ThreadPoolBuilder::new().full_blocking().num_threads(4).build().unwrap();

let dummy_collection: Vec<i32> = (1..=100).collect();
dummy_collection.par_iter().for_each(|_| {
mutex_and_par(mutex.clone(), &blocking_pool);
});
}
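For context on why `tests/issue592.rs` needs `full_blocking()`: with a default pool, this shape of code can hang. While `install` waits for the inner parallel iterator, the calling worker thread is allowed to steal another `for_each` job; that job re-enters `mutex_and_par` on the same OS thread and blocks on the mutex that thread already holds. Below is a sketch of the hazardous pattern (not part of this PR; with std's non-reentrant `Mutex`, this may deadlock or panic).

```rust
use std::sync::Mutex;

use rayon::prelude::*;
use rayon::ThreadPoolBuilder;

fn main() {
    // Note: no `.full_blocking()` here -- that is the hazard being illustrated.
    let pool = ThreadPoolBuilder::new().num_threads(4).build().unwrap();
    let mutex = Mutex::new(vec![1, 2, 3, 4, 5]);

    let work: Vec<i32> = (1..=100).collect();
    work.par_iter().for_each(|_| {
        let vec = mutex.lock().unwrap(); // this thread takes the lock...
        // ...then waits inside `install`. While it waits, work-stealing may
        // run another `for_each` iteration on this same thread, which then
        // blocks (or panics) on `mutex.lock()` above: deadlock.
        let _sum: i32 = pool.install(|| vec.par_iter().sum());
    });
}
```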