From 74b8bb2627962bb55d87715045b3f1201bc7051d Mon Sep 17 00:00:00 2001 From: polazarus Date: Thu, 10 Aug 2023 16:07:16 +0200 Subject: [PATCH] improve bytes serde --- Cargo.toml | 5 +- src/bytes/serde.rs | 129 +++++++++++---------------------------------- 2 files changed, 36 insertions(+), 98 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a6f152e..8b1516b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ rust-version = "1.64.0" [features] unstable = [] +serde = ["dep:serde", "dep:serde_bytes"] [[bench]] name = "competition" @@ -25,8 +26,10 @@ arcstr = "1.1.5" fastrand = "1.9.0" flexstr = "0.9.2" imstr = "0.2.0" -serde_test = "1.0.163" +serde_test = "1.0" serde = { version = "1.0.60", features = ["derive"] } +serde_json = "1.0" [dependencies] serde = { version = "1.0.60", optional = true } +serde_bytes = { version = "0.11", optional = true } diff --git a/src/bytes/serde.rs b/src/bytes/serde.rs index 94666c2..4153ebf 100644 --- a/src/bytes/serde.rs +++ b/src/bytes/serde.rs @@ -1,6 +1,5 @@ -use std::marker::PhantomData; +use std::borrow::Cow; -use serde::de::Visitor; use serde::{Deserialize, Serialize}; use super::HipByt; @@ -18,49 +17,6 @@ where } } -#[derive(Clone, Copy, Debug)] -struct BytesVisitor; - -impl<'de> Visitor<'de> for BytesVisitor { - type Value = Vec; - - fn visit_byte_buf(self, v: Vec) -> Result - where - E: serde::de::Error, - { - Ok(v) - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: serde::de::Error, - { - Ok(v.into()) - } - - fn visit_borrowed_bytes(self, v: &'de [u8]) -> Result - where - E: serde::de::Error, - { - Ok(v.into()) - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut v = seq.size_hint().map_or_else(Vec::new, Vec::with_capacity); - while let Some(e) = seq.next_element()? { - v.push(e); - } - Ok(v) - } - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "bytes") - } -} - impl<'de, 'borrow, B> Deserialize<'de> for HipByt<'borrow, B> where B: Backend, @@ -69,61 +25,16 @@ where where D: serde::Deserializer<'de>, { - let v = deserializer.deserialize_byte_buf(BytesVisitor)?; + let v: Vec = serde_bytes::deserialize(deserializer)?; Ok(Self::from(v)) } } -#[derive(Clone, Copy, Debug)] -struct BorrowingByteVisitor(PhantomData); - -impl<'de, B> Visitor<'de> for BorrowingByteVisitor -where - B: Backend, -{ - type Value = HipByt<'de, B>; - - fn visit_byte_buf(self, v: Vec) -> Result - where - E: serde::de::Error, - { - Ok(v.into()) - } - - fn visit_bytes(self, v: &[u8]) -> Result - where - E: serde::de::Error, - { - Ok(v.into()) - } - - fn visit_borrowed_bytes(self, v: &'de [u8]) -> Result - where - E: serde::de::Error, - { - Ok(HipByt::borrowed(v)) - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut v = seq.size_hint().map_or_else(Vec::new, Vec::with_capacity); - while let Some(e) = seq.next_element()? { - v.push(e); - } - Ok(v.into()) - } - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(formatter, "bytes") - } -} - /// Deserializes a `HipByt` as a borrow if possible. /// /// ```rust /// # use serde::Deserialize; +/// # use serde_json; /// use hipstr::bytes::HipByt; /// use hipstr::Local; /// #[derive(Deserialize)] @@ -131,14 +42,22 @@ where /// #[serde(borrow, deserialize_with = "hipstr::bytes::serde::borrowing_deserialize")] /// field: HipByt<'a, Local>, /// } -/// # fn main() {} +/// # fn main() { +/// let s: MyStruct = serde_json::from_str(r#"{"field": "abc"}"#).unwrap(); +/// assert!(s.field.is_borrowed()); +/// # } /// ``` -pub fn borrowing_deserialize<'de, D, B>(deserializer: D) -> Result, D::Error> +/// +/// # Errors +/// +/// Returns a deserializer if either the serialization is incorrect or an unexpected value is encountered. +pub fn borrowing_deserialize<'de: 'a, 'a, D, B>(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, B: Backend, { - deserializer.deserialize_byte_buf(BorrowingByteVisitor(PhantomData)) + let cow: Cow<'de, [u8]> = serde_bytes::Deserialize::deserialize(deserializer)?; + Ok(HipByt::from(cow)) } #[cfg(test)] @@ -147,6 +66,7 @@ mod tests { assert_de_tokens, assert_de_tokens_error, assert_ser_tokens, assert_tokens, Token, }; + use crate::bytes::serde::borrowing_deserialize; use crate::HipByt; #[test] @@ -177,8 +97,23 @@ mod tests { #[test] fn test_de_error() { assert_de_tokens_error::( - &[Token::Str("")], - "invalid type: string \"\", expected bytes", + &[Token::F32(0.0)], + "invalid type: floating point `0`, expected byte array", ); } + + #[test] + fn test_serde_borrowing() { + use serde::de::Deserialize; + use serde_json::Value; + + use super::super::HipByt; + use crate::Local; + + let v = Value::from("abcdefghijklmnopqrstuvwxyz"); + let h1: HipByt<'_, Local> = borrowing_deserialize(&v).unwrap(); + let h2: HipByt<'_, Local> = Deserialize::deserialize(&v).unwrap(); + assert!(h1.is_borrowed()); + assert!(!h2.is_borrowed()); + } }