Skip to content

Commit

Permalink
Merge pull request #28 from sgued/byte-array
Browse files Browse the repository at this point in the history
Add support for `[u8; N]`
  • Loading branch information
dtolnay authored Dec 27, 2023
2 parents 51cc56b + 95d9f83 commit 5da9904
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ edition = "2018"
keywords = ["serde", "serialization", "no_std", "bytes"]
license = "MIT OR Apache-2.0"
repository = "https://github.com/serde-rs/bytes"
rust-version = "1.31"
rust-version = "1.53"

[features]
default = ["std"]
Expand Down
225 changes: 225 additions & 0 deletions src/bytearray.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
use crate::Bytes;
use core::borrow::{Borrow, BorrowMut};
use core::cmp::Ordering;
use core::convert::TryInto;
use core::fmt::{self, Debug};
use core::hash::{Hash, Hasher};
use core::ops::{Deref, DerefMut};

use serde::de::{Deserialize, Deserializer, Error, SeqAccess, Visitor};
use serde::ser::{Serialize, Serializer};

/// Wrapper around `[u8; N]` to serialize and deserialize efficiently.
///
/// ```
/// use std::collections::HashMap;
/// use std::io;
///
/// use serde_bytes::ByteArray;
///
/// fn deserialize_bytearrays() -> bincode::Result<()> {
/// let example_data = [
/// 2, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 116,
/// 119, 111, 1, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 111, 110, 101
/// ];
///
/// let map: HashMap<u32, ByteArray<3>> = bincode::deserialize(&example_data[..])?;
///
/// println!("{:?}", map);
///
/// Ok(())
/// }
/// #
/// # fn main() {
/// # deserialize_bytearrays().unwrap();
/// # }
/// ```
#[derive(Clone, Eq, Ord)]
pub struct ByteArray<const N: usize> {
bytes: [u8; N],
}

impl<const N: usize> ByteArray<N> {
/// Transform an [array](https://doc.rust-lang.org/stable/std/primitive.array.html) to the equivalent `ByteArray`
pub fn new(bytes: [u8; N]) -> Self {
Self { bytes }
}

/// Wrap existing bytes into a `ByteArray`
pub fn from<T: Into<[u8; N]>>(bytes: T) -> Self {
Self {
bytes: bytes.into(),
}
}

/// Unwraps the byte array underlying this `ByteArray`
pub fn into_array(self) -> [u8; N] {
self.bytes
}
}

impl<const N: usize> Debug for ByteArray<N> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Debug::fmt(&self.bytes, f)
}
}

impl<const N: usize> AsRef<[u8; N]> for ByteArray<N> {
fn as_ref(&self) -> &[u8; N] {
&self.bytes
}
}
impl<const N: usize> AsMut<[u8; N]> for ByteArray<N> {
fn as_mut(&mut self) -> &mut [u8; N] {
&mut self.bytes
}
}

impl<const N: usize> Borrow<[u8; N]> for ByteArray<N> {
fn borrow(&self) -> &[u8; N] {
&self.bytes
}
}
impl<const N: usize> BorrowMut<[u8; N]> for ByteArray<N> {
fn borrow_mut(&mut self) -> &mut [u8; N] {
&mut self.bytes
}
}

impl<const N: usize> Deref for ByteArray<N> {
type Target = [u8; N];

fn deref(&self) -> &Self::Target {
&self.bytes
}
}

impl<const N: usize> DerefMut for ByteArray<N> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.bytes
}
}

impl<const N: usize> Borrow<Bytes> for ByteArray<N> {
fn borrow(&self) -> &Bytes {
Bytes::new(&self.bytes)
}
}

impl<const N: usize> BorrowMut<Bytes> for ByteArray<N> {
fn borrow_mut(&mut self) -> &mut Bytes {
unsafe { &mut *(&mut self.bytes as &mut [u8] as *mut [u8] as *mut Bytes) }
}
}

impl<Rhs, const N: usize> PartialEq<Rhs> for ByteArray<N>
where
Rhs: ?Sized + Borrow<[u8; N]>,
{
fn eq(&self, other: &Rhs) -> bool {
self.as_ref().eq(other.borrow())
}
}

impl<Rhs, const N: usize> PartialOrd<Rhs> for ByteArray<N>
where
Rhs: ?Sized + Borrow<[u8; N]>,
{
fn partial_cmp(&self, other: &Rhs) -> Option<Ordering> {
self.as_ref().partial_cmp(other.borrow())
}
}

impl<const N: usize> Hash for ByteArray<N> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.bytes.hash(state);
}
}

impl<const N: usize> IntoIterator for ByteArray<N> {
type Item = u8;
type IntoIter = <[u8; N] as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
IntoIterator::into_iter(self.bytes)
}
}

