Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Thread throttling #1167

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions rayon-core/src/latch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::marker::PhantomData;
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::usize;

use crate::registry::{Registry, WorkerThread};
use crate::sync::{Condvar, Mutex};
Expand Down Expand Up @@ -269,6 +270,49 @@ impl Latch for LockLatch {
}
}

/// A Latch starts as false and can be toggled multipe times. One can block
/// until it becomes true and get the value
#[derive(Debug)]
pub(super) struct ToggleLatch {
m: Mutex<bool>,
v: Condvar,
}

impl ToggleLatch {
#[inline]
pub(super) fn new() -> ToggleLatch {
ToggleLatch {
m: Mutex::new(false),
v: Condvar::new(),
}
}

pub(super) fn get(&self) -> bool {
let guard = self.m.lock().unwrap();
return *guard;
}

/// Block until latch is set.
pub(super) fn wait(&self) {
let mut guard = self.m.lock().unwrap();
while !*guard {
guard = self.v.wait(guard).unwrap();
}
}
}

impl Latch for ToggleLatch {
#[inline]
unsafe fn set(this: *const Self) {
let mut guard = (*this).m.lock().unwrap();
*guard = !*guard;
if *guard {
(*this).v.notify_all();
}
}
}


/// Once latches are used to implement one-time blocking, primarily
/// for the termination flag of the threads in the pool.
///
Expand Down
4 changes: 3 additions & 1 deletion rayon-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@
//! conflicting requirements will need to be resolved before the build will
//! succeed.

#![deny(missing_debug_implementations)]
// TODO
//#![deny(missing_debug_implementations)]
#![deny(missing_docs)]
#![deny(unreachable_pub)]
#![warn(rust_2018_idioms)]
Expand Down Expand Up @@ -102,6 +103,7 @@ pub use self::thread_pool::current_thread_has_pending_tasks;
pub use self::thread_pool::current_thread_index;
pub use self::thread_pool::ThreadPool;
pub use self::thread_pool::{yield_local, yield_now, Yield};
pub use self::registry::Registry;

#[cfg(not(feature = "web_spin_lock"))]
use std::sync;
Expand Down
95 changes: 89 additions & 6 deletions rayon-core/src/registry.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, Latch, LatchRef, LockLatch, OnceLatch, SpinLatch};
use crate::latch::{AsCoreLatch, CoreLatch, Latch, LatchRef, LockLatch, OnceLatch, SpinLatch, ToggleLatch};
use crate::sleep::Sleep;
use crate::sync::Mutex;
use crate::unwind;
Expand All @@ -18,6 +18,7 @@ use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Once};
use std::thread;
use std::usize;

/// Thread builder used for customization via
/// [`ThreadPoolBuilder::spawn_handler`](struct.ThreadPoolBuilder.html#method.spawn_handler).
Expand Down Expand Up @@ -127,7 +128,8 @@ where
}
}

