test.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::Ident;
  3use quote::{format_ident, quote};
  4use std::mem;
  5use syn::{
  6    parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, FnArg, ItemFn, Lit, Meta,
  7    NestedMeta, Type,
  8};
  9
 10pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 11    let args = syn::parse_macro_input!(args as AttributeArgs);
 12    let mut max_retries = 0;
 13    let mut num_iterations = 1;
 14    let mut on_failure_fn_name = quote!(None);
 15
 16    for arg in args {
 17        match arg {
 18            NestedMeta::Meta(Meta::NameValue(meta)) => {
 19                let key_name = meta.path.get_ident().map(|i| i.to_string());
 20                let result = (|| {
 21                    match key_name.as_deref() {
 22                        Some("retries") => max_retries = parse_int(&meta.lit)?,
 23                        Some("iterations") => num_iterations = parse_int(&meta.lit)?,
 24                        Some("on_failure") => {
 25                            if let Lit::Str(name) = meta.lit {
 26                                let mut path = syn::Path {
 27                                    leading_colon: None,
 28                                    segments: Default::default(),
 29                                };
 30                                for part in name.value().split("::") {
 31                                    path.segments.push(Ident::new(part, name.span()).into());
 32                                }
 33                                on_failure_fn_name = quote!(Some(#path));
 34                            } else {
 35                                return Err(TokenStream::from(
 36                                    syn::Error::new(
 37                                        meta.lit.span(),
 38                                        "on_failure argument must be a string",
 39                                    )
 40                                    .into_compile_error(),
 41                                ));
 42                            }
 43                        }
 44                        _ => {
 45                            return Err(TokenStream::from(
 46                                syn::Error::new(meta.path.span(), "invalid argument")
 47                                    .into_compile_error(),
 48                            ))
 49                        }
 50                    }
 51                    Ok(())
 52                })();
 53
 54                if let Err(tokens) = result {
 55                    return tokens;
 56                }
 57            }
 58            other => {
 59                return TokenStream::from(
 60                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 61                )
 62            }
 63        }
 64    }
 65
 66    let mut inner_fn = parse_macro_input!(function as ItemFn);
 67    if max_retries > 0 && num_iterations > 1 {
 68        return TokenStream::from(
 69            syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
 70                .into_compile_error(),
 71        );
 72    }
 73    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 74    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 75    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 76
 77    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 78        // Pass to the test function the number of app contexts that it needs,
 79        // based on its parameter list.
 80        let mut cx_vars = proc_macro2::TokenStream::new();
 81        let mut cx_teardowns = proc_macro2::TokenStream::new();
 82        let mut inner_fn_args = proc_macro2::TokenStream::new();
 83        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
 84            if let FnArg::Typed(arg) = arg {
 85                if let Type::Path(ty) = &*arg.ty {
 86                    let last_segment = ty.path.segments.last();
 87                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
 88                        Some("StdRng") => {
 89                            inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
 90                            continue;
 91                        }
 92                        Some("BackgroundExecutor") => {
 93                            inner_fn_args.extend(quote!(gpui::BackgroundExecutor::new(
 94                                std::sync::Arc::new(dispatcher.clone()),
 95                            ),));
 96                            continue;
 97                        }
 98                        _ => {}
 99                    }
100                } else if let Type::Reference(ty) = &*arg.ty {
101                    if let Type::Path(ty) = &*ty.elem {
102                        let last_segment = ty.path.segments.last();
103                        if let Some("TestAppContext") =
104                            last_segment.map(|s| s.ident.to_string()).as_deref()
105                        {
106                            let cx_varname = format_ident!("cx_{}", ix);
107                            cx_vars.extend(quote!(
108                                let mut #cx_varname = gpui::TestAppContext::new(
109                                    dispatcher.clone(),
110                                    Some(stringify!(#outer_fn_name)),
111                                );
112                            ));
113                            cx_teardowns.extend(quote!(
114                                dispatcher.run_until_parked();
115                                #cx_varname.quit();
116                                dispatcher.run_until_parked();
117                            ));
118                            inner_fn_args.extend(quote!(&mut #cx_varname,));
119                            continue;
120                        }
121                    }
122                }
123            }
124
125            return TokenStream::from(
126                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
127            );
128        }
129
130        parse_quote! {
131            #[test]
132            fn #outer_fn_name() {
133                #inner_fn
134
135                gpui::run_test(
136                    #num_iterations as u64,
137                    #max_retries,
138                    &mut |dispatcher, _seed| {
139                        let executor = gpui::BackgroundExecutor::new(std::sync::Arc::new(dispatcher.clone()));
140                        #cx_vars
141                        executor.block_test(#inner_fn_name(#inner_fn_args));
142                        #cx_teardowns
143                    },
144                    #on_failure_fn_name
145                );
146            }
147        }
148    } else {
149        // Pass to the test function the number of app contexts that it needs,
150        // based on its parameter list.
151        let mut cx_vars = proc_macro2::TokenStream::new();
152        let mut cx_teardowns = proc_macro2::TokenStream::new();
153        let mut inner_fn_args = proc_macro2::TokenStream::new();
154        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
155            if let FnArg::Typed(arg) = arg {
156                if let Type::Path(ty) = &*arg.ty {
157                    let last_segment = ty.path.segments.last();
158
159                    if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
160                        inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
161                        continue;
162                    }
163                } else if let Type::Reference(ty) = &*arg.ty {
164                    if let Type::Path(ty) = &*ty.elem {
165                        let last_segment = ty.path.segments.last();
166                        match last_segment.map(|s| s.ident.to_string()).as_deref() {
167                            Some("App") => {
168                                let cx_varname = format_ident!("cx_{}", ix);
169                                let cx_varname_lock = format_ident!("cx_{}_lock", ix);
170                                cx_vars.extend(quote!(
171                                    let mut #cx_varname = gpui::TestAppContext::new(
172                                       dispatcher.clone(),
173                                       Some(stringify!(#outer_fn_name))
174                                    );
175                                    let mut #cx_varname_lock = #cx_varname.app.borrow_mut();
176                                ));
177                                inner_fn_args.extend(quote!(&mut #cx_varname_lock,));
178                                cx_teardowns.extend(quote!(
179                                    drop(#cx_varname_lock);
180                                    dispatcher.run_until_parked();
181                                    #cx_varname.update(|cx| { cx.quit() });
182                                    dispatcher.run_until_parked();
183                                ));
184                                continue;
185                            }
186                            Some("TestAppContext") => {
187                                let cx_varname = format_ident!("cx_{}", ix);
188                                cx_vars.extend(quote!(
189                                    let mut #cx_varname = gpui::TestAppContext::new(
190                                        dispatcher.clone(),
191                                        Some(stringify!(#outer_fn_name))
192                                    );
193                                ));
194                                cx_teardowns.extend(quote!(
195                                    dispatcher.run_until_parked();
196                                    #cx_varname.quit();
197                                    dispatcher.run_until_parked();
198                                ));
199                                inner_fn_args.extend(quote!(&mut #cx_varname,));
200                                continue;
201                            }
202                            _ => {}
203                        }
204                    }
205                }
206            }
207
208            return TokenStream::from(
209                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
210            );
211        }
212
213        parse_quote! {
214            #[test]
215            fn #outer_fn_name() {
216                #inner_fn
217
218                gpui::run_test(
219                    #num_iterations as u64,
220                    #max_retries,
221                    &mut |dispatcher, _seed| {
222                        #cx_vars
223                        #inner_fn_name(#inner_fn_args);
224                        #cx_teardowns
225                    },
226                    #on_failure_fn_name,
227                );
228            }
229        }
230    };
231    outer_fn.attrs.extend(inner_fn_attributes);
232
233    TokenStream::from(quote!(#outer_fn))
234}
235
236fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
237    let result = if let Lit::Int(int) = &literal {
238        int.base10_parse()
239    } else {
240        Err(syn::Error::new(literal.span(), "must be an integer"))
241    };
242
243    result.map_err(|err| TokenStream::from(err.into_compile_error()))
244}