diff --git a/futures-util/src/io/mod.rs b/futures-util/src/io/mod.rs
index 4f474f757d..6a91ba785e 100644
--- a/futures-util/src/io/mod.rs
+++ b/futures-util/src/io/mod.rs
@@ -91,6 +91,9 @@ pub use self::into_sink::IntoSink;
 mod lines;
 pub use self::lines::Lines;
 
+mod pipe;
+pub use self::pipe::{pipe, PipeReader, PipeWriter};
+
 mod read;
 pub use self::read::Read;
 
diff --git a/futures-util/src/io/pipe.rs b/futures-util/src/io/pipe.rs
new file mode 100644
index 0000000000..20a2eb1e4e
--- /dev/null
+++ b/futures-util/src/io/pipe.rs
@@ -0,0 +1,302 @@
+use core::pin::Pin;
+use core::ptr::copy_nonoverlapping;
+use core::slice;
+use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
+
+use alloc::boxed::Box;
+use alloc::sync::Arc;
+use futures_core::ready;
+use futures_core::task::{Context, Poll, Waker};
+use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite, Error, ErrorKind, Result};
+
+use crate::task::AtomicWaker;
+
+/// Creates a unidirectional bounded pipe for data transfer between asynchronous tasks.
+///
+/// The internal buffer size is given by `buffer`, which must be non-zero. The returned
+/// [`PipeWriter`] implements the [`AsyncWrite`] trait, while [`PipeReader`] implements
+/// [`AsyncRead`].
+///
+/// # Panics
+///
+/// Panics when `buffer` is zero.
+#[must_use]
+pub fn pipe(buffer: usize) -> (PipeWriter, PipeReader) {
+    assert!(buffer != 0, "pipe buffer size must be non-zero");
+    // If `buffer` is `usize::MAX`, the allocation below must fail anyway, since Rust forbids
+    // allocations larger than `isize::MAX as usize`. That counts as OOM, so it needs no
+    // explicit check here.
+    let len = buffer.saturating_add(1);
+    let ptr = Box::into_raw(alloc::vec![0u8; len].into_boxed_slice());
+    let inner = Arc::new(Shared {
+        len,
+        buffer: ptr.cast(),
+        write_pos: AtomicUsize::new(0),
+        read_pos: AtomicUsize::new(0),
+        writer_waker: AtomicWaker::new(),
+        reader_waker: AtomicWaker::new(),
+        closed: AtomicBool::new(false),
+    });
+    (PipeWriter { inner: inner.clone() }, PipeReader { inner })
+}
+
+// `read_pos..write_pos` (wrapping around, same below) contains the buffered content.
+// `write_pos..(read_pos-1)` is the empty space for further data.
+// Note that index `read_pos-1` is left vacant so that `read_pos == write_pos` if and only if
+// the buffer is empty.
+//
+// Invariants, at any time:
+// 1. `read_pos` and `buffer[read_pos..write_pos]` are owned by the read end.
+//    The read end can increment `read_pos` in that range to transfer
+//    a portion of the buffer to the write end.
+// 2. `write_pos` and `buffer[write_pos..(read_pos-1)]` are owned by the write end.
+//    The write end can increment `write_pos` in that range to transfer
+//    a portion of the buffer to the read end.
+// 3. The read end can only park (returning `Pending`) when it observed
+//    `read_pos == write_pos` after registering the waker.
+// 4. The write end can only park when it observed `write_pos == read_pos-1` after
+//    registering the waker.
+#[derive(Debug)]
+struct Shared {
+    len: usize,
+    buffer: *mut u8,
+    read_pos: AtomicUsize,
+    write_pos: AtomicUsize,
+    reader_waker: AtomicWaker,
+    writer_waker: AtomicWaker,
+    closed: AtomicBool,
+}
+
+unsafe impl Send for Shared {}
+unsafe impl Sync for Shared {}
+
+impl Drop for Shared {
+    fn drop(&mut self) {
+        unsafe {
+            drop(Box::from_raw(slice::from_raw_parts_mut(self.buffer, self.len)));
+        }
+    }
+}
+
+impl Shared {
+    fn poll_read_ready(&self, waker: &Waker) -> Poll<Result<(usize, usize)>> {
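+        // Register-then-recheck: per invariant 3 above, `Pending` may only be returned
+        // after the waker has been registered and emptiness re-observed; otherwise a
+        // write landing between the first check and the registration could be missed
+        // (a lost wakeup).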
+        // Only mutated by us, the reader; no synchronization is needed for this load.
+        let data_start = self.read_pos.load(Ordering::Relaxed);
+        // "Acquire" the bytes for read.
+        let mut data_end = self.write_pos.load(Ordering::Acquire);
+        // Slow path: the buffer looks empty.
+        if data_start == data_end {
+            // Implicit "Acquire" of `write_pos` below.
+            self.reader_waker.register(waker);
+            // Double-check for readiness.
+            data_end = self.write_pos.load(Ordering::Acquire);
+            if data_start == data_end {
+                // Already "Acquire"d via `reader_waker`.
+                if self.closed.load(Ordering::Relaxed) {
+                    return Poll::Ready(Ok((0, 0)));
+                }
+                return Poll::Pending;
+            }
+        }
+        Poll::Ready(Ok((data_start, data_end)))
+    }
+
+    unsafe fn commit_read(&self, new_read_pos: usize) {
+        // "Release" the bytes just read.
+        self.read_pos.store(new_read_pos, Ordering::Release);
+        // Implicit "Release" of the `read_pos` change.
+        self.writer_waker.wake();
+    }
+
+    fn poll_write_ready(&self, waker: &Waker) -> Poll<Result<(usize, usize)>> {
+        // Only mutated by us, the writer; no synchronization is needed for this load.
+        let write_start = self.write_pos.load(Ordering::Relaxed);
+        // "Acquire" the bytes for write.
+        let mut write_end =
+            self.read_pos.load(Ordering::Acquire).checked_sub(1).unwrap_or(self.len - 1);
+        if write_start == write_end {
+            // Implicit "Acquire" of `read_pos` below.
+            self.writer_waker.register(waker);
+            // Double-check for writability.
+            write_end =
+                self.read_pos.load(Ordering::Acquire).checked_sub(1).unwrap_or(self.len - 1);
+            if write_start == write_end {
+                // Already "Acquire"d via `writer_waker`.
+                if self.closed.load(Ordering::Relaxed) {
+                    return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, "pipe closed")));
+                }
+                return Poll::Pending;
+            }
+        }
+        Poll::Ready(Ok((write_start, write_end)))
+    }
+
+    unsafe fn commit_write(&self, new_write_pos: usize) {
+        // "Release" the bytes just written.
+        self.write_pos.store(new_write_pos, Ordering::Release);
+        // Implicit "Release" of the `write_pos` change.
+        self.reader_waker.wake();
+    }
+}
+
+/// The write end of a bounded pipe.
+///
+/// This value is created by the [`pipe`] function.
+#[derive(Debug)]
+pub struct PipeWriter {
+    inner: Arc<Shared>,
+}
+
+impl Drop for PipeWriter {
+    fn drop(&mut self) {
+        self.inner.closed.store(true, Ordering::Relaxed);
+        // The wake "Release"s `closed`.
+        self.inner.reader_waker.wake();
+    }
+}
+
+impl AsyncWrite for PipeWriter {
+    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
+        if buf.is_empty() {
+            return Poll::Ready(Ok(0));
+        }
+
+        let inner = &*self.inner;
+
+        let (write_start, write_end) = ready!(inner.poll_write_ready(cx.waker()))?;
+
+        let written = if write_start <= write_end {
+            let written = buf.len().min(write_end - write_start);
+            // SAFETY: `buffer[write_pos..read_pos-1]` is owned by us, the writer.
+            unsafe {
+                copy_nonoverlapping(buf.as_ptr(), inner.buffer.add(write_start), written);
+            }
+            written
+        } else {
+            let written1 = buf.len().min(inner.len - write_start);
+            let written2 = (buf.len() - written1).min(write_end);
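+            // Worked example (illustrative): with `len == 8`, `write_start == 6`,
+            // `write_end == 3` and a 5-byte `buf`, the first copy fills
+            // `buffer[6..8]` (`written1 == 2`) and the second wraps around into
+            // `buffer[0..3]` (`written2 == 3`).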
+            // SAFETY: `buffer[write_pos..]` and `buffer[..read_pos-1]` are owned by us,
+            // the writer.
+            unsafe {
+                copy_nonoverlapping(buf.as_ptr(), inner.buffer.add(write_start), written1);
+                copy_nonoverlapping(buf.as_ptr().add(written1), inner.buffer, written2);
+            }
+            written1 + written2
+        };
+
+        let mut new_write_pos = write_start + written;
+        if new_write_pos >= inner.len {
+            new_write_pos -= inner.len;
+        }
+
+        unsafe {
+            inner.commit_write(new_write_pos);
+        }
+
+        Poll::Ready(Ok(written))
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    // The pipe is closed when the writer is dropped; `poll_close` itself has nothing
+    // extra to do.
+    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+}
+
+/// The read end of a bounded pipe.
+///
+/// This value is created by the [`pipe`] function.
+#[derive(Debug)]
+pub struct PipeReader {
+    inner: Arc<Shared>,
+}
+
+impl Drop for PipeReader {
+    fn drop(&mut self) {
+        self.inner.closed.store(true, Ordering::Relaxed);
+        // The wake "Release"s `closed`.
+        self.inner.writer_waker.wake();
+    }
+}
+
+impl AsyncRead for PipeReader {
+    fn poll_read(
+        self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut [u8],
+    ) -> Poll<Result<usize>> {
+        if buf.is_empty() {
+            return Poll::Ready(Ok(0));
+        }
+
+        let inner = &*self.inner;
+
+        let (data_start, data_end) = ready!(inner.poll_read_ready(cx.waker()))?;
+
+        let read = if data_start <= data_end {
+            let read = buf.len().min(data_end - data_start);
+            // SAFETY: `buffer[read_pos..write_pos]` is owned by us, the reader.
+            unsafe {
+                copy_nonoverlapping(inner.buffer.add(data_start), buf.as_mut_ptr(), read);
+            }
+            read
+        } else {
+            let read1 = buf.len().min(inner.len - data_start);
+            let read2 = (buf.len() - read1).min(data_end);
+            // SAFETY: `buffer[read_pos..]` and `buffer[..write_pos]` are owned by us,
+            // the reader.
+            unsafe {
+                copy_nonoverlapping(inner.buffer.add(data_start), buf.as_mut_ptr(), read1);
+                copy_nonoverlapping(inner.buffer, buf.as_mut_ptr().add(read1), read2);
+            }
+            read1 + read2
+        };
+
+        let mut new_read_pos = data_start + read;
+        if new_read_pos >= inner.len {
+            new_read_pos -= inner.len;
+        }
+
+        unsafe {
+            self.inner.commit_read(new_read_pos);
+        }
+
+        Poll::Ready(Ok(read))
+    }
+}
+
+impl AsyncBufRead for PipeReader {
+    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
+        let inner = &*self.inner;
+        let (data_start, mut data_end) = ready!(inner.poll_read_ready(cx.waker()))?;
+        if data_end < data_start {
+            data_end = inner.len;
+        }
+        // SAFETY: `buffer[read_pos..]` is owned by us, the reader.
+        let data =
+            unsafe { slice::from_raw_parts(inner.buffer.add(data_start), data_end - data_start) };
+        Poll::Ready(Ok(data))
+    }
+
+    fn consume(self: Pin<&mut Self>, amt: usize) {
+        let inner = &*self.inner;
+        // Only mutated by us, the reader; no synchronization is needed for this load.
+        let data_start = inner.read_pos.load(Ordering::Relaxed);
+        // `write_pos` can only have advanced since the last `poll_fill_buf` on this end,
+        // and it does not need to be up to date.
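+        // A stale (smaller) value only shrinks the length computed below, and it is
+        // still at least as large as what the matching `poll_fill_buf` exposed, so
+        // the `amt <= len` assertion cannot fire spuriously.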
+        let data_end = inner.write_pos.load(Ordering::Relaxed);
+
+        let len = if data_start <= data_end {
+            data_end - data_start
+        } else {
+            data_end + inner.len - data_start
+        };
+        assert!(amt <= len, "invalid advance");
+
+        let mut new_read_pos = data_start + amt;
+        if new_read_pos >= inner.len {
+            new_read_pos -= inner.len;
+        }
+        unsafe {
+            inner.commit_read(new_read_pos);
+        }
+    }
+}
diff --git a/futures/tests/io_pipe.rs b/futures/tests/io_pipe.rs
new file mode 100644
index 0000000000..9f08edbeb0
--- /dev/null
+++ b/futures/tests/io_pipe.rs
@@ -0,0 +1,205 @@
+use futures::future::FutureExt;
+use futures::task::Poll;
+use futures_core::task::Context;
+use futures_executor::block_on;
+use futures_io::ErrorKind;
+use futures_test::future::FutureTestExt;
+use futures_test::task::{new_count_waker, panic_context};
+use futures_util::io::{pipe, PipeReader, PipeWriter};
+use futures_util::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
+use static_assertions::assert_impl_all;
+
+trait PollExt<T> {
+    fn expect_pending(self);
+    fn expect_ready(self) -> T;
+}
+
+impl<T> PollExt<T> for Poll<T> {
+    #[track_caller]
+    fn expect_pending(self) {
+        assert!(self.is_pending());
+    }
+
+    #[track_caller]
+    fn expect_ready(self) -> T {
+        match self {
+            Poll::Ready(v) => v,
+            Poll::Pending => panic!("should be ready"),
+        }
+    }
+}
+
+// Both ends only have `Pin<&mut Self>` methods; nothing can be done through `&Self`,
+// so they are `Sync`.
+assert_impl_all!(PipeReader: Send, Sync, Unpin);
+assert_impl_all!(PipeWriter: Send, Sync, Unpin);
+
+#[test]
+fn small_write_nonblocking() {
+    let (mut w, mut r) = pipe(8);
+    let mut cx = panic_context();
+    for _ in 0..10 {
+        let mut buf = [0u8; 10];
+        assert_eq!(w.write(b"12345").poll_unpin(&mut cx).expect_ready().unwrap(), 5);
+        assert_eq!(r.read(&mut buf).poll_unpin(&mut cx).expect_ready().unwrap(), 5);
+        assert_eq!(&buf[..5], b"12345");
+    }
+}
+
+#[test]
+fn big_write_nonblocking() {
+    let (mut w, mut r) = pipe(7);
+    let mut cx = panic_context();
+    for _ in 0..10 {
+        let mut buf = [0u8; 10];
+        assert_eq!(w.write(b"1234567890").poll_unpin(&mut cx).expect_ready().unwrap(), 7);
+        assert_eq!(r.read(&mut buf).poll_unpin(&mut cx).expect_ready().unwrap(), 7);
+        assert_eq!(&buf[..7], b"1234567");
+    }
+}
+
+#[test]
+fn reader_blocked() {
+    let (mut w, mut r) = pipe(8);
+    let (waker, cnt) = new_count_waker();
+    let mut cx = Context::from_waker(&waker);
+
+    let mut buf = [0u8; 5];
+    r.read(&mut buf).poll_unpin(&mut cx).expect_pending();
+    assert_eq!(cnt.get(), 0);
+    assert_eq!(w.write(b"12345").poll_unpin(&mut cx).expect_ready().unwrap(), 5);
+    assert_eq!(cnt.get(), 1);
+    assert_eq!(r.read(&mut buf[..3]).poll_unpin(&mut cx).expect_ready().unwrap(), 3);
+    assert_eq!(&buf[..3], b"123");
+    assert_eq!(cnt.get(), 1);
+    assert_eq!(r.read(&mut buf[..3]).poll_unpin(&mut cx).expect_ready().unwrap(), 2);
+    assert_eq!(&buf[..2], b"45");
+    assert_eq!(cnt.get(), 1);
+    r.read(&mut buf[..3]).poll_unpin(&mut cx).expect_pending();
+    assert_eq!(cnt.get(), 1);
+}
+
+#[test]
+fn writer_blocked() {
+    let (mut w, mut r) = pipe(7);
+    let (waker, cnt) = new_count_waker();
+    let mut cx = Context::from_waker(&waker);
+    let mut buf = [0u8; 10];
+
+    assert_eq!(w.write(b"12345").poll_unpin(&mut cx).expect_ready().unwrap(), 5);
+    assert_eq!(w.write(b"67890").poll_unpin(&mut cx).expect_ready().unwrap(), 2);
+    assert_eq!(cnt.get(), 0);
+    w.write(b"xxx").poll_unpin(&mut cx).expect_pending();
+    assert_eq!(cnt.get(), 0);
+    assert_eq!(r.read(&mut buf[..3]).poll_unpin(&mut cx).expect_ready().unwrap(), 3);
+    assert_eq!(&buf[..3], b"123");
+    assert_eq!(cnt.get(), 1);
+    assert_eq!(w.write(b"abcde").poll_unpin(&mut cx).expect_ready().unwrap(), 3);
+    assert_eq!(cnt.get(), 1);
+    w.write(b"xxx").poll_unpin(&mut cx).expect_pending();
+    assert_eq!(cnt.get(), 1);
+    assert_eq!(r.read(&mut buf).poll_unpin(&mut cx).expect_ready().unwrap(), 7);
+    assert_eq!(&buf[..7], b"4567abc");
+    assert_eq!(cnt.get(), 2);
+}
+
+#[test]
+fn reader_closed_notify_writer() {
+    let (mut w, r) = pipe(4);
+    let (waker, cnt) = new_count_waker();
+    let mut cx = Context::from_waker(&waker);
+
+    assert_eq!(cnt.get(), 0);
+    assert_eq!(w.write(b"1234").poll_unpin(&mut cx).expect_ready().unwrap(), 4);
+    w.write(b"xxx").poll_unpin(&mut cx).expect_pending();
+    assert_eq!(cnt.get(), 0);
+    drop(r);
+    assert_eq!(cnt.get(), 1);
+
+    assert_eq!(
+        w.write(b"xxx").poll_unpin(&mut cx).expect_ready().unwrap_err().kind(),
+        ErrorKind::BrokenPipe
+    );
+}
+
+#[test]
+fn writer_closed_notify_reader() {
+    let (w, mut r) = pipe(4);
+    let (waker, cnt) = new_count_waker();
+    let mut cx = Context::from_waker(&waker);
+    let mut buf = [0u8; 10];
+
+    assert_eq!(cnt.get(), 0);
+    r.read(&mut buf).poll_unpin(&mut cx).expect_pending();
+    assert_eq!(cnt.get(), 0);
+    drop(w);
+    assert_eq!(cnt.get(), 1);
+
+    assert_eq!(r.read(&mut [0u8; 10]).poll_unpin(&mut cx).expect_ready().unwrap(), 0);
+}
+
+#[test]
+fn writer_closed_with_data() {
+    let (mut w, mut r) = pipe(4);
+    let mut cx = panic_context();
+    let mut buf = [0u8; 10];
+
+    assert_eq!(w.write(b"1234").poll_unpin(&mut cx).expect_ready().unwrap(), 4);
+    drop(w);
+    assert_eq!(r.read(&mut buf).poll_unpin(&mut cx).expect_ready().unwrap(), 4);
+    assert_eq!(&buf[..4], b"1234");
+    assert_eq!(r.read(&mut [0u8; 10]).poll_unpin(&mut cx).expect_ready().unwrap(), 0);
+}
+
+#[test]
+fn smoke() {
+    let (mut w, mut r) = pipe(128);
+    let data = "hello world".repeat(1024);
+
+    let reader = std::thread::spawn(|| {
+        block_on(async move {
+            let mut buf = String::new();
+            r.read_to_string(&mut buf).interleave_pending().await.unwrap();
+            buf
+        })
+    });
+
+    let writer = std::thread::spawn({
+        let data = data.clone();
+        || {
+            block_on(async move {
+                w.write_all(data.as_bytes()).interleave_pending().await.unwrap();
+            });
+        }
+    });
+
+    writer.join().unwrap();
+    let ret = reader.join().unwrap();
+    assert_eq!(ret, data);
+}
+
+#[test]
+fn smoke_bufread() {
+    let (mut w, mut r) = pipe(128);
+    let data = "hello world\n".repeat(1024);
+
+    let reader = std::thread::spawn(|| {
+        block_on(async move {
+            let mut buf = String::new();
+            while r.read_line(&mut buf).await.unwrap() != 0 {}
+            buf
+        })
+    });
+
+    let writer = std::thread::spawn({
+        let data = data.clone();
+        || {
+            block_on(async move {
+                w.write_all(data.as_bytes()).interleave_pending().await.unwrap();
+            });
+        }
+    });
+
+    writer.join().unwrap();
+    let ret = reader.join().unwrap();
+    assert_eq!(ret, data);
+}
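+
+// A minimal end-to-end usage sketch of the `pipe` API (illustrative addition, not
+// part of the behavioral tests above): write a message, close the write end by
+// dropping it, then drain the read end to EOF.
+#[test]
+fn usage_sketch() {
+    let (mut w, mut r) = pipe(16);
+    block_on(async move {
+        w.write_all(b"ping").await.unwrap();
+        // Dropping the writer closes the pipe, so `read_to_end` sees EOF.
+        drop(w);
+        let mut buf = Vec::new();
+        r.read_to_end(&mut buf).await.unwrap();
+        assert_eq!(buf, b"ping");
+    });
+}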