pub(super) struct Registry {
/// The Registry
pub struct Registry {
thread_infos: Vec<ThreadInfo>,
sleep: Sleep,
injected_jobs: Injector<JobRef>,
Expand Down Expand Up @@ -311,7 +313,57 @@ impl Registry {
Ok(registry)
}

pub(super) fn current() -> Arc<Registry> {
/// Block `num` threads
pub(crate) fn block_threads(&self, num: usize) {
// reverse so we reach thread 0 last (wich we should *never!* block)
let unblocked_threads = self.thread_infos.iter().rev().filter(|&p| {
!p.blocked.get()
});

for (i, thread) in unblocked_threads.enumerate() {
if (i + 1) >= num { // do not block thread with id 0 or the programm will be stalled
return;
}
unsafe { Latch::set(&thread.blocked); }; // toggles the blocked latch to block
}
}

/// Unblock `num` threads
pub fn unblock_threads(&self, num: usize) {
let blocked_threads = self.thread_infos.iter().filter(|&p| {
p.blocked.get()
});

for (i, thread) in blocked_threads.enumerate() {
if i >= num {
return;
}
unsafe { Latch::set(&thread.blocked); }; // toggles the blocked latch to unblock
}
}

/// Adjust so `num` threads are unblocked
pub fn adjust_blocked_threads(&self, num: usize) {
let unblocked_threads = self.thread_infos.iter().filter(|&p| {
!p.blocked.get()
});

let unblocked_threads = unblocked_threads.count();

match unblocked_threads.cmp(&num) {
std::cmp::Ordering::Less => {
self.unblock_threads(num - unblocked_threads)
},
std::cmp::Ordering::Greater => {
self.block_threads(unblocked_threads - num) },
std::cmp::Ordering::Equal => {
return;
},
};
}

/// get the global registry
pub fn current() -> Arc<Registry> {
unsafe {
let worker_thread = WorkerThread::current();
let registry = if worker_thread.is_null() {
Expand Down Expand Up @@ -359,7 +411,9 @@ impl Registry {
}

pub(super) fn num_threads(&self) -> usize {
self.thread_infos.len()
self.thread_infos.iter().filter(|&p | {
!p.blocked.get()
}).count()
}

pub(super) fn catch_unwind(&self, f: impl FnOnce()) {
Expand All @@ -373,6 +427,15 @@ impl Registry {
}
}

/// Waits for the worker threads to be unblocked
pub(super) fn wait_until_unblocked(&self, index: usize) {
self.thread_infos[index].blocked.wait();
for info in &self.thread_infos {
info.blocked.wait();
}
}


/// Waits for the worker threads to get up and running. This is
/// meant to be used for benchmarking purposes, primarily, so that
/// you can get more consistent numbers by having everything
Expand Down Expand Up @@ -405,6 +468,8 @@ impl Registry {
let worker_thread = WorkerThread::current();
unsafe {
if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
// wait if we are blocked
(*worker_thread).registry().wait_until_unblocked((*worker_thread).index());
(*worker_thread).push(job_ref);
} else {
self.inject(job_ref);
Expand Down Expand Up @@ -456,6 +521,12 @@ impl Registry {
assert_eq!(self.num_threads(), injected_jobs.len());
{
let broadcasts = self.broadcasts.lock().unwrap();
let filtered_broadcasts = broadcasts.iter().zip(&self.thread_infos).filter(|(_, info)| {
!&info.blocked.get()
})
.map(|(worker, _)| {
worker
});

// It should not be possible for `state.terminate` to be true
// here. It is only set to true when the user creates (and
Expand All @@ -468,8 +539,17 @@ impl Registry {
"inject_broadcast() sees state.terminate as true"
);

assert_eq!(broadcasts.len(), injected_jobs.len());
for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) {
// TODO, can't use count without move, so we reconstruct it...
// should better be coded better...
assert_eq!(broadcasts.iter().zip(&self.thread_infos).filter(|(_, info)| {
!&info.blocked.get()
})
.map(|(worker, _)| {
worker
}).count(), injected_jobs.len());

for (worker, job_ref) in filtered_broadcasts
.zip(injected_jobs) {
worker.push(job_ref);
}
}
Expand Down Expand Up @@ -618,6 +698,8 @@ struct ThreadInfo {

/// the "stealer" half of the worker's deque
stealer: Stealer<JobRef>,

blocked: ToggleLatch,
}

impl ThreadInfo {
Expand All @@ -626,6 +708,7 @@ impl ThreadInfo {
primed: LockLatch::new(),
stopped: LockLatch::new(),
terminate: OnceLatch::new(),
blocked: ToggleLatch::new(),
stealer,
}
}
Expand Down
18 changes: 17 additions & 1 deletion rayon-core/src/thread_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{scope, Scope};
use crate::{scope_fifo, ScopeFifo};
use crate::{ThreadPoolBuildError, ThreadPoolBuilder};
use std::error::Error;
use std::fmt;
use std::{fmt, usize};
use std::sync::Arc;

mod test;
Expand Down Expand Up @@ -349,6 +349,22 @@ impl ThreadPool {
unsafe { spawn::spawn_in(op, &self.registry) }
}

/// Block `num` threads
pub fn block_threads(&self, num: usize) {
self.registry.block_threads(num);
}

/// Unblock `num` threads
pub fn unblock_threads(&self, num: usize) {
self.registry.unblock_threads(num);
}

/// Adjust so `num` threads are unblocked
pub fn adjust_blocked_threads_threads(&self, num: usize) {
self.registry.adjust_blocked_threads(num);
}


/// Spawns an asynchronous task in this thread-pool. This task will
/// run in the implicit, global scope, which means that it may outlast
/// the current stack frame -- therefore, it cannot capture any references
Expand Down