gpui_macros.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::Ident;
  3use quote::{format_ident, quote};
  4use std::mem;
  5use syn::{
  6    parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, FnArg, ItemFn, Lit, Meta,
  7    NestedMeta, Type,
  8};
  9
 10#[proc_macro_attribute]
 11pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 12    let mut namespace = format_ident!("gpui");
 13
 14    let args = syn::parse_macro_input!(args as AttributeArgs);
 15    let mut max_retries = 0;
 16    let mut num_iterations = 1;
 17    let mut starting_seed = 0;
 18    let mut detect_nondeterminism = false;
 19    let mut on_failure_fn_name = quote!(None);
 20
 21    for arg in args {
 22        match arg {
 23            NestedMeta::Meta(Meta::Path(name))
 24                if name.get_ident().map_or(false, |n| n == "self") =>
 25            {
 26                namespace = format_ident!("crate");
 27            }
 28            NestedMeta::Meta(Meta::NameValue(meta)) => {
 29                let key_name = meta.path.get_ident().map(|i| i.to_string());
 30                let result = (|| {
 31                    match key_name.as_deref() {
 32                        Some("detect_nondeterminism") => {
 33                            detect_nondeterminism = parse_bool(&meta.lit)?
 34                        }
 35                        Some("retries") => max_retries = parse_int(&meta.lit)?,
 36                        Some("iterations") => num_iterations = parse_int(&meta.lit)?,
 37                        Some("seed") => starting_seed = parse_int(&meta.lit)?,
 38                        Some("on_failure") => {
 39                            if let Lit::Str(name) = meta.lit {
 40                                let ident = Ident::new(&name.value(), name.span());
 41                                on_failure_fn_name = quote!(Some(#ident));
 42                            } else {
 43                                return Err(TokenStream::from(
 44                                    syn::Error::new(
 45                                        meta.lit.span(),
 46                                        "on_failure argument must be a string",
 47                                    )
 48                                    .into_compile_error(),
 49                                ));
 50                            }
 51                        }
 52                        _ => {
 53                            return Err(TokenStream::from(
 54                                syn::Error::new(meta.path.span(), "invalid argument")
 55                                    .into_compile_error(),
 56                            ))
 57                        }
 58                    }
 59                    Ok(())
 60                })();
 61
 62                if let Err(tokens) = result {
 63                    return tokens;
 64                }
 65            }
 66            other => {
 67                return TokenStream::from(
 68                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 69                )
 70            }
 71        }
 72    }
 73
 74    let mut inner_fn = parse_macro_input!(function as ItemFn);
 75    if max_retries > 0 && num_iterations > 1 {
 76        return TokenStream::from(
 77            syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
 78                .into_compile_error(),
 79        );
 80    }
 81    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 82    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 83    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 84
 85    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 86        // Pass to the test function the number of app contexts that it needs,
 87        // based on its parameter list.
 88        let mut cx_vars = proc_macro2::TokenStream::new();
 89        let mut cx_teardowns = proc_macro2::TokenStream::new();
 90        let mut inner_fn_args = proc_macro2::TokenStream::new();
 91        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
 92            if let FnArg::Typed(arg) = arg {
 93                if let Type::Path(ty) = &*arg.ty {
 94                    let last_segment = ty.path.segments.last();
 95                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
 96                        Some("StdRng") => {
 97                            inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
 98                            continue;
 99                        }
100                        Some("Arc") => {
101                            if let syn::PathArguments::AngleBracketed(args) =
102                                &last_segment.unwrap().arguments
103                            {
104                                if let Some(syn::GenericArgument::Type(syn::Type::Path(ty))) =
105                                    args.args.last()
106                                {
107                                    let last_segment = ty.path.segments.last();
108                                    if let Some("Deterministic") =
109                                        last_segment.map(|s| s.ident.to_string()).as_deref()
110                                    {
111                                        inner_fn_args.extend(quote!(deterministic.clone(),));
112                                        continue;
113                                    }
114                                }
115                            }
116                        }
117                        _ => {}
118                    }
119                } else if let Type::Reference(ty) = &*arg.ty {
120                    if let Type::Path(ty) = &*ty.elem {
121                        let last_segment = ty.path.segments.last();
122                        if let Some("TestAppContext") =
123                            last_segment.map(|s| s.ident.to_string()).as_deref()
124                        {
125                            let first_entity_id = ix * 100_000;
126                            let cx_varname = format_ident!("cx_{}", ix);
127                            cx_vars.extend(quote!(
128                                let mut #cx_varname = #namespace::TestAppContext::new(
129                                    foreground_platform.clone(),
130                                    cx.platform().clone(),
131                                    deterministic.build_foreground(#ix),
132                                    deterministic.build_background(),
133                                    cx.font_cache().clone(),
134                                    cx.leak_detector(),
135                                    #first_entity_id,
136                                    stringify!(#outer_fn_name).to_string(),
137                                );
138                            ));
139                            cx_teardowns.extend(quote!(
140                                #cx_varname.update(|cx| cx.remove_all_windows());
141                                deterministic.run_until_parked();
142                                #cx_varname.update(|cx| cx.clear_globals());
143                            ));
144                            inner_fn_args.extend(quote!(&mut #cx_varname,));
145                            continue;
146                        }
147                    }
148                }
149            }
150
151            return TokenStream::from(
152                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
153            );
154        }
155
156        parse_quote! {
157            #[test]
158            fn #outer_fn_name() {
159                #inner_fn
160
161                #namespace::test::run_test(
162                    #num_iterations as u64,
163                    #starting_seed as u64,
164                    #max_retries,
165                    #detect_nondeterminism,
166                    &mut |cx, foreground_platform, deterministic, seed| {
167                        #cx_vars
168                        cx.foreground().run(#inner_fn_name(#inner_fn_args));
169                        #cx_teardowns
170                    },
171                    #on_failure_fn_name,
172                    stringify!(#outer_fn_name).to_string(),
173                );
174            }
175        }
176    } else {
177        let mut inner_fn_args = proc_macro2::TokenStream::new();
178        for arg in inner_fn.sig.inputs.iter() {
179            if let FnArg::Typed(arg) = arg {
180                if let Type::Path(ty) = &*arg.ty {
181                    let last_segment = ty.path.segments.last();
182
183                    if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
184                        inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
185                    }
186                } else {
187                    inner_fn_args.extend(quote!(cx,));
188                }
189            } else {
190                return TokenStream::from(
191                    syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
192                );
193            }
194        }
195
196        parse_quote! {
197            #[test]
198            fn #outer_fn_name() {
199                #inner_fn
200
201                #namespace::test::run_test(
202                    #num_iterations as u64,
203                    #starting_seed as u64,
204                    #max_retries,
205                    #detect_nondeterminism,
206                    &mut |cx, _, _, seed| #inner_fn_name(#inner_fn_args),
207                    #on_failure_fn_name,
208                    stringify!(#outer_fn_name).to_string(),
209                );
210            }
211        }
212    };
213    outer_fn.attrs.extend(inner_fn_attributes);
214
215    TokenStream::from(quote!(#outer_fn))
216}
217
218fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
219    let result = if let Lit::Int(int) = &literal {
220        int.base10_parse()
221    } else {
222        Err(syn::Error::new(literal.span(), "must be an integer"))
223    };
224
225    result.map_err(|err| TokenStream::from(err.into_compile_error()))
226}
227
228fn parse_bool(literal: &Lit) -> Result<bool, TokenStream> {
229    let result = if let Lit::Bool(result) = &literal {
230        Ok(result.value)
231    } else {
232        Err(syn::Error::new(literal.span(), "must be a boolean"))
233    };
234
235    result.map_err(|err| TokenStream::from(err.into_compile_error()))
236}