test.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::Ident;
  3use quote::{format_ident, quote};
  4use std::mem;
  5use syn::{
  6    parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, FnArg, ItemFn, Lit, Meta,
  7    NestedMeta, Type,
  8};
  9
 10pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 11    let args = syn::parse_macro_input!(args as AttributeArgs);
 12    let mut max_retries = 0;
 13    let mut num_iterations = 1;
 14    let mut on_failure_fn_name = quote!(None);
 15
 16    for arg in args {
 17        match arg {
 18            NestedMeta::Meta(Meta::NameValue(meta)) => {
 19                let key_name = meta.path.get_ident().map(|i| i.to_string());
 20                let result = (|| {
 21                    match key_name.as_deref() {
 22                        Some("retries") => max_retries = parse_int(&meta.lit)?,
 23                        Some("iterations") => num_iterations = parse_int(&meta.lit)?,
 24                        Some("on_failure") => {
 25                            if let Lit::Str(name) = meta.lit {
 26                                let mut path = syn::Path {
 27                                    leading_colon: None,
 28                                    segments: Default::default(),
 29                                };
 30                                for part in name.value().split("::") {
 31                                    path.segments.push(Ident::new(part, name.span()).into());
 32                                }
 33                                on_failure_fn_name = quote!(Some(#path));
 34                            } else {
 35                                return Err(TokenStream::from(
 36                                    syn::Error::new(
 37                                        meta.lit.span(),
 38                                        "on_failure argument must be a string",
 39                                    )
 40                                    .into_compile_error(),
 41                                ));
 42                            }
 43                        }
 44                        _ => {
 45                            return Err(TokenStream::from(
 46                                syn::Error::new(meta.path.span(), "invalid argument")
 47                                    .into_compile_error(),
 48                            ))
 49                        }
 50                    }
 51                    Ok(())
 52                })();
 53
 54                if let Err(tokens) = result {
 55                    return tokens;
 56                }
 57            }
 58            other => {
 59                return TokenStream::from(
 60                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 61                )
 62            }
 63        }
 64    }
 65
 66    let mut inner_fn = parse_macro_input!(function as ItemFn);
 67    if max_retries > 0 && num_iterations > 1 {
 68        return TokenStream::from(
 69            syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
 70                .into_compile_error(),
 71        );
 72    }
 73    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 74    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 75    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 76
 77    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 78        // Pass to the test function the number of app contexts that it needs,
 79        // based on its parameter list.
 80        let mut cx_vars = proc_macro2::TokenStream::new();
 81        let mut cx_teardowns = proc_macro2::TokenStream::new();
 82        let mut inner_fn_args = proc_macro2::TokenStream::new();
 83        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
 84            if let FnArg::Typed(arg) = arg {
 85                if let Type::Path(ty) = &*arg.ty {
 86                    let last_segment = ty.path.segments.last();
 87                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
 88                        Some("StdRng") => {
 89                            inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
 90                            continue;
 91                        }
 92                        Some("BackgroundExecutor") => {
 93                            inner_fn_args.extend(quote!(gpui::BackgroundExecutor::new(
 94                                std::sync::Arc::new(dispatcher.clone()),
 95                            ),));
 96                            continue;
 97                        }
 98                        _ => {}
 99                    }
100                } else if let Type::Reference(ty) = &*arg.ty {
101                    if let Type::Path(ty) = &*ty.elem {
102                        let last_segment = ty.path.segments.last();
103                        if let Some("TestAppContext") =
104                            last_segment.map(|s| s.ident.to_string()).as_deref()
105                        {
106                            let cx_varname = format_ident!("cx_{}", ix);
107                            cx_vars.extend(quote!(
108                                let mut #cx_varname = gpui::TestAppContext::new(
109                                    dispatcher.clone()
110                                );
111                            ));
112                            cx_teardowns.extend(quote!(
113                                dispatcher.run_until_parked();
114                                #cx_varname.quit();
115                                dispatcher.run_until_parked();
116                            ));
117                            inner_fn_args.extend(quote!(&mut #cx_varname,));
118                            continue;
119                        }
120                    }
121                }
122            }
123
124            return TokenStream::from(
125                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
126            );
127        }
128
129        parse_quote! {
130            #[test]
131            fn #outer_fn_name() {
132                #inner_fn
133
134                gpui::run_test(
135                    #num_iterations as u64,
136                    #max_retries,
137                    &mut |dispatcher, _seed| {
138                        let executor = gpui::BackgroundExecutor::new(std::sync::Arc::new(dispatcher.clone()));
139                        #cx_vars
140                        executor.block_test(#inner_fn_name(#inner_fn_args));
141                        #cx_teardowns
142                    },
143                    #on_failure_fn_name,
144                    stringify!(#outer_fn_name).to_string(),
145                );
146            }
147        }
148    } else {
149        // Pass to the test function the number of app contexts that it needs,
150        // based on its parameter list.
151        let mut cx_vars = proc_macro2::TokenStream::new();
152        let mut cx_teardowns = proc_macro2::TokenStream::new();
153        let mut inner_fn_args = proc_macro2::TokenStream::new();
154        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
155            if let FnArg::Typed(arg) = arg {
156                if let Type::Path(ty) = &*arg.ty {
157                    let last_segment = ty.path.segments.last();
158
159                    if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
160                        inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
161                        continue;
162                    }
163                } else if let Type::Reference(ty) = &*arg.ty {
164                    if let Type::Path(ty) = &*ty.elem {
165                        let last_segment = ty.path.segments.last();
166                        match last_segment.map(|s| s.ident.to_string()).as_deref() {
167                            Some("AppContext") => {
168                                let cx_varname = format_ident!("cx_{}", ix);
169                                let cx_varname_lock = format_ident!("cx_{}_lock", ix);
170                                cx_vars.extend(quote!(
171                                    let mut #cx_varname = gpui::TestAppContext::new(
172                                       dispatcher.clone()
173                                    );
174                                    let mut #cx_varname_lock = #cx_varname.app.borrow_mut();
175                                ));
176                                inner_fn_args.extend(quote!(&mut #cx_varname_lock,));
177                                cx_teardowns.extend(quote!(
178                                    drop(#cx_varname_lock);
179                                    dispatcher.run_until_parked();
180                                    #cx_varname.update(|cx| { cx.quit() });
181                                    dispatcher.run_until_parked();
182                                ));
183                                continue;
184                            }
185                            Some("TestAppContext") => {
186                                let cx_varname = format_ident!("cx_{}", ix);
187                                cx_vars.extend(quote!(
188                                    let mut #cx_varname = gpui::TestAppContext::new(
189                                        dispatcher.clone()
190                                    );
191                                ));
192                                cx_teardowns.extend(quote!(
193                                    dispatcher.run_until_parked();
194                                    #cx_varname.quit();
195                                    dispatcher.run_until_parked();
196                                ));
197                                inner_fn_args.extend(quote!(&mut #cx_varname,));
198                                continue;
199                            }
200                            _ => {}
201                        }
202                    }
203                }
204            }
205
206            return TokenStream::from(
207                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
208            );
209        }
210
211        parse_quote! {
212            #[test]
213            fn #outer_fn_name() {
214                #inner_fn
215
216                gpui::run_test(
217                    #num_iterations as u64,
218                    #max_retries,
219                    &mut |dispatcher, _seed| {
220                        #cx_vars
221                        #inner_fn_name(#inner_fn_args);
222                        #cx_teardowns
223                    },
224                    #on_failure_fn_name,
225                    stringify!(#outer_fn_name).to_string(),
226                );
227            }
228        }
229    };
230    outer_fn.attrs.extend(inner_fn_attributes);
231
232    TokenStream::from(quote!(#outer_fn))
233}
234
235fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
236    let result = if let Lit::Int(int) = &literal {
237        int.base10_parse()
238    } else {
239        Err(syn::Error::new(literal.span(), "must be an integer"))
240    };
241
242    result.map_err(|err| TokenStream::from(err.into_compile_error()))
243}