test.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::Ident;
  3use quote::{format_ident, quote};
  4use std::mem;
  5use syn::{
  6    self, Expr, ExprLit, FnArg, ItemFn, Lit, Meta, MetaList, PathSegment, Token, Type,
  7    parse::{Parse, ParseStream},
  8    parse_quote,
  9    punctuated::Punctuated,
 10    spanned::Spanned,
 11};
 12
 13struct Args {
 14    seeds: Vec<u64>,
 15    max_retries: usize,
 16    max_iterations: usize,
 17    on_failure_fn_name: proc_macro2::TokenStream,
 18}
 19
 20impl Parse for Args {
 21    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
 22        let mut seeds = Vec::<u64>::new();
 23        let mut max_retries = 0;
 24        let mut max_iterations = 1;
 25        let mut on_failure_fn_name = quote!(None);
 26
 27        let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
 28
 29        for meta in metas {
 30            let ident = {
 31                let meta_path = match &meta {
 32                    Meta::NameValue(meta) => &meta.path,
 33                    Meta::List(list) => &list.path,
 34                    Meta::Path(path) => {
 35                        return Err(syn::Error::new(path.span(), "invalid path argument"));
 36                    }
 37                };
 38                let Some(ident) = meta_path.get_ident() else {
 39                    return Err(syn::Error::new(meta_path.span(), "unexpected path"));
 40                };
 41                ident.to_string()
 42            };
 43
 44            match (&meta, ident.as_str()) {
 45                (Meta::NameValue(meta), "retries") => {
 46                    max_retries = parse_usize_from_expr(&meta.value)?
 47                }
 48                (Meta::NameValue(meta), "iterations") => {
 49                    max_iterations = parse_usize_from_expr(&meta.value)?
 50                }
 51                (Meta::NameValue(meta), "on_failure") => {
 52                    let Expr::Lit(ExprLit {
 53                        lit: Lit::Str(name),
 54                        ..
 55                    }) = &meta.value
 56                    else {
 57                        return Err(syn::Error::new(
 58                            meta.value.span(),
 59                            "on_failure argument must be a string",
 60                        ));
 61                    };
 62                    let segments = name
 63                        .value()
 64                        .split("::")
 65                        .map(|part| PathSegment::from(Ident::new(part, name.span())))
 66                        .collect();
 67                    let path = syn::Path {
 68                        leading_colon: None,
 69                        segments,
 70                    };
 71                    on_failure_fn_name = quote!(Some(#path));
 72                }
 73                (Meta::NameValue(meta), "seed") => {
 74                    seeds = vec![parse_usize_from_expr(&meta.value)? as u64]
 75                }
 76                (Meta::List(list), "seeds") => seeds = parse_u64_array(&list)?,
 77                (Meta::Path(_), _) => {
 78                    return Err(syn::Error::new(meta.span(), "invalid path argument"));
 79                }
 80                (_, _) => {
 81                    return Err(syn::Error::new(meta.span(), "invalid argument name"));
 82                }
 83            }
 84        }
 85
 86        Ok(Args {
 87            seeds,
 88            max_retries,
 89            max_iterations: max_iterations,
 90            on_failure_fn_name,
 91        })
 92    }
 93}
 94
 95pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 96    let args = syn::parse_macro_input!(args as Args);
 97    let mut inner_fn = match syn::parse::<ItemFn>(function) {
 98        Ok(f) => f,
 99        Err(err) => return error_to_stream(err),
100    };
101
102    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
103    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
104    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
105
106    let result = generate_test_function(
107        args,
108        inner_fn,
109        inner_fn_attributes,
110        inner_fn_name,
111        outer_fn_name,
112    );
113    match result {
114        Ok(tokens) => tokens,
115        Err(tokens) => tokens,
116    }
117}
118
119fn generate_test_function(
120    args: Args,
121    inner_fn: ItemFn,
122    inner_fn_attributes: Vec<syn::Attribute>,
123    inner_fn_name: Ident,
124    outer_fn_name: Ident,
125) -> Result<TokenStream, TokenStream> {
126    let seeds = &args.seeds;
127    let max_retries = args.max_retries;
128    let num_iterations = args.max_iterations;
129    let on_failure_fn_name = &args.on_failure_fn_name;
130    let seeds = quote!( #(#seeds),* );
131
132    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
133        // Pass to the test function the number of app contexts that it needs,
134        // based on its parameter list.
135        let mut cx_vars = proc_macro2::TokenStream::new();
136        let mut cx_teardowns = proc_macro2::TokenStream::new();
137        let mut inner_fn_args = proc_macro2::TokenStream::new();
138        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
139            if let FnArg::Typed(arg) = arg {
140                if let Type::Path(ty) = &*arg.ty {
141                    let last_segment = ty.path.segments.last();
142                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
143                        Some("StdRng") => {
144                            inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
145                            continue;
146                        }
147                        Some("BackgroundExecutor") => {
148                            inner_fn_args.extend(quote!(gpui::BackgroundExecutor::new(
149                                std::sync::Arc::new(dispatcher.clone()),
150                            ),));
151                            continue;
152                        }
153                        _ => {}
154                    }
155                } else if let Type::Reference(ty) = &*arg.ty {
156                    if let Type::Path(ty) = &*ty.elem {
157                        let last_segment = ty.path.segments.last();
158                        if let Some("TestAppContext") =
159                            last_segment.map(|s| s.ident.to_string()).as_deref()
160                        {
161                            let cx_varname = format_ident!("cx_{}", ix);
162                            cx_vars.extend(quote!(
163                                let mut #cx_varname = gpui::TestAppContext::build(
164                                    dispatcher.clone(),
165                                    Some(stringify!(#outer_fn_name)),
166                                );
167                            ));
168                            cx_teardowns.extend(quote!(
169                                dispatcher.run_until_parked();
170                                #cx_varname.quit();
171                                dispatcher.run_until_parked();
172                            ));
173                            inner_fn_args.extend(quote!(&mut #cx_varname,));
174                            continue;
175                        }
176                    }
177                }
178            }
179
180            return Err(error_with_message("invalid function signature", arg));
181        }
182
183        parse_quote! {
184            #[test]
185            fn #outer_fn_name() {
186                #inner_fn
187
188                gpui::run_test(
189                    #num_iterations,
190                    &[#seeds],
191                    #max_retries,
192                    &mut |dispatcher, _seed| {
193                        let executor = gpui::BackgroundExecutor::new(std::sync::Arc::new(dispatcher.clone()));
194                        #cx_vars
195                        executor.block_test(#inner_fn_name(#inner_fn_args));
196                        #cx_teardowns
197                    },
198                    #on_failure_fn_name
199                );
200            }
201        }
202    } else {
203        // Pass to the test function the number of app contexts that it needs,
204        // based on its parameter list.
205        let mut cx_vars = proc_macro2::TokenStream::new();
206        let mut cx_teardowns = proc_macro2::TokenStream::new();
207        let mut inner_fn_args = proc_macro2::TokenStream::new();
208        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
209            if let FnArg::Typed(arg) = arg {
210                if let Type::Path(ty) = &*arg.ty {
211                    let last_segment = ty.path.segments.last();
212
213                    if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
214                        inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
215                        continue;
216                    }
217                } else if let Type::Reference(ty) = &*arg.ty {
218                    if let Type::Path(ty) = &*ty.elem {
219                        let last_segment = ty.path.segments.last();
220                        match last_segment.map(|s| s.ident.to_string()).as_deref() {
221                            Some("App") => {
222                                let cx_varname = format_ident!("cx_{}", ix);
223                                let cx_varname_lock = format_ident!("cx_{}_lock", ix);
224                                cx_vars.extend(quote!(
225                                    let mut #cx_varname = gpui::TestAppContext::build(
226                                       dispatcher.clone(),
227                                       Some(stringify!(#outer_fn_name))
228                                    );
229                                    let mut #cx_varname_lock = #cx_varname.app.borrow_mut();
230                                ));
231                                inner_fn_args.extend(quote!(&mut #cx_varname_lock,));
232                                cx_teardowns.extend(quote!(
233                                    drop(#cx_varname_lock);
234                                    dispatcher.run_until_parked();
235                                    #cx_varname.update(|cx| { cx.quit() });
236                                    dispatcher.run_until_parked();
237                                ));
238                                continue;
239                            }
240                            Some("TestAppContext") => {
241                                let cx_varname = format_ident!("cx_{}", ix);
242                                cx_vars.extend(quote!(
243                                    let mut #cx_varname = gpui::TestAppContext::build(
244                                        dispatcher.clone(),
245                                        Some(stringify!(#outer_fn_name))
246                                    );
247                                ));
248                                cx_teardowns.extend(quote!(
249                                    dispatcher.run_until_parked();
250                                    #cx_varname.quit();
251                                    dispatcher.run_until_parked();
252                                ));
253                                inner_fn_args.extend(quote!(&mut #cx_varname,));
254                                continue;
255                            }
256                            _ => {}
257                        }
258                    }
259                }
260            }
261
262            return Err(error_with_message("invalid function signature", arg));
263        }
264
265        parse_quote! {
266            #[test]
267            fn #outer_fn_name() {
268                #inner_fn
269
270                gpui::run_test(
271                    #num_iterations,
272                    &[#seeds],
273                    #max_retries,
274                    &mut |dispatcher, _seed| {
275                        #cx_vars
276                        #inner_fn_name(#inner_fn_args);
277                        #cx_teardowns
278                    },
279                    #on_failure_fn_name,
280                );
281            }
282        }
283    };
284    outer_fn.attrs.extend(inner_fn_attributes);
285
286    Ok(TokenStream::from(quote!(#outer_fn)))
287}
288
289fn parse_usize_from_expr(expr: &Expr) -> Result<usize, syn::Error> {
290    let Expr::Lit(ExprLit {
291        lit: Lit::Int(int), ..
292    }) = expr
293    else {
294        return Err(syn::Error::new(expr.span(), "expected an integer"));
295    };
296    int.base10_parse()
297        .map_err(|_| syn::Error::new(int.span(), "failed to parse integer"))
298}
299
300fn parse_u64_array(meta_list: &MetaList) -> Result<Vec<u64>, syn::Error> {
301    let mut result = Vec::new();
302    let tokens = &meta_list.tokens;
303    let parser = |input: ParseStream| {
304        let exprs = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
305        for expr in exprs {
306            if let Expr::Lit(ExprLit {
307                lit: Lit::Int(int), ..
308            }) = expr
309            {
310                let value: usize = int.base10_parse()?;
311                result.push(value as u64);
312            } else {
313                return Err(syn::Error::new(expr.span(), "expected an integer"));
314            }
315        }
316        Ok(())
317    };
318    syn::parse::Parser::parse2(parser, tokens.clone())?;
319    Ok(result)
320}
321
322fn error_with_message(message: &str, spanned: impl Spanned) -> TokenStream {
323    error_to_stream(syn::Error::new(spanned.span(), message))
324}
325
326fn error_to_stream(err: syn::Error) -> TokenStream {
327    TokenStream::from(err.into_compile_error())
328}