test.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::Ident;
  3use quote::{format_ident, quote};
  4use std::mem;
  5use syn::{
  6    parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, FnArg, ItemFn, Lit, Meta,
  7    NestedMeta, Type,
  8};
  9
 10pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 11    let args = syn::parse_macro_input!(args as AttributeArgs);
 12    let mut max_retries = 0;
 13    let mut num_iterations = 1;
 14    let mut on_failure_fn_name = quote!(None);
 15
 16    for arg in args {
 17        match arg {
 18            NestedMeta::Meta(Meta::NameValue(meta)) => {
 19                let key_name = meta.path.get_ident().map(|i| i.to_string());
 20                let result = (|| {
 21                    match key_name.as_deref() {
 22                        Some("retries") => max_retries = parse_int(&meta.lit)?,
 23                        Some("iterations") => num_iterations = parse_int(&meta.lit)?,
 24                        Some("on_failure") => {
 25                            if let Lit::Str(name) = meta.lit {
 26                                let mut path = syn::Path {
 27                                    leading_colon: None,
 28                                    segments: Default::default(),
 29                                };
 30                                for part in name.value().split("::") {
 31                                    path.segments.push(Ident::new(part, name.span()).into());
 32                                }
 33                                on_failure_fn_name = quote!(Some(#path));
 34                            } else {
 35                                return Err(TokenStream::from(
 36                                    syn::Error::new(
 37                                        meta.lit.span(),
 38                                        "on_failure argument must be a string",
 39                                    )
 40                                    .into_compile_error(),
 41                                ));
 42                            }
 43                        }
 44                        _ => {
 45                            return Err(TokenStream::from(
 46                                syn::Error::new(meta.path.span(), "invalid argument")
 47                                    .into_compile_error(),
 48                            ))
 49                        }
 50                    }
 51                    Ok(())
 52                })();
 53
 54                if let Err(tokens) = result {
 55                    return tokens;
 56                }
 57            }
 58            other => {
 59                return TokenStream::from(
 60                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 61                )
 62            }
 63        }
 64    }
 65
 66    let mut inner_fn = parse_macro_input!(function as ItemFn);
 67    if max_retries > 0 && num_iterations > 1 {
 68        return TokenStream::from(
 69            syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
 70                .into_compile_error(),
 71        );
 72    }
 73    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 74    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 75    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 76
 77    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 78        // Pass to the test function the number of app contexts that it needs,
 79        // based on its parameter list.
 80        let mut cx_vars = proc_macro2::TokenStream::new();
 81        let mut cx_teardowns = proc_macro2::TokenStream::new();
 82        let mut inner_fn_args = proc_macro2::TokenStream::new();
 83        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
 84            if let FnArg::Typed(arg) = arg {
 85                if let Type::Path(ty) = &*arg.ty {
 86                    let last_segment = ty.path.segments.last();
 87                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
 88                        Some("StdRng") => {
 89                            inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
 90                            continue;
 91                        }
 92                        Some("Executor") => {
 93                            inner_fn_args.extend(quote!(gpui2::Executor::new(
 94                                std::sync::Arc::new(dispatcher.clone())
 95                            ),));
 96                            continue;
 97                        }
 98                        _ => {}
 99                    }
100                } else if let Type::Reference(ty) = &*arg.ty {
101                    if let Type::Path(ty) = &*ty.elem {
102                        let last_segment = ty.path.segments.last();
103                        if let Some("TestAppContext") =
104                            last_segment.map(|s| s.ident.to_string()).as_deref()
105                        {
106                            let cx_varname = format_ident!("cx_{}", ix);
107                            cx_vars.extend(quote!(
108                                let mut #cx_varname = gpui2::TestAppContext::new(
109                                    dispatcher.clone()
110                                );
111                            ));
112                            cx_teardowns.extend(quote!(
113                                #cx_varname.quit();
114                                dispatcher.run_until_parked();
115                            ));
116                            inner_fn_args.extend(quote!(&mut #cx_varname,));
117                            continue;
118                        }
119                    }
120                }
121            }
122
123            return TokenStream::from(
124                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
125            );
126        }
127
128        parse_quote! {
129            #[test]
130            fn #outer_fn_name() {
131                #inner_fn
132
133                gpui2::run_test(
134                    #num_iterations as u64,
135                    #max_retries,
136                    &mut |dispatcher, _seed| {
137                        let executor = gpui2::Executor::new(std::sync::Arc::new(dispatcher.clone()));
138                        #cx_vars
139                        executor.block(#inner_fn_name(#inner_fn_args));
140                        #cx_teardowns
141                    },
142                    #on_failure_fn_name,
143                    stringify!(#outer_fn_name).to_string(),
144                );
145            }
146        }
147    } else {
148        // Pass to the test function the number of app contexts that it needs,
149        // based on its parameter list.
150        let mut cx_vars = proc_macro2::TokenStream::new();
151        let mut cx_teardowns = proc_macro2::TokenStream::new();
152        let mut inner_fn_args = proc_macro2::TokenStream::new();
153        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
154            if let FnArg::Typed(arg) = arg {
155                if let Type::Path(ty) = &*arg.ty {
156                    let last_segment = ty.path.segments.last();
157
158                    if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
159                        inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
160                        continue;
161                    }
162                } else if let Type::Reference(ty) = &*arg.ty {
163                    if let Type::Path(ty) = &*ty.elem {
164                        let last_segment = ty.path.segments.last();
165                        match last_segment.map(|s| s.ident.to_string()).as_deref() {
166                            Some("AppContext") => {
167                                let cx_varname = format_ident!("cx_{}", ix);
168                                let cx_varname_lock = format_ident!("cx_{}_lock", ix);
169                                cx_vars.extend(quote!(
170                                    let mut #cx_varname = gpui2::TestAppContext::new(
171                                       dispatcher.clone()
172                                    );
173                                    let mut #cx_varname_lock = #cx_varname.app.lock();
174                                ));
175                                inner_fn_args.extend(quote!(&mut #cx_varname_lock,));
176                                cx_teardowns.extend(quote!(
177                                    #cx_varname_lock.quit();
178                                    dispatcher.run_until_parked();
179                                ));
180                                continue;
181                            }
182                            Some("TestAppContext") => {
183                                let cx_varname = format_ident!("cx_{}", ix);
184                                cx_vars.extend(quote!(
185                                    let mut #cx_varname = gpui2::TestAppContext::new(
186                                        dispatcher.clone()
187                                    );
188                                ));
189                                cx_teardowns.extend(quote!(
190                                    #cx_varname.quit();
191                                    dispatcher.run_until_parked();
192                                ));
193                                inner_fn_args.extend(quote!(&mut #cx_varname,));
194                                continue;
195                            }
196                            _ => {}
197                        }
198                    }
199                }
200            }
201
202            return TokenStream::from(
203                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
204            );
205        }
206
207        parse_quote! {
208            #[test]
209            fn #outer_fn_name() {
210                #inner_fn
211
212                gpui2::run_test(
213                    #num_iterations as u64,
214                    #max_retries,
215                    &mut |dispatcher, _seed| {
216                        #cx_vars
217                        #inner_fn_name(#inner_fn_args);
218                        #cx_teardowns
219                    },
220                    #on_failure_fn_name,
221                    stringify!(#outer_fn_name).to_string(),
222                );
223            }
224        }
225    };
226    outer_fn.attrs.extend(inner_fn_attributes);
227
228    TokenStream::from(quote!(#outer_fn))
229}
230
231fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
232    let result = if let Lit::Int(int) = &literal {
233        int.base10_parse()
234    } else {
235        Err(syn::Error::new(literal.span(), "must be an integer"))
236    };
237
238    result.map_err(|err| TokenStream::from(err.into_compile_error()))
239}