Skip to content

Commit

Permalink
Move visitors to a shared module
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Aug 6, 2023
1 parent d7cccf4 commit a155af0
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 202 deletions.
15 changes: 6 additions & 9 deletions serdect/src/array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Serialization primitives for arrays.
// Unfortunately, we currently cannot assert generically that we are serializing
// Unfortunately, we currently cannot tell `serde` in a uniform fashion that we are serializing
// a fixed-size byte array.
// See https://github.com/serde-rs/serde/issues/2120 for the discussion.
// Therefore we have to fall back to the slice methods,
Expand All @@ -13,7 +13,7 @@ use core::marker::PhantomData;

use serde::{Deserialize, Deserializer, Serialize, Serializer};

use crate::slice;
use crate::common::{self, ExactLength, SliceVisitor, StrIntoBufVisitor};

#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
Expand All @@ -25,7 +25,7 @@ where
S: Serializer,
T: AsRef<[u8]>,
{
slice::serialize_hex_lower_or_bin(value, serializer)
common::serialize_hex_lower_or_bin(value, serializer)
}

/// Serialize the given type as upper case hex when using human-readable
Expand All @@ -35,7 +35,7 @@ where
S: Serializer,
T: AsRef<[u8]>,
{
slice::serialize_hex_upper_or_bin(value, serializer)
common::serialize_hex_upper_or_bin(value, serializer)
}

/// Deserialize from hex when using human-readable formats or binary if the
Expand All @@ -46,12 +46,9 @@ where
D: Deserializer<'de>,
{
if deserializer.is_human_readable() {
deserializer.deserialize_str(slice::StrVisitor::<slice::ExactLength>(buffer, PhantomData))
deserializer.deserialize_str(StrIntoBufVisitor::<ExactLength>(buffer, PhantomData))
} else {
deserializer.deserialize_byte_buf(slice::SliceVisitor::<slice::ExactLength>(
buffer,
PhantomData,
))
deserializer.deserialize_byte_buf(SliceVisitor::<ExactLength>(buffer, PhantomData))
}
}

Expand Down
211 changes: 211 additions & 0 deletions serdect/src/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
use core::fmt;
use core::marker::PhantomData;

use serde::{
de::{Error, Visitor},
Serialize, Serializer,
};

#[cfg(feature = "alloc")]
use ::alloc::vec::Vec;

pub(crate) fn serialize_hex<S, T, const UPPERCASE: bool>(
value: &T,
serializer: S,
) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: AsRef<[u8]>,
{
#[cfg(feature = "alloc")]
if UPPERCASE {
return base16ct::upper::encode_string(value.as_ref()).serialize(serializer);
} else {
return base16ct::lower::encode_string(value.as_ref()).serialize(serializer);
}
#[cfg(not(feature = "alloc"))]
{
let _ = value;
let _ = serializer;
return Err(S::Error::custom(
"serializer is human readable, which requires the `alloc` crate feature",
));
}
}

pub(crate) fn serialize_hex_lower_or_bin<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: AsRef<[u8]>,
{
if serializer.is_human_readable() {
serialize_hex::<_, _, false>(value, serializer)
} else {
serializer.serialize_bytes(value.as_ref())
}
}

/// Serialize the given type as upper case hex when using human-readable
/// formats or binary if the format is binary.
pub(crate) fn serialize_hex_upper_or_bin<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: AsRef<[u8]>,
{
if serializer.is_human_readable() {
serialize_hex::<_, _, true>(value, serializer)
} else {
serializer.serialize_bytes(value.as_ref())
}
}

pub(crate) trait LengthCheck {
fn length_check(buffer_length: usize, data_length: usize) -> bool;
fn expecting(
formatter: &mut fmt::Formatter<'_>,
data_type: &str,
data_length: usize,
) -> fmt::Result;
}

pub(crate) struct ExactLength;

impl LengthCheck for ExactLength {
fn length_check(buffer_length: usize, data_length: usize) -> bool {
buffer_length == data_length
}
fn expecting(
formatter: &mut fmt::Formatter<'_>,
data_type: &str,
data_length: usize,
) -> fmt::Result {
write!(formatter, "{} of length {}", data_type, data_length)
}
}

