Skip to content

Commit

Permalink
feat: support default Read::read_exact
Browse files Browse the repository at this point in the history
  • Loading branch information
ethe committed Oct 9, 2024
1 parent 919b298 commit 39e5d29
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 29 deletions.
79 changes: 50 additions & 29 deletions fusio/src/buf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,14 @@ impl IoBufMut for Vec<u8> {
}

unsafe fn to_buf_mut_nocopy(self) -> BufMut {
BufMut(BufMutInner::Vec(self))
BufMut {
layout: BufMutInner::Vec(self),
start: 0,
}
}

unsafe fn recover_from_buf_mut(buf: BufMut) -> Self {
match buf.0 {
match buf.layout {
BufMutInner::Vec(vec) => vec,
_ => unreachable!(),
}
Expand Down Expand Up @@ -159,14 +162,17 @@ impl IoBufMut for &mut [u8] {
}

unsafe fn to_buf_mut_nocopy(self) -> BufMut {
BufMut(BufMutInner::Slice {
ptr: self.as_mut_ptr(),
len: self.len(),
})
BufMut {
layout: BufMutInner::Slice {
ptr: self.as_mut_ptr(),
len: self.len(),
},
start: 0,
}
}

unsafe fn recover_from_buf_mut(buf: BufMut) -> Self {
match buf.0 {
match buf.layout {
BufMutInner::Slice { ptr, len } => std::slice::from_raw_parts_mut(ptr, len),
_ => unreachable!(),
}
Expand Down Expand Up @@ -265,11 +271,14 @@ impl IoBufMut for bytes::BytesMut {
}

unsafe fn to_buf_mut_nocopy(self) -> BufMut {
BufMut(BufMutInner::BytesMut(self))
BufMut {
layout: BufMutInner::BytesMut(self),
start: 0,
}
}

unsafe fn recover_from_buf_mut(buf: BufMut) -> Self {
match buf.0 {
match buf.layout {
BufMutInner::BytesMut(bytes) => bytes,
_ => unreachable!(),
}
Expand Down Expand Up @@ -339,12 +348,21 @@ impl IoBuf for Buf {
}
}

pub struct BufMut(BufMutInner);
pub struct BufMut {
layout: BufMutInner,
start: usize,
}

impl BufMut {
pub(crate) fn set_start(&mut self, start: usize) {
self.start = start;
}
}

#[cfg(not(feature = "no-send"))]
unsafe impl Send for BufMut {}

enum BufMutInner {
pub(crate) enum BufMutInner {
#[allow(unused)]
Slice {
ptr: *mut u8,
Expand All @@ -357,37 +375,37 @@ enum BufMutInner {

impl IoBuf for BufMut {
fn as_ptr(&self) -> *const u8 {
match &self.0 {
BufMutInner::Slice { ptr, .. } => *ptr,
BufMutInner::Vec(vec) => vec.as_ptr(),
match &self.layout {
BufMutInner::Slice { ptr, .. } => unsafe { (*ptr).add(self.start) },
BufMutInner::Vec(vec) => vec[self.start..].as_ptr(),
#[cfg(feature = "bytes")]
BufMutInner::BytesMut(bytes) => bytes.as_ptr(),
}
}

fn bytes_init(&self) -> usize {
match &self.0 {
BufMutInner::Slice { len, .. } => *len,
BufMutInner::Vec(vec) => vec.len(),
match &self.layout {
BufMutInner::Slice { len, .. } => *len - self.start,
BufMutInner::Vec(vec) => vec.len() - self.start,
#[cfg(feature = "bytes")]
BufMutInner::BytesMut(bytes) => bytes.len(),
BufMutInner::BytesMut(bytes) => bytes.len() - self.start,
}
}

#[cfg(feature = "bytes")]
fn as_bytes(&self) -> bytes::Bytes {
match &self.0 {
match &self.layout {
BufMutInner::Slice { ptr, len } => {
bytes::Bytes::copy_from_slice(unsafe { std::slice::from_raw_parts(*ptr, *len) })
}
BufMutInner::Vec(vec) => bytes::Bytes::copy_from_slice(vec),
BufMutInner::Vec(vec) => bytes::Bytes::copy_from_slice(&vec[self.start..]),
#[cfg(feature = "bytes")]
BufMutInner::BytesMut(bytes) => bytes.clone().freeze(),
}
}

unsafe fn to_buf_nocopy(self) -> Buf {
match self.0 {
match self.layout {
BufMutInner::Slice { ptr, len } => Buf(BufInner::Slice { ptr, len }),
BufMutInner::Vec(vec) => Buf(BufInner::Vec(vec)),
#[cfg(feature = "bytes")]
Expand All @@ -397,22 +415,25 @@ impl IoBuf for BufMut {

unsafe fn recover_from_buf(buf: Buf) -> Self {
match buf.0 {
BufInner::Slice { ptr, len } => BufMut(BufMutInner::Slice {
ptr: ptr as *mut _,
len,
}),
BufInner::Vec(vec) => BufMut(BufMutInner::Vec(vec)),
BufInner::Slice { .. } => unreachable!(),
BufInner::Vec(vec) => BufMut {
layout: BufMutInner::Vec(vec),
start: 0,
},
#[cfg(feature = "bytes")]
BufInner::Bytes(_) => unreachable!(),
#[cfg(feature = "bytes")]
BufInner::BytesMut(bytes) => BufMut(BufMutInner::BytesMut(bytes)),
BufInner::BytesMut(bytes) => BufMut {
layout: BufMutInner::BytesMut(bytes),
start: 0,
},
}
}
}

impl IoBufMut for BufMut {
fn set_init(&mut self, init: usize) {
match &mut self.0 {
match &mut self.layout {
BufMutInner::Slice { .. } => {}
BufMutInner::Vec(vec) => vec.set_init(init),
#[cfg(feature = "bytes")]
Expand All @@ -421,7 +442,7 @@ impl IoBufMut for BufMut {
}

fn as_mut_ptr(&mut self) -> *mut u8 {
match &mut self.0 {
match &mut self.layout {
BufMutInner::Slice { ptr, .. } => *ptr,
BufMutInner::Vec(vec) => vec.as_mut_ptr(),
#[cfg(feature = "bytes")]
Expand Down
56 changes: 56 additions & 0 deletions fusio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,42 @@ pub trait Read: MaybeSend + MaybeSync {
buf: B,
) -> impl Future<Output = (Result<u64, Error>, B)> + MaybeSend;

fn read_exact<B: IoBufMut>(
&mut self,
mut buf: B,
) -> impl Future<Output = (Result<(), Error>, B)> + MaybeSend {
async move {
let len = buf.bytes_init() as u64;
let mut read = 0;

while read < len {
let mut buf_mut = unsafe { buf.to_buf_mut_nocopy() };
buf_mut.set_start(read as usize);
let (result, buf_mut) = self.read(buf_mut).await;
buf = unsafe { B::recover_from_buf_mut(buf_mut) };

match result {
Ok(0) => {
return (
Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"failed to fill whole buffer",
)
.into()),
buf,
)
}
Ok(n) => {
read += n;
}
Err(Error::Io(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => {}
Err(e) => return (Err(e), buf),
}
}
(Ok(()), buf)
}
}

fn read_to_end(
&mut self,
buf: Vec<u8>,
Expand Down Expand Up @@ -336,6 +372,26 @@ mod tests {
test_local_fs(TokioFs).await.unwrap();
}

#[cfg(feature = "tokio")]
#[tokio::test]
async fn test_read_exact() {
use tempfile::tempfile;
use tokio::fs::File;

let mut file = File::from_std(tempfile().unwrap());
let (result, _) = file.write_all(&b"hello, world"[..]).await;
result.unwrap();
file.seek(0).await.unwrap();
let (result, buf) = file.read_exact(vec![0u8; 5]).await;
result.unwrap();
assert_eq!(buf.as_slice(), b"hello");
let (result, _) = file.read_exact(vec![0u8; 8]).await;
assert!(result.is_err());
if let Error::Io(e) = result.unwrap_err() {
assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof);
}
}

#[cfg(feature = "monoio")]
#[monoio::test]
async fn test_monoio() {
Expand Down

0 comments on commit 39e5d29

Please sign in to comment.