lib.rs

  1use proc_macro::TokenStream;
  2use quote::{format_ident, quote};
  3use std::mem;
  4use syn::{
  5    parse_macro_input, parse_quote, AttributeArgs, ItemFn, Lit, Meta, MetaNameValue, NestedMeta,
  6};
  7
  8#[proc_macro_attribute]
  9pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 10    let mut namespace = format_ident!("gpui");
 11
 12    let args = syn::parse_macro_input!(args as AttributeArgs);
 13    let mut max_retries = 0;
 14    for arg in args {
 15        match arg {
 16            NestedMeta::Meta(Meta::Path(name))
 17                if name.get_ident().map_or(false, |n| n == "self") =>
 18            {
 19                namespace = format_ident!("crate");
 20            }
 21            NestedMeta::Meta(Meta::NameValue(meta)) => {
 22                if let Some(result) = parse_retries(&meta) {
 23                    match result {
 24                        Ok(retries) => max_retries = retries,
 25                        Err(error) => return TokenStream::from(error.into_compile_error()),
 26                    }
 27                }
 28            }
 29            other => {
 30                return TokenStream::from(
 31                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 32                )
 33            }
 34        }
 35    }
 36
 37    let mut inner_fn = parse_macro_input!(function as ItemFn);
 38    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 39    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 40    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 41    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 42        parse_quote! {
 43            #[test]
 44            fn #outer_fn_name() {
 45                #inner_fn
 46
 47                if #max_retries > 0 {
 48                    let mut retries = 0;
 49                    loop {
 50                        let result = std::panic::catch_unwind(|| {
 51                            #namespace::App::test_async(move |cx| async {
 52                                #inner_fn_name(cx).await;
 53                            });
 54                        });
 55
 56                        match result {
 57                            Ok(result) => return result,
 58                            Err(error) => {
 59                                if retries < #max_retries {
 60                                    retries += 1;
 61                                    println!("retrying: attempt {}", retries);
 62                                } else {
 63                                    std::panic::resume_unwind(error);
 64                                }
 65                            }
 66                        }
 67                    }
 68                } else {
 69                    #namespace::App::test_async(move |cx| async {
 70                        #inner_fn_name(cx).await;
 71                    });
 72                }
 73            }
 74        }
 75    } else {
 76        parse_quote! {
 77            #[test]
 78            fn #outer_fn_name() {
 79                #inner_fn
 80
 81                if #max_retries > 0 {
 82                    let mut retries = 0;
 83                    loop {
 84                        let result = std::panic::catch_unwind(|| {
 85                            #namespace::App::test(|cx| {
 86                                #inner_fn_name(cx);
 87                            });
 88                        });
 89
 90                        match result {
 91                            Ok(result) => return result,
 92                            Err(error) => {
 93                                if retries < #max_retries {
 94                                    retries += 1;
 95                                    println!("retrying: attempt {}", retries);
 96                                } else {
 97                                    std::panic::resume_unwind(error);
 98                                }
 99                            }
100                        }
101                    }
102                } else {
103                    #namespace::App::test(|cx| {
104                        #inner_fn_name(cx);
105                    });
106                }
107            }
108        }
109    };
110    outer_fn.attrs.extend(inner_fn_attributes);
111
112    TokenStream::from(quote!(#outer_fn))
113}
114
115fn parse_retries(meta: &MetaNameValue) -> Option<syn::Result<usize>> {
116    let ident = meta.path.get_ident();
117    if ident.map_or(false, |n| n == "retries") {
118        if let Lit::Int(int) = &meta.lit {
119            Some(int.base10_parse())
120        } else {
121            Some(Err(syn::Error::new(
122                meta.lit.span(),
123                "retries mut be an integer",
124            )))
125        }
126    } else {
127        None
128    }
129}