From 7c85617810e0af6410e104028511abfa6fc62433 Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Sat, 17 Aug 2024 12:45:42 -0600 Subject: [PATCH] [WIP] der: refactor `Reader::read_nested` Gets rid of a provided method along with the `NestedReader` type, in order to fix #1228. Instead, each reader must now implement its own strategy for nested reading. --- der/src/bytes_ref.rs | 10 ++++ der/src/lib.rs | 2 +- der/src/reader.rs | 24 ++++------ der/src/reader/nested.rs | 100 --------------------------------------- der/src/reader/pem.rs | 29 +++++++++--- der/src/reader/slice.rs | 22 +++++++++ 6 files changed, 64 insertions(+), 123 deletions(-) delete mode 100644 der/src/reader/nested.rs diff --git a/der/src/bytes_ref.rs b/der/src/bytes_ref.rs index 7d384ba41..ce4fb800c 100644 --- a/der/src/bytes_ref.rs +++ b/der/src/bytes_ref.rs @@ -49,6 +49,16 @@ impl<'a> BytesRef<'a> { pub fn is_empty(self) -> bool { self.len() == Length::ZERO } + + /// Get a prefix of a [`BytesRef`] of the given length. + pub fn prefix(self, length: Length) -> Result { + let inner = self + .as_slice() + .get(..usize::try_from(length)?) + .ok_or_else(|| Error::incomplete(self.length))?; + + Ok(Self { length, inner }) + } } impl AsRef<[u8]> for BytesRef<'_> { diff --git a/der/src/lib.rs b/der/src/lib.rs index ce341bb99..441cd2f5d 100644 --- a/der/src/lib.rs +++ b/der/src/lib.rs @@ -367,7 +367,7 @@ pub use crate::{ header::Header, length::{IndefiniteLength, Length}, ord::{DerOrd, ValueOrd}, - reader::{nested::NestedReader, slice::SliceReader, Reader}, + reader::{slice::SliceReader, Reader}, tag::{Class, FixedTag, Tag, TagMode, TagNumber, Tagged}, writer::{slice::SliceWriter, Writer}, }; diff --git a/der/src/reader.rs b/der/src/reader.rs index 10576be29..1227d7389 100644 --- a/der/src/reader.rs +++ b/der/src/reader.rs @@ -1,12 +1,9 @@ //! Reader trait. -pub(crate) mod nested; #[cfg(feature = "pem")] pub(crate) mod pem; pub(crate) mod slice; -pub(crate) use nested::NestedReader; - use crate::{ asn1::ContextSpecific, Decode, DecodeValue, Encode, EncodingRules, Error, ErrorKind, FixedTag, Header, Length, Tag, TagMode, TagNumber, @@ -35,6 +32,12 @@ pub trait Reader<'r>: Sized { /// Get the position within the buffer. fn position(&self) -> Length; + /// Read nested data of the given length. + fn read_nested(&mut self, len: Length, f: F) -> Result + where + E: From, + F: FnOnce(&mut Self) -> Result; + /// Attempt to read data borrowed directly from the input as a slice, /// updating the internal cursor position. /// @@ -130,17 +133,6 @@ pub trait Reader<'r>: Sized { Ok(buf) } - /// Read nested data of the given length. - fn read_nested<'n, T, F, E>(&'n mut self, len: Length, f: F) -> Result - where - F: FnOnce(&mut NestedReader<'n, Self>) -> Result, - E: From, - { - let mut reader = NestedReader::new(self, len)?; - let ret = f(&mut reader)?; - Ok(reader.finish(ret)?) - } - /// Read a byte vector of the given length. #[cfg(feature = "alloc")] fn read_vec(&mut self, len: Length) -> Result, Error> { @@ -157,9 +149,9 @@ pub trait Reader<'r>: Sized { /// Read an ASN.1 `SEQUENCE`, creating a nested [`Reader`] for the body and /// calling the provided closure with it. - fn sequence<'n, F, T, E>(&'n mut self, f: F) -> Result + fn sequence(&mut self, f: F) -> Result where - F: FnOnce(&mut NestedReader<'n, Self>) -> Result, + F: FnOnce(&mut Self) -> Result, E: From, { let header = Header::decode(self)?; diff --git a/der/src/reader/nested.rs b/der/src/reader/nested.rs deleted file mode 100644 index 77c6cdc45..000000000 --- a/der/src/reader/nested.rs +++ /dev/null @@ -1,100 +0,0 @@ -//! Reader type for consuming nested TLV records within a DER document. - -use crate::{reader::Reader, EncodingRules, Error, ErrorKind, Header, Length, Result}; - -/// Reader type used by [`Reader::read_nested`]. -pub struct NestedReader<'i, R> { - /// Inner reader type. - inner: &'i mut R, - - /// Nested input length. - input_len: Length, - - /// Position within the nested input. - position: Length, -} - -impl<'i, 'r, R: Reader<'r>> NestedReader<'i, R> { - /// Create a new nested reader which can read the given [`Length`]. - pub(crate) fn new(inner: &'i mut R, len: Length) -> Result { - if len <= inner.remaining_len() { - Ok(Self { - inner, - input_len: len, - position: Length::ZERO, - }) - } else { - Err(ErrorKind::Incomplete { - expected_len: (inner.offset() + len)?, - actual_len: (inner.offset() + inner.remaining_len())?, - } - .at(inner.offset())) - } - } - - /// Move the position cursor the given length, returning an error if there - /// isn't enough remaining data in the nested input. - fn advance_position(&mut self, len: Length) -> Result<()> { - let new_position = (self.position + len)?; - - if new_position <= self.input_len { - self.position = new_position; - Ok(()) - } else { - Err(ErrorKind::Incomplete { - expected_len: (self.inner.offset() + len)?, - actual_len: (self.inner.offset() + self.remaining_len())?, - } - .at(self.inner.offset())) - } - } -} - -impl<'i, 'r, R: Reader<'r>> Reader<'r> for NestedReader<'i, R> { - fn encoding_rules(&self) -> EncodingRules { - self.inner.encoding_rules() - } - - fn input_len(&self) -> Length { - self.input_len - } - - fn peek_byte(&self) -> Option { - if self.is_finished() { - None - } else { - self.inner.peek_byte() - } - } - - fn peek_header(&self) -> Result
{ - if self.is_finished() { - Err(Error::incomplete(self.offset())) - } else { - // TODO(tarcieri): handle peeking past nested length - self.inner.peek_header() - } - } - - fn position(&self) -> Length { - self.position - } - - fn read_slice(&mut self, len: Length) -> Result<&'r [u8]> { - self.advance_position(len)?; - self.inner.read_slice(len) - } - - fn error(&mut self, kind: ErrorKind) -> Error { - self.inner.error(kind) - } - - fn offset(&self) -> Length { - self.inner.offset() - } - - fn read_into<'o>(&mut self, out: &'o mut [u8]) -> Result<&'o [u8]> { - self.advance_position(Length::try_from(out.len())?)?; - self.inner.read_into(out) - } -} diff --git a/der/src/reader/pem.rs b/der/src/reader/pem.rs index 9e552246f..9f2a22f01 100644 --- a/der/src/reader/pem.rs +++ b/der/src/reader/pem.rs @@ -1,7 +1,7 @@ //! Streaming PEM reader. use super::Reader; -use crate::{Decode, EncodingRules, ErrorKind, Header, Length, Result}; +use crate::{Decode, EncodingRules, Error, ErrorKind, Header, Length}; use pem_rfc7468::Decoder; /// `Reader` type which decodes PEM on-the-fly. @@ -28,7 +28,7 @@ impl<'i> PemReader<'i> { /// Create a new PEM reader which decodes data on-the-fly. /// /// Uses the default 64-character line wrapping. - pub fn new(pem: &'i [u8]) -> Result { + pub fn new(pem: &'i [u8]) -> crate::Result { let decoder = Decoder::new(pem)?; let input_len = Length::try_from(decoder.remaining_len())?; @@ -50,7 +50,7 @@ impl<'i> PemReader<'i> { /// output buffer. /// /// Attempts to fill the entire buffer, returning an error if there is not enough data. - pub fn peek_into(&self, buf: &mut [u8]) -> Result<()> { + pub fn peek_into(&self, buf: &mut [u8]) -> crate::Result<()> { self.decoder.clone().decode(buf)?; Ok(()) } @@ -72,7 +72,7 @@ impl<'i> Reader<'i> for PemReader<'i> { self.peek_into(&mut byte).ok().map(|_| byte[0]) } - fn peek_header(&self) -> Result
{ + fn peek_header(&self) -> crate::Result
{ Header::decode(&mut self.clone()) } @@ -80,12 +80,29 @@ impl<'i> Reader<'i> for PemReader<'i> { self.position } - fn read_slice(&mut self, _len: Length) -> Result<&'i [u8]> { + fn read_nested(&mut self, len: Length, f: F) -> Result + where + F: FnOnce(&mut Self) -> Result, + E: From, + { + let nested_input_len = (self.position + len)?; + if nested_input_len > self.input_len { + return Err(Error::incomplete(self.input_len).into()); + } + + let orig_input_len = self.input_len; + self.input_len = nested_input_len; + let ret = f(self); + self.input_len = orig_input_len; + ret + } + + fn read_slice(&mut self, _len: Length) -> crate::Result<&'i [u8]> { // Can't borrow from PEM because it requires decoding Err(ErrorKind::Reader.into()) } - fn read_into<'o>(&mut self, buf: &'o mut [u8]) -> Result<&'o [u8]> { + fn read_into<'o>(&mut self, buf: &'o mut [u8]) -> crate::Result<&'o [u8]> { let bytes = self.decoder.decode(buf)?; self.position = (self.position + bytes.len())?; diff --git a/der/src/reader/slice.rs b/der/src/reader/slice.rs index a9e1cabe7..132cb6df3 100644 --- a/der/src/reader/slice.rs +++ b/der/src/reader/slice.rs @@ -91,6 +91,28 @@ impl<'a> Reader<'a> for SliceReader<'a> { self.position } + /// Read nested data of the given length. + fn read_nested(&mut self, len: Length, f: F) -> Result + where + F: FnOnce(&mut Self) -> Result, + E: From, + { + let prefix_len = (self.position + len)?; + let mut nested_reader = self.clone(); + nested_reader.bytes = self.bytes.prefix(prefix_len)?; + + let ret = f(&mut nested_reader); + self.position = nested_reader.position; + self.failed = nested_reader.failed; + + ret.and_then(|value| { + nested_reader.finish(value).map_err(|e| { + self.failed = true; + e.into() + }) + }) + } + fn read_slice(&mut self, len: Length) -> Result<&'a [u8], Error> { if self.is_failed() { return Err(self.error(ErrorKind::Failed));