diff --git a/README.md b/README.md index e9bb6d6..39e298a 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ enumflags2 = "^0.6" ## Features - [x] Uses enums to represent individual flags—a set of flags is a separate type from a single flag. +- [x] Automatically chooses a free bit when you don't specify. - [x] Detects incorrect BitFlags at compile time. - [x] Has a similar API compared to the popular [bitflags](https://crates.io/crates/bitflags) crate. - [x] Does not expose the generated types explicity. The user interacts exclusively with `struct BitFlags;`. @@ -37,7 +38,7 @@ use enumflags2::{bitflags, make_bitflags, BitFlags}; enum Test { A = 0b0001, B = 0b0010, - C = 0b0100, + C, // unspecified variants pick unused bits automatically D = 0b1000, } diff --git a/enumflags_derive/src/lib.rs b/enumflags_derive/src/lib.rs index cfd760e..8e4c2d9 100644 --- a/enumflags_derive/src/lib.rs +++ b/enumflags_derive/src/lib.rs @@ -9,21 +9,28 @@ use syn::{ parse::{Parse, ParseStream}, parse_macro_input, spanned::Spanned, - Ident, Item, ItemEnum, Token, + Expr, Ident, Item, ItemEnum, Token, Variant, }; -#[derive(Debug)] -struct Flag { +struct Flag<'a> { name: Ident, span: Span, - value: FlagValue, + value: FlagValue<'a>, } -#[derive(Debug)] -enum FlagValue { +enum FlagValue<'a> { Literal(u128), Deferred, - Inferred, + Inferred(&'a mut Variant), +} + +impl FlagValue<'_> { + fn is_inferred(&self) -> bool { + match self { + FlagValue::Inferred(_) => true, + _ => false, + } + } } struct Parameters { @@ -54,9 +61,9 @@ pub fn bitflags_internal( input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { let Parameters { default } = parse_macro_input!(attr as Parameters); - let ast = parse_macro_input!(input as Item); + let mut ast = parse_macro_input!(input as Item); let output = match ast { - Item::Enum(ref item_enum) => gen_enumflags(item_enum, default), + Item::Enum(ref mut item_enum) => gen_enumflags(item_enum, default), _ => Err(syn::Error::new_spanned( &ast, "#[bitflags] requires an enum", @@ -76,7 +83,6 @@ pub fn bitflags_internal( /// Try to evaluate the expression given. fn fold_expr(expr: &syn::Expr) -> Option { - use syn::Expr; match expr { Expr::Lit(ref expr_lit) => match expr_lit.lit { syn::Lit::Int(ref lit_int) => lit_int.base10_parse().ok(), @@ -98,8 +104,8 @@ fn fold_expr(expr: &syn::Expr) -> Option { } fn collect_flags<'a>( - variants: impl Iterator, -) -> Result, syn::Error> { + variants: impl Iterator, +) -> Result>, syn::Error> { variants .map(|variant| { // MSRV: Would this be cleaner with `matches!`? @@ -113,6 +119,8 @@ fn collect_flags<'a>( } } + let name = variant.ident.clone(); + let span = variant.span(); let value = if let Some(ref expr) = variant.discriminant { if let Some(n) = fold_expr(&expr.1) { FlagValue::Literal(n) @@ -120,18 +128,42 @@ fn collect_flags<'a>( FlagValue::Deferred } } else { - FlagValue::Inferred + FlagValue::Inferred(variant) }; - Ok(Flag { - name: variant.ident.clone(), - span: variant.span(), - value, - }) + Ok(Flag { name, span, value }) }) .collect() } +fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr { + let tokens = if previous_variants.is_empty() { + quote!(1) + } else { + quote!(::enumflags2::_internal::next_bit( + #(#type_name::#previous_variants as u128)|* + ) as #repr) + }; + + syn::parse2(tokens).expect("couldn't parse inferred value") +} + +fn infer_values<'a>(flags: &mut [Flag], type_name: &Ident, repr: &Ident) { + let mut previous_variants: Vec = flags.iter() + .filter(|flag| !flag.value.is_inferred()) + .map(|flag| flag.name.clone()).collect(); + + for flag in flags { + match flag.value { + FlagValue::Inferred(ref mut variant) => { + variant.discriminant = Some((::default(), inferred_value(type_name, &previous_variants, repr))); + previous_variants.push(flag.name.clone()); + } + _ => {} + } + } +} + /// Given a list of attributes, find the `repr`, if any, and return the integer /// type specified. fn extract_repr(attrs: &[syn::Attribute]) -> Result, syn::Error> { @@ -210,10 +242,7 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result Err(syn::Error::new( - flag.span, - "Please add an explicit discriminant", - )), + Inferred(_) => Ok(None), Deferred => { let variant_name = &flag.name; // MSRV: Use an unnamed constant (`const _: ...`). @@ -235,33 +264,34 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result) -> Result { +fn gen_enumflags(ast: &mut ItemEnum, default: Vec) -> Result { let ident = &ast.ident; let span = Span::call_site(); - // for quote! interpolation - let variant_names = ast.variants.iter().map(|v| &v.ident).collect::>(); - let repeated_name = vec![&ident; ast.variants.len()]; - let ty = extract_repr(&ast.attrs)? + let repr = extract_repr(&ast.attrs)? .ok_or_else(|| syn::Error::new_spanned(&ident, "repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?; - let bits = type_bits(&ty)?; + let bits = type_bits(&repr)?; - let variants = collect_flags(ast.variants.iter())?; + let mut variants = collect_flags(ast.variants.iter_mut())?; let deferred = variants .iter() .flat_map(|variant| check_flag(ident, variant, bits).transpose()) .collect::, _>>()?; + infer_values(&mut variants, ident, &repr); + if (bits as usize) < variants.len() { return Err(syn::Error::new_spanned( - &ty, + &repr, format!("Not enough bits for {} flags", variants.len()), )); } let std_path = quote_spanned!(span => ::enumflags2::_internal::core); + let variant_names = ast.variants.iter().map(|v| &v.ident).collect::>(); + let repeated_name = vec![&ident; ast.variants.len()]; Ok(quote_spanned! { span => @@ -303,15 +333,15 @@ fn gen_enumflags(ast: &ItemEnum, default: Vec) -> Result) -> Result"); fn bits(self) -> Self::Numeric { - self as #ty + self as #repr } } diff --git a/src/lib.rs b/src/lib.rs index a48eff8..197ced9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ //! enum Test { //! A = 0b0001, //! B = 0b0010, -//! C = 0b0100, +//! C, // unspecified variants pick unused bits automatically //! D = 0b1000, //! } //! @@ -259,6 +259,11 @@ pub mod _internal { impl AssertionHelper for [(); 0] { type Status = AssertionFailed; } + + pub const fn next_bit(x: u128) -> u128 { + // trailing_ones is beyond our MSRV + 1 << (!x).trailing_zeros() + } } // Internal debug formatting implementations diff --git a/test_suite/common.rs b/test_suite/common.rs index cf5a6c1..3052053 100644 --- a/test_suite/common.rs +++ b/test_suite/common.rs @@ -19,7 +19,7 @@ enum Test1 { E = 1 << 34, } -#[enumflags2::bitflags(default = B | C)] +#[bitflags(default = B | C)] #[derive(Copy, Clone, Debug)] #[repr(u8)] enum Default6 { @@ -129,3 +129,34 @@ fn module() { } } } + +#[test] +fn inferred_values() { + #[bitflags] + #[derive(Copy, Clone, Debug)] + #[repr(u8)] + enum Inferred { + Infer2, + SpecifiedA = 1, + Infer8, + SpecifiedB = 4, + } + + assert_eq!(Inferred::Infer2 as u8, 2); + assert_eq!(Inferred::Infer8 as u8, 8); + + #[bitflags] + #[derive(Copy, Clone, Debug)] + #[repr(u8)] + enum OnlyInferred { + Infer1, + Infer2, + Infer4, + Infer8, + } + + assert_eq!(OnlyInferred::Infer1 as u8, 1); + assert_eq!(OnlyInferred::Infer2 as u8, 2); + assert_eq!(OnlyInferred::Infer4 as u8, 4); + assert_eq!(OnlyInferred::Infer8 as u8, 8); +} diff --git a/test_suite/ui/missing_disciminant.rs b/test_suite/ui/missing_disciminant.rs deleted file mode 100644 index 9b79345..0000000 --- a/test_suite/ui/missing_disciminant.rs +++ /dev/null @@ -1,9 +0,0 @@ -#[enumflags2::bitflags] -#[repr(u8)] -#[derive(Copy, Clone)] -enum Foo { - OhNoTheresNoDiscriminant, - WhatWillTheMacroDo, -} - -fn main() {} diff --git a/test_suite/ui/missing_disciminant.stderr b/test_suite/ui/missing_disciminant.stderr deleted file mode 100644 index 241eb1b..0000000 --- a/test_suite/ui/missing_disciminant.stderr +++ /dev/null @@ -1,5 +0,0 @@ -error: Please add an explicit discriminant - --> $DIR/missing_disciminant.rs:5:5 - | -5 | OhNoTheresNoDiscriminant, - | ^^^^^^^^^^^^^^^^^^^^^^^^