From 39e5d292fb6a67f8c034df9a0f77755ea2c93f55 Mon Sep 17 00:00:00 2001 From: Gwo Tzu-Hsing Date: Wed, 9 Oct 2024 10:52:07 +0800 Subject: [PATCH] feat: support default `Read::read_exact` --- fusio/src/buf.rs | 79 ++++++++++++++++++++++++++++++------------------ fusio/src/lib.rs | 56 ++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 29 deletions(-) diff --git a/fusio/src/buf.rs b/fusio/src/buf.rs index f7f452a..5784f86 100644 --- a/fusio/src/buf.rs +++ b/fusio/src/buf.rs @@ -89,11 +89,14 @@ impl IoBufMut for Vec { } unsafe fn to_buf_mut_nocopy(self) -> BufMut { - BufMut(BufMutInner::Vec(self)) + BufMut { + layout: BufMutInner::Vec(self), + start: 0, + } } unsafe fn recover_from_buf_mut(buf: BufMut) -> Self { - match buf.0 { + match buf.layout { BufMutInner::Vec(vec) => vec, _ => unreachable!(), } @@ -159,14 +162,17 @@ impl IoBufMut for &mut [u8] { } unsafe fn to_buf_mut_nocopy(self) -> BufMut { - BufMut(BufMutInner::Slice { - ptr: self.as_mut_ptr(), - len: self.len(), - }) + BufMut { + layout: BufMutInner::Slice { + ptr: self.as_mut_ptr(), + len: self.len(), + }, + start: 0, + } } unsafe fn recover_from_buf_mut(buf: BufMut) -> Self { - match buf.0 { + match buf.layout { BufMutInner::Slice { ptr, len } => std::slice::from_raw_parts_mut(ptr, len), _ => unreachable!(), } @@ -265,11 +271,14 @@ impl IoBufMut for bytes::BytesMut { } unsafe fn to_buf_mut_nocopy(self) -> BufMut { - BufMut(BufMutInner::BytesMut(self)) + BufMut { + layout: BufMutInner::BytesMut(self), + start: 0, + } } unsafe fn recover_from_buf_mut(buf: BufMut) -> Self { - match buf.0 { + match buf.layout { BufMutInner::BytesMut(bytes) => bytes, _ => unreachable!(), } @@ -339,12 +348,21 @@ impl IoBuf for Buf { } } -pub struct BufMut(BufMutInner); +pub struct BufMut { + layout: BufMutInner, + start: usize, +} + +impl BufMut { + pub(crate) fn set_start(&mut self, start: usize) { + self.start = start; + } +} #[cfg(not(feature = "no-send"))] unsafe impl Send for BufMut {} -enum BufMutInner { +pub(crate) enum BufMutInner { #[allow(unused)] Slice { ptr: *mut u8, @@ -357,37 +375,37 @@ enum BufMutInner { impl IoBuf for BufMut { fn as_ptr(&self) -> *const u8 { - match &self.0 { - BufMutInner::Slice { ptr, .. } => *ptr, - BufMutInner::Vec(vec) => vec.as_ptr(), + match &self.layout { + BufMutInner::Slice { ptr, .. } => unsafe { (*ptr).add(self.start) }, + BufMutInner::Vec(vec) => vec[self.start..].as_ptr(), #[cfg(feature = "bytes")] BufMutInner::BytesMut(bytes) => bytes.as_ptr(), } } fn bytes_init(&self) -> usize { - match &self.0 { - BufMutInner::Slice { len, .. } => *len, - BufMutInner::Vec(vec) => vec.len(), + match &self.layout { + BufMutInner::Slice { len, .. } => *len - self.start, + BufMutInner::Vec(vec) => vec.len() - self.start, #[cfg(feature = "bytes")] - BufMutInner::BytesMut(bytes) => bytes.len(), + BufMutInner::BytesMut(bytes) => bytes.len() - self.start, } } #[cfg(feature = "bytes")] fn as_bytes(&self) -> bytes::Bytes { - match &self.0 { + match &self.layout { BufMutInner::Slice { ptr, len } => { bytes::Bytes::copy_from_slice(unsafe { std::slice::from_raw_parts(*ptr, *len) }) } - BufMutInner::Vec(vec) => bytes::Bytes::copy_from_slice(vec), + BufMutInner::Vec(vec) => bytes::Bytes::copy_from_slice(&vec[self.start..]), #[cfg(feature = "bytes")] BufMutInner::BytesMut(bytes) => bytes.clone().freeze(), } } unsafe fn to_buf_nocopy(self) -> Buf { - match self.0 { + match self.layout { BufMutInner::Slice { ptr, len } => Buf(BufInner::Slice { ptr, len }), BufMutInner::Vec(vec) => Buf(BufInner::Vec(vec)), #[cfg(feature = "bytes")] @@ -397,22 +415,25 @@ impl IoBuf for BufMut { unsafe fn recover_from_buf(buf: Buf) -> Self { match buf.0 { - BufInner::Slice { ptr, len } => BufMut(BufMutInner::Slice { - ptr: ptr as *mut _, - len, - }), - BufInner::Vec(vec) => BufMut(BufMutInner::Vec(vec)), + BufInner::Slice { .. } => unreachable!(), + BufInner::Vec(vec) => BufMut { + layout: BufMutInner::Vec(vec), + start: 0, + }, #[cfg(feature = "bytes")] BufInner::Bytes(_) => unreachable!(), #[cfg(feature = "bytes")] - BufInner::BytesMut(bytes) => BufMut(BufMutInner::BytesMut(bytes)), + BufInner::BytesMut(bytes) => BufMut { + layout: BufMutInner::BytesMut(bytes), + start: 0, + }, } } } impl IoBufMut for BufMut { fn set_init(&mut self, init: usize) { - match &mut self.0 { + match &mut self.layout { BufMutInner::Slice { .. } => {} BufMutInner::Vec(vec) => vec.set_init(init), #[cfg(feature = "bytes")] @@ -421,7 +442,7 @@ impl IoBufMut for BufMut { } fn as_mut_ptr(&mut self) -> *mut u8 { - match &mut self.0 { + match &mut self.layout { BufMutInner::Slice { ptr, .. } => *ptr, BufMutInner::Vec(vec) => vec.as_mut_ptr(), #[cfg(feature = "bytes")] diff --git a/fusio/src/lib.rs b/fusio/src/lib.rs index 1de65cd..105eb9e 100644 --- a/fusio/src/lib.rs +++ b/fusio/src/lib.rs @@ -66,6 +66,42 @@ pub trait Read: MaybeSend + MaybeSync { buf: B, ) -> impl Future, B)> + MaybeSend; + fn read_exact( + &mut self, + mut buf: B, + ) -> impl Future, B)> + MaybeSend { + async move { + let len = buf.bytes_init() as u64; + let mut read = 0; + + while read < len { + let mut buf_mut = unsafe { buf.to_buf_mut_nocopy() }; + buf_mut.set_start(read as usize); + let (result, buf_mut) = self.read(buf_mut).await; + buf = unsafe { B::recover_from_buf_mut(buf_mut) }; + + match result { + Ok(0) => { + return ( + Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "failed to fill whole buffer", + ) + .into()), + buf, + ) + } + Ok(n) => { + read += n; + } + Err(Error::Io(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => {} + Err(e) => return (Err(e), buf), + } + } + (Ok(()), buf) + } + } + fn read_to_end( &mut self, buf: Vec, @@ -336,6 +372,26 @@ mod tests { test_local_fs(TokioFs).await.unwrap(); } + #[cfg(feature = "tokio")] + #[tokio::test] + async fn test_read_exact() { + use tempfile::tempfile; + use tokio::fs::File; + + let mut file = File::from_std(tempfile().unwrap()); + let (result, _) = file.write_all(&b"hello, world"[..]).await; + result.unwrap(); + file.seek(0).await.unwrap(); + let (result, buf) = file.read_exact(vec![0u8; 5]).await; + result.unwrap(); + assert_eq!(buf.as_slice(), b"hello"); + let (result, _) = file.read_exact(vec![0u8; 8]).await; + assert!(result.is_err()); + if let Error::Io(e) = result.unwrap_err() { + assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof); + } + } + #[cfg(feature = "monoio")] #[monoio::test] async fn test_monoio() {