diff --git a/src/bit_vec.rs b/src/bit_vec.rs index 20d65120..c40a3b1a 100644 --- a/src/bit_vec.rs +++ b/src/bit_vec.rs @@ -73,6 +73,19 @@ impl Decode for BitVec { Ok(result) }) } + + fn skip(input: &mut I) -> Result<(), Error> { + let Compact(bits) = >::decode(input)?; + let len = bitvec::mem::elts::(bits as usize); + + // Attempt to get the fixed size and check for overflow + if let Some(size) = T::encoded_fixed_size().and_then(|size| size.checked_mul(len)) { + input.skip(size) + } else { + // Fallback when there is no fixed size or on overflow + Result::from_iter((0..len).map(|_| T::skip(input))) + } + } } impl DecodeWithMemTracking for BitVec {} @@ -89,6 +102,10 @@ impl Decode for BitBox { fn decode(input: &mut I) -> Result { Ok(BitVec::::decode(input)?.into()) } + + fn skip(input: &mut I) -> Result<(), Error> { + BitVec::::skip(input) + } } impl DecodeWithMemTracking for BitBox {} @@ -145,6 +162,10 @@ mod tests { let elements = bitvec::mem::elts::(v.len()); let compact_len = Compact::compact_len(&(v.len() as u32)); assert_eq!(compact_len + elements, encoded.len(), "{}", v); + + let input = &mut &encoded[..]; + BitVec::::skip(input).unwrap(); + assert_eq!(input.remaining_len().unwrap(), Some(0)); } } @@ -157,6 +178,10 @@ mod tests { let elements = bitvec::mem::elts::(v.len()); let compact_len = Compact::compact_len(&(v.len() as u32)); assert_eq!(compact_len + elements * 2, encoded.len(), "{}", v); + + let input = &mut &encoded[..]; + BitVec::::skip(input).unwrap(); + assert_eq!(input.remaining_len().unwrap(), Some(0)); } } @@ -169,6 +194,10 @@ mod tests { let elements = bitvec::mem::elts::(v.len()); let compact_len = Compact::compact_len(&(v.len() as u32)); assert_eq!(compact_len + elements * 4, encoded.len(), "{}", v); + + let input = &mut &encoded[..]; + BitVec::::skip(input).unwrap(); + assert_eq!(input.remaining_len().unwrap(), Some(0)); } } @@ -181,6 +210,10 @@ mod tests { let elements = bitvec::mem::elts::(v.len()); let compact_len = Compact::compact_len(&(v.len() as u32)); assert_eq!(compact_len + elements * 8, encoded.len(), "{}", v); + + let input = &mut &encoded[..]; + BitVec::::skip(input).unwrap(); + assert_eq!(input.remaining_len().unwrap(), Some(0)); } } @@ -201,6 +234,10 @@ mod tests { let encoded = bb.encode(); let decoded = BitBox::::decode(&mut &encoded[..]).unwrap(); assert_eq!(bb, decoded); + + let input = &mut &encoded[..]; + BitBox::::skip(input).unwrap(); + assert_eq!(input.remaining_len().unwrap(), Some(0)); } #[test] diff --git a/src/codec.rs b/src/codec.rs index 1dce353a..857e9c9a 100644 --- a/src/codec.rs +++ b/src/codec.rs @@ -74,8 +74,25 @@ pub trait Input { Ok(buf[0]) } + /// Skip the exact number of bytes in the input. + /// + /// Note that the default implementation does an actual read and discards the bytes. + /// When possible, an implementation should provide a specialized implementation. + fn skip(&mut self, len: usize) -> Result<(), Error> { + let mut buf = vec![0u8; len.min(MAX_PREALLOCATION)]; + + let mut remains = len; + while remains > MAX_PREALLOCATION { + self.read(&mut buf[..])?; + remains -= buf.len(); + } + + self.read(&mut buf[..remains])?; + Ok(()) + } + /// Descend into nested reference when decoding. - /// This is called when decoding a new refence-based instance, + /// This is called when decoding a new reference-based instance, /// such as `Vec` or `Box`. Currently, all such types are /// allocated on the heap. fn descend_ref(&mut self) -> Result<(), Error> { @@ -124,6 +141,14 @@ impl<'a> Input for &'a [u8] { *self = &self[len..]; Ok(()) } + + fn skip(&mut self, len: usize) -> Result<(), Error> { + if len > self.len() { + return Err("Not enough data to skip".into()); + } + *self = &self[len..]; + Ok(()) + } } #[cfg(feature = "std")] @@ -299,7 +324,7 @@ pub trait Decode: Sized { #[doc(hidden)] const TYPE_INFO: TypeInfo = TypeInfo::Unknown; - /// Attempt to deserialise the value from input. + /// Attempt to deserialize the value from input. fn decode(input: &mut I) -> Result; /// Attempt to deserialize the value from input into a pre-allocated piece of memory. @@ -324,12 +349,17 @@ pub trait Decode: Sized { unsafe { Ok(DecodeFinished::assert_decoding_finished()) } } - /// Attempt to skip the encoded value from input. + /// Attempt to skip the encoded value from input without validating it. /// - /// The default implementation of this function is just calling [`Decode::decode`]. + /// The default implementation of this function is skipping the fixed encoded size + /// if it is known. Otherwise, it is just calling [`Decode::decode`]. /// When possible, an implementation should provide a specialized implementation. fn skip(input: &mut I) -> Result<(), Error> { - Self::decode(input).map(|_| ()) + if let Some(size) = Self::encoded_fixed_size() { + input.skip(size) + } else { + Self::decode(input).map(|_| ()) + } } /// Returns the fixed encoded size of the type. @@ -347,12 +377,12 @@ pub trait Decode: Sized { pub trait Codec: Decode + Encode {} impl Codec for S {} -/// Trait that bound `EncodeLike` along with `Encode`. Usefull for generic being used in function +/// Trait that bound `EncodeLike` along with `Encode`. Useful for generic being used in function /// with `EncodeLike` parameters. pub trait FullEncode: Encode + EncodeLike {} impl FullEncode for S {} -/// Trait that bound `EncodeLike` along with `Codec`. Usefull for generic being used in function +/// Trait that bound `EncodeLike` along with `Codec`. Useful for generic being used in function /// with `EncodeLike` parameters. pub trait FullCodec: Decode + FullEncode {} impl FullCodec for S {} @@ -442,6 +472,15 @@ impl Input for BytesCursor { Ok(()) } + fn skip(&mut self, len: usize) -> Result<(), Error> { + if len > self.bytes.len() - self.position { + return Err("Not enough data to skip".into()); + } + + self.position += len; + Ok(()) + } + fn scale_internal_decode_bytes(&mut self) -> Result { let length = >::decode(self)?.0 as usize; @@ -488,6 +527,10 @@ impl Decode for bytes::Bytes { fn decode(input: &mut I) -> Result { input.scale_internal_decode_bytes() } + + fn skip(input: &mut I) -> Result<(), Error> { + Vec::::skip(input) + } } #[cfg(feature = "bytes")] @@ -635,6 +678,19 @@ where fn decode(input: &mut I) -> Result { Self::decode_wrapped(input) } + + fn skip(input: &mut I) -> Result<(), Error> { + input.descend_ref()?; + + T::skip(input)?; + + input.ascend_ref(); + Ok(()) + } + + fn encoded_fixed_size() -> Option { + T::encoded_fixed_size() + } } /// A macro that matches on a [`TypeInfo`] and expands a given macro per variant. @@ -713,6 +769,17 @@ impl Decode for Result { _ => Err("unexpected first byte decoding Result".into()), } } + + fn skip(input: &mut I) -> Result<(), Error> { + match input + .read_byte() + .map_err(|e| e.chain("Could not result variant byte for `Result`"))? + { + 0 => T::skip(input).map_err(|e| e.chain("Could not skip `Result::Ok(T)`")), + 1 => E::skip(input).map_err(|e| e.chain("Could not skip `Result::Error(E)`")), + _ => Err("unexpected first byte decoding Result".into()), + } + } } impl DecodeWithMemTracking for Result {} @@ -752,6 +819,10 @@ impl Decode for OptionBool { _ => Err("unexpected first byte decoding OptionBool".into()), } } + + fn encoded_fixed_size() -> Option { + Some(1) + } } impl DecodeWithMemTracking for OptionBool {} @@ -790,6 +861,21 @@ impl Decode for Option { _ => Err("unexpected first byte decoding Option".into()), } } + + fn skip(input: &mut I) -> Result<(), Error> { + match input + .read_byte() + .map_err(|e| e.chain("Could not decode variant byte for `Option`"))? + { + 0 => Ok(()), + 1 => T::skip(input).map_err(|e| e.chain("Could not skip `Option::Some(T)`")), + _ => Err("unexpected first byte decoding Option".into()), + } + } + + fn encoded_fixed_size() -> Option { + Some(T::encoded_fixed_size()? + 1) + } } impl DecodeWithMemTracking for Option {} @@ -822,6 +908,10 @@ macro_rules! impl_for_non_zero { Self::new(Decode::decode(input)?) .ok_or_else(|| Error::from("cannot create non-zero number from 0")) } + + fn encoded_fixed_size() -> Option { + Some(mem::size_of::<$name>()) + } } impl DecodeWithMemTracking for $name {} @@ -829,6 +919,19 @@ macro_rules! impl_for_non_zero { } } +impl_for_non_zero! { + NonZeroI8, + NonZeroI16, + NonZeroI32, + NonZeroI64, + NonZeroI128, + NonZeroU8, + NonZeroU16, + NonZeroU32, + NonZeroU64, + NonZeroU128, +} + /// Encode the slice without prepending the len. /// /// This is equivalent to encoding all the element one by one, but it is optimized for some types. @@ -867,19 +970,6 @@ pub(crate) fn encode_slice_no_len(slice: &[T], de } } -impl_for_non_zero! { - NonZeroI8, - NonZeroI16, - NonZeroI32, - NonZeroI64, - NonZeroI128, - NonZeroU8, - NonZeroU16, - NonZeroU32, - NonZeroU64, - NonZeroU128, -} - impl Encode for [T; N] { fn size_hint(&self) -> usize { mem::size_of::() * N @@ -918,7 +1008,7 @@ impl Decode for [T; N] { ) -> Result { let is_primitive = match ::TYPE_INFO { | TypeInfo::U8 | TypeInfo::I8 => true, - | TypeInfo::U16 | + TypeInfo::U16 | TypeInfo::I16 | TypeInfo::U32 | TypeInfo::I32 | @@ -1012,15 +1102,10 @@ impl Decode for [T; N] { } fn skip(input: &mut I) -> Result<(), Error> { - if Self::encoded_fixed_size().is_some() { - // Should skip the bytes, but Input does not support skip. - for _ in 0..N { - T::skip(input)?; - } - } else { - Self::decode(input)?; + match Self::encoded_fixed_size() { + Some(len) => input.skip(len), + None => Result::from_iter((0..N).map(|_| T::skip(input))), } - Ok(()) } fn encoded_fixed_size() -> Option { @@ -1057,6 +1142,14 @@ where fn decode(input: &mut I) -> Result { Ok(Cow::Owned(Decode::decode(input)?)) } + + fn skip(input: &mut I) -> Result<(), Error> { + T::Owned::skip(input) + } + + fn encoded_fixed_size() -> Option { + T::Owned::encoded_fixed_size() + } } impl<'a, T: ToOwned + DecodeWithMemTracking> DecodeWithMemTracking for Cow<'a, T> where @@ -1074,6 +1167,10 @@ impl Decode for PhantomData { fn decode(_input: &mut I) -> Result { Ok(PhantomData) } + + fn encoded_fixed_size() -> Option { + Some(0) + } } impl DecodeWithMemTracking for PhantomData where PhantomData: Decode {} @@ -1082,6 +1179,10 @@ impl Decode for String { fn decode(input: &mut I) -> Result { Self::from_utf8(Vec::decode(input)?).map_err(|_| "Invalid utf8 sequence".into()) } + + fn skip(input: &mut I) -> Result<(), Error> { + Vec::::skip(input) + } } impl DecodeWithMemTracking for String {} @@ -1227,6 +1328,24 @@ impl Decode for Vec { >::decode(input) .and_then(move |Compact(len)| decode_vec_with_len(input, len as usize)) } + + fn skip(input: &mut I) -> Result<(), Error> { + let Compact(len) = >::decode(input)?; + + input.descend_ref()?; + + // Attempt to get the fixed size and check for overflow + if let Some(size) = T::encoded_fixed_size().and_then(|size| size.checked_mul(len as usize)) + { + input.skip(size) + } else { + // Fallback when there is no fixed size or on overflow + Result::from_iter((0..len).map(|_| T::skip(input))) + }?; + + input.ascend_ref(); + Ok(()) + } } impl DecodeWithMemTracking for Vec {} @@ -1276,6 +1395,10 @@ impl Decode for BTreeMap { result }) } + + fn skip(input: &mut I) -> Result<(), Error> { + Vec::<(K, V)>::skip(input) + } } impl DecodeWithMemTracking for BTreeMap where @@ -1298,6 +1421,10 @@ impl Decode for BTreeSet { result }) } + + fn skip(input: &mut I) -> Result<(), Error> { + Vec::::skip(input) + } } impl DecodeWithMemTracking for BTreeSet where BTreeSet: Decode {} @@ -1322,6 +1449,10 @@ impl Decode for LinkedList { result }) } + + fn skip(input: &mut I) -> Result<(), Error> { + Vec::::skip(input) + } } impl DecodeWithMemTracking for LinkedList where LinkedList: Decode {} @@ -1335,6 +1466,10 @@ impl Decode for BinaryHeap { fn decode(input: &mut I) -> Result { Ok(Vec::decode(input)?.into()) } + + fn skip(input: &mut I) -> Result<(), Error> { + Vec::::skip(input) + } } impl DecodeWithMemTracking for BinaryHeap where BinaryHeap: Decode {} @@ -1362,6 +1497,10 @@ impl Decode for VecDeque { fn decode(input: &mut I) -> Result { Ok(>::decode(input)?.into()) } + + fn skip(input: &mut I) -> Result<(), Error> { + Vec::::skip(input) + } } impl DecodeWithMemTracking for VecDeque {} @@ -1384,6 +1523,10 @@ impl Decode for () { fn decode(_: &mut I) -> Result<(), Error> { Ok(()) } + + fn encoded_fixed_size() -> Option { + Some(0) + } } macro_rules! impl_len { @@ -1429,6 +1572,14 @@ macro_rules! tuple_impl { Ok($one) => Ok(($one,)), } } + + fn skip(input: &mut I) -> Result<(), Error> { + $one::skip(input) + } + + fn encoded_fixed_size() -> Option { + $one::encoded_fixed_size() + } } impl<$one: DecodeLength> DecodeLength for ($one,) { @@ -1474,6 +1625,16 @@ macro_rules! tuple_impl { },)+ )) } + + fn skip(input: &mut INPUT) -> Result<(), super::Error> { + $first::skip(input)?; + $($rest::skip(input)?;)+ + Ok(()) + } + + fn encoded_fixed_size() -> Option { + Some( $first::encoded_fixed_size()? $( + $rest::encoded_fixed_size()? )+) + } } impl<$first: EncodeLike<$fextra>, $fextra: Encode, @@ -1572,6 +1733,10 @@ macro_rules! impl_one_byte { fn decode(input: &mut I) -> Result { Ok(input.read_byte()? as $t) } + + fn encoded_fixed_size() -> Option { + Some(1) + } } impl DecodeWithMemTracking for $t {} @@ -1634,6 +1799,10 @@ impl Decode for Duration { Ok(Duration::new(secs, nanos)) } } + + fn encoded_fixed_size() -> Option { + <(u64, u32)>::encoded_fixed_size() + } } impl DecodeWithMemTracking for Duration {} @@ -1662,6 +1831,14 @@ where <(T, T)>::decode(input).map_err(|e| e.chain("Could not decode `Range`"))?; Ok(Range { start, end }) } + + fn skip(input: &mut I) -> Result<(), Error> { + <(T, T)>::skip(input) + } + + fn encoded_fixed_size() -> Option { + <(T, T)>::encoded_fixed_size() + } } impl DecodeWithMemTracking for Range {} @@ -1688,6 +1865,14 @@ where <(T, T)>::decode(input).map_err(|e| e.chain("Could not decode `RangeInclusive`"))?; Ok(RangeInclusive::new(start, end)) } + + fn skip(input: &mut I) -> Result<(), Error> { + <(T, T)>::skip(input) + } + + fn encoded_fixed_size() -> Option { + <(T, T)>::encoded_fixed_size() + } } impl DecodeWithMemTracking for RangeInclusive {} @@ -1967,6 +2152,16 @@ mod tests { assert_eq!(io_reader.read_byte(), Err("io error: UnexpectedEof".into())); } + #[test] + fn io_reader_skip() { + let mut io_reader = IoReader(std::io::Cursor::new(&[1u8, 2, 3, 4][..])); + + io_reader.skip(0).unwrap(); + io_reader.skip(2).unwrap(); + assert_eq!(io_reader.read_byte().unwrap(), 3); + assert_eq!(io_reader.skip(2), Err("io error: UnexpectedEof".into())); + } + #[test] fn shared_references_implement_encode() { Arc::new(10u32).encode(); @@ -2196,4 +2391,401 @@ mod tests { assert_eq!(range_inclusive.encode(), range_inclusive_bytes); assert_eq!(RangeInclusive::decode(&mut &range_inclusive_bytes[..]), Ok(range_inclusive)); } + + #[test] + fn input_skip() { + struct MyInput(Vec); + impl Input for MyInput { + fn remaining_len(&mut self) -> Result, Error> { + Ok(None) + } + fn read(&mut self, into: &mut [u8]) -> Result<(), Error> { + let i = &mut &self.0[..]; + let res = i.read(into); + self.0 = i.to_vec(); + res + } + } + + let mut input = MyInput(vec![1, 2, 3, 4, 5, 6]); + input.skip(2).unwrap(); + assert_eq!(input.read_byte().unwrap(), 3); + input.skip(1).unwrap(); + assert_eq!(input.read_byte().unwrap(), 5); + input.skip(2).unwrap_err(); + + let mut input = MyInput((0..MAX_PREALLOCATION * 2).map(|i| i as u8).collect()); + input.skip(MAX_PREALLOCATION + 1).unwrap(); + assert_eq!(input.read_byte().unwrap(), (MAX_PREALLOCATION + 1) as u8); + input.skip(1).unwrap(); + assert_eq!(input.read_byte().unwrap(), (MAX_PREALLOCATION + 3) as u8); + } + + #[test] + fn u8_slice_skip() { + let mut input = &[1, 2, 3, 4, 5, 6][..]; + input.skip(2).unwrap(); + assert_eq!(input.read_byte().unwrap(), 3); + input.skip(1).unwrap(); + assert_eq!(input.read_byte().unwrap(), 5); + input.skip(2).unwrap_err(); + } + + #[test] + #[cfg(feature = "bytes")] + fn bytes_cursor_skip() { + let mut input = + BytesCursor { bytes: bytes::Bytes::from_static(&[1, 2, 3, 4, 5, 6]), position: 0 }; + + input.skip(2).unwrap(); + assert_eq!(input.read_byte().unwrap(), 3); + input.skip(1).unwrap(); + assert_eq!(input.read_byte().unwrap(), 5); + input.skip(2).unwrap_err(); + } + + #[test] + #[cfg(feature = "bytes")] + fn bytes_skip() { + let mut input = &(bytes::Bytes::from_static(&[1, 2]), 3u8).encode()[..]; + + bytes::Bytes::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 3); + } + + #[test] + fn skip_and_fixed_len_for_wrapper_type() { + let mut input = &(1u8, 2u8).encode()[..]; + + Arc::::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + assert_eq!(Arc::::encoded_fixed_size().unwrap(), 1); + } + + #[test] + fn skip_result() { + type R = Result; + let mut input = &(R::Ok(1u8), 2u8).encode()[..]; + + R::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + let mut input = &(R::Err(1u16), 2u8).encode()[..]; + + R::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + #[test] + fn skip_and_encoded_len_optionbool() { + assert_eq!( + OptionBool(Some(true)).encoded_size(), + OptionBool::encoded_fixed_size().unwrap() + ); + assert_eq!( + OptionBool(Some(false)).encoded_size(), + OptionBool::encoded_fixed_size().unwrap() + ); + assert_eq!(OptionBool(None).encoded_size(), OptionBool::encoded_fixed_size().unwrap()); + + let mut input = &(OptionBool(None), 2u8).encode()[..]; + + OptionBool::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + #[test] + fn skip_and_encoded_len_option() { + assert_eq!(Option::::encoded_fixed_size(), Some(2)); + assert_eq!(Option::>::encoded_fixed_size(), None); + + let mut input = &(Some(1u8), 2u8).encode()[..]; + + Option::::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + let mut input = &(Option::::None, 2u8).encode()[..]; + + Option::::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + let mut input = &(Some(vec![1u8, 2, 3]), 2u8).encode()[..]; + + Option::>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // Array + #[test] + fn skip_and_encoded_len_array() { + assert_eq!(<[u8; 32]>::encoded_fixed_size(), Some(32)); + assert_eq!(<[u8; 0]>::encoded_fixed_size(), Some(0)); + assert_eq!(<[Vec; 32]>::encoded_fixed_size(), None); + + let mut input = &([1u8; 32], 2u8).encode()[..]; + + <[u8; 32]>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + let mut input = &([vec![1u8, 2, 3], vec![1, 2]], 2u8).encode()[..]; + + <[Vec; 2]>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // Cow + #[test] + fn skip_and_encoded_len_cow() { + assert_eq!(Cow::<[u8]>::encoded_fixed_size(), None); + assert_eq!(Cow::::encoded_fixed_size(), None); + + let mut input = &(Cow::<[u8]>::Borrowed(&[1u8, 2, 3]), 2u8).encode()[..]; + + Cow::<[u8]>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + let mut input = &(Cow::::Borrowed("123"), 2u8).encode()[..]; + + Cow::::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // PhantomData + #[test] + fn skip_and_encoded_len_phantomdata() { + assert_eq!(PhantomData::::encoded_fixed_size(), Some(0)); + + let mut input = &(PhantomData::, 2u8).encode()[..]; + + PhantomData::::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // String + #[test] + fn skip_and_encoded_len_string() { + assert_eq!(String::encoded_fixed_size(), None); + + let mut input = &(String::from("123"), 2u8).encode()[..]; + + String::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // Vec (vec of u8 and vec of vec) + #[test] + fn skip_and_encoded_len_vec() { + assert_eq!(Vec::::encoded_fixed_size(), None); + assert_eq!(Vec::>::encoded_fixed_size(), None); + + let mut input = &(vec![1u8, 2, 3], 2u8).encode()[..]; + + Vec::::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + let mut input = &(vec![vec![1u8, 2, 3], vec![1, 2]], 2u8).encode()[..]; + + Vec::>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // BTreeMap + #[test] + fn skip_and_encoded_len_btreemap() { + assert_eq!(BTreeMap::::encoded_fixed_size(), None); + assert_eq!(BTreeMap::>::encoded_fixed_size(), None); + + let mut input = &(BTreeMap::::new(), 2u8).encode()[..]; + + BTreeMap::::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + let mut input = &(BTreeMap::>::new(), 2u8).encode()[..]; + + BTreeMap::>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // BTreeSet + #[test] + fn skip_and_encoded_len_btreeset() { + assert_eq!(BTreeSet::::encoded_fixed_size(), None); + assert_eq!(BTreeSet::>::encoded_fixed_size(), None); + + let mut input = &(BTreeSet::::new(), 2u8).encode()[..]; + + BTreeSet::::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + let mut input = &(BTreeSet::>::new(), 2u8).encode()[..]; + + BTreeSet::>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // BinaryHeap + #[test] + fn skip_and_encoded_len_binaryheap() { + assert_eq!(BinaryHeap::::encoded_fixed_size(), None); + assert_eq!(BinaryHeap::>::encoded_fixed_size(), None); + + let mut input = &(BinaryHeap::::new(), 2u8).encode()[..]; + + BinaryHeap::::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + let mut input = &(BinaryHeap::>::new(), 2u8).encode()[..]; + + BinaryHeap::>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // VecDeque + #[test] + fn skip_and_encoded_len_vecdeque() { + assert_eq!(VecDeque::::encoded_fixed_size(), None); + assert_eq!(VecDeque::>::encoded_fixed_size(), None); + + let mut input = &(VecDeque::::new(), 2u8).encode()[..]; + + VecDeque::::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + let mut input = &(VecDeque::>::new(), 2u8).encode()[..]; + + VecDeque::>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // () + #[test] + fn skip_and_encoded_len_unit() { + assert_eq!(<()>::encoded_fixed_size(), Some(0)); + + let mut input = &((), 2u8).encode()[..]; + + <()>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // tuple + #[test] + fn skip_and_encoded_len_tuple() { + assert_eq!(<(u8, u8)>::encoded_fixed_size(), Some(2)); + assert_eq!(<(u8, Vec)>::encoded_fixed_size(), None); + + let mut input = &((1u8, 101u8), 2u8).encode()[..]; + + <(u8, u8)>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + + let mut input = &((1u8, vec![1u8, 2, 3]), 2u8).encode()[..]; + + <(u8, Vec)>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // Duration + #[test] + fn skip_and_encoded_len_duration() { + assert_eq!(Duration::encoded_fixed_size(), Some(12)); + + let mut input = &(Duration::new(1, 2), 2u8).encode()[..]; + + Duration::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // Range + #[test] + fn skip_and_encoded_len_range() { + assert_eq!(Range::::encoded_fixed_size(), Some(2)); + + let mut input = &(Range { start: 1u8, end: 2 }, 2u8).encode()[..]; + + as Decode>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + // Range inclusive + #[test] + fn skip_and_encoded_len_range_inclusive() { + assert_eq!(RangeInclusive::::encoded_fixed_size(), Some(2)); + + let mut input = &(RangeInclusive::new(1u8, 2), 2u8).encode()[..]; + + as Decode>::skip(&mut input).unwrap(); + assert_eq!(u8::decode(&mut input).unwrap(), 2); + } + + #[test] + fn descend_ascend_when_skipping() { + struct TestingDepthTrackingInput<'a, I> { + input: &'a mut I, + depth: u32, + max_depth: u32, + } + + impl<'a, I: Input> Input for TestingDepthTrackingInput<'a, I> { + fn remaining_len(&mut self) -> Result, Error> { + self.input.remaining_len() + } + + fn read(&mut self, into: &mut [u8]) -> Result<(), Error> { + self.input.read(into) + } + + fn read_byte(&mut self) -> Result { + self.input.read_byte() + } + + fn skip(&mut self, len: usize) -> Result<(), Error> { + self.input.skip(len) + } + + fn descend_ref(&mut self) -> Result<(), Error> { + self.input.descend_ref()?; + self.depth += 1; + if self.depth > self.max_depth { + Err("Depth limit reached".into()) + } else { + Ok(()) + } + } + + fn ascend_ref(&mut self) { + self.input.ascend_ref(); + self.depth -= 1; + } + + fn on_before_alloc_mem(&mut self, size: usize) -> Result<(), Error> { + self.input.on_before_alloc_mem(size) + } + } + + // Wrapper type + let input = (MyWrapper(Compact(3u32)), 2u8).encode(); + + let mut input_limited_1 = + TestingDepthTrackingInput { input: &mut &input[..], depth: 0, max_depth: 0 }; + MyWrapper::skip(&mut input_limited_1).unwrap_err(); + + let mut input_limited_2 = + TestingDepthTrackingInput { input: &mut &input[..], depth: 0, max_depth: 1 }; + MyWrapper::skip(&mut input_limited_2).unwrap(); + assert_eq!(u8::decode(&mut input_limited_2).unwrap(), 2); + + // Vec type + let input = (vec![1u8, 2, 3, 4, 5, 6], 2u8).encode(); + + let mut input_limited_1 = + TestingDepthTrackingInput { input: &mut &input[..], depth: 0, max_depth: 0 }; + Vec::::skip(&mut input_limited_1).unwrap_err(); + + let mut input_limited_2 = + TestingDepthTrackingInput { input: &mut &input[..], depth: 0, max_depth: 1 }; + Vec::::skip(&mut input_limited_2).unwrap(); + assert_eq!(u8::decode(&mut input_limited_2).unwrap(), 2); + } } diff --git a/src/compact.rs b/src/compact.rs index ffcd2ae4..4281ace2 100644 --- a/src/compact.rs +++ b/src/compact.rs @@ -75,6 +75,14 @@ impl<'a, T: 'a + Input> Input for PrefixInput<'a, T> { _ => self.input.read(buffer), } } + + fn skip(&mut self, len: usize) -> Result<(), Error> { + if len == 0 { + return Ok(()); + } + + self.input.skip(len - self.prefix.take().is_some() as usize) + } } /// Something that can return the compact encoded length for a given value. @@ -173,6 +181,10 @@ where let as_ = Compact::::decode(input)?; Ok(Compact(::decode_from(as_.0)?)) } + + fn skip(input: &mut I) -> Result<(), Error> { + Compact::::skip(input) + } } impl DecodeWithMemTracking for Compact @@ -406,7 +418,7 @@ impl<'a> Encode for CompactRef<'a, u64> { let bytes_needed = 8 - self.0.leading_zeros() / 8; assert!( bytes_needed >= 4, - "Previous match arm matches anyting less than 2^30; qed" + "Previous match arm matches anything less than 2^30; qed" ); dest.push_byte(0b11 + ((bytes_needed - 4) << 2) as u8); let mut v = *self.0; @@ -487,6 +499,10 @@ impl Decode for Compact<()> { fn decode(_input: &mut I) -> Result { Ok(Compact(())) } + + fn encoded_fixed_size() -> Option { + Some(0) + } } impl DecodeWithMemTracking for Compact<()> {} @@ -497,6 +513,16 @@ const U32_OUT_OF_RANGE: &str = "out of range decoding Compact"; const U64_OUT_OF_RANGE: &str = "out of range decoding Compact"; const U128_OUT_OF_RANGE: &str = "out of range decoding Compact"; +fn skip_compact(input: &mut I) -> Result<(), Error> { + let prefix = input.read_byte()?; + match prefix % 4 { + 1 => input.skip(1), + 2 => input.skip(3), + 3 => input.skip(((prefix >> 2) + 4) as usize), + _ => Ok(()), + } +} + impl Decode for Compact { fn decode(input: &mut I) -> Result { let prefix = input.read_byte()?; @@ -513,6 +539,10 @@ impl Decode for Compact { _ => return Err("unexpected prefix decoding Compact".into()), })) } + + fn skip(input: &mut I) -> Result<(), Error> { + skip_compact(input) + } } impl DecodeWithMemTracking for Compact {} @@ -541,6 +571,10 @@ impl Decode for Compact { _ => return Err("unexpected prefix decoding Compact".into()), })) } + + fn skip(input: &mut I) -> Result<(), Error> { + skip_compact(input) + } } impl DecodeWithMemTracking for Compact {} @@ -583,6 +617,10 @@ impl Decode for Compact { _ => unreachable!(), })) } + + fn skip(input: &mut I) -> Result<(), Error> { + skip_compact(input) + } } impl DecodeWithMemTracking for Compact {} @@ -641,6 +679,10 @@ impl Decode for Compact { _ => unreachable!(), })) } + + fn skip(input: &mut I) -> Result<(), Error> { + skip_compact(input) + } } impl DecodeWithMemTracking for Compact {} @@ -707,6 +749,10 @@ impl Decode for Compact { _ => unreachable!(), })) } + + fn skip(input: &mut I) -> Result<(), Error> { + skip_compact(input) + } } impl DecodeWithMemTracking for Compact {} @@ -725,6 +771,17 @@ mod tests { assert_eq!(input.read_byte(), Ok(1)); } + #[test] + fn prefix_input_skip() { + let mut input = PrefixInput { prefix: Some(1), input: &mut &vec![2, 3, 4][..] }; + assert_eq!(input.remaining_len(), Ok(Some(4))); + input.skip(0).unwrap(); + assert_eq!(input.remaining_len(), Ok(Some(4))); + input.skip(2).unwrap(); + assert_eq!(input.remaining_len(), Ok(Some(2))); + assert_eq!(input.read_byte(), Ok(3)); + } + #[test] fn compact_128_encoding_works() { let tests = [ @@ -1063,25 +1120,54 @@ mod tests { } macro_rules! quick_check_roundtrip { - ( $( $ty:ty : $test:ident ),* ) => { + ( $( $ty:ty : $test1:ident, $test2:ident ),* ) => { $( quickcheck::quickcheck! { - fn $test(v: $ty) -> bool { + fn $test1(v: $ty) -> bool { let encoded = Compact(v).encode(); let deencoded = >::decode(&mut &encoded[..]).unwrap().0; v == deencoded } } + + quickcheck::quickcheck! { + fn $test2(v: $ty) -> bool { + let mut encoded = &(Compact(v), 23u8).encode()[..]; + >::skip(&mut encoded).unwrap(); + let deencoded = u8::decode(&mut encoded).unwrap(); + + 23u8 == deencoded + } + } )* } } quick_check_roundtrip! { - u8: u8_roundtrip, - u16: u16_roundtrip, - u32 : u32_roundtrip, - u64 : u64_roundtrip, - u128 : u128_roundtrip + u8: u8_roundtrip, u8_skip_roundtrip, + u16: u16_roundtrip, u16_skip_roundtrip, + u32 : u32_roundtrip, u32_skip_roundtrip, + u64 : u64_roundtrip, u64_skip_roundtrip, + u128 : u128_roundtrip, u128_skip_roundtrip + } + + #[test] + fn skip_prefix_input() { + let mut input = PrefixInput { prefix: Some(1), input: &mut &vec![2, 3, 4][..] }; + assert_eq!(input.remaining_len(), Ok(Some(4))); + input.skip(0).unwrap(); + assert_eq!(input.remaining_len(), Ok(Some(4))); + input.skip(2).unwrap(); + assert_eq!(input.remaining_len(), Ok(Some(2))); + assert_eq!(input.read_byte(), Ok(3)); + + let mut input = PrefixInput { prefix: None, input: &mut &vec![2, 3, 4][..] }; + assert_eq!(input.remaining_len(), Ok(Some(3))); + input.skip(0).unwrap(); + assert_eq!(input.remaining_len(), Ok(Some(3))); + input.skip(2).unwrap(); + assert_eq!(input.remaining_len(), Ok(Some(1))); + assert_eq!(input.read_byte(), Ok(4)); } } diff --git a/src/depth_limit.rs b/src/depth_limit.rs index 8b3a7edf..28d0193c 100644 --- a/src/depth_limit.rs +++ b/src/depth_limit.rs @@ -50,6 +50,10 @@ impl<'a, I: Input> Input for DepthTrackingInput<'a, I> { self.input.read_byte() } + fn skip(&mut self, len: usize) -> Result<(), Error> { + self.input.skip(len) + } + fn descend_ref(&mut self) -> Result<(), Error> { self.input.descend_ref()?; self.depth += 1; diff --git a/src/encode_like.rs b/src/encode_like.rs index 8280b85d..dfd9e699 100644 --- a/src/encode_like.rs +++ b/src/encode_like.rs @@ -28,7 +28,7 @@ use crate::codec::Encode; /// } /// /// fn main() { -/// // Just pass the a reference to the normal tuple. +/// // Just pass a reference to the normal tuple. /// encode_like::<(u32, u32), _>(&(1u32, 2u32)); /// // Pass a tuple of references /// encode_like::<(u32, u32), _>(&(&1u32, &2u32)); @@ -39,9 +39,9 @@ use crate::codec::Encode; /// /// # Warning /// -/// The relation is not symetric, `T` implements `EncodeLike` does not mean `U` has same +/// The relation is not symmetric, `T` implements `EncodeLike` does not mean `U` has same /// representation as `T`. -/// For instance we could imaging a non zero integer to be encoded to the same representation as +/// For instance, we could imagine a non-zero integer to be encoded to the same representation as /// the said integer but not the other way around. /// /// # Limitation diff --git a/src/generic_array.rs b/src/generic_array.rs index 18dffceb..697754be 100644 --- a/src/generic_array.rs +++ b/src/generic_array.rs @@ -37,6 +37,17 @@ impl> Decode for generic_array::Gene None => Err("array length does not match definition".into()), } } + + fn skip(input: &mut I) -> Result<(), Error> { + match Self::encoded_fixed_size() { + Some(len) => input.skip(len), + None => Result::from_iter((0..L::to_usize()).map(|_| T::skip(input))), + } + } + + fn encoded_fixed_size() -> Option { + Some(T::encoded_fixed_size()? * L::to_usize()) + } } #[cfg(test)] @@ -62,4 +73,27 @@ mod tests { let encoded = test.encode(); assert_eq!(test, GenericArray::::decode(&mut &encoded[..]).unwrap()); } + + #[test] + fn skip_generic_array() { + let test = arr![u8; 3, 4, 5]; + let mut encoded = &(test, 23u8).encode()[..]; + GenericArray::::skip(&mut encoded).unwrap(); + assert_eq!(u8::decode(&mut encoded).unwrap(), 23); + + let test = arr![u16; 3, 4, 5, 6, 7, 8, 0]; + let mut encoded = &(test, 23u8).encode()[..]; + GenericArray::::skip(&mut encoded).unwrap(); + assert_eq!(u8::decode(&mut encoded).unwrap(), 23); + + let test = arr![u32; 3, 4, 5, 0, 1]; + let mut encoded = &(test, 23u8).encode()[..]; + GenericArray::::skip(&mut encoded).unwrap(); + assert_eq!(u8::decode(&mut encoded).unwrap(), 23); + + let test = arr![u64; 3]; + let mut encoded = &(test, 23u8).encode()[..]; + GenericArray::::skip(&mut encoded).unwrap(); + assert_eq!(u8::decode(&mut encoded).unwrap(), 23); + } }