pub(crate) struct UpperBound;

impl LengthCheck for UpperBound {
fn length_check(buffer_length: usize, data_length: usize) -> bool {
buffer_length >= data_length
}
fn expecting(
formatter: &mut fmt::Formatter<'_>,
data_type: &str,
data_length: usize,
) -> fmt::Result {
write!(
formatter,
"{} with a maximum length of {}",
data_type, data_length
)
}
}

pub(crate) struct StrIntoBufVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData<T>);

impl<'de, 'b, T: LengthCheck> Visitor<'de> for StrIntoBufVisitor<'b, T> {
type Value = ();

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
T::expecting(formatter, "a string", self.0.len() * 2)
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
if !T::length_check(self.0.len() * 2, v.len()) {
return Err(Error::invalid_length(v.len(), &self));
}
// TODO: Map `base16ct::Error::InvalidLength` to `Error::invalid_length`.
base16ct::mixed::decode(v, self.0)
.map(|_| ())
.map_err(E::custom)
}
}

#[cfg(feature = "alloc")]
pub(crate) struct StrIntoVecVisitor;

impl<'de> Visitor<'de> for StrIntoVecVisitor {
type Value = Vec<u8>;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "a string")
}

fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: Error,
{
base16ct::mixed::decode_vec(v).map_err(E::custom)
}
}

pub(crate) struct SliceVisitor<'b, T: LengthCheck>(pub &'b mut [u8], pub PhantomData<T>);

impl<'de, 'b, T: LengthCheck> Visitor<'de> for SliceVisitor<'b, T> {
type Value = ();

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
T::expecting(formatter, "an array", self.0.len())
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: Error,
{
// Workaround for
// https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions
if T::length_check(self.0.len(), v.len()) {
let buffer = &mut self.0[..v.len()];
buffer.copy_from_slice(v);
return Ok(());
}

Err(E::invalid_length(v.len(), &self))
}

#[cfg(feature = "alloc")]
fn visit_byte_buf<E>(self, mut v: Vec<u8>) -> Result<Self::Value, E>
where
E: Error,
{
// Workaround for
// https://github.com/rust-lang/rfcs/blob/b1de05846d9bc5591d753f611ab8ee84a01fa500/text/2094-nll.md#problem-case-3-conditional-control-flow-across-functions
if T::length_check(self.0.len(), v.len()) {
let buffer = &mut self.0[..v.len()];
buffer.swap_with_slice(&mut v);
return Ok(());
}

Err(E::invalid_length(v.len(), &self))
}
}

#[cfg(feature = "alloc")]
pub(crate) struct VecVisitor;

#[cfg(feature = "alloc")]
impl<'de> Visitor<'de> for VecVisitor {
type Value = Vec<u8>;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "a bytestring")
}

fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: Error,
{
Ok(v.into())
}

fn visit_byte_buf<E>(self, v: Vec<u8>) -> Result<Self::Value, E>
where
E: Error,
{
Ok(v)
}
}
30 changes: 1 addition & 29 deletions serdect/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,35 +131,7 @@
extern crate alloc;

pub mod array;
mod common;
pub mod slice;

pub use serde;

use serde::Serializer;

#[cfg(not(feature = "alloc"))]
use serde::ser::Error;

#[cfg(feature = "alloc")]
use serde::Serialize;

fn serialize_hex<S, T, const UPPERCASE: bool>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
T: AsRef<[u8]>,
{
#[cfg(feature = "alloc")]
if UPPERCASE {
return base16ct::upper::encode_string(value.as_ref()).serialize(serializer);
} else {
return base16ct::lower::encode_string(value.as_ref()).serialize(serializer);
}
#[cfg(not(feature = "alloc"))]
{
let _ = value;
let _ = serializer;
return Err(S::Error::custom(
"serializer is human readable, which requires the `alloc` crate feature",
));
}
}
Loading

0 comments on commit a155af0

Please sign in to comment.