impl<'a, const N: usize> IntoIterator for &'a ByteArray<N> {
type Item = &'a u8;
type IntoIter = <&'a [u8; N] as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
self.bytes.iter()
}
}

impl<'a, const N: usize> IntoIterator for &'a mut ByteArray<N> {
type Item = &'a mut u8;
type IntoIter = <&'a mut [u8; N] as IntoIterator>::IntoIter;

fn into_iter(self) -> Self::IntoIter {
self.bytes.iter_mut()
}
}

impl<const N: usize> Serialize for ByteArray<N> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(&self.bytes)
}
}

struct ByteArrayVisitor<const N: usize>;

impl<'de, const N: usize> Visitor<'de> for ByteArrayVisitor<N> {
type Value = ByteArray<N>;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "a byte array of length {}", N)
}

fn visit_seq<V>(self, mut seq: V) -> Result<ByteArray<N>, V::Error>
where
V: SeqAccess<'de>,
{
let mut bytes = [0; N];

for (idx, byte) in bytes.iter_mut().enumerate() {
*byte = seq
.next_element()?
.ok_or_else(|| V::Error::invalid_length(idx, &self))?;
}

Ok(ByteArray::from(bytes))
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<ByteArray<N>, E>
where
E: Error,
{
Ok(ByteArray {
bytes: v
.try_into()
.map_err(|_| E::invalid_length(v.len(), &self))?,
})
}

fn visit_str<E>(self, v: &str) -> Result<ByteArray<N>, E>
where
E: Error,
{
self.visit_bytes(v.as_bytes())
}
}

impl<'de, const N: usize> Deserialize<'de> for ByteArray<N> {
fn deserialize<D>(deserializer: D) -> Result<ByteArray<N>, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_bytes(ByteArrayVisitor::<N>)
}
}
22 changes: 21 additions & 1 deletion src/de.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::Bytes;
use crate::{ByteArray, Bytes};
use core::fmt;
use core::marker::PhantomData;
use serde::de::{Error, Visitor};
Expand Down Expand Up @@ -63,6 +63,26 @@ impl<'de: 'a, 'a> Deserialize<'de> for &'a Bytes {
}
}

impl<'de, const N: usize> Deserialize<'de> for [u8; N] {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let arr: ByteArray<N> = serde::Deserialize::deserialize(deserializer)?;
Ok(*arr)
}
}

impl<'de, const N: usize> Deserialize<'de> for ByteArray<N> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// Via the serde::Deserialize impl for ByteArray
serde::Deserialize::deserialize(deserializer)
}
}

#[cfg(any(feature = "std", feature = "alloc"))]
impl<'de> Deserialize<'de> for ByteBuf {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
Expand Down
11 changes: 11 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
//!
//! #[serde(with = "serde_bytes")]
//! byte_buf: Vec<u8>,
//!
//! #[serde(with = "serde_bytes")]
//! byte_array: [u8; 314],
//! }
//! ```
Expand All @@ -37,6 +40,7 @@
clippy::needless_doctest_main
)]

mod bytearray;
mod bytes;
mod de;
mod ser;
Expand All @@ -52,6 +56,7 @@ use serde::Deserializer;

use serde::Serializer;

pub use crate::bytearray::ByteArray;
pub use crate::bytes::Bytes;
pub use crate::de::Deserialize;
pub use crate::ser::Serialize;
Expand All @@ -77,6 +82,9 @@ pub use crate::bytebuf::ByteBuf;
///
/// #[serde(with = "serde_bytes")]
/// byte_buf: Vec<u8>,
///
/// #[serde(with = "serde_bytes")]
/// byte_array: [u8; 314],
/// }
/// ```
pub fn serialize<T, S>(bytes: &T, serializer: S) -> Result<S::Ok, S::Error>
Expand All @@ -102,6 +110,9 @@ where
/// struct Packet {
/// #[serde(with = "serde_bytes")]
/// payload: Vec<u8>,
///
/// #[serde(with = "serde_bytes")]
/// byte_array: [u8; 314],
/// }
/// ```
#[cfg(any(feature = "std", feature = "alloc"))]
Expand Down
20 changes: 19 additions & 1 deletion src/ser.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::Bytes;
use crate::{ByteArray, Bytes};
use serde::Serializer;

#[cfg(any(feature = "std", feature = "alloc"))]
Expand Down Expand Up @@ -51,6 +51,24 @@ impl Serialize for Bytes {
}
}

impl<const N: usize> Serialize for [u8; N] {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(self)
}
}

impl<const N: usize> Serialize for ByteArray<N> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_bytes(&**self)
}
}

#[cfg(any(feature = "std", feature = "alloc"))]
impl Serialize for ByteBuf {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
Expand Down
Loading

0 comments on commit 5da9904

Please sign in to comment.