lib.rs

  1use proc_macro::TokenStream;
  2use quote::{format_ident, quote};
  3use std::mem;
  4use syn::{
  5    parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, ItemFn, Lit, Meta,
  6    NestedMeta,
  7};
  8
  9#[proc_macro_attribute]
 10pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 11    let mut namespace = format_ident!("gpui");
 12
 13    let args = syn::parse_macro_input!(args as AttributeArgs);
 14    let mut max_retries = 0;
 15    let mut num_iterations = 1;
 16    let mut starting_seed = std::env::var("SEED")
 17        .map(|i| i.parse().expect("invalid `SEED`"))
 18        .ok();
 19
 20    for arg in args {
 21        match arg {
 22            NestedMeta::Meta(Meta::Path(name))
 23                if name.get_ident().map_or(false, |n| n == "self") =>
 24            {
 25                namespace = format_ident!("crate");
 26            }
 27            NestedMeta::Meta(Meta::NameValue(meta)) => {
 28                let key_name = meta.path.get_ident().map(|i| i.to_string());
 29                let result = (|| {
 30                    match key_name.as_ref().map(String::as_str) {
 31                        Some("retries") => max_retries = parse_int(&meta.lit)?,
 32                        Some("iterations") => {
 33                            if let Ok(iters) = std::env::var("ITERATIONS") {
 34                                num_iterations = iters.parse().expect("invalid `ITERATIONS`");
 35                            } else {
 36                                num_iterations = parse_int(&meta.lit)?;
 37                            }
 38                        }
 39                        Some("seed") => {
 40                            if starting_seed.is_none() {
 41                                starting_seed = Some(parse_int(&meta.lit)?);
 42                            }
 43                        }
 44                        _ => {
 45                            return Err(TokenStream::from(
 46                                syn::Error::new(meta.path.span(), "invalid argument")
 47                                    .into_compile_error(),
 48                            ))
 49                        }
 50                    }
 51                    Ok(())
 52                })();
 53
 54                if let Err(tokens) = result {
 55                    return tokens;
 56                }
 57            }
 58            other => {
 59                return TokenStream::from(
 60                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 61                )
 62            }
 63        }
 64    }
 65    let starting_seed = starting_seed.unwrap_or(0);
 66
 67    let mut inner_fn = parse_macro_input!(function as ItemFn);
 68    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 69    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 70    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 71
 72    // Pass to the test function the number of app contexts that it needs,
 73    // based on its parameter list.
 74    let inner_fn_args = (0..inner_fn.sig.inputs.len())
 75        .map(|i| {
 76            let first_entity_id = i * 100_000;
 77            quote!(#namespace::TestAppContext::new(foreground.clone(), background.clone(), #first_entity_id),)
 78        })
 79        .collect::<proc_macro2::TokenStream>();
 80
 81    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 82        parse_quote! {
 83            #[test]
 84            fn #outer_fn_name() {
 85                #inner_fn
 86
 87                let mut retries = 0;
 88                let mut i = 0;
 89                loop {
 90                    let seed = #starting_seed + i;
 91                    let result = std::panic::catch_unwind(|| {
 92                        let (foreground, background) = #namespace::executor::deterministic(seed as u64);
 93                        foreground.run(#inner_fn_name(#inner_fn_args));
 94                    });
 95
 96                    match result {
 97                        Ok(result) => {
 98                            retries = 0;
 99                            i += 1;
100                            if i == #num_iterations {
101                                return result
102                            }
103                        }
104                        Err(error) => {
105                            if retries < #max_retries {
106                                retries += 1;
107                                println!("retrying: attempt {}", retries);
108                            } else {
109                                if #num_iterations > 1 {
110                                    eprintln!("failing seed: {}", seed);
111                                }
112                                std::panic::resume_unwind(error);
113                            }
114                        }
115                    }
116                }
117            }
118        }
119    } else {
120        parse_quote! {
121            #[test]
122            fn #outer_fn_name() {
123                #inner_fn
124
125                if #max_retries > 0 {
126                    let mut retries = 0;
127                    loop {
128                        let result = std::panic::catch_unwind(|| {
129                            #namespace::App::test(|cx| {
130                                #inner_fn_name(cx);
131                            });
132                        });
133
134                        match result {
135                            Ok(result) => return result,
136                            Err(error) => {
137                                if retries < #max_retries {
138                                    retries += 1;
139                                    println!("retrying: attempt {}", retries);
140                                } else {
141                                    std::panic::resume_unwind(error);
142                                }
143                            }
144                        }
145                    }
146                } else {
147                    #namespace::App::test(|cx| {
148                        #inner_fn_name(cx);
149                    });
150                }
151            }
152        }
153    };
154    outer_fn.attrs.extend(inner_fn_attributes);
155
156    TokenStream::from(quote!(#outer_fn))
157}
158
159fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
160    let result = if let Lit::Int(int) = &literal {
161        int.base10_parse()
162    } else {
163        Err(syn::Error::new(literal.span(), "must be an integer"))
164    };
165
166    result.map_err(|err| TokenStream::from(err.into_compile_error()))
167}