From b495ba2d1f4c4ba56eba8fc832c2ac23adcf0b36 Mon Sep 17 00:00:00 2001 From: Franziskus Kiefer Date: Sun, 19 Nov 2023 10:27:23 +0100 Subject: [PATCH] tls_codec: Don't require std feature for consumers when deriving (#1262) --- Cargo.lock | 4 +- tls_codec/CHANGELOG.md | 1 + tls_codec/Cargo.toml | 4 +- tls_codec/derive/Cargo.toml | 2 +- tls_codec/derive/src/lib.rs | 122 +++++++++++++++++++++--------------- 5 files changed, 77 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 766872578..945e2c72e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1589,7 +1589,7 @@ dependencies = [ [[package]] name = "tls_codec" -version = "0.4.0-pre.1" +version = "0.4.0-pre.2" dependencies = [ "anstyle", "anstyle-parse", @@ -1605,7 +1605,7 @@ dependencies = [ [[package]] name = "tls_codec_derive" -version = "0.4.0-pre.1" +version = "0.4.0-pre.2" dependencies = [ "proc-macro2", "quote", diff --git a/tls_codec/CHANGELOG.md b/tls_codec/CHANGELOG.md index d1f2d4dd5..0b6bc3fc5 100644 --- a/tls_codec/CHANGELOG.md +++ b/tls_codec/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - [#1251](https://github.com/RustCrypto/formats/pull/1251): Add `_bytes` suffix to function names in the `DeserializeBytes` trait to avoid collisions with function names in the `Deserialize` trait +- [#1135](https://github.com/RustCrypto/formats/pull/1135): `no_std` support for the derive crate. This requires the `std` feature to be enabled when using derive with `Serialize` and `Deserialize`. ### Removed diff --git a/tls_codec/Cargo.toml b/tls_codec/Cargo.toml index 8b99494f6..217e69166 100644 --- a/tls_codec/Cargo.toml +++ b/tls_codec/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tls_codec" -version = "0.4.0-pre.1" +version = "0.4.0-pre.2" authors = ["RustCrypto Developers"] license = "Apache-2.0 OR MIT" documentation = "https://docs.rs/tls_codec/" @@ -18,7 +18,7 @@ zeroize = { version = "1.7", default-features = false, features = [ # optional dependencies arbitrary = { version = "1.3", features = ["derive"], optional = true } -tls_codec_derive = { version = "=0.4.0-pre.1", path = "./derive", optional = true } +tls_codec_derive = { version = "=0.4.0-pre.2", path = "./derive", optional = true } serde = { version = "1.0.184", features = ["derive"], optional = true } [dev-dependencies] diff --git a/tls_codec/derive/Cargo.toml b/tls_codec/derive/Cargo.toml index f7196d748..1fca2ab83 100644 --- a/tls_codec/derive/Cargo.toml +++ b/tls_codec/derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "tls_codec_derive" -version = "0.4.0-pre.1" +version = "0.4.0-pre.2" authors = ["RustCrypto Developers"] license = "Apache-2.0 OR MIT" documentation = "https://docs.rs/tls_codec_derive/" diff --git a/tls_codec/derive/src/lib.rs b/tls_codec/derive/src/lib.rs index 5966a44f2..1c4d7e339 100644 --- a/tls_codec/derive/src/lib.rs +++ b/tls_codec/derive/src/lib.rs @@ -814,34 +814,39 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr match svariant { SerializeVariant::Write => { - quote! { - impl #impl_generics tls_codec::Serialize for #ident #ty_generics #where_clause { - #[cfg(feature = "std")] - fn tls_serialize(&self, writer: &mut W) -> core::result::Result { - let mut written = 0usize; - #( - written += #prefixes::tls_serialize(&self.#members, writer)?; - )* - if cfg!(debug_assertions) { - let expected_written = tls_codec::Size::tls_serialized_len(&self); - debug_assert_eq!(written, expected_written, "Expected to serialize {} bytes but only {} were generated.", expected_written, written); - if written != expected_written { - Err(tls_codec::Error::EncodingError(format!("Expected to serialize {} bytes but only {} were generated.", expected_written, written))) + if cfg!(feature = "std") { + quote! { + impl #impl_generics tls_codec::Serialize for #ident #ty_generics #where_clause { + fn tls_serialize(&self, writer: &mut W) -> core::result::Result { + let mut written = 0usize; + #( + written += #prefixes::tls_serialize(&self.#members, writer)?; + )* + if cfg!(debug_assertions) { + let expected_written = tls_codec::Size::tls_serialized_len(&self); + debug_assert_eq!(written, expected_written, "Expected to serialize {} bytes but only {} were generated.", expected_written, written); + if written != expected_written { + Err(tls_codec::Error::EncodingError(format!("Expected to serialize {} bytes but only {} were generated.", expected_written, written))) + } else { + Ok(written) + } } else { Ok(written) } - } else { - Ok(written) } } - } - impl #impl_generics tls_codec::Serialize for &#ident #ty_generics #where_clause { - #[cfg(feature = "std")] - fn tls_serialize(&self, writer: &mut W) -> core::result::Result { - tls_codec::Serialize::tls_serialize(*self, writer) + impl #impl_generics tls_codec::Serialize for &#ident #ty_generics #where_clause { + fn tls_serialize(&self, writer: &mut W) -> core::result::Result { + tls_codec::Serialize::tls_serialize(*self, writer) + } } } + } else { + quote! { + impl #impl_generics tls_codec::Serialize for #ident #ty_generics #where_clause {} + impl #impl_generics tls_codec::Serialize for &#ident #ty_generics #where_clause {} + } } } SerializeVariant::Bytes => { @@ -926,23 +931,28 @@ fn impl_serialize(parsed_ast: TlsStruct, svariant: SerializeVariant) -> TokenStr match svariant { SerializeVariant::Write => { - quote! { - impl #impl_generics tls_codec::Serialize for #ident #ty_generics #where_clause { - #[cfg(feature = "std")] - fn tls_serialize(&self, writer: &mut W) -> core::result::Result { - #discriminant_constants - match self { - #(#arms)* + if cfg!(feature = "std") { + quote! { + impl #impl_generics tls_codec::Serialize for #ident #ty_generics #where_clause { + fn tls_serialize(&self, writer: &mut W) -> core::result::Result { + #discriminant_constants + match self { + #(#arms)* + } } } - } - impl #impl_generics tls_codec::Serialize for &#ident #ty_generics #where_clause { - #[cfg(feature = "std")] - fn tls_serialize(&self, writer: &mut W) -> core::result::Result { - tls_codec::Serialize::tls_serialize(*self, writer) + impl #impl_generics tls_codec::Serialize for &#ident #ty_generics #where_clause { + fn tls_serialize(&self, writer: &mut W) -> core::result::Result { + tls_codec::Serialize::tls_serialize(*self, writer) + } } } + } else { + quote! { + impl #impl_generics tls_codec::Serialize for #ident #ty_generics #where_clause {} + impl #impl_generics tls_codec::Serialize for &#ident #ty_generics #where_clause {} + } } } SerializeVariant::Bytes => { @@ -1011,16 +1021,21 @@ fn impl_deserialize(parsed_ast: TlsStruct) -> TokenStream2 { #[cfg(feature = "conditional_deserialization")] let (impl_generics, ty_generics) = restrict_conditional_generic(impl_generics, ty_generics, true); - quote! { - impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause { - #[cfg(feature = "std")] - fn tls_deserialize(bytes: &mut R) -> core::result::Result { - Ok(Self { - #(#members: #prefixes::tls_deserialize(bytes)?,)* - #(#members_default: Default::default(),)* - }) + if cfg!(feature = "std") { + quote! { + impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause { + fn tls_deserialize(bytes: &mut R) -> core::result::Result { + Ok(Self { + #(#members: #prefixes::tls_deserialize(bytes)?,)* + #(#members_default: Default::default(),)* + }) + } } } + } else { + quote! { + impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause {} + } } } TlsStruct::Enum(Enum { @@ -1050,20 +1065,25 @@ fn impl_deserialize(parsed_ast: TlsStruct) -> TokenStream2 { }) .collect::>(); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - quote! { - impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause { - #[cfg(feature = "std")] - fn tls_deserialize(bytes: &mut R) -> core::result::Result { - #discriminant_constants - let discriminant = <#repr as tls_codec::Deserialize>::tls_deserialize(bytes)?; - match discriminant { - #(#arms)* - _ => { - Err(tls_codec::Error::UnknownValue(discriminant.into())) - }, + if cfg!(feature = "std") { + quote! { + impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause { + fn tls_deserialize(bytes: &mut R) -> core::result::Result { + #discriminant_constants + let discriminant = <#repr as tls_codec::Deserialize>::tls_deserialize(bytes)?; + match discriminant { + #(#arms)* + _ => { + Err(tls_codec::Error::UnknownValue(discriminant.into())) + }, + } } } } + } else { + quote! { + impl #impl_generics tls_codec::Deserialize for #ident #ty_generics #where_clause {} + } } } }