Skip to content

Commit

Permalink
Add custom error types support to the Decode and DecodeValue traits.
Browse files Browse the repository at this point in the history
  • Loading branch information
turbocool3r authored and admin committed Jan 23, 2024
1 parent 0f34bf8 commit 1ff2bf5
Show file tree
Hide file tree
Showing 58 changed files with 400 additions and 217 deletions.
2 changes: 2 additions & 0 deletions crmf/src/pop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ impl<'a> ::der::Choice<'a> for EncKeyWithIdChoice<'a> {
}
}
impl<'a> ::der::Decode<'a> for EncKeyWithIdChoice<'a> {
type Error = ::der::Error;

fn decode<R: ::der::Reader<'a>>(reader: &mut R) -> ::der::Result<Self> {
let t = reader.peek_tag()?;
if t == <Utf8StringRef<'a> as ::der::FixedTag>::TAG {
Expand Down
2 changes: 2 additions & 0 deletions der/derive/src/choice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ impl DeriveChoice {
}

impl<#lifetime> ::der::Decode<#lifetime> for #ident<#lt_params> {
type Error = ::der::Error;

fn decode<R: ::der::Reader<#lifetime>>(reader: &mut R) -> ::der::Result<Self> {
use der::Reader as _;
match reader.peek_tag()? {
Expand Down
2 changes: 2 additions & 0 deletions der/derive/src/enumerated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ impl DeriveEnumerated {

quote! {
impl<#default_lifetime> ::der::DecodeValue<#default_lifetime> for #ident {
type Error = ::der::Error;

fn decode_value<R: ::der::Reader<#default_lifetime>>(
reader: &mut R,
header: ::der::Header
Expand Down
2 changes: 2 additions & 0 deletions der/derive/src/sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ impl DeriveSequence {

quote! {
impl #impl_generics ::der::DecodeValue<#lifetime> for #ident #ty_generics #where_clause {
type Error = ::der::Error;

fn decode_value<R: ::der::Reader<#lifetime>>(
reader: &mut R,
header: ::der::Header,
Expand Down
58 changes: 34 additions & 24 deletions der/src/asn1/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

use crate::{
BytesRef, Choice, Decode, DecodeValue, DerOrd, EncodeValue, Error, ErrorKind, Header, Length,
Reader, Result, SliceReader, Tag, Tagged, ValueOrd, Writer,
Reader, SliceReader, Tag, Tagged, ValueOrd, Writer,
};
use core::cmp::Ordering;

Expand Down Expand Up @@ -40,7 +40,7 @@ impl<'a> AnyRef<'a> {
};

/// Create a new [`AnyRef`] from the provided [`Tag`] and DER bytes.
pub fn new(tag: Tag, bytes: &'a [u8]) -> Result<Self> {
pub fn new(tag: Tag, bytes: &'a [u8]) -> Result<Self, Error> {
let value = BytesRef::new(bytes).map_err(|_| ErrorKind::Length { tag })?;
Ok(Self { tag, value })
}
Expand All @@ -56,12 +56,12 @@ impl<'a> AnyRef<'a> {
}

/// Attempt to decode this [`AnyRef`] type into the inner value.
pub fn decode_as<T>(self) -> Result<T>
pub fn decode_as<T>(self) -> Result<T, <T as DecodeValue<'a>>::Error>
where
T: Choice<'a> + DecodeValue<'a>,
{
if !T::can_decode(self.tag) {
return Err(self.tag.unexpected_error(None));
return Err(self.tag.unexpected_error(None).into());
}

let header = Header {
Expand All @@ -71,7 +71,7 @@ impl<'a> AnyRef<'a> {

let mut decoder = SliceReader::new(self.value())?;
let result = T::decode_value(&mut decoder, header)?;
decoder.finish(result)
Ok(decoder.finish(result)?)
}

/// Is this value an ASN.1 `NULL` value?
Expand All @@ -81,14 +81,15 @@ impl<'a> AnyRef<'a> {

/// Attempt to decode this value an ASN.1 `SEQUENCE`, creating a new
/// nested reader and calling the provided argument with it.
pub fn sequence<F, T>(self, f: F) -> Result<T>
pub fn sequence<F, T, E>(self, f: F) -> Result<T, E>
where
F: FnOnce(&mut SliceReader<'a>) -> Result<T>,
F: FnOnce(&mut SliceReader<'a>) -> Result<T, E>,
E: From<Error>,
{
self.tag.assert_eq(Tag::Sequence)?;
let mut reader = SliceReader::new(self.value.as_slice())?;
let result = f(&mut reader)?;
reader.finish(result)
Ok(reader.finish(result)?)
}
}

Expand All @@ -99,14 +100,18 @@ impl<'a> Choice<'a> for AnyRef<'a> {
}

impl<'a> Decode<'a> for AnyRef<'a> {
fn decode<R: Reader<'a>>(reader: &mut R) -> Result<AnyRef<'a>> {
type Error = Error;

fn decode<R: Reader<'a>>(reader: &mut R) -> Result<AnyRef<'a>, Error> {
let header = Header::decode(reader)?;
Self::decode_value(reader, header)
}
}

impl<'a> DecodeValue<'a> for AnyRef<'a> {
fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
type Error = Error;

fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self, Error> {
Ok(Self {
tag: header.tag,
value: BytesRef::decode_value(reader, header)?,
Expand All @@ -115,11 +120,11 @@ impl<'a> DecodeValue<'a> for AnyRef<'a> {
}

impl EncodeValue for AnyRef<'_> {
fn value_len(&self) -> Result<Length> {
fn value_len(&self) -> Result<Length, Error> {
Ok(self.value.len())
}

fn encode_value(&self, writer: &mut impl Writer) -> Result<()> {
fn encode_value(&self, writer: &mut impl Writer) -> Result<(), Error> {
writer.write(self.value())
}
}
Expand All @@ -131,7 +136,7 @@ impl Tagged for AnyRef<'_> {
}

impl ValueOrd for AnyRef<'_> {
fn value_cmp(&self, other: &Self) -> Result<Ordering> {
fn value_cmp(&self, other: &Self) -> Result<Ordering, Error> {
self.value.der_cmp(&other.value)
}
}
Expand All @@ -145,7 +150,7 @@ impl<'a> From<AnyRef<'a>> for BytesRef<'a> {
impl<'a> TryFrom<&'a [u8]> for AnyRef<'a> {
type Error = Error;

fn try_from(bytes: &'a [u8]) -> Result<AnyRef<'a>> {
fn try_from(bytes: &'a [u8]) -> Result<AnyRef<'a>, Error> {
AnyRef::from_der(bytes)
}
}
Expand Down Expand Up @@ -175,7 +180,7 @@ mod allocating {

impl Any {
/// Create a new [`Any`] from the provided [`Tag`] and DER bytes.
pub fn new(tag: Tag, bytes: impl Into<Box<[u8]>>) -> Result<Self> {
pub fn new(tag: Tag, bytes: impl Into<Box<[u8]>>) -> Result<Self, Error> {
let value = BytesOwned::new(bytes)?;

// Ensure the tag and value are a valid `AnyRef`.
Expand All @@ -189,15 +194,15 @@ mod allocating {
}

/// Attempt to decode this [`Any`] type into the inner value.
pub fn decode_as<'a, T>(&'a self) -> Result<T>
pub fn decode_as<'a, T>(&'a self) -> Result<T, <T as DecodeValue<'a>>::Error>
where
T: Choice<'a> + DecodeValue<'a>,
{
AnyRef::from(self).decode_as()
}

/// Encode the provided type as an [`Any`] value.
pub fn encode_from<T>(msg: &T) -> Result<Self>
pub fn encode_from<T>(msg: &T) -> Result<Self, Error>
where
T: Tagged + EncodeValue,
{
Expand All @@ -211,9 +216,10 @@ mod allocating {

/// Attempt to decode this value an ASN.1 `SEQUENCE`, creating a new
/// nested reader and calling the provided argument with it.
pub fn sequence<'a, F, T>(&'a self, f: F) -> Result<T>
pub fn sequence<'a, F, T, E>(&'a self, f: F) -> Result<T, E>
where
F: FnOnce(&mut SliceReader<'a>) -> Result<T>,
F: FnOnce(&mut SliceReader<'a>) -> Result<T, E>,
E: From<Error>,
{
AnyRef::from(self).sequence(f)
}
Expand All @@ -234,25 +240,29 @@ mod allocating {
}

impl<'a> Decode<'a> for Any {
fn decode<R: Reader<'a>>(reader: &mut R) -> Result<Self> {
type Error = Error;

fn decode<R: Reader<'a>>(reader: &mut R) -> Result<Self, Error> {
let header = Header::decode(reader)?;
Self::decode_value(reader, header)
}
}

impl<'a> DecodeValue<'a> for Any {
fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
type Error = Error;

fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self, Error> {
let value = reader.read_vec(header.length)?;
Self::new(header.tag, value)
}
}

impl EncodeValue for Any {
fn value_len(&self) -> Result<Length> {
fn value_len(&self) -> Result<Length, Error> {
Ok(self.value.len())
}

fn encode_value(&self, writer: &mut impl Writer) -> Result<()> {
fn encode_value(&self, writer: &mut impl Writer) -> Result<(), Error> {
writer.write(self.value.as_slice())
}
}
Expand All @@ -271,7 +281,7 @@ mod allocating {
}

impl ValueOrd for Any {
fn value_cmp(&self, other: &Self) -> Result<Ordering> {
fn value_cmp(&self, other: &Self) -> Result<Ordering, Error> {
self.value.der_cmp(&other.value)
}
}
Expand Down
6 changes: 6 additions & 0 deletions der/src/asn1/bit_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ impl<'a> BitStringRef<'a> {
impl_any_conversions!(BitStringRef<'a>, 'a);

impl<'a> DecodeValue<'a> for BitStringRef<'a> {
type Error = Error;

fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
let header = Header {
tag: header.tag,
Expand Down Expand Up @@ -309,6 +311,8 @@ mod allocating {
impl_any_conversions!(BitString);

impl<'a> DecodeValue<'a> for BitString {
type Error = Error;

fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
let inner_len = (header.length - Length::ONE)?;
let unused_bits = reader.read_byte()?;
Expand Down Expand Up @@ -442,6 +446,8 @@ where
T::Type: From<bool>,
T::Type: core::ops::Shl<usize, Output = T::Type>,
{
type Error = Error;

fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
let position = reader.position();
let bits = BitStringRef::decode_value(reader, header)?;
Expand Down
2 changes: 2 additions & 0 deletions der/src/asn1/bmp_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ impl AsRef<[u8]> for BmpString {
}

impl<'a> DecodeValue<'a> for BmpString {
type Error = Error;

fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
Self::from_ucs2(reader.read_vec(header.length)?)
}
Expand Down
2 changes: 2 additions & 0 deletions der/src/asn1/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ const TRUE_OCTET: u8 = 0b11111111;
const FALSE_OCTET: u8 = 0b00000000;

impl<'a> DecodeValue<'a> for bool {
type Error = Error;

fn decode_value<R: Reader<'a>>(reader: &mut R, header: Header) -> Result<Self> {
if header.length != Length::ONE {
return Err(reader.error(ErrorKind::Length { tag: Self::TAG }));
Expand Down
Loading

0 comments on commit 1ff2bf5

Please sign in to comment.