Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 127 additions & 60 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,15 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
let (lifetime_without_bounds, lifetime_with_bounds) =
build_arbitrary_lifetime(input.generics.clone());

// This won't be used if `needs_recursive_count` ends up false.
let recursive_count = syn::Ident::new(
&format!("RECURSIVE_COUNT_{}", input.ident),
Span::call_site(),
);

let arbitrary_method =
let (arbitrary_method, needs_recursive_count) =
gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
let size_hint_method = gen_size_hint_method(&input)?;
let size_hint_method = gen_size_hint_method(&input, needs_recursive_count)?;
let name = input.ident;

// Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
Expand All @@ -56,17 +57,25 @@ fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
// Build TypeGenerics and WhereClause without a lifetime
let (_, ty_generics, where_clause) = generics.split_for_impl();

Ok(quote! {
const _: () = {
let recursive_count = needs_recursive_count.then(|| {
Some(quote! {
::std::thread_local! {
#[allow(non_upper_case_globals)]
static #recursive_count: ::core::cell::Cell<u32> = const {
::core::cell::Cell::new(0)
};
}
})
});

Ok(quote! {
const _: () = {
#recursive_count

#[automatically_derived]
impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds>
for #name #ty_generics #where_clause
{
#arbitrary_method
#size_hint_method
}
Expand Down Expand Up @@ -149,10 +158,7 @@ fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics
generics
}

fn with_recursive_count_guard(
recursive_count: &syn::Ident,
expr: impl quote::ToTokens,
) -> impl quote::ToTokens {
fn with_recursive_count_guard(recursive_count: &syn::Ident, expr: TokenStream) -> TokenStream {
quote! {
let guard_against_recursion = u.is_empty();
if guard_against_recursion {
Expand Down Expand Up @@ -181,7 +187,7 @@ fn gen_arbitrary_method(
input: &DeriveInput,
lifetime: LifetimeParam,
recursive_count: &syn::Ident,
) -> Result<TokenStream> {
) -> Result<(TokenStream, bool)> {
fn arbitrary_structlike(
fields: &Fields,
ident: &syn::Ident,
Expand Down Expand Up @@ -219,28 +225,36 @@ fn gen_arbitrary_method(
recursive_count: &syn::Ident,
unstructured: TokenStream,
variants: &[TokenStream],
) -> impl quote::ToTokens {
needs_recursive_count: bool,
) -> TokenStream {
let count = variants.len() as u64;
with_recursive_count_guard(
recursive_count,
quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count) >> 32 {
#(#variants,)*
_ => unreachable!()
})
},
)

let do_variants = quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
Ok(match (
u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count
) >> 32
{
#(#variants,)*
_ => unreachable!()
})
};

if needs_recursive_count {
with_recursive_count_guard(recursive_count, do_variants)
} else {
do_variants
}
}

fn arbitrary_enum(
DataEnum { variants, .. }: &DataEnum,
enum_name: &Ident,
lifetime: LifetimeParam,
recursive_count: &syn::Ident,
) -> Result<TokenStream> {
) -> Result<(TokenStream, bool)> {
let filtered_variants = variants.iter().filter(not_skipped);

// Check attributes of all variants:
Expand All @@ -254,11 +268,16 @@ fn gen_arbitrary_method(
.map(|(index, variant)| (index as u64, variant));

// Construct `match`-arms for the `arbitrary` method.
let mut needs_recursive_count = false;
let variants = enumerated_variants
.clone()
.map(|(index, Variant { fields, ident, .. })| {
construct(fields, |_, field| gen_constructor_for_field(field))
.map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
construct(fields, |_, field| gen_constructor_for_field(field)).map(|ctor| {
if !ctor.is_empty() {
needs_recursive_count = true;
}
arbitrary_variant(index, enum_name, ident, ctor)
})
})
.collect::<Result<Vec<TokenStream>>>()?;

Expand All @@ -277,34 +296,56 @@ fn gen_arbitrary_method(
(!variants.is_empty())
.then(|| {
// TODO: Improve dealing with `u` vs. `&mut u`.
let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants);
let arbitrary_take_rest = arbitrary_enum_method(recursive_count, quote! { &mut u }, &variants_take_rest);

quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary
}
let arbitrary = arbitrary_enum_method(
recursive_count,
quote! { u },
&variants,
needs_recursive_count,
);
let arbitrary_take_rest = arbitrary_enum_method(
recursive_count,
quote! { &mut u },
&variants_take_rest,
needs_recursive_count,
);

(
quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>)
-> arbitrary::Result<Self>
{
#arbitrary
}

fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary_take_rest
}
}
fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>)
-> arbitrary::Result<Self>
{
#arbitrary_take_rest
}
},
needs_recursive_count,
)
})
.ok_or_else(|| {
Error::new_spanned(
enum_name,
"Enum must have at least one variant, that is not skipped",
)
})
.ok_or_else(|| Error::new_spanned(
enum_name,
"Enum must have at least one variant, that is not skipped"
))
}

let ident = &input.ident;
let needs_recursive_count = true;
match &input.data {
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count),
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)
.map(|ts| (ts, needs_recursive_count)),
Data::Union(data) => arbitrary_structlike(
&Fields::Named(data.fields.clone()),
ident,
lifetime,
recursive_count,
),
)
.map(|ts| (ts, needs_recursive_count)),
Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
}
}
Expand Down Expand Up @@ -357,7 +398,7 @@ fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
})
}

fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
fn gen_size_hint_method(input: &DeriveInput, needs_recursive_count: bool) -> Result<TokenStream> {
let size_hint_fields = |fields: &Fields| {
fields
.iter()
Expand All @@ -372,9 +413,9 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
quote! { <#ty as arbitrary::Arbitrary>::try_size_hint(depth) }
}

// Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
// just an educated guess, although it's gonna be inaccurate for dynamically
// allocated types (Vec, HashMap, etc.).
// Note that in this case it's hard to determine what size_hint must be, so
// size_of::<T>() is just an educated guess, although it's gonna be
// inaccurate for dynamically allocated types (Vec, HashMap, etc.).
FieldConstructor::With(_) => {
quote! { Ok((::core::mem::size_of::<#ty>(), None)) }
}
Expand All @@ -391,6 +432,7 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
})
};
let size_hint_structlike = |fields: &Fields| {
assert!(needs_recursive_count);
size_hint_fields(fields).map(|hint| {
quote! {
#[inline]
Expand All @@ -399,7 +441,12 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
}

#[inline]
fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
fn try_size_hint(depth: usize)
-> ::core::result::Result<
(usize, ::core::option::Option<usize>),
arbitrary::MaxRecursionReached,
>
{
arbitrary::size_hint::try_recursion_guard(depth, |depth| #hint)
}
}
Expand All @@ -413,24 +460,44 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
.iter()
.filter(not_skipped)
.map(|Variant { fields, .. }| {
if !needs_recursive_count {
assert!(fields.is_empty());
}
// The attributes of all variants are checked in `gen_arbitrary_method` above
// and can therefore assume that they are valid.
// and can therefore assume that they are valid.
size_hint_fields(fields)
})
.collect::<Result<Vec<TokenStream>>>()
.map(|variants| {
quote! {
fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
Self::try_size_hint(depth).unwrap_or_default()
if needs_recursive_count {
// The enum might be recursive: `try_size_hint` is the primary one, and
// `size_hint` is defined in terms of it.
quote! {
fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
Self::try_size_hint(depth).unwrap_or_default()
}
#[inline]
fn try_size_hint(depth: usize)
-> ::core::result::Result<
(usize, ::core::option::Option<usize>),
arbitrary::MaxRecursionReached,
>
{
Ok(arbitrary::size_hint::and(
<u32 as arbitrary::Arbitrary>::size_hint(depth),
arbitrary::size_hint::try_recursion_guard(depth, |depth| {
Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ]))
})?,
))
}
}
#[inline]
fn try_size_hint(depth: usize) -> ::core::result::Result<(usize, ::core::option::Option<usize>), arbitrary::MaxRecursionReached> {
Ok(arbitrary::size_hint::and(
<u32 as arbitrary::Arbitrary>::try_size_hint(depth)?,
arbitrary::size_hint::try_recursion_guard(depth, |depth| {
Ok(arbitrary::size_hint::or_all(&[ #( #variants? ),* ]))
})?,
))
} else {
// The enum is guaranteed non-recursive, i.e. fieldless: `size_hint` is the
// primary one, and the default `try_size_hint` is good enough.
quote! {
fn size_hint(depth: usize) -> (usize, ::core::option::Option<usize>) {
<u32 as arbitrary::Arbitrary>::size_hint(depth)
}
}
}
}),
Expand Down