lib.rs

  1use proc_macro::TokenStream;
  2use quote::{format_ident, quote};
  3use std::mem;
  4use syn::{
  5    parse_macro_input, parse_quote, AttributeArgs, ItemFn, Lit, Meta, MetaNameValue, NestedMeta,
  6};
  7
  8#[proc_macro_attribute]
  9pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 10    let mut namespace = format_ident!("gpui");
 11
 12    let args = syn::parse_macro_input!(args as AttributeArgs);
 13    let mut max_retries = 0;
 14    let mut iterations = 1;
 15    for arg in args {
 16        match arg {
 17            NestedMeta::Meta(Meta::Path(name))
 18                if name.get_ident().map_or(false, |n| n == "self") =>
 19            {
 20                namespace = format_ident!("crate");
 21            }
 22            NestedMeta::Meta(Meta::NameValue(meta)) => {
 23                if let Some(result) = parse_int_meta(&meta, "retries") {
 24                    match result {
 25                        Ok(value) => max_retries = value,
 26                        Err(error) => return TokenStream::from(error.into_compile_error()),
 27                    }
 28                } else if let Some(result) = parse_int_meta(&meta, "iterations") {
 29                    match result {
 30                        Ok(value) => iterations = value,
 31                        Err(error) => return TokenStream::from(error.into_compile_error()),
 32                    }
 33                }
 34            }
 35            other => {
 36                return TokenStream::from(
 37                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 38                )
 39            }
 40        }
 41    }
 42
 43    let mut inner_fn = parse_macro_input!(function as ItemFn);
 44    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 45    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 46    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 47
 48    // Pass to the test function the number of app contexts that it needs,
 49    // based on its parameter list.
 50    let inner_fn_args = (0..inner_fn.sig.inputs.len())
 51        .map(|i| {
 52            let first_entity_id = i * 100_000;
 53            quote!(#namespace::TestAppContext::new(foreground.clone(), background.clone(), #first_entity_id),)
 54        })
 55        .collect::<proc_macro2::TokenStream>();
 56
 57    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 58        parse_quote! {
 59            #[test]
 60            fn #outer_fn_name() {
 61                #inner_fn
 62
 63                let mut retries = 0;
 64                let mut seed = 0;
 65                loop {
 66                    let result = std::panic::catch_unwind(|| {
 67                        let (foreground, background) = #namespace::executor::deterministic(seed as u64);
 68                        foreground.run(#inner_fn_name(#inner_fn_args));
 69                    });
 70
 71                    match result {
 72                        Ok(result) => {
 73                            seed += 1;
 74                            retries = 0;
 75                            if seed == #iterations {
 76                                return result
 77                            }
 78                        }
 79                        Err(error) => {
 80                            if retries < #max_retries {
 81                                retries += 1;
 82                                println!("retrying: attempt {}", retries);
 83                            } else {
 84                                if #iterations > 1 {
 85                                    eprintln!("failing seed: {}", seed);
 86                                }
 87                                std::panic::resume_unwind(error);
 88                            }
 89                        }
 90                    }
 91                }
 92            }
 93        }
 94    } else {
 95        parse_quote! {
 96            #[test]
 97            fn #outer_fn_name() {
 98                #inner_fn
 99
100                if #max_retries > 0 {
101                    let mut retries = 0;
102                    loop {
103                        let result = std::panic::catch_unwind(|| {
104                            #namespace::App::test(|cx| {
105                                #inner_fn_name(cx);
106                            });
107                        });
108
109                        match result {
110                            Ok(result) => return result,
111                            Err(error) => {
112                                if retries < #max_retries {
113                                    retries += 1;
114                                    println!("retrying: attempt {}", retries);
115                                } else {
116                                    std::panic::resume_unwind(error);
117                                }
118                            }
119                        }
120                    }
121                } else {
122                    #namespace::App::test(|cx| {
123                        #inner_fn_name(cx);
124                    });
125                }
126            }
127        }
128    };
129    outer_fn.attrs.extend(inner_fn_attributes);
130
131    TokenStream::from(quote!(#outer_fn))
132}
133
134fn parse_int_meta(meta: &MetaNameValue, name: &str) -> Option<syn::Result<usize>> {
135    let ident = meta.path.get_ident();
136    if ident.map_or(false, |n| n == name) {
137        if let Lit::Int(int) = &meta.lit {
138            Some(int.base10_parse())
139        } else {
140            Some(Err(syn::Error::new(
141                meta.lit.span(),
142                format!("{} mut be an integer", name),
143            )))
144        }
145    } else {
146        None
147    }
148}