lib.rs

  1use proc_macro::TokenStream;
  2use quote::{format_ident, quote};
  3use std::mem;
  4use syn::{
  5    parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, ItemFn, Lit, Meta,
  6    NestedMeta,
  7};
  8
  9#[proc_macro_attribute]
 10pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 11    let mut namespace = format_ident!("gpui");
 12
 13    let args = syn::parse_macro_input!(args as AttributeArgs);
 14    let mut max_retries = 0;
 15    let mut num_iterations = 1;
 16    let mut starting_seed = 0;
 17    for arg in args {
 18        match arg {
 19            NestedMeta::Meta(Meta::Path(name))
 20                if name.get_ident().map_or(false, |n| n == "self") =>
 21            {
 22                namespace = format_ident!("crate");
 23            }
 24            NestedMeta::Meta(Meta::NameValue(meta)) => {
 25                let key_name = meta.path.get_ident().map(|i| i.to_string());
 26                let variable = match key_name.as_ref().map(String::as_str) {
 27                    Some("retries") => &mut max_retries,
 28                    Some("iterations") => &mut num_iterations,
 29                    Some("seed") => &mut starting_seed,
 30                    _ => {
 31                        return TokenStream::from(
 32                            syn::Error::new(meta.path.span(), "invalid argument")
 33                                .into_compile_error(),
 34                        )
 35                    }
 36                };
 37
 38                *variable = match parse_int(&meta.lit) {
 39                    Ok(value) => value,
 40                    Err(error) => {
 41                        return TokenStream::from(error.into_compile_error());
 42                    }
 43                };
 44            }
 45            other => {
 46                return TokenStream::from(
 47                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 48                )
 49            }
 50        }
 51    }
 52
 53    let mut inner_fn = parse_macro_input!(function as ItemFn);
 54    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 55    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 56    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 57
 58    // Pass to the test function the number of app contexts that it needs,
 59    // based on its parameter list.
 60    let inner_fn_args = (0..inner_fn.sig.inputs.len())
 61        .map(|i| {
 62            let first_entity_id = i * 100_000;
 63            quote!(#namespace::TestAppContext::new(foreground.clone(), background.clone(), #first_entity_id),)
 64        })
 65        .collect::<proc_macro2::TokenStream>();
 66
 67    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 68        parse_quote! {
 69            #[test]
 70            fn #outer_fn_name() {
 71                #inner_fn
 72
 73                let mut retries = 0;
 74                let mut i = 0;
 75                loop {
 76                    let seed = #starting_seed + i;
 77                    let result = std::panic::catch_unwind(|| {
 78                        let (foreground, background) = #namespace::executor::deterministic(seed as u64);
 79                        foreground.run(#inner_fn_name(#inner_fn_args));
 80                    });
 81
 82                    match result {
 83                        Ok(result) => {
 84                            retries = 0;
 85                            i += 1;
 86                            if i == #num_iterations {
 87                                return result
 88                            }
 89                        }
 90                        Err(error) => {
 91                            if retries < #max_retries {
 92                                retries += 1;
 93                                println!("retrying: attempt {}", retries);
 94                            } else {
 95                                if #num_iterations > 1 {
 96                                    eprintln!("failing seed: {}", seed);
 97                                }
 98                                std::panic::resume_unwind(error);
 99                            }
100                        }
101                    }
102                }
103            }
104        }
105    } else {
106        parse_quote! {
107            #[test]
108            fn #outer_fn_name() {
109                #inner_fn
110
111                if #max_retries > 0 {
112                    let mut retries = 0;
113                    loop {
114                        let result = std::panic::catch_unwind(|| {
115                            #namespace::App::test(|cx| {
116                                #inner_fn_name(cx);
117                            });
118                        });
119
120                        match result {
121                            Ok(result) => return result,
122                            Err(error) => {
123                                if retries < #max_retries {
124                                    retries += 1;
125                                    println!("retrying: attempt {}", retries);
126                                } else {
127                                    std::panic::resume_unwind(error);
128                                }
129                            }
130                        }
131                    }
132                } else {
133                    #namespace::App::test(|cx| {
134                        #inner_fn_name(cx);
135                    });
136                }
137            }
138        }
139    };
140    outer_fn.attrs.extend(inner_fn_attributes);
141
142    TokenStream::from(quote!(#outer_fn))
143}
144
145fn parse_int(literal: &Lit) -> syn::Result<usize> {
146    if let Lit::Int(int) = &literal {
147        int.base10_parse()
148    } else {
149        Err(syn::Error::new(literal.span(), "must be an integer"))
150    }
151}