From fbb8d01498816ae2bc924ca8c3bd53d540566d18 Mon Sep 17 00:00:00 2001 From: Al Liu Date: Sat, 17 Feb 2024 02:22:46 +0800 Subject: [PATCH] Add `tokio` specific `AsyncWaitGroup` --- .github/workflows/ci.yml | 19 +++ Cargo.toml | 19 ++- README.md | 67 ++++++-- examples/future.rs | 30 ++++ examples/{axync.rs => tokio.rs} | 4 +- src/future.rs | 294 ++++++++------------------------ src/lib.rs | 13 +- src/tokio.rs | 267 +++++++++++++++++++++++++++++ tests/future.rs | 162 ++++++++++++++++++ tests/tokio.rs | 162 ++++++++++++++++++ 10 files changed, 795 insertions(+), 242 deletions(-) create mode 100644 examples/future.rs rename examples/{axync.rs => tokio.rs} (88%) create mode 100644 src/tokio.rs create mode 100644 tests/future.rs create mode 100644 tests/tokio.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0bc8501..bfcd80b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -88,6 +88,24 @@ jobs: run: rustup update stable --no-self-update && rustup default stable - name: Test run: cargo test --lib --no-default-features --features future + + tokio: + name: tokio + strategy: + matrix: + os: + - ubuntu-latest + - macos-latest + - windows-latest + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v3 + - name: Install Rust + # --no-self-update is necessary because the windows environment cannot self-update rustup.exe. + run: rustup update stable --no-self-update && rustup default stable + - name: Test + run: cargo test --lib --no-default-features --features tokio + sync: name: sync strategy: @@ -109,6 +127,7 @@ jobs: name: cargo tarpaulin runs-on: ubuntu-latest needs: + - tokio - future - sync - build diff --git a/Cargo.toml b/Cargo.toml index 68ed178..ad6bdf0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ homepage = "https://github.com/al8n/wg" repository = "https://github.com/al8n/wg.git" documentation = "https://docs.rs/wg/" readme = "README.md" -version = "0.6.2" +version = "0.7.0" license = "MIT OR Apache-2.0" keywords = ["waitgroup", "async", "sync", "notify", "wake"] categories = ["asynchronous", "concurrency", "data-structures"] @@ -20,6 +20,8 @@ parking_lot = ["dep:parking_lot"] future = ["event-listener", "event-listener-strategy", "pin-project-lite"] +tokio = ["dep:tokio", "futures-core"] + [dependencies] parking_lot = { version = "0.12", optional = true } triomphe = { version = "0.1", optional = true } @@ -27,9 +29,24 @@ event-listener = { version = "5", optional = true } event-listener-strategy = { version = "0.5", optional = true } pin-project-lite = { version = "0.2", optional = true } +tokio = { version = "1", default-features = false, optional = true, features = ["sync", "rt"] } +futures-core = { version = "0.3", optional = true } + [dev-dependencies] tokio = { version = "1", features = ["full"] } +async-std = { version = "1", features = ["attributes"] } [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] + +[[test]] +name = "tokio" +path = "tests/tokio.rs" +required-features = ["tokio"] + +[[test]] +name = "future" +path = "tests/future.rs" +required-features = ["future"] + diff --git a/README.md b/README.md index 30bfff1..fedd873 100644 --- a/README.md +++ b/README.md @@ -16,27 +16,41 @@ Golang like WaitGroup implementation for sync/async Rust. license -## Installation +## Introduction -By default, blocking version `WaitGroup` is enabled, if you want to use non-blocking `AsyncWaitGroup`, you need to -enbale `future` feature in your `Cargo.toml`. +By default, blocking version `WaitGroup` is enabled. + +If you are using `tokio`, you need to enable `tokio` feature in your `Cargo.toml` and use `wg::tokio::AsyncWaitGroup`. + +If you are using other async runtime, you need to +enbale `future` feature in your `Cargo.toml` and use `wg::future::AsyncWaitGroup`. ### Sync ```toml [dependencies] -wg = "0.6" +wg = "0.7" ``` -### Async +### `tokio` + +An async implementation for `tokio` runtime. ```toml [dependencies] -wg = { version: "0.6", features = ["future"] } +wg = { version: "0.7", features = ["tokio"] } ``` +### `future` -## Example +A more generic async implementation. + +```toml +[dependencies] +wg = { version: "0.7", features = ["future"] } +``` + +## Instruction ### Sync @@ -69,10 +83,10 @@ fn main() { } ``` -### Async +### `tokio` ```rust -use wg::AsyncWaitGroup; +use wg::tokio::AsyncWaitGroup; use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use tokio::{spawn, time::{sleep, Duration}}; @@ -100,9 +114,42 @@ async fn main() { } ``` +### `async-io` + +```rust +use wg::future::AsyncWaitGroup; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; +use async_std::task::{spawn, block_on, sleep}; + +fn main() { + block_on(async { + let wg = AsyncWaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + + for _ in 0..5 { + let ctrx = ctr.clone(); + let t_wg = wg.add(1); + spawn(async move { + // mock some time consuming task + sleep(Duration::from_millis(50)).await; + ctrx.fetch_add(1, Ordering::Relaxed); + + // mock task is finished + t_wg.done(); + }); + } + + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 5); + }); +} +``` + ## Acknowledgements -- Inspired by Golang sync.WaitGroup, [ibraheemdev's `AwaitGroup`] and [`crossbeam_utils::WaitGroup`]. +- Inspired by Golang sync.WaitGroup and [`crossbeam_utils::WaitGroup`]. ## License diff --git a/examples/future.rs b/examples/future.rs new file mode 100644 index 0000000..60ed1d5 --- /dev/null +++ b/examples/future.rs @@ -0,0 +1,30 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::{ + spawn, + time::{sleep, Duration}, +}; +use wg::future::AsyncWaitGroup; + +fn main() { + async_std::task::block_on(async { + let wg = AsyncWaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + + for _ in 0..5 { + let ctrx = ctr.clone(); + let t_wg = wg.add(1); + spawn(async move { + // mock some time consuming task + sleep(Duration::from_millis(50)).await; + ctrx.fetch_add(1, Ordering::Relaxed); + + // mock task is finished + t_wg.done(); + }); + } + + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 5); + }); +} diff --git a/examples/axync.rs b/examples/tokio.rs similarity index 88% rename from examples/axync.rs rename to examples/tokio.rs index 017d337..f10bc6b 100644 --- a/examples/axync.rs +++ b/examples/tokio.rs @@ -4,9 +4,9 @@ use tokio::{ spawn, time::{sleep, Duration}, }; -use wg::AsyncWaitGroup; +use wg::tokio::AsyncWaitGroup; -#[tokio::main(flavor = "multi_thread", worker_threads = 10)] +#[tokio::main] async fn main() { let wg = AsyncWaitGroup::new(); let ctr = Arc::new(AtomicUsize::new(0)); diff --git a/src/future.rs b/src/future.rs index 4e4a71e..698f868 100644 --- a/src/future.rs +++ b/src/future.rs @@ -25,32 +25,32 @@ struct AsyncInner { /// # Example /// /// ```rust -/// use wg::AsyncWaitGroup; +/// use wg::future::AsyncWaitGroup; /// use std::sync::Arc; /// use std::sync::atomic::{AtomicUsize, Ordering}; -/// use tokio::{spawn, time::{sleep, Duration}}; +/// use std::time::Duration; +/// use async_std::task::{spawn, sleep}; /// -/// #[tokio::main] -/// async fn main() { -/// let wg = AsyncWaitGroup::new(); -/// let ctr = Arc::new(AtomicUsize::new(0)); +/// # async_std::task::block_on(async { +/// let wg = AsyncWaitGroup::new(); +/// let ctr = Arc::new(AtomicUsize::new(0)); /// -/// for _ in 0..5 { -/// let ctrx = ctr.clone(); -/// let t_wg = wg.add(1); -/// spawn(async move { -/// // mock some time consuming task -/// sleep(Duration::from_millis(50)).await; -/// ctrx.fetch_add(1, Ordering::Relaxed); +/// for _ in 0..5 { +/// let ctrx = ctr.clone(); +/// let t_wg = wg.add(1); +/// spawn(async move { +/// // mock some time consuming task +/// sleep(Duration::from_millis(50)).await; +/// ctrx.fetch_add(1, Ordering::Relaxed); /// -/// // mock task is finished -/// t_wg.done(); -/// }); -/// } -/// -/// wg.wait().await; -/// assert_eq!(ctr.load(Ordering::Relaxed), 5); +/// // mock task is finished +/// t_wg.done(); +/// }); /// } +/// +/// wg.wait().await; +/// assert_eq!(ctr.load(Ordering::Relaxed), 5); +/// # }) /// ``` /// /// [`wait`]: struct.AsyncWaitGroup.html#method.wait @@ -115,24 +115,25 @@ impl AsyncWaitGroup { /// new `add` calls must happen after all previous [`wait`] calls have returned. /// /// # Example + /// /// ```rust - /// use wg::AsyncWaitGroup; - /// - /// #[tokio::main] - /// async fn main() { - /// let wg = AsyncWaitGroup::new(); + /// use wg::future::AsyncWaitGroup; + /// use async_std::task::spawn; + /// + /// # async_std::task::block_on(async { + /// let wg = AsyncWaitGroup::new(); /// - /// wg.add(3); - /// (0..3).for_each(|_| { - /// let t_wg = wg.clone(); - /// tokio::spawn(async move { - /// // do some time consuming work - /// t_wg.done(); - /// }); + /// wg.add(3); + /// (0..3).for_each(|_| { + /// let t_wg = wg.clone(); + /// spawn(async move { + /// // do some time consuming work + /// t_wg.done(); /// }); + /// }); /// - /// wg.wait().await; - /// } + /// wg.wait().await; + /// # }) /// ``` /// /// [`wait`]: struct.AsyncWaitGroup.html#method.wait @@ -149,18 +150,18 @@ impl AsyncWaitGroup { /// # Example /// /// ```rust - /// use wg::AsyncWaitGroup; + /// use wg::future::AsyncWaitGroup; + /// use async_std::task::spawn; /// - /// #[tokio::main] - /// async fn main() { + /// # async_std::task::block_on(async { /// let wg = AsyncWaitGroup::new(); /// wg.add(1); /// let t_wg = wg.clone(); - /// tokio::spawn(async move { + /// spawn(async move { /// // do some time consuming task /// t_wg.done(); /// }); - /// } + /// # }) /// ``` pub fn done(&self) { if self.inner.counter.fetch_sub(1, Ordering::SeqCst) == 1 { @@ -178,22 +179,22 @@ impl AsyncWaitGroup { /// # Example /// /// ```rust - /// use wg::AsyncWaitGroup; - /// - /// #[tokio::main] - /// async fn main() { - /// let wg = AsyncWaitGroup::new(); - /// wg.add(1); - /// let t_wg = wg.clone(); + /// use wg::future::AsyncWaitGroup; + /// use async_std::task::spawn; + /// + /// # async_std::task::block_on(async { + /// let wg = AsyncWaitGroup::new(); + /// wg.add(1); + /// let t_wg = wg.clone(); /// - /// tokio::spawn( async move { - /// // do some time consuming task - /// t_wg.done() - /// }); + /// spawn(async move { + /// // do some time consuming task + /// t_wg.done() + /// }); /// - /// // wait other thread completes - /// wg.wait().await; - /// } + /// // wait other thread completes + /// wg.wait().await; + /// # }) /// ``` pub fn wait(&self) -> WaitGroupFuture<'_> { WaitGroupFuture::_new(WaitGroupFutureInner::new(&self.inner)) @@ -208,22 +209,22 @@ impl AsyncWaitGroup { /// # Example /// /// ```rust - /// use wg::AsyncWaitGroup; - /// - /// #[tokio::main] - /// async fn main() { - /// let wg = AsyncWaitGroup::new(); - /// wg.add(1); - /// let t_wg = wg.clone(); + /// use wg::future::AsyncWaitGroup; + /// use async_std::task::spawn; + /// + /// # async_std::task::block_on(async { + /// let wg = AsyncWaitGroup::new(); + /// wg.add(1); + /// let t_wg = wg.clone(); /// - /// tokio::spawn( async move { - /// // do some time consuming task - /// t_wg.done() - /// }); + /// spawn(async move { + /// // do some time consuming task + /// t_wg.done() + /// }); /// - /// // wait other thread completes - /// wg.block_wait(); - /// } + /// // wait other thread completes + /// wg.block_wait(); + /// # }) /// ``` pub fn block_wait(&self) { WaitGroupFutureInner::new(&self.inner).wait(); @@ -298,162 +299,3 @@ impl EventListenerFuture for WaitGroupFutureInner<'_> { } } } - -#[cfg(test)] -mod tests { - use super::*; - use std::time::Duration; - - #[tokio::test] - async fn test_async_wait_group() { - let wg = AsyncWaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - - for _ in 0..5 { - let ctrx = ctr.clone(); - let wg = wg.add(1); - - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(50)).await; - ctrx.fetch_add(1, Ordering::Relaxed); - wg.done(); - }); - } - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 5); - } - - #[tokio::test] - async fn test_async_wait_group_reuse() { - let wg = AsyncWaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - for _ in 0..6 { - let wg = wg.add(1); - let ctrx = ctr.clone(); - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(5)).await; - ctrx.fetch_add(1, Ordering::Relaxed); - wg.done(); - }); - } - - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 6); - - let worker = wg.add(1); - - let ctrx = ctr.clone(); - tokio::spawn(async move { - tokio::time::sleep(Duration::from_millis(5)).await; - ctrx.fetch_add(1, Ordering::Relaxed); - worker.done(); - }); - - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 7); - } - - #[tokio::test] - async fn test_async_wait_group_nested() { - let wg = AsyncWaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - for _ in 0..5 { - let worker = wg.add(1); - let ctrx = ctr.clone(); - tokio::spawn(async move { - let nested_worker = worker.add(1); - let ctrxx = ctrx.clone(); - tokio::spawn(async move { - ctrxx.fetch_add(1, Ordering::Relaxed); - nested_worker.done(); - }); - ctrx.fetch_add(1, Ordering::Relaxed); - worker.done(); - }); - } - - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 10); - } - - #[tokio::test] - async fn test_async_wait_group_from() { - let wg = AsyncWaitGroup::from(5); - for _ in 0..5 { - let t = wg.clone(); - tokio::spawn(async move { - t.done(); - }); - } - wg.wait().await; - } - - #[tokio::test] - async fn test_sync_wait_group() { - let wg = AsyncWaitGroup::new(); - let ctr = Arc::new(AtomicUsize::new(0)); - - for _ in 0..5 { - let ctrx = ctr.clone(); - let wg = wg.add(1); - std::thread::spawn(move || { - std::thread::sleep(Duration::from_millis(50)); - ctrx.fetch_add(1, Ordering::Relaxed); - - wg.done(); - }); - } - wg.wait().await; - assert_eq!(ctr.load(Ordering::Relaxed), 5); - } - - #[tokio::test] - async fn test_async_waitings() { - let wg = AsyncWaitGroup::new(); - wg.add(1); - wg.add(1); - assert_eq!(wg.waitings(), 2); - } - - #[test] - fn test_async_block_wait() { - let wg = AsyncWaitGroup::new(); - let t_wg = wg.add(1); - std::thread::spawn(move || { - // do some time consuming task - t_wg.done(); - }); - - // wait other thread completes - wg.block_wait(); - - assert_eq!(wg.waitings(), 0); - } - - #[tokio::test] - async fn test_wake_after_updating() { - let wg = AsyncWaitGroup::new(); - for _ in 0..100000 { - let worker = wg.add(1); - tokio::spawn(async move { - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - let mut a = 0; - for _ in 0..1000 { - a += 1; - } - println!("{a}"); - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - worker.done(); - }); - } - wg.wait().await; - } - - #[test] - fn test_clone_and_fmt() { - let awg = AsyncWaitGroup::new(); - let awg1 = awg.clone(); - awg1.add(3); - assert_eq!(format!("{:?}", awg), format!("{:?}", awg1)); - } -} diff --git a/src/lib.rs b/src/lib.rs index 06487b2..b7c5e74 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,10 +22,17 @@ #![cfg_attr(docsrs, feature(doc_cfg))] #![cfg_attr(docsrs, allow(unused_attributes))] +/// [`AsyncWaitGroup`] for `futures`. #[cfg(feature = "future")] -mod future; -#[cfg(feature = "future")] -pub use future::*; +#[cfg_attr(docsrs, doc(cfg(feature = "future")))] +pub mod future; +// #[cfg(feature = "future")] +// pub use future::*; + +/// [`AsyncWaitGroup`] for `tokio` runtime. +#[cfg(feature = "tokio")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] +pub mod tokio; trait Mu { type Guard<'a> diff --git a/src/tokio.rs b/src/tokio.rs new file mode 100644 index 0000000..b48c81a --- /dev/null +++ b/src/tokio.rs @@ -0,0 +1,267 @@ +use super::*; +use ::tokio::sync::{futures::Notified, Notify}; + +use std::{ + future::Future, + pin::Pin, + sync::atomic::{AtomicUsize, Ordering}, + task::{Context, Poll}, +}; + +#[derive(Debug)] +struct AsyncInner { + counter: AtomicUsize, + notify: Notify, +} + +/// An AsyncWaitGroup waits for a collection of threads to finish. +/// The main thread calls [`add`] to set the number of +/// thread to wait for. Then each of the tasks +/// runs and calls Done when finished. At the same time, +/// Wait can be used to block until all tasks have finished. +/// +/// A WaitGroup must not be copied after first use. +/// +/// # Example +/// +/// ```rust +/// use wg::tokio::AsyncWaitGroup; +/// use std::sync::Arc; +/// use std::sync::atomic::{AtomicUsize, Ordering}; +/// use tokio::{spawn, time::{sleep, Duration}}; +/// +/// #[tokio::main] +/// async fn main() { +/// let wg = AsyncWaitGroup::new(); +/// let ctr = Arc::new(AtomicUsize::new(0)); +/// +/// for _ in 0..5 { +/// let ctrx = ctr.clone(); +/// let t_wg = wg.add(1); +/// spawn(async move { +/// // mock some time consuming task +/// sleep(Duration::from_millis(50)).await; +/// ctrx.fetch_add(1, Ordering::Relaxed); +/// +/// // mock task is finished +/// t_wg.done(); +/// }); +/// } +/// +/// wg.wait().await; +/// assert_eq!(ctr.load(Ordering::Relaxed), 5); +/// } +/// ``` +/// +/// [`wait`]: struct.AsyncWaitGroup.html#method.wait +/// [`add`]: struct.AsyncWaitGroup.html#method.add +#[cfg_attr(docsrs, doc(cfg(feature = "future")))] +pub struct AsyncWaitGroup { + inner: Arc, +} + +impl Default for AsyncWaitGroup { + fn default() -> Self { + Self { + inner: Arc::new(AsyncInner { + counter: AtomicUsize::new(0), + notify: Notify::new(), + }), + } + } +} + +impl From for AsyncWaitGroup { + fn from(count: usize) -> Self { + Self { + inner: Arc::new(AsyncInner { + counter: AtomicUsize::new(count), + notify: Notify::new(), + }), + } + } +} + +impl Clone for AsyncWaitGroup { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl std::fmt::Debug for AsyncWaitGroup { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AsyncWaitGroup") + .field("counter", &self.inner.counter) + .finish() + } +} + +impl AsyncWaitGroup { + /// Creates a new `AsyncWaitGroup` + pub fn new() -> Self { + Self::default() + } + + /// Adds delta to the WaitGroup counter. + /// If the counter becomes zero, all threads blocked on [`wait`] are released. + /// + /// Note that calls with a delta that occur when the counter is zero + /// must happen before a Wait. + /// Typically this means the calls to add should execute before the statement + /// creating the thread or other event to be waited for. + /// If a `AsyncWaitGroup` is reused to [`wait`] for several independent sets of events, + /// new `add` calls must happen after all previous [`wait`] calls have returned. + /// + /// # Example + /// ```rust + /// use wg::tokio::AsyncWaitGroup; + /// + /// #[tokio::main] + /// async fn main() { + /// let wg = AsyncWaitGroup::new(); + /// + /// wg.add(3); + /// (0..3).for_each(|_| { + /// let t_wg = wg.clone(); + /// tokio::spawn(async move { + /// // do some time consuming work + /// t_wg.done(); + /// }); + /// }); + /// + /// wg.wait().await; + /// } + /// ``` + /// + /// [`wait`]: struct.AsyncWaitGroup.html#method.wait + pub fn add(&self, num: usize) -> Self { + self.inner.counter.fetch_add(num, Ordering::AcqRel); + + Self { + inner: self.inner.clone(), + } + } + + /// done decrements the WaitGroup counter by one. + /// + /// # Example + /// + /// ```rust + /// use wg::tokio::AsyncWaitGroup; + /// + /// #[tokio::main] + /// async fn main() { + /// let wg = AsyncWaitGroup::new(); + /// wg.add(1); + /// let t_wg = wg.clone(); + /// tokio::spawn(async move { + /// // do some time consuming task + /// t_wg.done(); + /// }); + /// } + /// ``` + pub fn done(&self) { + if self.inner.counter.fetch_sub(1, Ordering::SeqCst) == 1 { + self.inner.notify.notify_waiters(); + } + } + + /// waitings return how many jobs are waiting. + pub fn waitings(&self) -> usize { + self.inner.counter.load(Ordering::Acquire) + } + + /// wait blocks until the [`AsyncWaitGroup`] counter is zero. + /// + /// # Example + /// + /// ```rust + /// use wg::tokio::AsyncWaitGroup; + /// + /// #[tokio::main] + /// async fn main() { + /// let wg = AsyncWaitGroup::new(); + /// wg.add(1); + /// let t_wg = wg.clone(); + /// + /// tokio::spawn( async move { + /// // do some time consuming task + /// t_wg.done() + /// }); + /// + /// // wait other thread completes + /// wg.wait().await; + /// } + /// ``` + pub fn wait(&self) -> WaitGroupFuture<'_> { + WaitGroupFuture { + inner: self, + notified: self.inner.notify.notified(), + _pin: std::marker::PhantomPinned, + } + } + + /// Wait blocks until the [`AsyncWaitGroup`] counter is zero. This method is + /// intended to be used in a non-async context, + /// e.g. when implementing the [`Drop`] trait. + /// + /// The implementation is like a spin lock, which is not efficient, so use it with caution. + /// + /// # Example + /// + /// ```rust + /// use wg::tokio::AsyncWaitGroup; + /// + /// #[tokio::main(flavor = "multi_thread")] + /// async fn main() { + /// let wg = AsyncWaitGroup::new(); + /// wg.add(1); + /// let t_wg = wg.clone(); + /// + /// tokio::spawn( async move { + /// // do some time consuming task + /// t_wg.done() + /// }); + /// + /// // wait other thread completes + /// wg.block_wait(); + /// } + /// ``` + pub fn block_wait(&self) { + let this = self.clone(); + let (tx, rx) = std::sync::mpsc::channel(); + ::tokio::task::spawn(async move { + this.wait().await; + let _ = tx.send(()); + }); + let _ = rx.recv(); + } +} + +pin_project_lite::pin_project! { + /// A future returned by [`AsyncWaitGroup::wait()`]. + #[derive(Debug)] + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))] + pub struct WaitGroupFuture<'a> { + inner: &'a AsyncWaitGroup, + #[pin] + notified: Notified<'a>, + #[pin] + _pin: std::marker::PhantomPinned, + } +} + +impl<'a> Future for WaitGroupFuture<'a> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.inner.inner.counter.load(Ordering::SeqCst) == 0 { + return Poll::Ready(()); + } + + self.project().notified.poll(cx) + } +} diff --git a/tests/future.rs b/tests/future.rs new file mode 100644 index 0000000..93af4b3 --- /dev/null +++ b/tests/future.rs @@ -0,0 +1,162 @@ +use wg::future::AsyncWaitGroup; + +use std::{ + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::Duration, +}; + +#[async_std::test] +async fn test_async_wait_group() { + let wg = AsyncWaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + + for _ in 0..5 { + let ctrx = ctr.clone(); + let wg = wg.add(1); + + async_std::task::spawn(async move { + async_std::task::sleep(Duration::from_millis(50)).await; + ctrx.fetch_add(1, Ordering::Relaxed); + wg.done(); + }); + } + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 5); +} + +#[async_std::test] +async fn test_async_wait_group_reuse() { + let wg = AsyncWaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + for _ in 0..6 { + let wg = wg.add(1); + let ctrx = ctr.clone(); + async_std::task::spawn(async move { + async_std::task::sleep(Duration::from_millis(5)).await; + ctrx.fetch_add(1, Ordering::Relaxed); + wg.done(); + }); + } + + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 6); + + let worker = wg.add(1); + + let ctrx = ctr.clone(); + async_std::task::spawn(async move { + async_std::task::sleep(Duration::from_millis(5)).await; + ctrx.fetch_add(1, Ordering::Relaxed); + worker.done(); + }); + + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 7); +} + +#[async_std::test] +async fn test_async_wait_group_nested() { + let wg = AsyncWaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + for _ in 0..5 { + let worker = wg.add(1); + let ctrx = ctr.clone(); + async_std::task::spawn(async move { + let nested_worker = worker.add(1); + let ctrxx = ctrx.clone(); + async_std::task::spawn(async move { + ctrxx.fetch_add(1, Ordering::Relaxed); + nested_worker.done(); + }); + ctrx.fetch_add(1, Ordering::Relaxed); + worker.done(); + }); + } + + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 10); +} + +#[async_std::test] +async fn test_async_wait_group_from() { + let wg = AsyncWaitGroup::from(5); + for _ in 0..5 { + let t = wg.clone(); + async_std::task::spawn(async move { + t.done(); + }); + } + wg.wait().await; +} + +#[async_std::test] +async fn test_sync_wait_group() { + let wg = AsyncWaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + + for _ in 0..5 { + let ctrx = ctr.clone(); + let wg = wg.add(1); + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(50)); + ctrx.fetch_add(1, Ordering::Relaxed); + + wg.done(); + }); + } + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 5); +} + +#[async_std::test] +async fn test_async_waitings() { + let wg = AsyncWaitGroup::new(); + wg.add(1); + wg.add(1); + assert_eq!(wg.waitings(), 2); +} + +#[test] +fn test_async_block_wait() { + let wg = AsyncWaitGroup::new(); + let t_wg = wg.add(1); + std::thread::spawn(move || { + // do some time consuming task + t_wg.done(); + }); + + // wait other thread completes + wg.block_wait(); + + assert_eq!(wg.waitings(), 0); +} + +#[async_std::test] +async fn test_wake_after_updating() { + let wg = AsyncWaitGroup::new(); + for _ in 0..100000 { + let worker = wg.add(1); + async_std::task::spawn(async move { + async_std::task::sleep(std::time::Duration::from_millis(10)).await; + let mut a = 0; + for _ in 0..1000 { + a += 1; + } + println!("{a}"); + async_std::task::sleep(std::time::Duration::from_millis(10)).await; + worker.done(); + }); + } + wg.wait().await; +} + +#[test] +fn test_clone_and_fmt() { + let awg = AsyncWaitGroup::new(); + let awg1 = awg.clone(); + awg1.add(3); + assert_eq!(format!("{:?}", awg), format!("{:?}", awg1)); +} diff --git a/tests/tokio.rs b/tests/tokio.rs new file mode 100644 index 0000000..79fa576 --- /dev/null +++ b/tests/tokio.rs @@ -0,0 +1,162 @@ +use std::{ + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, + time::Duration, +}; +use wg::tokio::*; + +#[::tokio::test] +async fn test_async_wait_group() { + let wg = AsyncWaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + + for _ in 0..5 { + let ctrx = ctr.clone(); + let wg = wg.add(1); + + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(50)).await; + ctrx.fetch_add(1, Ordering::Relaxed); + wg.done(); + }); + } + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 5); +} + +#[::tokio::test] +async fn test_async_wait_group_reuse() { + let wg = AsyncWaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + for _ in 0..6 { + let wg = wg.add(1); + let ctrx = ctr.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(5)).await; + ctrx.fetch_add(1, Ordering::Relaxed); + wg.done(); + }); + } + + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 6); + + let worker = wg.add(1); + + let ctrx = ctr.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(5)).await; + ctrx.fetch_add(1, Ordering::Relaxed); + worker.done(); + }); + + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 7); +} + +#[::tokio::test] +async fn test_async_wait_group_nested() { + let wg = AsyncWaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + for _ in 0..5 { + let worker = wg.add(1); + let ctrx = ctr.clone(); + tokio::spawn(async move { + let nested_worker = worker.add(1); + let ctrxx = ctrx.clone(); + tokio::spawn(async move { + ctrxx.fetch_add(1, Ordering::Relaxed); + nested_worker.done(); + }); + ctrx.fetch_add(1, Ordering::Relaxed); + worker.done(); + }); + } + + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 10); +} + +#[::tokio::test] +async fn test_async_wait_group_from() { + let wg = AsyncWaitGroup::from(5); + for _ in 0..5 { + let t = wg.clone(); + tokio::spawn(async move { + t.done(); + }); + } + wg.wait().await; +} + +#[::tokio::test] +async fn test_sync_wait_group() { + let wg = AsyncWaitGroup::new(); + let ctr = Arc::new(AtomicUsize::new(0)); + + for _ in 0..5 { + let ctrx = ctr.clone(); + let wg = wg.add(1); + std::thread::spawn(move || { + std::thread::sleep(Duration::from_millis(50)); + ctrx.fetch_add(1, Ordering::Relaxed); + + wg.done(); + }); + } + wg.wait().await; + assert_eq!(ctr.load(Ordering::Relaxed), 5); +} + +#[::tokio::test] +async fn test_async_waitings() { + let wg = AsyncWaitGroup::new(); + wg.add(1); + wg.add(1); + assert_eq!(wg.waitings(), 2); +} + +#[::tokio::test(flavor = "multi_thread")] +async fn test_async_block_wait() { + let wg = AsyncWaitGroup::new(); + let t_wg = wg.add(1); + ::tokio::spawn(async move { + // do some time consuming task + t_wg.done(); + ::tokio::task::yield_now().await; + }); + + // wait other thread completes + wg.block_wait(); + + assert_eq!(wg.waitings(), 0); +} + +#[::tokio::test] +async fn test_wake_after_updating() { + let wg = AsyncWaitGroup::new(); + for _ in 0..100000 { + let worker = wg.add(1); + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + let mut a = 0; + for _ in 0..1000 { + a += 1; + } + println!("{a}"); + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + worker.done(); + }); + } + wg.wait().await; +} + +#[test] +fn test_clone_and_fmt() { + let awg = AsyncWaitGroup::new(); + let awg1 = awg.clone(); + awg1.add(3); + assert_eq!(format!("{:?}", awg), format!("{:?}", awg1)); +}