diff --git a/rayon-core/src/lib.rs b/rayon-core/src/lib.rs
index 03456b3ee..811bf4ad8 100644
--- a/rayon-core/src/lib.rs
+++ b/rayon-core/src/lib.rs
@@ -200,6 +200,9 @@ pub struct ThreadPoolBuilder<S = DefaultSpawn> {
     /// Closure invoked on worker thread exit.
     exit_handler: Option<Box<ExitHandler>>,
 
+    /// Affects the blocking/work-stealing behavior when using nested thread pools.
+    full_blocking: bool,
+
     /// Closure invoked to spawn threads.
     spawn_handler: S,
 
@@ -245,6 +248,7 @@ impl Default for ThreadPoolBuilder {
             exit_handler: None,
             spawn_handler: DefaultSpawn,
             breadth_first: false,
+            full_blocking: false,
         }
     }
 }
@@ -455,6 +459,7 @@ impl<S> ThreadPoolBuilder<S> {
             start_handler: self.start_handler,
             exit_handler: self.exit_handler,
             breadth_first: self.breadth_first,
+            full_blocking: self.full_blocking,
         }
     }
 
@@ -672,6 +677,25 @@ impl<S> ThreadPoolBuilder<S> {
         self.exit_handler = Some(Box::new(exit_handler));
         self
     }
+
+    /// Changes the behavior of nested thread pools.
+    ///
+    /// If false, when a job is created on this thread pool by a job running in a separate thread
+    /// pool, the parent thread is allowed to start executing a new job in the parent thread pool.
+    ///
+    /// If true, when a job is created on this thread pool by a job running in a separate thread
+    /// pool, the parent thread will block until the jobs in this thread pool are completed. This
+    /// is useful for avoiding deadlock when using mutexes.
+    ///
+    /// Default is false.
+    pub fn full_blocking(mut self) -> Self {
+        self.full_blocking = true;
+        self
+    }
+
+    fn get_full_blocking(&self) -> bool {
+        self.full_blocking
+    }
 }
 
 #[allow(deprecated)]
@@ -811,6 +835,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
             ref exit_handler,
             spawn_handler: _,
             ref breadth_first,
+            ref full_blocking,
         } = *self;
 
         // Just print `Some(<closure>)` or `None` to the debug
@@ -835,6 +860,7 @@ impl<S> fmt::Debug for ThreadPoolBuilder<S> {
             .field("start_handler", &start_handler)
             .field("exit_handler", &exit_handler)
             .field("breadth_first", &breadth_first)
+            .field("full_blocking", &full_blocking)
             .finish()
     }
 }
diff --git a/rayon-core/src/registry.rs b/rayon-core/src/registry.rs
index d30f815bd..7db367224 100644
--- a/rayon-core/src/registry.rs
+++ b/rayon-core/src/registry.rs
@@ -135,6 +135,7 @@ pub(super) struct Registry {
     panic_handler: Option<Box<PanicHandler>>,
     start_handler: Option<Box<StartHandler>>,
     exit_handler: Option<Box<ExitHandler>>,
+    full_blocking: bool,
 
     // When this latch reaches 0, it means that all work on this
     // registry must be complete. This is ensured in the following ways:
@@ -267,6 +268,7 @@ impl Registry {
             panic_handler: builder.take_panic_handler(),
             start_handler: builder.take_start_handler(),
             exit_handler: builder.take_exit_handler(),
+            full_blocking: builder.get_full_blocking(),
         });
 
         // If we return early or panic, make sure to terminate existing threads.
@@ -493,7 +495,11 @@ impl Registry {
             if worker_thread.is_null() {
                 self.in_worker_cold(op)
             } else if (*worker_thread).registry().id() != self.id() {
-                self.in_worker_cross(&*worker_thread, op)
+                if self.full_blocking {
+                    self.in_worker_cross_blocking(op)
+                } else {
+                    self.in_worker_cross(&*worker_thread, op)
+                }
             } else {
                 // Perfectly valid to give them a `&T`: this is the
                 // current thread, so we know the data structure won't be
@@ -552,6 +558,30 @@ impl Registry {
         job.into_result()
     }
 
+    #[cold]
+    unsafe fn in_worker_cross_blocking<OP, R>(&self, op: OP) -> R
+    where
+        OP: FnOnce(&WorkerThread, bool) -> R + Send,
+        R: Send,
+    {
+        thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());
+
+        LOCK_LATCH.with(|l| {
+            let job = StackJob::new(
+                |injected| {
+                    let worker_thread = WorkerThread::current();
+                    assert!(injected && !worker_thread.is_null());
+                    op(&*worker_thread, true)
+                },
+                LatchRef::new(l),
+            );
+            self.inject(job.as_job_ref());
+            job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.
+
+            job.into_result()
+        })
+    }
+
     /// Increments the terminate counter. This increment should be
     /// balanced by a call to `terminate`, which will decrement. This
     /// is used when spawning asynchronous work, which needs to
diff --git a/rayon-core/src/thread_pool/test.rs b/rayon-core/src/thread_pool/test.rs
index 88b36282d..811125aaf 100644
--- a/rayon-core/src/thread_pool/test.rs
+++ b/rayon-core/src/thread_pool/test.rs
@@ -3,6 +3,8 @@
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::mpsc::channel;
 use std::sync::{Arc, Mutex};
+use std::thread;
+use std::time::{Duration, Instant};
 
 use crate::{join, Scope, ScopeFifo, ThreadPool, ThreadPoolBuilder};
 
@@ -416,3 +418,46 @@ fn yield_local_to_spawn() {
     // for it to finish if a different thread stole it first.
     assert_eq!(22, rx.recv().unwrap());
 }
+
+#[test]
+fn nested_thread_pools_deadlock() {
+    let global_pool = ThreadPoolBuilder::new().num_threads(1).build().unwrap();
+    // The lock thread pool must be full_blocking for this test to pass.
+    let lock_pool = Arc::new(
+        ThreadPoolBuilder::new()
+            .full_blocking()
+            .num_threads(1)
+            .build()
+            .unwrap(),
+    );
+    let mutex = Arc::new(Mutex::new(()));
+    let start_time = Instant::now();
+
+    global_pool.scope(|s| {
+        for i in 0..5 {
+            let mutex = mutex.clone();
+            let lock_pool = lock_pool.clone();
+            // Create 5 jobs that try to acquire the lock.
+            // If all 5 jobs are unable to acquire the lock within 2 seconds, a deadlock occurred.
+            s.spawn(move |_| {
+                let mut acquired = false;
+                while start_time.elapsed() < Duration::from_secs(2) {
+                    if let Ok(_guard) = mutex.try_lock() {
+                        println!("Thread {i} acquired the mutex");
+                        lock_pool.scope(|lock_s| {
+                            lock_s.spawn(|_| {
+                                thread::sleep(Duration::from_millis(100));
+                            });
+                        });
+                        acquired = true;
+                        break;
+                    }
+                    thread::sleep(Duration::from_millis(10));
+                }
+                if !acquired {
+                    panic!("Thread {i} failed to acquire the mutex within 2 seconds.");
+                }
+            });
+        }
+    });
+}
diff --git a/tests/issue592.rs b/tests/issue592.rs
new file mode 100644
index 000000000..8b5409938
--- /dev/null
+++ b/tests/issue592.rs
@@ -0,0 +1,25 @@
+use std::sync::{Arc, Mutex};
+use rayon::ThreadPoolBuilder;
+use rayon::iter::IntoParallelRefIterator;
+use rayon::iter::ParallelIterator;
+
+fn mutex_and_par(mutex: Arc<Mutex<Vec<i32>>>, blocking_pool: &rayon::ThreadPool) {
+    // Lock the mutex and collect items using the full blocking thread pool
+    let vec = mutex.lock().unwrap();
+    let result: Vec<i32> = blocking_pool.install(|| vec.par_iter().cloned().collect());
+    println!("{:?}", result);
+}
+
+#[test]
+fn test_issue592() {
+    let collection = vec![1, 2, 3, 4, 5];
+    let mutex = Arc::new(Mutex::new(collection));
+
+    let blocking_pool = ThreadPoolBuilder::new().full_blocking().num_threads(4).build().unwrap();
+
+    let dummy_collection: Vec<i32> = (1..=100).collect();
+    dummy_collection.par_iter().for_each(|_| {
+        mutex_and_par(mutex.clone(), &blocking_pool);
+    });
+}
+
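Usage sketch (not part of the patch): a minimal example of how the new `full_blocking()` builder method is intended to be used when nested work runs while a mutex is held, mirroring the pattern exercised by the tests above. It assumes a rayon build that includes this change; the pool sizes and the `outer`/`inner`/`shared` names are illustrative, not taken from the diff.

use std::sync::{Arc, Mutex};
use rayon::ThreadPoolBuilder;

fn main() {
    // Outer pool with a single thread, so queued jobs must wait their turn.
    let outer = ThreadPoolBuilder::new().num_threads(1).build().unwrap();

    // Inner pool marked full_blocking: a thread from another pool that injects
    // work here blocks until that work finishes, instead of stealing another
    // job from its own pool while it waits (a stolen job could try to re-lock
    // `shared` on the same thread and deadlock).
    let inner = ThreadPoolBuilder::new()
        .full_blocking()
        .num_threads(2)
        .build()
        .unwrap();

    let shared = Arc::new(Mutex::new(vec![1, 2, 3]));

    outer.scope(|s| {
        for _ in 0..4 {
            let shared = Arc::clone(&shared);
            let inner = &inner;
            s.spawn(move |_| {
                // Hold the mutex while running nested work on the blocking pool.
                let guard = shared.lock().unwrap();
                inner.scope(|s2| {
                    s2.spawn(|_| {
                        let _sum: i32 = guard.iter().sum();
                    });
                });
            });
        }
    });
}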