test.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::Ident;
  3use quote::{format_ident, quote};
  4use std::mem;
  5use syn::{
  6    self, Expr, ExprLit, FnArg, ItemFn, Lit, Meta, MetaList, PathSegment, Token, Type,
  7    parse::{Parse, ParseStream},
  8    parse_quote,
  9    punctuated::Punctuated,
 10    spanned::Spanned,
 11};
 12
 13struct Args {
 14    seeds: Vec<u64>,
 15    max_retries: usize,
 16    max_iterations: usize,
 17    on_failure_fn_name: proc_macro2::TokenStream,
 18}
 19
 20impl Parse for Args {
 21    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
 22        let mut seeds = Vec::<u64>::new();
 23        let mut max_retries = 0;
 24        let mut max_iterations = 1;
 25        let mut on_failure_fn_name = quote!(None);
 26
 27        let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
 28
 29        for meta in metas {
 30            let ident = {
 31                let meta_path = match &meta {
 32                    Meta::NameValue(meta) => &meta.path,
 33                    Meta::List(list) => &list.path,
 34                    Meta::Path(path) => {
 35                        return Err(syn::Error::new(path.span(), "invalid path argument"));
 36                    }
 37                };
 38                let Some(ident) = meta_path.get_ident() else {
 39                    return Err(syn::Error::new(meta_path.span(), "unexpected path"));
 40                };
 41                ident.to_string()
 42            };
 43
 44            match (&meta, ident.as_str()) {
 45                (Meta::NameValue(meta), "retries") => {
 46                    max_retries = parse_usize_from_expr(&meta.value)?
 47                }
 48                (Meta::NameValue(meta), "iterations") => {
 49                    max_iterations = parse_usize_from_expr(&meta.value)?
 50                }
 51                (Meta::NameValue(meta), "on_failure") => {
 52                    let Expr::Lit(ExprLit {
 53                        lit: Lit::Str(name),
 54                        ..
 55                    }) = &meta.value
 56                    else {
 57                        return Err(syn::Error::new(
 58                            meta.value.span(),
 59                            "on_failure argument must be a string",
 60                        ));
 61                    };
 62                    let segments = name
 63                        .value()
 64                        .split("::")
 65                        .map(|part| PathSegment::from(Ident::new(part, name.span())))
 66                        .collect();
 67                    let path = syn::Path {
 68                        leading_colon: None,
 69                        segments,
 70                    };
 71                    on_failure_fn_name = quote!(Some(#path));
 72                }
 73                (Meta::NameValue(meta), "seed") => {
 74                    seeds = vec![parse_usize_from_expr(&meta.value)? as u64]
 75                }
 76                (Meta::List(list), "seeds") => seeds = parse_u64_array(list)?,
 77                (Meta::Path(_), _) => {
 78                    return Err(syn::Error::new(meta.span(), "invalid path argument"));
 79                }
 80                (_, _) => {
 81                    return Err(syn::Error::new(meta.span(), "invalid argument name"));
 82                }
 83            }
 84        }
 85
 86        Ok(Args {
 87            seeds,
 88            max_retries,
 89            max_iterations,
 90            on_failure_fn_name,
 91        })
 92    }
 93}
 94
 95pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 96    let args = syn::parse_macro_input!(args as Args);
 97    let mut inner_fn = match syn::parse::<ItemFn>(function) {
 98        Ok(f) => f,
 99        Err(err) => return error_to_stream(err),
100    };
101
102    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
103    let inner_fn_name = format_ident!("__{}", inner_fn.sig.ident);
104    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
105
106    let result = generate_test_function(
107        args,
108        inner_fn,
109        inner_fn_attributes,
110        inner_fn_name,
111        outer_fn_name,
112    );
113    match result {
114        Ok(tokens) => tokens,
115        Err(tokens) => tokens,
116    }
117}
118
119fn generate_test_function(
120    args: Args,
121    inner_fn: ItemFn,
122    inner_fn_attributes: Vec<syn::Attribute>,
123    inner_fn_name: Ident,
124    outer_fn_name: Ident,
125) -> Result<TokenStream, TokenStream> {
126    let seeds = &args.seeds;
127    let max_retries = args.max_retries;
128    let num_iterations = args.max_iterations;
129    let on_failure_fn_name = &args.on_failure_fn_name;
130    let seeds = quote!( #(#seeds),* );
131
132    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
133        // Pass to the test function the number of app contexts that it needs,
134        // based on its parameter list.
135        let mut cx_vars = proc_macro2::TokenStream::new();
136        let mut cx_teardowns = proc_macro2::TokenStream::new();
137        let mut inner_fn_args = proc_macro2::TokenStream::new();
138        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
139            if let FnArg::Typed(arg) = arg {
140                if let Type::Path(ty) = &*arg.ty {
141                    let last_segment = ty.path.segments.last();
142                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
143                        Some("StdRng") => {
144                            inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
145                            continue;
146                        }
147                        Some("BackgroundExecutor") => {
148                            inner_fn_args.extend(quote!(gpui::BackgroundExecutor::new(
149                                std::sync::Arc::new(dispatcher.clone()),
150                            ),));
151                            continue;
152                        }
153                        _ => {}
154                    }
155                } else if let Type::Reference(ty) = &*arg.ty
156                    && let Type::Path(ty) = &*ty.elem
157                {
158                    let last_segment = ty.path.segments.last();
159                    if let Some("TestAppContext") =
160                        last_segment.map(|s| s.ident.to_string()).as_deref()
161                    {
162                        let cx_varname = format_ident!("cx_{}", ix);
163                        cx_vars.extend(quote!(
164                            let mut #cx_varname = gpui::TestAppContext::build(
165                                dispatcher.clone(),
166                                Some(stringify!(#outer_fn_name)),
167                            );
168                            let _entity_refcounts = #cx_varname.app.borrow().ref_counts_drop_handle();
169                        ));
170                        cx_teardowns.extend(quote!(
171                            #cx_varname.run_until_parked();
172                            #cx_varname.update(|cx| { cx.background_executor().forbid_parking(); cx.quit(); });
173                            #cx_varname.run_until_parked();
174                            drop(#cx_varname);
175                        ));
176                        inner_fn_args.extend(quote!(&mut #cx_varname,));
177                        continue;
178                    }
179                }
180            }
181
182            return Err(error_with_message("invalid function signature", arg));
183        }
184
185        parse_quote! {
186            #[test]
187            fn #outer_fn_name() {
188                #inner_fn
189
190                gpui::run_test(
191                    #num_iterations,
192                    &[#seeds],
193                    #max_retries,
194                    &mut |dispatcher, _seed| {
195                        let exec = std::sync::Arc::new(dispatcher.clone());
196                        #cx_vars
197                        gpui::ForegroundExecutor::new(exec.clone()).block_test(#inner_fn_name(#inner_fn_args));
198                        drop(exec);
199                        #cx_teardowns
200                        // Ideally we would only drop cancelled tasks, that way we could detect leaks due to task <-> entity
201                        // cycles as cancelled tasks will be dropped properly once the runnable gets run again
202                        //
203                        // async-task does not give us the power to do this just yet though
204                        dispatcher.drain_tasks();
205                        drop(dispatcher);
206                    },
207                    #on_failure_fn_name
208                );
209            }
210        }
211    } else {
212        // Pass to the test function the number of app contexts that it needs,
213        // based on its parameter list.
214        let mut cx_vars = proc_macro2::TokenStream::new();
215        let mut cx_teardowns = proc_macro2::TokenStream::new();
216        let mut inner_fn_args = proc_macro2::TokenStream::new();
217        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
218            if let FnArg::Typed(arg) = arg {
219                if let Type::Path(ty) = &*arg.ty {
220                    let last_segment = ty.path.segments.last();
221
222                    if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
223                        inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
224                        continue;
225                    }
226                } else if let Type::Reference(ty) = &*arg.ty
227                    && let Type::Path(ty) = &*ty.elem
228                {
229                    let last_segment = ty.path.segments.last();
230                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
231                        Some("App") => {
232                            let cx_varname = format_ident!("cx_{}", ix);
233                            let cx_varname_lock = format_ident!("cx_{}_lock", ix);
234                            cx_vars.extend(quote!(
235                                let mut #cx_varname = gpui::TestAppContext::build(
236                                   dispatcher.clone(),
237                                   Some(stringify!(#outer_fn_name))
238                                );
239                                let mut #cx_varname_lock = #cx_varname.app.borrow_mut();
240                                let _entity_refcounts = #cx_varname_lock.ref_counts_drop_handle();
241                            ));
242                            inner_fn_args.extend(quote!(&mut #cx_varname_lock,));
243                            cx_teardowns.extend(quote!(
244                                    drop(#cx_varname_lock);
245                                    #cx_varname.run_until_parked();
246                                    #cx_varname.update(|cx| { cx.background_executor().forbid_parking(); cx.quit(); });
247                                    #cx_varname.run_until_parked();
248                                    drop(#cx_varname);
249                                ));
250                            continue;
251                        }
252                        Some("TestAppContext") => {
253                            let cx_varname = format_ident!("cx_{}", ix);
254                            cx_vars.extend(quote!(
255                                let mut #cx_varname = gpui::TestAppContext::build(
256                                    dispatcher.clone(),
257                                    Some(stringify!(#outer_fn_name))
258                                );
259                                let _entity_refcounts = #cx_varname.app.borrow().ref_counts_drop_handle();
260                            ));
261                            cx_teardowns.extend(quote!(
262                                #cx_varname.run_until_parked();
263                                #cx_varname.update(|cx| { cx.background_executor().forbid_parking(); cx.quit(); });
264                                #cx_varname.run_until_parked();
265                                drop(#cx_varname);
266                            ));
267                            inner_fn_args.extend(quote!(&mut #cx_varname,));
268                            continue;
269                        }
270                        _ => {}
271                    }
272                }
273            }
274
275            return Err(error_with_message("invalid function signature", arg));
276        }
277
278        parse_quote! {
279            #[test]
280            fn #outer_fn_name() {
281                #inner_fn
282
283                gpui::run_test(
284                    #num_iterations,
285                    &[#seeds],
286                    #max_retries,
287                    &mut |dispatcher, _seed| {
288                        #cx_vars
289                        #inner_fn_name(#inner_fn_args);
290                        #cx_teardowns
291                        // Ideally we would only drop cancelled tasks, that way we could detect leaks due to task <-> entity
292                        // cycles as cancelled tasks will be dropped properly once they runnable gets run again
293                        //
294                        // async-task does not give us the power to do this just yet though
295                        dispatcher.drain_tasks();
296                        drop(dispatcher);
297                    },
298                    #on_failure_fn_name,
299                );
300            }
301        }
302    };
303    outer_fn.attrs.extend(inner_fn_attributes);
304
305    Ok(TokenStream::from(quote!(#outer_fn)))
306}
307
308fn parse_usize_from_expr(expr: &Expr) -> Result<usize, syn::Error> {
309    let Expr::Lit(ExprLit {
310        lit: Lit::Int(int), ..
311    }) = expr
312    else {
313        return Err(syn::Error::new(expr.span(), "expected an integer"));
314    };
315    int.base10_parse()
316        .map_err(|_| syn::Error::new(int.span(), "failed to parse integer"))
317}
318
319fn parse_u64_array(meta_list: &MetaList) -> Result<Vec<u64>, syn::Error> {
320    let mut result = Vec::new();
321    let tokens = &meta_list.tokens;
322    let parser = |input: ParseStream| {
323        let exprs = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
324        for expr in exprs {
325            if let Expr::Lit(ExprLit {
326                lit: Lit::Int(int), ..
327            }) = expr
328            {
329                let value: usize = int.base10_parse()?;
330                result.push(value as u64);
331            } else {
332                return Err(syn::Error::new(expr.span(), "expected an integer"));
333            }
334        }
335        Ok(())
336    };
337    syn::parse::Parser::parse2(parser, tokens.clone())?;
338    Ok(result)
339}
340
341fn error_with_message(message: &str, spanned: impl Spanned) -> TokenStream {
342    error_to_stream(syn::Error::new(spanned.span(), message))
343}
344
345fn error_to_stream(err: syn::Error) -> TokenStream {
346    TokenStream::from(err.into_compile_error())
347}