Skip to content

Commit

Permalink
Added serde impls for Array; gated by serde
Browse files Browse the repository at this point in the history
  • Loading branch information
rozbb committed Nov 1, 2023
1 parent 6c39ebb commit e0f2b7d
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 2 deletions.
16 changes: 14 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions hybrid-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,12 @@ rust-version = "1.65"

[dependencies]
typenum = "1.17"
serde = { version = "1", optional = true }

[dev-dependencies]
bincode = "1.3"
serde_json = "1"

[features]
default = []
serde = ["dep:serde"]
135 changes: 135 additions & 0 deletions hybrid-array/src/impl_serde.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
// This file was modified from
// https://github.com/fizyk20/generic-array/blob/0e2a03714b05bb7a737a677f8df77d6360d19c99/src/impl_serde.rs

use crate::{Array, ArraySize};
use core::fmt;
use core::marker::PhantomData;
use serde::de::{self, SeqAccess, Visitor};
use serde::{ser::SerializeTuple, Deserialize, Deserializer, Serialize, Serializer};

impl<T, N: ArraySize> Serialize for Array<T, N>
where
T: Serialize,
{
#[inline]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut tup = serializer.serialize_tuple(N::USIZE)?;
for el in self {
tup.serialize_element(el)?;
}

tup.end()
}
}

struct GAVisitor<T, N> {
_t: PhantomData<T>,
_n: PhantomData<N>,
}

// to avoid extra computation when testing for extra elements in the sequence
struct Dummy;
impl<'de> Deserialize<'de> for Dummy {
fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
Ok(Dummy)
}
}

impl<'de, T, N: ArraySize> Visitor<'de> for GAVisitor<T, N>
where
T: Deserialize<'de>,
{
type Value = Array<T, N>;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "struct Array<T, U{}>", N::USIZE)
}

fn visit_seq<A>(self, mut seq: A) -> Result<Array<T, N>, A::Error>
where
A: SeqAccess<'de>,
{
// Check the length in advance
match seq.size_hint() {
Some(n) if n != N::USIZE => {
return Err(de::Error::invalid_length(n, &self));
}
_ => {}
}

// Deserialize the array
let arr = Array::try_from_fn(|idx| {
let next_elem_opt = seq.next_element()?;
next_elem_opt.ok_or(de::Error::invalid_length(idx, &self))
});

// If there's a value allegedly remaining, and deserializing it doesn't fail, then that's a
// length mismatch error
if seq.size_hint() != Some(0) && seq.next_element::<Dummy>()?.is_some() {
Err(de::Error::invalid_length(N::USIZE + 1, &self))
} else {
arr
}
}
}

impl<'de, T, N: ArraySize> Deserialize<'de> for Array<T, N>
where
T: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Array<T, N>, D::Error>
where
D: Deserializer<'de>,
{
let visitor = GAVisitor {
_t: PhantomData,
_n: PhantomData,
};
deserializer.deserialize_tuple(N::USIZE, visitor)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_serialize() {
let array = Array::<u8, typenum::U2>::default();
let serialized = bincode::serialize(&array);
assert!(serialized.is_ok());
}

#[test]
fn test_deserialize() {
let mut array = Array::<u8, typenum::U2>::default();
array[0] = 1;
array[1] = 2;
let serialized = bincode::serialize(&array).unwrap();
let deserialized = bincode::deserialize::<Array<u8, typenum::U2>>(&serialized);
assert!(deserialized.is_ok());
let array = deserialized.unwrap();
assert_eq!(array[0], 1);
assert_eq!(array[1], 2);
}

#[test]
fn test_serialized_size() {
let array = Array::<u8, typenum::U1>::default();
let size = bincode::serialized_size(&array).unwrap();
assert_eq!(size, 1);
}

#[test]
#[should_panic]
fn test_too_many() {
let serialized = "[1, 2, 3, 4, 5]";
let _ = serde_json::from_str::<Array<u8, typenum::U4>>(serialized).unwrap();
}
}
48 changes: 48 additions & 0 deletions hybrid-array/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ use typenum::{Diff, Sum, Unsigned};

mod impls;

#[cfg(feature = "serde")]
mod impl_serde;

/// Hybrid typenum-based and const generic array type.
///
/// Provides the flexibility of typenum-based expressions while also
Expand All @@ -64,6 +67,14 @@ where
Self(ArrayExt::from_fn(cb))
}

/// Create array where each array element `T` is returned by the `cb` call.
pub fn try_from_fn<F, E>(cb: F) -> Result<Self, E>
where
F: FnMut(usize) -> Result<T, E>,
{
ArrayExt::try_from_fn(cb).map(Self)
}

/// Create array from a slice.
pub fn from_slice(slice: &[T]) -> Result<Self, TryFromSliceError>
where
Expand Down Expand Up @@ -570,6 +581,12 @@ pub trait ArrayExt<T>: Sized {
where
F: FnMut(usize) -> T;

/// Try to create an array using the given callback function for each element. Returns an error
/// if any one of the calls errors
fn try_from_fn<F, E>(cb: F) -> Result<Self, E>
where
F: FnMut(usize) -> Result<T, E>;

/// Create array from a slice, returning [`TryFromSliceError`] if the slice
/// length does not match the array length.
fn from_slice(slice: &[T]) -> Result<Self, TryFromSliceError>
Expand All @@ -591,6 +608,37 @@ impl<T, const N: usize> ArrayExt<T> for [T; N] {
})
}

fn try_from_fn<F, E>(mut cb: F) -> Result<Self, E>
where
F: FnMut(usize) -> Result<T, E>,
{
// TODO: Replace this entire function with array::try_map once it stabilizes
// https://doc.rust-lang.org/std/primitive.array.html#method.try_map

// Make an uninitialized array. We will populate it element-by-element
let mut arr: [MaybeUninit<T>; N] = unsafe { MaybeUninit::uninit().assume_init() };

// Dropping a `MaybeUninit` does nothing, so if there is a panic during this loop,
// we have a memory leak, but there is no memory safety issue.
for (idx, elem) in arr.iter_mut().enumerate() {
// Run the callback. On success, write it to the array. On error, return immediately
match cb(idx) {
Ok(val) => {
elem.write(val);
}
Err(e) => {
return Err(e);
}
}
}

// If we've made it this far, all the elements have been written. Convert the uninitialized
// array to an initialized array
// TODO: Replace this map with MaybeUninit::array_assume_init() once it stabilizes
let arr = arr.map(|elem: MaybeUninit<T>| unsafe { elem.assume_init() });
Ok(arr)
}

fn from_slice(slice: &[T]) -> Result<Self, TryFromSliceError>
where
T: Copy,
Expand Down

0 comments on commit e0f2b7d

Please sign in to comment.