test.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::Ident;
  3use quote::{format_ident, quote};
  4use std::mem;
  5use syn::{
  6    self, Expr, ExprLit, FnArg, ItemFn, Lit, Meta, MetaList, PathSegment, Token, Type,
  7    parse::{Parse, ParseStream},
  8    parse_quote,
  9    punctuated::Punctuated,
 10    spanned::Spanned,
 11};
 12
 13struct Args {
 14    seeds: Vec<u64>,
 15    max_retries: usize,
 16    max_iterations: usize,
 17    on_failure_fn_name: proc_macro2::TokenStream,
 18}
 19
 20impl Parse for Args {
 21    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
 22        let mut seeds = Vec::<u64>::new();
 23        let mut max_retries = 0;
 24        let mut max_iterations = 1;
 25        let mut on_failure_fn_name = quote!(None);
 26
 27        let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
 28
 29        for meta in metas {
 30            let ident = {
 31                let meta_path = match &meta {
 32                    Meta::NameValue(meta) => &meta.path,
 33                    Meta::List(list) => &list.path,
 34                    Meta::Path(path) => {
 35                        return Err(syn::Error::new(path.span(), "invalid path argument"));
 36                    }
 37                };
 38                let Some(ident) = meta_path.get_ident() else {
 39                    return Err(syn::Error::new(meta_path.span(), "unexpected path"));
 40                };
 41                ident.to_string()
 42            };
 43
 44            match (&meta, ident.as_str()) {
 45                (Meta::NameValue(meta), "retries") => {
 46                    max_retries = parse_usize_from_expr(&meta.value)?
 47                }
 48                (Meta::NameValue(meta), "iterations") => {
 49                    max_iterations = parse_usize_from_expr(&meta.value)?
 50                }
 51                (Meta::NameValue(meta), "on_failure") => {
 52                    let Expr::Lit(ExprLit {
 53                        lit: Lit::Str(name),
 54                        ..
 55                    }) = &meta.value
 56                    else {
 57                        return Err(syn::Error::new(
 58                            meta.value.span(),
 59                            "on_failure argument must be a string",
 60                        ));
 61                    };
 62                    let segments = name
 63                        .value()
 64                        .split("::")
 65                        .map(|part| PathSegment::from(Ident::new(part, name.span())))
 66                        .collect();
 67                    let path = syn::Path {
 68                        leading_colon: None,
 69                        segments,
 70                    };
 71                    on_failure_fn_name = quote!(Some(#path));
 72                }
 73                (Meta::NameValue(meta), "seed") => {
 74                    seeds = vec![parse_usize_from_expr(&meta.value)? as u64]
 75                }
 76                (Meta::List(list), "seeds") => seeds = parse_u64_array(list)?,
 77                (Meta::Path(_), _) => {
 78                    return Err(syn::Error::new(meta.span(), "invalid path argument"));
 79                }
 80                (_, _) => {
 81                    return Err(syn::Error::new(meta.span(), "invalid argument name"));
 82                }
 83            }
 84        }
 85
 86        Ok(Args {
 87            seeds,
 88            max_retries,
 89            max_iterations,
 90            on_failure_fn_name,
 91        })
 92    }
 93}
 94
 95pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 96    let args = syn::parse_macro_input!(args as Args);
 97    let mut inner_fn = match syn::parse::<ItemFn>(function) {
 98        Ok(f) => f,
 99        Err(err) => return error_to_stream(err),
100    };
101
102    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
103    let inner_fn_name = format_ident!("__{}", inner_fn.sig.ident);
104    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
105
106    let result = generate_test_function(
107        args,
108        inner_fn,
109        inner_fn_attributes,
110        inner_fn_name,
111        outer_fn_name,
112    );
113    match result {
114        Ok(tokens) => tokens,
115        Err(tokens) => tokens,
116    }
117}
118
119fn generate_test_function(
120    args: Args,
121    inner_fn: ItemFn,
122    inner_fn_attributes: Vec<syn::Attribute>,
123    inner_fn_name: Ident,
124    outer_fn_name: Ident,
125) -> Result<TokenStream, TokenStream> {
126    let seeds = &args.seeds;
127    let max_retries = args.max_retries;
128    let num_iterations = args.max_iterations;
129    let on_failure_fn_name = &args.on_failure_fn_name;
130    let seeds = quote!( #(#seeds),* );
131
132    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
133        // Pass to the test function the number of app contexts that it needs,
134        // based on its parameter list.
135        let mut cx_vars = proc_macro2::TokenStream::new();
136        let mut cx_teardowns = proc_macro2::TokenStream::new();
137        let mut inner_fn_args = proc_macro2::TokenStream::new();
138        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
139            if let FnArg::Typed(arg) = arg {
140                if let Type::Path(ty) = &*arg.ty {
141                    let last_segment = ty.path.segments.last();
142                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
143                        Some("StdRng") => {
144                            inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
145                            continue;
146                        }
147                        Some("BackgroundExecutor") => {
148                            inner_fn_args.extend(quote!(gpui::BackgroundExecutor::new(
149                                std::sync::Arc::new(dispatcher.clone()),
150                            ),));
151                            continue;
152                        }
153                        _ => {}
154                    }
155                } else if let Type::Reference(ty) = &*arg.ty
156                    && let Type::Path(ty) = &*ty.elem
157                {
158                    let last_segment = ty.path.segments.last();
159                    if let Some("TestAppContext") =
160                        last_segment.map(|s| s.ident.to_string()).as_deref()
161                    {
162                        let cx_varname = format_ident!("cx_{}", ix);
163                        cx_vars.extend(quote!(
164                            let mut #cx_varname = gpui::TestAppContext::build(
165                                dispatcher.clone(),
166                                Some(stringify!(#outer_fn_name)),
167                            );
168                        ));
169                        cx_teardowns.extend(quote!(
170                            dispatcher.run_until_parked();
171                            #cx_varname.executor().forbid_parking();
172                            #cx_varname.quit();
173                            dispatcher.run_until_parked();
174                        ));
175                        inner_fn_args.extend(quote!(&mut #cx_varname,));
176                        continue;
177                    }
178                }
179            }
180
181            return Err(error_with_message("invalid function signature", arg));
182        }
183
184        parse_quote! {
185            #[test]
186            fn #outer_fn_name() {
187                #inner_fn
188
189                gpui::run_test(
190                    #num_iterations,
191                    &[#seeds],
192                    #max_retries,
193                    &mut |dispatcher, _seed| {
194                        let executor = gpui::BackgroundExecutor::new(std::sync::Arc::new(dispatcher.clone()));
195                        #cx_vars
196                        executor.block_test(#inner_fn_name(#inner_fn_args));
197                        #cx_teardowns
198                    },
199                    #on_failure_fn_name
200                );
201            }
202        }
203    } else {
204        // Pass to the test function the number of app contexts that it needs,
205        // based on its parameter list.
206        let mut cx_vars = proc_macro2::TokenStream::new();
207        let mut cx_teardowns = proc_macro2::TokenStream::new();
208        let mut inner_fn_args = proc_macro2::TokenStream::new();
209        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
210            if let FnArg::Typed(arg) = arg {
211                if let Type::Path(ty) = &*arg.ty {
212                    let last_segment = ty.path.segments.last();
213
214                    if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
215                        inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
216                        continue;
217                    }
218                } else if let Type::Reference(ty) = &*arg.ty
219                    && let Type::Path(ty) = &*ty.elem
220                {
221                    let last_segment = ty.path.segments.last();
222                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
223                        Some("App") => {
224                            let cx_varname = format_ident!("cx_{}", ix);
225                            let cx_varname_lock = format_ident!("cx_{}_lock", ix);
226                            cx_vars.extend(quote!(
227                                let mut #cx_varname = gpui::TestAppContext::build(
228                                   dispatcher.clone(),
229                                   Some(stringify!(#outer_fn_name))
230                                );
231                                let mut #cx_varname_lock = #cx_varname.app.borrow_mut();
232                            ));
233                            inner_fn_args.extend(quote!(&mut #cx_varname_lock,));
234                            cx_teardowns.extend(quote!(
235                                    drop(#cx_varname_lock);
236                                    dispatcher.run_until_parked();
237                                    #cx_varname.update(|cx| { cx.background_executor().forbid_parking(); cx.quit(); });
238                                    dispatcher.run_until_parked();
239                                ));
240                            continue;
241                        }
242                        Some("TestAppContext") => {
243                            let cx_varname = format_ident!("cx_{}", ix);
244                            cx_vars.extend(quote!(
245                                let mut #cx_varname = gpui::TestAppContext::build(
246                                    dispatcher.clone(),
247                                    Some(stringify!(#outer_fn_name))
248                                );
249                            ));
250                            cx_teardowns.extend(quote!(
251                                dispatcher.run_until_parked();
252                                #cx_varname.executor().forbid_parking();
253                                #cx_varname.quit();
254                                dispatcher.run_until_parked();
255                            ));
256                            inner_fn_args.extend(quote!(&mut #cx_varname,));
257                            continue;
258                        }
259                        _ => {}
260                    }
261                }
262            }
263
264            return Err(error_with_message("invalid function signature", arg));
265        }
266
267        parse_quote! {
268            #[test]
269            fn #outer_fn_name() {
270                #inner_fn
271
272                gpui::run_test(
273                    #num_iterations,
274                    &[#seeds],
275                    #max_retries,
276                    &mut |dispatcher, _seed| {
277                        #cx_vars
278                        #inner_fn_name(#inner_fn_args);
279                        #cx_teardowns
280                    },
281                    #on_failure_fn_name,
282                );
283            }
284        }
285    };
286    outer_fn.attrs.extend(inner_fn_attributes);
287
288    Ok(TokenStream::from(quote!(#outer_fn)))
289}
290
291fn parse_usize_from_expr(expr: &Expr) -> Result<usize, syn::Error> {
292    let Expr::Lit(ExprLit {
293        lit: Lit::Int(int), ..
294    }) = expr
295    else {
296        return Err(syn::Error::new(expr.span(), "expected an integer"));
297    };
298    int.base10_parse()
299        .map_err(|_| syn::Error::new(int.span(), "failed to parse integer"))
300}
301
302fn parse_u64_array(meta_list: &MetaList) -> Result<Vec<u64>, syn::Error> {
303    let mut result = Vec::new();
304    let tokens = &meta_list.tokens;
305    let parser = |input: ParseStream| {
306        let exprs = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
307        for expr in exprs {
308            if let Expr::Lit(ExprLit {
309                lit: Lit::Int(int), ..
310            }) = expr
311            {
312                let value: usize = int.base10_parse()?;
313                result.push(value as u64);
314            } else {
315                return Err(syn::Error::new(expr.span(), "expected an integer"));
316            }
317        }
318        Ok(())
319    };
320    syn::parse::Parser::parse2(parser, tokens.clone())?;
321    Ok(result)
322}
323
324fn error_with_message(message: &str, spanned: impl Spanned) -> TokenStream {
325    error_to_stream(syn::Error::new(spanned.span(), message))
326}
327
328fn error_to_stream(err: syn::Error) -> TokenStream {
329    TokenStream::from(err.into_compile_error())
330}