lib.rs

  1use proc_macro::TokenStream;
  2use quote::{format_ident, quote};
  3use std::mem;
  4use syn::{
  5    parse_macro_input, parse_quote, AttributeArgs, ItemFn, Lit, Meta, MetaNameValue, NestedMeta,
  6};
  7
  8#[proc_macro_attribute]
  9pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 10    let mut namespace = format_ident!("gpui");
 11
 12    let args = syn::parse_macro_input!(args as AttributeArgs);
 13    let mut max_retries = 0;
 14    for arg in args {
 15        match arg {
 16            NestedMeta::Meta(Meta::Path(name))
 17                if name.get_ident().map_or(false, |n| n == "self") =>
 18            {
 19                namespace = format_ident!("crate");
 20            }
 21            NestedMeta::Meta(Meta::NameValue(meta)) => {
 22                if let Some(result) = parse_retries(&meta) {
 23                    match result {
 24                        Ok(retries) => max_retries = retries,
 25                        Err(error) => return TokenStream::from(error.into_compile_error()),
 26                    }
 27                }
 28            }
 29            other => {
 30                return TokenStream::from(
 31                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 32                )
 33            }
 34        }
 35    }
 36
 37    let mut inner_fn = parse_macro_input!(function as ItemFn);
 38    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 39    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 40    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 41
 42    // Pass to the test function the number of app contexts that it needs,
 43    // based on its parameter list.
 44    let inner_fn_args = (0..inner_fn.sig.inputs.len())
 45        .map(|i| {
 46            let first_entity_id = i * 100_000;
 47            quote!(#namespace::TestAppContext::new(foreground.clone(), #first_entity_id),)
 48        })
 49        .collect::<proc_macro2::TokenStream>();
 50
 51    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 52        parse_quote! {
 53            #[test]
 54            fn #outer_fn_name() {
 55                #inner_fn
 56
 57                if #max_retries > 0 {
 58                    let mut retries = 0;
 59                    loop {
 60                        let result = std::panic::catch_unwind(|| {
 61                            let foreground = ::std::rc::Rc::new(#namespace::executor::Foreground::test());
 62                            #namespace::block_on(foreground.run(#inner_fn_name(#inner_fn_args)));
 63                        });
 64
 65                        match result {
 66                            Ok(result) => return result,
 67                            Err(error) => {
 68                                if retries < #max_retries {
 69                                    retries += 1;
 70                                    println!("retrying: attempt {}", retries);
 71                                } else {
 72                                    std::panic::resume_unwind(error);
 73                                }
 74                            }
 75                        }
 76                    }
 77                } else {
 78                    let foreground = ::std::rc::Rc::new(#namespace::executor::Foreground::test());
 79                    #namespace::block_on(foreground.run(#inner_fn_name(#inner_fn_args)));
 80                }
 81            }
 82        }
 83    } else {
 84        parse_quote! {
 85            #[test]
 86            fn #outer_fn_name() {
 87                #inner_fn
 88
 89                if #max_retries > 0 {
 90                    let mut retries = 0;
 91                    loop {
 92                        let result = std::panic::catch_unwind(|| {
 93                            #namespace::App::test(|cx| {
 94                                #inner_fn_name(cx);
 95                            });
 96                        });
 97
 98                        match result {
 99                            Ok(result) => return result,
100                            Err(error) => {
101                                if retries < #max_retries {
102                                    retries += 1;
103                                    println!("retrying: attempt {}", retries);
104                                } else {
105                                    std::panic::resume_unwind(error);
106                                }
107                            }
108                        }
109                    }
110                } else {
111                    #namespace::App::test(|cx| {
112                        #inner_fn_name(cx);
113                    });
114                }
115            }
116        }
117    };
118    outer_fn.attrs.extend(inner_fn_attributes);
119
120    TokenStream::from(quote!(#outer_fn))
121}
122
123fn parse_retries(meta: &MetaNameValue) -> Option<syn::Result<usize>> {
124    let ident = meta.path.get_ident();
125    if ident.map_or(false, |n| n == "retries") {
126        if let Lit::Int(int) = &meta.lit {
127            Some(int.base10_parse())
128        } else {
129            Some(Err(syn::Error::new(
130                meta.lit.span(),
131                "retries mut be an integer",
132            )))
133        }
134    } else {
135        None
136    }
137}