Skip to content

Commit

Permalink
Pick discriminants automatically when unspecified
Browse files Browse the repository at this point in the history
Closes #21
  • Loading branch information
meithecatte committed Feb 24, 2021
1 parent 88cd25e commit 70f22ac
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 51 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ enumflags2 = "^0.6"
## Features

- [x] Uses enums to represent individual flags—a set of flags is a separate type from a single flag.
- [x] Automatically chooses a free bit when you don't specify.
- [x] Detects incorrect BitFlags at compile time.
- [x] Has a similar API compared to the popular [bitflags](https://crates.io/crates/bitflags) crate.
- [x] Does not expose the generated types explicity. The user interacts exclusively with `struct BitFlags<Enum>;`.
Expand All @@ -37,7 +38,7 @@ use enumflags2::{bitflags, make_bitflags, BitFlags};
enum Test {
A = 0b0001,
B = 0b0010,
C = 0b0100,
C, // unspecified variants pick unused bits automatically
D = 0b1000,
}

Expand Down
98 changes: 64 additions & 34 deletions enumflags_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,28 @@ use syn::{
parse::{Parse, ParseStream},
parse_macro_input,
spanned::Spanned,
Ident, Item, ItemEnum, Token,
Expr, Ident, Item, ItemEnum, Token, Variant,
};

#[derive(Debug)]
struct Flag {
struct Flag<'a> {
name: Ident,
span: Span,
value: FlagValue,
value: FlagValue<'a>,
}

#[derive(Debug)]
enum FlagValue {
enum FlagValue<'a> {
Literal(u128),
Deferred,
Inferred,
Inferred(&'a mut Variant),
}

impl FlagValue<'_> {
fn is_inferred(&self) -> bool {
match self {
FlagValue::Inferred(_) => true,
_ => false,
}
}
}

struct Parameters {
Expand Down Expand Up @@ -54,9 +61,9 @@ pub fn bitflags_internal(
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let Parameters { default } = parse_macro_input!(attr as Parameters);
let ast = parse_macro_input!(input as Item);
let mut ast = parse_macro_input!(input as Item);
let output = match ast {
Item::Enum(ref item_enum) => gen_enumflags(item_enum, default),
Item::Enum(ref mut item_enum) => gen_enumflags(item_enum, default),
_ => Err(syn::Error::new_spanned(
&ast,
"#[bitflags] requires an enum",
Expand All @@ -76,7 +83,6 @@ pub fn bitflags_internal(

/// Try to evaluate the expression given.
fn fold_expr(expr: &syn::Expr) -> Option<u128> {
use syn::Expr;
match expr {
Expr::Lit(ref expr_lit) => match expr_lit.lit {
syn::Lit::Int(ref lit_int) => lit_int.base10_parse().ok(),
Expand All @@ -98,8 +104,8 @@ fn fold_expr(expr: &syn::Expr) -> Option<u128> {
}

fn collect_flags<'a>(
variants: impl Iterator<Item = &'a syn::Variant>,
) -> Result<Vec<Flag>, syn::Error> {
variants: impl Iterator<Item = &'a mut Variant>,
) -> Result<Vec<Flag<'a>>, syn::Error> {
variants
.map(|variant| {
// MSRV: Would this be cleaner with `matches!`?
Expand All @@ -113,25 +119,51 @@ fn collect_flags<'a>(
}
}

let name = variant.ident.clone();
let span = variant.span();
let value = if let Some(ref expr) = variant.discriminant {
if let Some(n) = fold_expr(&expr.1) {
FlagValue::Literal(n)
} else {
FlagValue::Deferred
}
} else {
FlagValue::Inferred
FlagValue::Inferred(variant)
};

Ok(Flag {
name: variant.ident.clone(),
span: variant.span(),
value,
})
Ok(Flag { name, span, value })
})
.collect()
}

fn inferred_value(type_name: &Ident, previous_variants: &[Ident], repr: &Ident) -> Expr {
let tokens = if previous_variants.is_empty() {
quote!(1)
} else {
quote!(::enumflags2::_internal::next_bit(
#(#type_name::#previous_variants as u128)|*
) as #repr)
};

syn::parse2(tokens).expect("couldn't parse inferred value")
}

fn infer_values<'a>(flags: &mut [Flag], type_name: &Ident, repr: &Ident) {
let mut previous_variants: Vec<Ident> = flags.iter()
.filter(|flag| !flag.value.is_inferred())
.map(|flag| flag.name.clone()).collect();

for flag in flags {
match flag.value {
FlagValue::Inferred(ref mut variant) => {
variant.discriminant = Some((<Token![=]>::default(), inferred_value(type_name, &previous_variants, repr)));
previous_variants.push(flag.name.clone());
}
_ => {}
}
}
}

/// Given a list of attributes, find the `repr`, if any, and return the integer
/// type specified.
fn extract_repr(attrs: &[syn::Attribute]) -> Result<Option<Ident>, syn::Error> {
Expand Down Expand Up @@ -210,10 +242,7 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenSt
Ok(None)
}
}
Inferred => Err(syn::Error::new(
flag.span,
"Please add an explicit discriminant",
)),
Inferred(_) => Ok(None),
Deferred => {
let variant_name = &flag.name;
// MSRV: Use an unnamed constant (`const _: ...`).
Expand All @@ -235,33 +264,34 @@ fn check_flag(type_name: &Ident, flag: &Flag, bits: u8) -> Result<Option<TokenSt
}
}

fn gen_enumflags(ast: &ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn::Error> {
fn gen_enumflags(ast: &mut ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn::Error> {
let ident = &ast.ident;

let span = Span::call_site();
// for quote! interpolation
let variant_names = ast.variants.iter().map(|v| &v.ident).collect::<Vec<_>>();
let repeated_name = vec![&ident; ast.variants.len()];

let ty = extract_repr(&ast.attrs)?
let repr = extract_repr(&ast.attrs)?
.ok_or_else(|| syn::Error::new_spanned(&ident,
"repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
let bits = type_bits(&ty)?;
let bits = type_bits(&repr)?;

let variants = collect_flags(ast.variants.iter())?;
let mut variants = collect_flags(ast.variants.iter_mut())?;
let deferred = variants
.iter()
.flat_map(|variant| check_flag(ident, variant, bits).transpose())
.collect::<Result<Vec<_>, _>>()?;

infer_values(&mut variants, ident, &repr);

if (bits as usize) < variants.len() {
return Err(syn::Error::new_spanned(
&ty,
&repr,
format!("Not enough bits for {} flags", variants.len()),
));
}

let std_path = quote_spanned!(span => ::enumflags2::_internal::core);
let variant_names = ast.variants.iter().map(|v| &v.ident).collect::<Vec<_>>();
let repeated_name = vec![&ident; ast.variants.len()];

Ok(quote_spanned! {
span =>
Expand Down Expand Up @@ -303,15 +333,15 @@ fn gen_enumflags(ast: &ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn
}

impl ::enumflags2::_internal::RawBitFlags for #ident {
type Numeric = #ty;
type Numeric = #repr;

const EMPTY: Self::Numeric = 0;

const DEFAULT: Self::Numeric =
0 #(| (#repeated_name::#default as #ty))*;
0 #(| (#repeated_name::#default as #repr))*;

const ALL_BITS: Self::Numeric =
0 #(| (#repeated_name::#variant_names as #ty))*;
0 #(| (#repeated_name::#variant_names as #repr))*;

const FLAG_LIST: &'static [Self] =
&[#(#repeated_name::#variant_names),*];
Expand All @@ -320,7 +350,7 @@ fn gen_enumflags(ast: &ItemEnum, default: Vec<Ident>) -> Result<TokenStream, syn
concat!("BitFlags<", stringify!(#ident), ">");

fn bits(self) -> Self::Numeric {
self as #ty
self as #repr
}
}

Expand Down
7 changes: 6 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
//! enum Test {
//! A = 0b0001,
//! B = 0b0010,
//! C = 0b0100,
//! C, // unspecified variants pick unused bits automatically
//! D = 0b1000,
//! }
//!
Expand Down Expand Up @@ -259,6 +259,11 @@ pub mod _internal {
impl AssertionHelper for [(); 0] {
type Status = AssertionFailed;
}

pub const fn next_bit(x: u128) -> u128 {
// trailing_ones is beyond our MSRV
1 << (!x).trailing_zeros()
}
}

// Internal debug formatting implementations
Expand Down
33 changes: 32 additions & 1 deletion test_suite/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ enum Test1 {
E = 1 << 34,
}

#[enumflags2::bitflags(default = B | C)]
#[bitflags(default = B | C)]
#[derive(Copy, Clone, Debug)]
#[repr(u8)]
enum Default6 {
Expand Down Expand Up @@ -129,3 +129,34 @@ fn module() {
}
}
}

#[test]
fn inferred_values() {
#[bitflags]
#[derive(Copy, Clone, Debug)]
#[repr(u8)]
enum Inferred {
Infer2,
SpecifiedA = 1,
Infer8,
SpecifiedB = 4,
}

assert_eq!(Inferred::Infer2 as u8, 2);
assert_eq!(Inferred::Infer8 as u8, 8);

#[bitflags]
#[derive(Copy, Clone, Debug)]
#[repr(u8)]
enum OnlyInferred {
Infer1,
Infer2,
Infer4,
Infer8,
}

assert_eq!(OnlyInferred::Infer1 as u8, 1);
assert_eq!(OnlyInferred::Infer2 as u8, 2);
assert_eq!(OnlyInferred::Infer4 as u8, 4);
assert_eq!(OnlyInferred::Infer8 as u8, 8);
}
9 changes: 0 additions & 9 deletions test_suite/ui/missing_disciminant.rs

This file was deleted.

5 changes: 0 additions & 5 deletions test_suite/ui/missing_disciminant.stderr

This file was deleted.

1 comment on commit 70f22ac

@wycats
Copy link

@wycats wycats commented on 70f22ac Feb 24, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚀

Please sign in to comment.