diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..fde16e7 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "ruchei" +version = "0.1.0" +edition = "2021" + +[dependencies] +futures-util = { version = "0.3.30", features = ["sink"] } +pin-project = "1" + +[dev-dependencies] +async-std = { version = "1.12.0", features = ["attributes"] } +async-tungstenite = { version = "0.24.0", features = ["async-std-runtime"] } diff --git a/examples/ws-buffered.rs b/examples/ws-buffered.rs new file mode 100644 index 0000000..12830e9 --- /dev/null +++ b/examples/ws-buffered.rs @@ -0,0 +1,18 @@ +use async_std::net::TcpListener; +use futures_util::StreamExt; +use ruchei::{concurrent::ConcurrentExt, echo::EchoExt, fanout_buffered::MulticastBuffered}; + +#[async_std::main] +async fn main() { + let streams = TcpListener::bind("127.0.0.1:8080").await.unwrap(); + streams + .incoming() + .filter_map(|r| async { r.ok() }) + .map(async_tungstenite::accept_async) + .concurrent() + .filter_map(|r| async { r.ok() }) + .multicast_buffered(|_| {}) + .echo() + .await + .unwrap(); +} diff --git a/examples/ws-bufferless.rs b/examples/ws-bufferless.rs new file mode 100644 index 0000000..7e202f8 --- /dev/null +++ b/examples/ws-bufferless.rs @@ -0,0 +1,18 @@ +use async_std::net::TcpListener; +use futures_util::StreamExt; +use ruchei::{concurrent::ConcurrentExt, echo::EchoExt, fanout_bufferless::MulticastBufferless}; + +#[async_std::main] +async fn main() { + let streams = TcpListener::bind("127.0.0.1:8080").await.unwrap(); + streams + .incoming() + .filter_map(|r| async { r.ok() }) + .map(async_tungstenite::accept_async) + .concurrent() + .filter_map(|r| async { r.ok() }) + .multicast_bufferless(|_| {}) + .echo() + .await + .unwrap(); +} diff --git a/examples/ws-client.rs b/examples/ws-client.rs new file mode 100644 index 0000000..66e3980 --- /dev/null +++ b/examples/ws-client.rs @@ -0,0 +1,12 @@ +use ruchei::echo::EchoExt; + +#[async_std::main] +async fn main() { + async_tungstenite::async_std::connect_async("ws://127.0.0.1:8080/") + .await + .unwrap() + .0 + .echo() + .await + .unwrap(); +} diff --git a/src/callback.rs b/src/callback.rs new file mode 100644 index 0000000..36b143b --- /dev/null +++ b/src/callback.rs @@ -0,0 +1,9 @@ +pub trait Callback: Clone { + fn on_close(&self, error: Option); +} + +impl)> Callback for F { + fn on_close(&self, error: Option) { + self(error) + } +} diff --git a/src/concurrent.rs b/src/concurrent.rs new file mode 100644 index 0000000..77b447e --- /dev/null +++ b/src/concurrent.rs @@ -0,0 +1,55 @@ +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::{ + stream::{Fuse, FuturesUnordered}, + Future, Stream, StreamExt, +}; + +#[pin_project::pin_project] +pub struct Concurrent { + #[pin] + streams: Fuse, + #[pin] + futures: FuturesUnordered, +} + +impl> Stream for Concurrent { + type Item = Fut::Output; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + while let Poll::Ready(Some(future)) = this.streams.as_mut().poll_next(cx) { + this.futures.push(future) + } + match this.futures.poll_next(cx) { + Poll::Ready(None) if !this.streams.is_done() => Poll::Pending, + poll => poll, + } + } +} + +impl> From for Concurrent { + fn from(streams: R) -> Self { + Self { + streams: streams.fuse(), + futures: Default::default(), + } + } +} + +pub trait ConcurrentExt: Sized { + type Fut; + + fn concurrent(self) -> Concurrent; +} + +impl> ConcurrentExt for R { + type Fut = Fut; + + fn concurrent(self) -> Concurrent { + self.into() + } +} diff --git a/src/echo.rs b/src/echo.rs new file mode 100644 index 0000000..6fa1f8c --- /dev/null +++ b/src/echo.rs @@ -0,0 +1,77 @@ +use std::{ + collections::VecDeque, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::{stream::Fuse, Future, Sink, Stream, StreamExt}; +use pin_project::pin_project; + +#[pin_project] +pub struct Echo { + #[pin] + stream: Fuse, + queue: VecDeque, + item: Option, + started: bool, +} + +impl> + Sink> Future for Echo { + type Output = Result<(), E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut this = self.project(); + while let Poll::Ready(Some(t)) = this.stream.as_mut().poll_next(cx)? { + this.queue.push_back(t); + } + loop { + match this.item.take() { + Some(item) => match this.stream.as_mut().poll_ready(cx)? { + Poll::Ready(()) => { + this.stream.as_mut().start_send(item)?; + *this.started = true; + } + Poll::Pending => { + *this.item = Some(item); + break; + } + }, + None => match this.queue.pop_front() { + Some(item) => *this.item = Some(item), + None => { + break; + } + }, + } + } + if *this.started && this.stream.as_mut().poll_flush(cx)?.is_ready() { + *this.started = false; + } + Poll::Pending + } +} + +impl>> From for Echo { + fn from(stream: S) -> Self { + Self { + stream: stream.fuse(), + queue: Default::default(), + item: None, + started: false, + } + } +} + +pub trait EchoExt: Sized { + type T; + + fn echo(self) -> Echo; +} + +impl>> EchoExt for S { + type T = T; + + fn echo(self) -> Echo { + self.into() + } +} diff --git a/src/fanout_buffered.rs b/src/fanout_buffered.rs new file mode 100644 index 0000000..89d0109 --- /dev/null +++ b/src/fanout_buffered.rs @@ -0,0 +1,330 @@ +use std::{ + convert::Infallible, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use futures_util::{ + future::FusedFuture, + lock::{Mutex, OwnedMutexGuard, OwnedMutexLockFuture}, + ready, + stream::{Fuse, FuturesUnordered, SelectAll}, + Future, Sink, Stream, StreamExt, +}; +use pin_project::pin_project; + +use crate::{callback::Callback, owned_close::OwnedClose}; + +#[derive(Clone)] +struct Done(Arc>); + +struct Node(Out, Done, List); + +struct List(Arc>>>); + +impl Drop for List { + fn drop(&mut self) { + loop { + let node = { + let Some(mut guard) = self.0.try_lock() else { + break; + }; + let Some(node) = guard.take() else { + break; + }; + node + }; + *self = node.2; + } + } +} + +#[derive(Default)] +enum State { + #[default] + Flushed, + Readying(Out, Done), + Started(Done), +} + +impl State { + fn take(&mut self) -> Self { + std::mem::take(self) + } +} + +#[pin_project] +struct Unicast { + #[pin] + stream: S, + #[pin] + list: OwnedMutexLockFuture>>, + state: State, + callback: F, +} + +impl> + Sink, F: Callback> + Unicast +{ + fn poll_list(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let mut this = self.project(); + if !this.list.is_terminated() { + match this.list.as_mut().poll(cx) { + Poll::Ready(guard) => match guard.as_ref() { + Some(Node(out, done, list)) => { + *this.state = State::Readying(out.clone(), done.clone()); + *this.list = list.0.clone().lock_owned(); + Poll::Ready(()) + } + None => Poll::Pending, + }, + Poll::Pending => Poll::Pending, + } + } else { + Poll::Pending + } + } + + fn poll_send( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + out: Out, + done: Done, + ) -> Poll> { + let mut this = self.project(); + match this.stream.as_mut().poll_ready(cx)? { + Poll::Ready(()) => { + this.stream.start_send(out)?; + *this.state = State::Started(done); + Poll::Ready(Ok(())) + } + Poll::Pending => { + *this.state = State::Readying(out, done); + Poll::Pending + } + } + } + + fn state(self: Pin<&mut Self>) -> &mut State { + self.project().state + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>, done: Done) -> Poll> { + let this = self.project(); + match this.stream.poll_flush(cx)? { + Poll::Ready(()) => Poll::Ready(Ok(())), + Poll::Pending => { + *this.state = State::Started(done); + Poll::Pending + } + } + } + + fn pre_poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + match self.as_mut().state().take() { + State::Flushed => ready!(self.as_mut().poll_list(cx)), + State::Readying(out, done) => ready!(self.as_mut().poll_send(cx, out, done))?, + State::Started(done) => ready!(self.as_mut().poll_flush(cx, done))?, + } + } + } + + fn poll_inner(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>> { + let _ = self.as_mut().pre_poll(cx)?; + self.project().stream.poll_next(cx) + } +} + +impl> + Sink, F: Callback> + Stream for Unicast +{ + type Item = In; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.as_mut().poll_inner(cx) { + Poll::Ready(Some(Ok(out))) => Poll::Ready(Some(out)), + Poll::Ready(Some(Err(e))) => { + self.callback.on_close(Some(e)); + Poll::Ready(None) + } + Poll::Ready(None) => { + self.callback.on_close(None); + Poll::Ready(None) + } + Poll::Pending => Poll::Pending, + } + } +} + +#[pin_project] +pub struct Multicast { + #[pin] + streams: Fuse, + #[pin] + select: SelectAll>, + #[pin] + closing: FuturesUnordered>, + #[pin] + done: OwnedMutexLockFuture<()>, + list_guard: OwnedMutexGuard>>, + list_mutex: Arc>>>, + callback: F, +} + +impl< + In, + Out: Clone, + E, + S: Unpin + Stream> + Sink, + F: Callback, + R: Stream, + > Multicast +{ + fn poll_next_infallible(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + while let Poll::Ready(Some(stream)) = this.streams.as_mut().poll_next(cx) { + this.select.push(Unicast { + stream, + list: this.list_mutex.clone().lock_owned(), + state: Default::default(), + callback: this.callback.clone(), + }); + } + match this.select.poll_next(cx) { + Poll::Ready(None) if !this.streams.is_done() => Poll::Pending, + poll => poll, + } + } +} + +impl< + In, + Out: Clone, + E, + S: Unpin + Stream> + Sink, + F: Callback, + R: Stream, + > Stream for Multicast +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_next_infallible(cx).map(|o| o.map(Ok)) + } +} + +impl< + In, + Out: Clone, + E, + S: Unpin + Stream> + Sink, + F: Callback, + R: Stream, + > Sink for Multicast +{ + type Error = Infallible; + + fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> { + let mut this = self.project(); + let list_mutex = Arc::new(Mutex::new(None)); + let list = List(list_mutex.clone()); + let done_mutex = Arc::new(Mutex::new(())); + let list_guard = list_mutex.clone().try_lock_owned().unwrap(); + **this.list_guard = Some(Node( + item, + Done(Arc::new(done_mutex.clone().try_lock_owned().unwrap())), + list, + )); + *this.list_guard = list_guard; + *this.done = done_mutex.lock_owned(); + *this.list_mutex = list_mutex; + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let future = self.project().done; + if future.is_terminated() || future.poll(cx).is_ready() { + Poll::Ready(Ok(())) + } else { + Poll::Pending + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + if !this.select.is_empty() { + for unicast in std::mem::take(this.select.get_mut()) { + this.closing.push(unicast.stream.into()) + } + } + loop { + match this.closing.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(()))) => this.callback.on_close(None), + Poll::Ready(Some(Err(e))) => this.callback.on_close(Some(e)), + Poll::Ready(None) => break Poll::Ready(Ok(())), + Poll::Pending => break Poll::Pending, + } + } + } +} + +impl< + In, + Out: Clone, + E, + S: Unpin + Stream> + Sink, + F: Callback, + R: Stream, + > Multicast +{ + pub fn new(streams: R, callback: F) -> Self { + let list_mutex = Arc::new(Mutex::new(None)); + let list_guard = list_mutex.clone().try_lock_owned().unwrap(); + Self { + streams: streams.fuse(), + select: Default::default(), + closing: Default::default(), + done: Arc::new(Mutex::new(())).lock_owned(), + list_guard, + list_mutex, + callback, + } + } +} + +pub trait MulticastBuffered: Sized { + type S; + + type E; + + fn multicast_buffered>( + self, + callback: F, + ) -> Multicast; +} + +impl< + In, + Out: Clone, + E, + S: Unpin + Stream> + Sink, + R: Stream, + > MulticastBuffered for R +{ + type S = S; + + type E = E; + + fn multicast_buffered>( + self, + callback: F, + ) -> Multicast { + Multicast::new(self, callback) + } +} diff --git a/src/fanout_bufferless.rs b/src/fanout_bufferless.rs new file mode 100644 index 0000000..5475cae --- /dev/null +++ b/src/fanout_bufferless.rs @@ -0,0 +1,297 @@ +use std::{ + convert::Infallible, + pin::Pin, + task::{Context, Poll, Waker}, +}; + +use futures_util::{ + stream::{Fuse, FuturesUnordered, SelectAll}, + Future, Sink, Stream, StreamExt, +}; +use pin_project::pin_project; + +use crate::{ + callback::Callback, + owned_close::OwnedClose, + wait_all::{Completable, CompleteOne, WaitMany}, +}; + +#[pin_project] +struct Unicast { + #[pin] + stream: S, + waker: Option, + readying: CompleteOne, + flushing: CompleteOne, + ready: bool, + started: Option, + callback: F, +} + +impl Unicast { + fn wake(&mut self) { + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } +} + +impl> + Sink, F: Callback> Stream + for Unicast +{ + type Item = In; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + if this.readying.pending() { + match this.stream.as_mut().poll_ready(cx) { + Poll::Ready(Ok(())) => { + *this.ready = true; + this.readying.complete() + } + Poll::Ready(Err(e)) => { + this.callback.on_close(Some(e)); + return Poll::Ready(None); + } + Poll::Pending => {} + } + } + if *this.ready { + if let Some(out) = this.started.take() { + match this.stream.as_mut().start_send(out) { + Ok(()) => { + *this.ready = false; + } + Err(e) => { + this.callback.on_close(Some(e)); + return Poll::Ready(None); + } + } + } + } + if this.flushing.pending() { + match this.stream.as_mut().poll_flush(cx) { + Poll::Ready(Ok(())) => this.flushing.complete(), + Poll::Ready(Err(e)) => { + this.callback.on_close(Some(e)); + return Poll::Ready(None); + } + Poll::Pending => {} + } + } + *this.waker = Some(cx.waker().clone()); + this.stream.poll_next(cx).map(|o| match o { + Some(Ok(item)) => Some(item), + Some(Err(e)) => { + this.callback.on_close(Some(e)); + None + } + None => { + this.callback.on_close(None); + None + } + }) + } +} + +#[pin_project] +pub struct Multicast { + #[pin] + streams: Fuse, + #[pin] + select: SelectAll>, + #[pin] + readying: WaitMany, + #[pin] + flushing: WaitMany, + #[pin] + closing: FuturesUnordered>, + ready: bool, + flushed: bool, + callback: F, +} + +impl< + In, + Out, + E, + S: Unpin + Stream> + Sink, + F: Callback, + R: Stream, + > Multicast +{ + fn poll_next_infallible(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + while let Poll::Ready(Some(stream)) = this.streams.as_mut().poll_next(cx) { + this.select.push(Unicast { + stream, + waker: None, + readying: Default::default(), + flushing: Default::default(), + ready: false, + started: None, + callback: this.callback.clone(), + }); + } + match this.select.poll_next(cx) { + Poll::Ready(None) if !this.streams.is_done() => Poll::Pending, + poll => poll, + } + } +} + +impl< + In, + Out, + E, + S: Unpin + Stream> + Sink, + F: Callback, + R: Stream, + > Stream for Multicast +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_next_infallible(cx).map(|o| o.map(Ok)) + } +} + +impl< + In, + Out: Clone, + E, + S: Unpin + Stream> + Sink, + F: Callback, + R: Stream, + > Sink for Multicast +{ + type Error = Infallible; + + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + loop { + match this.readying.as_mut().poll(cx) { + Poll::Ready(()) => { + if *this.ready { + *this.ready = false; + break Poll::Ready(Ok(())); + } else { + *this.ready = true; + let completable = Completable::default(); + for unicast in this.select.iter_mut() { + unicast.readying.completable(completable.clone()); + unicast.wake(); + } + this.readying.completable(completable); + } + } + Poll::Pending => break Poll::Pending, + } + } + } + + fn start_send(self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> { + let mut this = self.project(); + for unicast in this.select.iter_mut() { + if unicast.ready { + unicast.started = Some(item.clone()); + unicast.wake(); + } + } + Ok(()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + loop { + match this.flushing.as_mut().poll(cx) { + Poll::Ready(()) => { + if *this.flushed { + *this.flushed = false; + break Poll::Ready(Ok(())); + } else { + *this.flushed = true; + let completable = Completable::default(); + for unicast in this.select.iter_mut() { + unicast.flushing.completable(completable.clone()); + unicast.wake(); + } + this.flushing.completable(completable); + } + } + Poll::Pending => break Poll::Pending, + } + } + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + if !this.select.is_empty() { + for unicast in std::mem::take(this.select.get_mut()) { + this.closing.push(unicast.stream.into()) + } + } + loop { + match this.closing.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(()))) => this.callback.on_close(None), + Poll::Ready(Some(Err(e))) => this.callback.on_close(Some(e)), + Poll::Ready(None) => break Poll::Ready(Ok(())), + Poll::Pending => break Poll::Pending, + } + } + } +} + +impl< + In, + Out, + E, + S: Unpin + Stream> + Sink, + F: Callback, + R: Stream, + > Multicast +{ + pub fn new(streams: R, callback: F) -> Self { + Self { + streams: streams.fuse(), + select: Default::default(), + readying: Default::default(), + flushing: Default::default(), + closing: Default::default(), + ready: false, + flushed: false, + callback, + } + } +} + +pub trait MulticastBufferless: Sized { + type S; + + type E; + + fn multicast_bufferless>( + self, + callback: F, + ) -> Multicast; +} + +impl< + In, + Out, + E, + S: Unpin + Stream> + Sink, + R: Stream, + > MulticastBufferless for R +{ + type S = S; + + type E = E; + + fn multicast_bufferless>( + self, + callback: F, + ) -> Multicast { + Multicast::new(self, callback) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..ae55b0a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,7 @@ +mod callback; +pub mod concurrent; +pub mod echo; +pub mod fanout_buffered; +pub mod fanout_bufferless; +mod owned_close; +mod wait_all; diff --git a/src/owned_close.rs b/src/owned_close.rs new file mode 100644 index 0000000..02e0a76 --- /dev/null +++ b/src/owned_close.rs @@ -0,0 +1,32 @@ +use std::{ + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; + +use futures_util::{Future, Sink}; +use pin_project::pin_project; + +#[pin_project] +pub struct OwnedClose { + #[pin] + sink: S, + _out: PhantomData, +} + +impl> Future for OwnedClose { + type Output = Result<(), E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().sink.poll_close(cx) + } +} + +impl From for OwnedClose { + fn from(sink: S) -> Self { + Self { + sink, + _out: PhantomData, + } + } +} diff --git a/src/wait_all.rs b/src/wait_all.rs new file mode 100644 index 0000000..1dad91d --- /dev/null +++ b/src/wait_all.rs @@ -0,0 +1,83 @@ +use std::{ + pin::Pin, + sync::{Arc, Mutex, Weak}, + task::{Context, Poll, Waker}, +}; + +use futures_util::Future; + +#[derive(Debug, Clone, Default)] +pub struct Completable { + waker: Arc>>, +} + +impl Completable { + fn complete(self) { + if let Some(waker) = Arc::into_inner(self.waker) { + if let Ok(mut waker) = waker.try_lock() { + if let Some(waker) = waker.take() { + waker.wake(); + } + } + } + } +} + +#[derive(Debug, Clone, Default)] +pub struct CompleteOne { + completable: Option, +} + +impl CompleteOne { + pub fn pending(&self) -> bool { + self.completable.is_some() + } + + pub fn complete(&mut self) { + if let Some(completable) = self.completable.take() { + completable.complete(); + } + } + + pub fn completable(&mut self, completable: Completable) { + self.completable = Some(completable); + } +} + +impl Drop for CompleteOne { + fn drop(&mut self) { + self.complete(); + } +} + +#[derive(Default)] +pub struct WaitMany { + waker: Weak>>, +} + +impl WaitMany { + pub fn completable(&mut self, completable: Completable) { + self.waker = Arc::downgrade(&completable.waker); + } +} + +impl Future for WaitMany { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + { + let Some(waker) = self.waker.upgrade() else { + return Poll::Ready(()); + }; + let Ok(mut waker) = waker.try_lock() else { + return Poll::Ready(()); + }; + *waker = Some(cx.waker().clone()); + } + if self.waker.strong_count() == 0 { + Poll::Ready(()) + } else { + Poll::Pending + } + } +}