From 1caeb97bcab0a2bdf34dbc92718d1cb3b25a4503 Mon Sep 17 00:00:00 2001 From: Konrad Kohbrok Date: Wed, 15 Nov 2023 10:07:40 +0100 Subject: [PATCH] tls_codec: feature for conditional deserialization derivation (#1214) --- .github/workflows/tls_codec.yml | 1 + tls_codec/Cargo.toml | 14 +- tls_codec/derive/Cargo.toml | 3 +- tls_codec/derive/src/lib.rs | 256 ++++++++++++++++++++++++++----- tls_codec/derive/tests/decode.rs | 46 ++++++ 5 files changed, 279 insertions(+), 41 deletions(-) diff --git a/.github/workflows/tls_codec.yml b/.github/workflows/tls_codec.yml index d4142bcef..12c259b3e 100644 --- a/.github/workflows/tls_codec.yml +++ b/.github/workflows/tls_codec.yml @@ -71,3 +71,4 @@ jobs: - uses: RustCrypto/actions/cargo-hack-install@master - run: cargo hack test --feature-powerset - run: cargo hack test -p tls_codec_derive --feature-powerset --test encode\* --test decode\* + - run: cargo hack test -p tls_codec_derive --feature-powerset --doc diff --git a/tls_codec/Cargo.toml b/tls_codec/Cargo.toml index a94c289ab..863309d99 100644 --- a/tls_codec/Cargo.toml +++ b/tls_codec/Cargo.toml @@ -27,12 +27,16 @@ criterion = { version = "0.5", default-features = false } regex = "1.8" [features] -default = [ "std" ] -arbitrary = [ "std", "dep:arbitrary" ] -derive = [ "tls_codec_derive" ] -serde = [ "std", "dep:serde" ] +default = ["std"] +arbitrary = ["std", "dep:arbitrary"] +derive = ["tls_codec_derive"] +serde = ["std", "dep:serde"] mls = [] # In MLS variable length vectors are limited compared to QUIC. 
-std = [ "tls_codec_derive?/std" ] +std = ["tls_codec_derive?/std"] +conditional_deserialization = [ + "derive", + "tls_codec_derive/conditional_deserialization", +] [[bench]] name = "tls_vec" diff --git a/tls_codec/derive/Cargo.toml b/tls_codec/derive/Cargo.toml index 6e3c10396..834096147 100644 --- a/tls_codec/derive/Cargo.toml +++ b/tls_codec/derive/Cargo.toml @@ -23,5 +23,6 @@ tls_codec = { path = "../" } trybuild = "1" [features] -default = [ "std" ] +default = ["std"] +conditional_deserialization = ["syn/full"] std = [] diff --git a/tls_codec/derive/src/lib.rs b/tls_codec/derive/src/lib.rs index 07abfad98..5966a44f2 100644 --- a/tls_codec/derive/src/lib.rs +++ b/tls_codec/derive/src/lib.rs @@ -1,21 +1,35 @@ //! # Derive macros for traits in `tls_codec` //! +//! Derive macros can be used to automatically implement the +//! [`Serialize`](../tls_codec::Serialize), +//! [`SerializeBytes`](../tls_codec::SerializeBytes), +//! [`Deserialize`](../tls_codec::Deserialize), +//! [`DeserializeBytes`](../tls_codec::DeserializeBytes), and +//! [`Size`](../tls_codec::Size) traits for structs and enums. Note that the +//! functions of the [`Serialize`](../tls_codec::Serialize) and +//! [`Deserialize`](../tls_codec::Deserialize) traits (and thus the +//! corresponding derive macros) require the `"std"` feature to work. +//! //! ## Warning //! -//! The derive macros support deriving the `tls_codec` traits for enumerations and the resulting -//! serialized format complies with [the "variants" section of the TLS RFC](https://datatracker.ietf.org/doc/html/rfc8446#section-3.8). -//! However support is limited to enumerations that are serialized with their discriminant -//! immediately followed by the variant data. If this is not appropriate (e.g. the format requires -//! other fields between the discriminant and variant data), the `tls_codec` traits can be -//! implemented manually. +//! The derive macros support deriving the `tls_codec` traits for enumerations +//! 
and the resulting serialized format complies with [the "variants" section of +//! the TLS RFC](https://datatracker.ietf.org/doc/html/rfc8446#section-3.8). +//! However support is limited to enumerations that are serialized with their +//! discriminant immediately followed by the variant data. If this is not +//! appropriate (e.g. the format requires other fields between the discriminant +//! and variant data), the `tls_codec` traits can be implemented manually. //! //! ## Parsing unknown values -//! In many cases it is necessary to deserialize structs with unknown values, e.g. -//! when receiving unknown TLS extensions. -//! In this case the deserialize function returns an `Error::UnknownValue` with -//! a `u64` value of the unknown type. +//! +//! In many cases it is necessary to deserialize structs with unknown values, +//! e.g. when receiving unknown TLS extensions. In this case the deserialize +//! function returns an `Error::UnknownValue` with a `u64` value of the unknown +//! type. //! //! ``` +//! # #[cfg(feature = "std")] +//! # { //! use tls_codec_derive::{TlsDeserialize, TlsSerialize, TlsSize}; //! //! #[derive(TlsDeserialize, TlsSerialize, TlsSize)] @@ -31,27 +45,35 @@ //! let deserialized = TypeWithUnknowns::tls_deserialize_exact(incoming); //! assert!(matches!(deserialized, Err(Error::UnknownValue(3)))); //! } +//! # } //! ``` //! //! ## Available attributes //! +//! Attributes can be used to control serialization and deserialization on a +//! per-field basis. +//! //! ### with //! //! ```text //! #[tls_codec(with = "prefix")] //! ``` //! -//! This attribute may be applied to a struct field. It indicates that deriving any of the -//! `tls_codec` traits for the containing struct calls the following functions: +//! This attribute may be applied to a struct field. It indicates that deriving +//! any of the `tls_codec` traits for the containing struct calls the following +//! functions: //! - `prefix::tls_deserialize` when deriving `Deserialize` //! 
- `prefix::tls_serialize` when deriving `Serialize` //! - `prefix::tls_serialized_len` when deriving `Size` //! -//! `prefix` can be a path to a module, type or trait where the functions are defined. +//! `prefix` can be a path to a module, type or trait where the functions are +//! defined. //! //! Their expected signatures match the corresponding methods in the traits. //! //! ``` +//! # #[cfg(feature = "std")] +//! # { //! use tls_codec_derive::{TlsSerialize, TlsSize}; //! //! #[derive(TlsSerialize, TlsSize)] @@ -72,6 +94,7 @@ //! TlsByteSliceU32(v).tls_serialize(writer) //! } //! } +//! # } //! ``` //! //! ### discriminant @@ -81,27 +104,34 @@ //! #[tls_codec(discriminant = "path::to::const::or::enum::Variant")] //! ``` //! -//! This attribute may be applied to an enum variant to specify the discriminant to use when -//! serializing it. If all variants are units (e.g. they do not have any data), this attribute -//! must not be used and the desired discriminants should be assigned to the variants using -//! standard Rust syntax (`Variant = Discriminant`). +//! This attribute may be applied to an enum variant to specify the discriminant +//! to use when serializing it. If all variants are units (e.g. they do not have +//! any data), this attribute must not be used and the desired discriminants +//! should be assigned to the variants using standard Rust syntax (`Variant = +//! Discriminant`). //! -//! For enumerations with non-unit variants, if no variant has this attribute, the serialization -//! discriminants will start from zero. If this attribute is used on a variant and the following -//! variant does not have it, its discriminant will be equal to the previous variant discriminant -//! plus 1. This behavior is referred to as "implicit discriminants". +//! For enumerations with non-unit variants, if no variant has this attribute, +//! the serialization discriminants will start from zero. If this attribute is +//! 
used on a variant and the following variant does not have it, its +//! discriminant will be equal to the previous variant discriminant plus 1. This +//! behavior is referred to as "implicit discriminants". //! -//! You can also provide paths that lead to `const` definitions or enum Variants. The important -//! thing is that any of those path expressions must resolve to something that can be coerced to -//! the `#[repr(enum_repr)]` of the enum. Please note that there are checks performed at compile -//! time to check if the provided value fits within the bounds of the `enum_repr` to avoid misuse. +//! You can also provide paths that lead to `const` definitions or enum +//! Variants. The important thing is that any of those path expressions must +//! resolve to something that can be coerced to the `#[repr(enum_repr)]` of the +//! enum. Please note that there are checks performed at compile time to check +//! if the provided value fits within the bounds of the `enum_repr` to avoid +//! misuse. //! -//! Note: When using paths *once* in your enum discriminants, as we do not have enough information -//! to deduce the next implicit discriminant (the constant expressions those paths resolve is only -//! evaluated at a later compilation stage than macros), you will be forced to use explicit -//! discriminants for all the other Variants of your enum. +//! Note: When using paths *once* in your enum discriminants, as we do not have +//! enough information to deduce the next implicit discriminant (the constant +//! expressions those paths resolve is only evaluated at a later compilation +//! stage than macros), you will be forced to use explicit discriminants for all +//! the other Variants of your enum. //! //! ``` +//! # #[cfg(feature = "std")] +//! # { //! use tls_codec_derive::{TlsSerialize, TlsSize}; //! //! const CONST_DISCRIMINANT: u8 = 5; @@ -130,7 +160,7 @@ //! #[tls_codec(discriminant = "CONST_DISCRIMINANT")] //! StaticConstant(u8), //! } -//! +//! # } //! ``` //! //! 
### skip @@ -139,13 +169,16 @@ //! #[tls_codec(skip)] //! ``` //! -//! This attribute may be applied to a struct field to specify that it should be skipped. Skipping -//! means that the field at hand will neither be serialized into TLS bytes nor deserialized from TLS -//! bytes. For deserialization, it is required to populate the field with a known value. Thus, when -//! `skip` is used, the field type needs to implement the [Default] trait so it can be populated -//! with a default value. +//! This attribute may be applied to a struct field to specify that it should be +//! skipped. Skipping means that the field at hand will neither be serialized +//! into TLS bytes nor deserialized from TLS bytes. For deserialization, it is +//! required to populate the field with a known value. Thus, when `skip` is +//! used, the field type needs to implement the [Default] trait so it can be +//! populated with a default value. //! //! ``` +//! # #[cfg(feature = "std")] +//! # { //! use tls_codec_derive::{TlsSerialize, TlsDeserialize, TlsSize}; //! //! struct CustomStruct; @@ -163,6 +196,43 @@ //! b: CustomStruct, //! c: u8, //! } +//! # } +//! ``` +//! +//! ## Conditional deserialization via the `conditionally_deserializable` attribute macro +//! +//! In some cases, it can be useful to have two variants of a struct, where one +//! is deserializable and one isn't. For example, the deserializable variant of +//! the struct could represent an unverified message, where only verification +//! produces the verified variant. Further processing could then be restricted +//! to the undeserializable struct variant. +//! +//! A pattern like this can be created via the `conditionally_deserializable` +//! attribute macro (requires the `conditional_deserialization` feature flag). +//! +//! The macro adds a boolean const generic to the struct and creates two +//! aliases, one for the deserializable variant (with a "`Deserializable`" +//! 
prefix) and one for the undeserializable one (with an "`Undeserializable`" +//! prefix). +//! +//! ``` +//! # #[cfg(all(feature = "conditional_deserialization", feature = "std"))] +//! # { +//! use tls_codec::{Serialize, Deserialize}; +//! use tls_codec_derive::{TlsSerialize, TlsSize, conditionally_deserializable}; +//! +//! #[conditionally_deserializable] +//! #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] +//! struct ExampleStruct { +//! a: u8, +//! b: u16, +//! } +//! +//! let undeserializable_struct = UndeserializableExampleStruct { a: 1, b: 2 }; +//! let serialized = undeserializable_struct.tls_serialize_detached().unwrap(); +//! let deserializable_struct = +//! DeserializableExampleStruct::tls_deserialize(&mut serialized.as_slice()).unwrap(); +//! # } //! ``` extern crate proc_macro; @@ -176,6 +246,9 @@ use syn::{ Expr, ExprLit, ExprPath, Field, Generics, Ident, Lit, Member, Meta, Result, Token, Type, }; +#[cfg(feature = "conditional_deserialization")] +use syn::{parse_quote, ConstParam, ImplGenerics, ItemStruct, TypeGenerics}; + /// Attribute name to identify attributes to be processed by derive-macros in this crate. const ATTR_IDENT: &str = "tls_codec"; @@ -895,6 +968,27 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr } } +#[cfg(feature = "conditional_deserialization")] +fn restrict_conditional_generic( + impl_generics: ImplGenerics, + ty_generics: TypeGenerics, + deserializable: bool, +) -> (TokenStream2, TokenStream2) { + let impl_generics = quote! { #impl_generics } + .to_string() + .replace(" const IS_DESERIALIZABLE : bool ", "") + .replace("<>", "") + .parse() + .unwrap(); + let deserializable_string = if deserializable { "true" } else { "false" }; + let ty_generics = quote! 
{ #ty_generics } + .to_string() + .replace("IS_DESERIALIZABLE", deserializable_string) + .parse() + .unwrap(); + (impl_generics, ty_generics) +} + #[allow(unused_variables)] fn impl_deserialize(parsed_ast: TlsStruct) -> TokenStream2 { match parsed_ast { @@ -914,6 +1008,9 @@ fn impl_deserialize(parsed_ast: TlsStruct) -> TokenStream2 { .map(|p| p.for_trait("Deserialize")) .collect::<Vec<_>>(); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + #[cfg(feature = "conditional_deserialization")] + let (impl_generics, ty_generics) = + restrict_conditional_generic(impl_generics, ty_generics, true); quote! { impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause { #[cfg(feature = "std")] @@ -1003,6 +1100,9 @@ fn impl_deserialize_bytes(parsed_ast: TlsStruct) -> TokenStream2 { .map(|p| p.for_trait("DeserializeBytes")) .collect::<Vec<_>>(); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + #[cfg(feature = "conditional_deserialization")] + let (impl_generics, ty_generics) = + restrict_conditional_generic(impl_generics, ty_generics, true); quote! { impl #impl_generics tls_codec::DeserializeBytes for #ident #ty_generics #where_clause { fn tls_deserialize_bytes(bytes: &[u8]) -> core::result::Result<(Self, &[u8]), tls_codec::Error> { @@ -1102,3 +1202,89 @@ fn partition_skipped( (members_skip, member_prefixes_skip), ) } + +/// The `conditionally_deserializable` attribute macro ignores any arguments +/// passed to it and does the following: +/// * Add a boolean const generic to the struct indicating if the variant of the +/// struct is deserializable or not. +/// * Implement both the `Deserialize` and `DeserializeBytes` traits for the +/// deserializable variant of the struct. +/// * Create type aliases for the deserializable and undeserializable variant of +/// the struct, where the alias is the name of the struct prefixed with +/// `Deserializable` or `Undeserializable` respectively. 
+/// +/// The `conditionally_deserializable` attribute macro is only available if the +/// `conditional_deserialization` feature is enabled. +/// +#[cfg_attr( + feature = "conditional_deserialization", + doc = r##" +```compile_fail +use tls_codec_derive::{TlsSerialize, TlsDeserialize, TlsSize, conditionally_deserializable}; + +#[conditionally_deserializable(Bytes)] +#[derive(TlsDeserialize, TlsSerialize, TlsSize)] +struct ExampleStruct { + pub a: u16, +} + +impl UndeserializableExampleStruct { + #[cfg(feature = "conditional_deserialization")] + fn deserialize(bytes: &[u8]) -> Result<Self, Error> { + Self::tls_deserialize_exact(bytes) + } +} +``` +"## +)] +#[cfg(feature = "conditional_deserialization")] +#[proc_macro_attribute] +pub fn conditionally_deserializable( + _input: TokenStream, + annotated_item: TokenStream, +) -> TokenStream { + let annotated_item = parse_macro_input!(annotated_item as ItemStruct); + impl_conditionally_deserializable(annotated_item).into() +} + +#[cfg(feature = "conditional_deserialization")] +fn impl_conditionally_deserializable(mut annotated_item: ItemStruct) -> TokenStream2 { + let deserializable_const_generic: ConstParam = parse_quote! 
{const IS_DESERIALIZABLE: bool}; + // Add the DESERIALIZABLE const generic to the struct + annotated_item + .generics + .params + .push(deserializable_const_generic.into()); + // Derive both TlsDeserialize and TlsDeserializeBytes + let deserialize_bytes_implementation = + impl_deserialize_bytes(parse_ast(annotated_item.clone().into()).unwrap()); + let deserialize_implementation = + impl_deserialize(parse_ast(annotated_item.clone().into()).unwrap()); + let (impl_generics, ty_generics, _) = annotated_item.generics.split_for_impl(); + // Patch generics for use by the type aliases + let (_deserializable_impl_generics, deserializable_ty_generics) = + restrict_conditional_generic(impl_generics.clone(), ty_generics.clone(), true); + let (_undeserializable_impl_generics, undeserializable_ty_generics) = + restrict_conditional_generic(impl_generics.clone(), ty_generics.clone(), false); + let annotated_item_ident = annotated_item.ident.clone(); + // Create Alias Idents by adding prefixes + let deserializable_ident = Ident::new( + &format!("Deserializable{}", annotated_item_ident), + Span::call_site(), + ); + let undeserializable_ident = Ident::new( + &format!("Undeserializable{}", annotated_item_ident), + Span::call_site(), + ); + let annotated_item_visibility = annotated_item.vis.clone(); + quote! 
{ + #annotated_item + + #annotated_item_visibility type #undeserializable_ident = #annotated_item_ident #undeserializable_ty_generics; + #annotated_item_visibility type #deserializable_ident = #annotated_item_ident #deserializable_ty_generics; + + #deserialize_implementation + + #deserialize_bytes_implementation + } +} diff --git a/tls_codec/derive/tests/decode.rs b/tls_codec/derive/tests/decode.rs index 93d5fa430..d0907b25c 100644 --- a/tls_codec/derive/tests/decode.rs +++ b/tls_codec/derive/tests/decode.rs @@ -527,3 +527,49 @@ fn type_with_unknowns() { let deserialized = TypeWithUnknowns::tls_deserialize_exact(incoming); assert!(matches!(deserialized, Err(Error::UnknownValue(3)))); } + +#[cfg(feature = "conditional_deserialization")] +mod conditional_deserialization { + use tls_codec::{Deserialize, Serialize}; + use tls_codec_derive::{conditionally_deserializable, TlsSerialize, TlsSize}; + + #[test] + fn conditionally_deserializable_struct() { + #[conditionally_deserializable] + #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] + struct ExampleStruct { + a: u8, + b: u16, + } + + let undeserializable_struct = UndeserializableExampleStruct { a: 1, b: 2 }; + let serialized = undeserializable_struct.tls_serialize_detached().unwrap(); + let deserializable_struct = + DeserializableExampleStruct::tls_deserialize(&mut serialized.as_slice()).unwrap(); + assert_eq!(deserializable_struct.a, undeserializable_struct.a); + assert_eq!(deserializable_struct.b, undeserializable_struct.b); + + #[conditionally_deserializable] + #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] + struct SecondExampleStruct { + a: u8, + b: u16, + } + } + + #[test] + fn conditional_deserializable_struct_bytes() { + #[conditionally_deserializable] + #[derive(TlsSize, TlsSerialize, PartialEq, Debug)] + struct ExampleStruct { + a: u8, + b: u16, + } + let undeserializable_struct = UndeserializableExampleStruct { a: 1, b: 2 }; + let serialized = undeserializable_struct.tls_serialize_detached().unwrap(); 
+ let deserializable_struct = + DeserializableExampleStruct::tls_deserialize_exact(&mut &*serialized).unwrap(); + assert_eq!(deserializable_struct.a, undeserializable_struct.a); + assert_eq!(deserializable_struct.b, undeserializable_struct.b); + } +}