gpui_macros.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::Ident;
  3use quote::{format_ident, quote};
  4use std::mem;
  5use syn::{
  6    parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, DeriveInput, FnArg,
  7    GenericParam, Generics, ItemFn, Lit, Meta, NestedMeta, Type, WhereClause,
  8};
  9
 10#[proc_macro_attribute]
 11pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 12    let mut namespace = format_ident!("gpui");
 13
 14    let args = syn::parse_macro_input!(args as AttributeArgs);
 15    let mut max_retries = 0;
 16    let mut num_iterations = 1;
 17    let mut starting_seed = 0;
 18    let mut detect_nondeterminism = false;
 19    let mut on_failure_fn_name = quote!(None);
 20
 21    for arg in args {
 22        match arg {
 23            NestedMeta::Meta(Meta::Path(name))
 24                if name.get_ident().map_or(false, |n| n == "self") =>
 25            {
 26                namespace = format_ident!("crate");
 27            }
 28            NestedMeta::Meta(Meta::NameValue(meta)) => {
 29                let key_name = meta.path.get_ident().map(|i| i.to_string());
 30                let result = (|| {
 31                    match key_name.as_deref() {
 32                        Some("detect_nondeterminism") => {
 33                            detect_nondeterminism = parse_bool(&meta.lit)?
 34                        }
 35                        Some("retries") => max_retries = parse_int(&meta.lit)?,
 36                        Some("iterations") => num_iterations = parse_int(&meta.lit)?,
 37                        Some("seed") => starting_seed = parse_int(&meta.lit)?,
 38                        Some("on_failure") => {
 39                            if let Lit::Str(name) = meta.lit {
 40                                let mut path = syn::Path {
 41                                    leading_colon: None,
 42                                    segments: Default::default(),
 43                                };
 44                                for part in name.value().split("::") {
 45                                    path.segments.push(Ident::new(part, name.span()).into());
 46                                }
 47                                on_failure_fn_name = quote!(Some(#path));
 48                            } else {
 49                                return Err(TokenStream::from(
 50                                    syn::Error::new(
 51                                        meta.lit.span(),
 52                                        "on_failure argument must be a string",
 53                                    )
 54                                    .into_compile_error(),
 55                                ));
 56                            }
 57                        }
 58                        _ => {
 59                            return Err(TokenStream::from(
 60                                syn::Error::new(meta.path.span(), "invalid argument")
 61                                    .into_compile_error(),
 62                            ))
 63                        }
 64                    }
 65                    Ok(())
 66                })();
 67
 68                if let Err(tokens) = result {
 69                    return tokens;
 70                }
 71            }
 72            other => {
 73                return TokenStream::from(
 74                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 75                )
 76            }
 77        }
 78    }
 79
 80    let mut inner_fn = parse_macro_input!(function as ItemFn);
 81    if max_retries > 0 && num_iterations > 1 {
 82        return TokenStream::from(
 83            syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
 84                .into_compile_error(),
 85        );
 86    }
 87    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 88    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 89    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 90
 91    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 92        // Pass to the test function the number of app contexts that it needs,
 93        // based on its parameter list.
 94        let mut cx_vars = proc_macro2::TokenStream::new();
 95        let mut cx_teardowns = proc_macro2::TokenStream::new();
 96        let mut inner_fn_args = proc_macro2::TokenStream::new();
 97        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
 98            if let FnArg::Typed(arg) = arg {
 99                if let Type::Path(ty) = &*arg.ty {
100                    let last_segment = ty.path.segments.last();
101                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
102                        Some("StdRng") => {
103                            inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
104                            continue;
105                        }
106                        Some("Arc") => {
107                            if let syn::PathArguments::AngleBracketed(args) =
108                                &last_segment.unwrap().arguments
109                            {
110                                if let Some(syn::GenericArgument::Type(syn::Type::Path(ty))) =
111                                    args.args.last()
112                                {
113                                    let last_segment = ty.path.segments.last();
114                                    if let Some("Deterministic") =
115                                        last_segment.map(|s| s.ident.to_string()).as_deref()
116                                    {
117                                        inner_fn_args.extend(quote!(deterministic.clone(),));
118                                        continue;
119                                    }
120                                }
121                            }
122                        }
123                        _ => {}
124                    }
125                } else if let Type::Reference(ty) = &*arg.ty {
126                    if let Type::Path(ty) = &*ty.elem {
127                        let last_segment = ty.path.segments.last();
128                        if let Some("TestAppContext") =
129                            last_segment.map(|s| s.ident.to_string()).as_deref()
130                        {
131                            let first_entity_id = ix * 100_000;
132                            let cx_varname = format_ident!("cx_{}", ix);
133                            cx_vars.extend(quote!(
134                                let mut #cx_varname = #namespace::TestAppContext::new(
135                                    foreground_platform.clone(),
136                                    cx.platform().clone(),
137                                    deterministic.build_foreground(#ix),
138                                    deterministic.build_background(),
139                                    cx.font_cache().clone(),
140                                    cx.leak_detector(),
141                                    #first_entity_id,
142                                    stringify!(#outer_fn_name).to_string(),
143                                );
144                            ));
145                            cx_teardowns.extend(quote!(
146                                #cx_varname.remove_all_windows();
147                                deterministic.run_until_parked();
148                                #cx_varname.update(|cx| cx.clear_globals());
149                            ));
150                            inner_fn_args.extend(quote!(&mut #cx_varname,));
151                            continue;
152                        }
153                    }
154                }
155            }
156
157            return TokenStream::from(
158                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
159            );
160        }
161
162        parse_quote! {
163            #[test]
164            fn #outer_fn_name() {
165                #inner_fn
166
167                #namespace::test::run_test(
168                    #num_iterations as u64,
169                    #starting_seed as u64,
170                    #max_retries,
171                    #detect_nondeterminism,
172                    &mut |cx, foreground_platform, deterministic, seed| {
173                        #cx_vars
174                        cx.foreground().run(#inner_fn_name(#inner_fn_args));
175                        #cx_teardowns
176                    },
177                    #on_failure_fn_name,
178                    stringify!(#outer_fn_name).to_string(),
179                );
180            }
181        }
182    } else {
183        // Pass to the test function the number of app contexts that it needs,
184        // based on its parameter list.
185        let mut cx_vars = proc_macro2::TokenStream::new();
186        let mut cx_teardowns = proc_macro2::TokenStream::new();
187        let mut inner_fn_args = proc_macro2::TokenStream::new();
188        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
189            if let FnArg::Typed(arg) = arg {
190                if let Type::Path(ty) = &*arg.ty {
191                    let last_segment = ty.path.segments.last();
192
193                    if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
194                        inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
195                        continue;
196                    }
197                } else if let Type::Reference(ty) = &*arg.ty {
198                    if let Type::Path(ty) = &*ty.elem {
199                        let last_segment = ty.path.segments.last();
200                        match last_segment.map(|s| s.ident.to_string()).as_deref() {
201                            Some("AppContext") => {
202                                inner_fn_args.extend(quote!(cx,));
203                                continue;
204                            }
205                            Some("TestAppContext") => {
206                                let first_entity_id = ix * 100_000;
207                                let cx_varname = format_ident!("cx_{}", ix);
208                                cx_vars.extend(quote!(
209                                    let mut #cx_varname = #namespace::TestAppContext::new(
210                                        foreground_platform.clone(),
211                                        cx.platform().clone(),
212                                        deterministic.build_foreground(#ix),
213                                        deterministic.build_background(),
214                                        cx.font_cache().clone(),
215                                        cx.leak_detector(),
216                                        #first_entity_id,
217                                        stringify!(#outer_fn_name).to_string(),
218                                    );
219                                ));
220                                cx_teardowns.extend(quote!(
221                                    #cx_varname.remove_all_windows();
222                                    deterministic.run_until_parked();
223                                    #cx_varname.update(|cx| cx.clear_globals());
224                                ));
225                                inner_fn_args.extend(quote!(&mut #cx_varname,));
226                                continue;
227                            }
228                            _ => {}
229                        }
230                    }
231                }
232            }
233
234            return TokenStream::from(
235                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
236            );
237        }
238
239        parse_quote! {
240            #[test]
241            fn #outer_fn_name() {
242                #inner_fn
243
244                #namespace::test::run_test(
245                    #num_iterations as u64,
246                    #starting_seed as u64,
247                    #max_retries,
248                    #detect_nondeterminism,
249                    &mut |cx, foreground_platform, deterministic, seed| {
250                        #cx_vars
251                        #inner_fn_name(#inner_fn_args);
252                        #cx_teardowns
253                    },
254                    #on_failure_fn_name,
255                    stringify!(#outer_fn_name).to_string(),
256                );
257            }
258        }
259    };
260    outer_fn.attrs.extend(inner_fn_attributes);
261
262    TokenStream::from(quote!(#outer_fn))
263}
264
265fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
266    let result = if let Lit::Int(int) = &literal {
267        int.base10_parse()
268    } else {
269        Err(syn::Error::new(literal.span(), "must be an integer"))
270    };
271
272    result.map_err(|err| TokenStream::from(err.into_compile_error()))
273}
274
275fn parse_bool(literal: &Lit) -> Result<bool, TokenStream> {
276    let result = if let Lit::Bool(result) = &literal {
277        Ok(result.value)
278    } else {
279        Err(syn::Error::new(literal.span(), "must be a boolean"))
280    };
281
282    result.map_err(|err| TokenStream::from(err.into_compile_error()))
283}
284
285#[proc_macro_derive(Element)]
286pub fn element_derive(input: TokenStream) -> TokenStream {
287    let ast = parse_macro_input!(input as DeriveInput);
288    let type_name = ast.ident;
289
290    let placeholder_view_generics: Generics = parse_quote! { <V: 'static> };
291    let placeholder_view_type_name: Ident = parse_quote! { V };
292    let view_type_name: Ident;
293    let impl_generics: syn::ImplGenerics<'_>;
294    let type_generics: Option<syn::TypeGenerics<'_>>;
295    let where_clause: Option<&'_ WhereClause>;
296
297    match ast.generics.params.iter().find_map(|param| {
298        if let GenericParam::Type(type_param) = param {
299            Some(type_param.ident.clone())
300        } else {
301            None
302        }
303    }) {
304        Some(type_name) => {
305            view_type_name = type_name;
306            let generics = ast.generics.split_for_impl();
307            impl_generics = generics.0;
308            type_generics = Some(generics.1);
309            where_clause = generics.2;
310        }
311        _ => {
312            view_type_name = placeholder_view_type_name;
313            let generics = placeholder_view_generics.split_for_impl();
314            impl_generics = generics.0;
315            type_generics = None;
316            where_clause = generics.2;
317        }
318    }
319
320    let gen = quote! {
321        impl #impl_generics Element<#view_type_name> for #type_name #type_generics
322        #where_clause
323        {
324
325            type LayoutState = gpui::elements::AnyElement<V>;
326            type PaintState = ();
327
328            fn layout(
329                &mut self,
330                constraint: gpui::SizeConstraint,
331                view: &mut V,
332                cx: &mut gpui::ViewContext<V>,
333            ) -> (gpui::geometry::vector::Vector2F, gpui::elements::AnyElement<V>) {
334                let mut element = self.render(view, cx).into_any();
335                let size = element.layout(constraint, view, cx);
336                (size, element)
337            }
338
339            fn paint(
340                &mut self,
341                bounds: gpui::geometry::rect::RectF,
342                visible_bounds: gpui::geometry::rect::RectF,
343                element: &mut gpui::elements::AnyElement<V>,
344                view: &mut V,
345                cx: &mut gpui::ViewContext<V>,
346            ) {
347                element.paint(bounds.origin(), visible_bounds, view, cx);
348            }
349
350            fn rect_for_text_range(
351                &self,
352                range_utf16: std::ops::Range<usize>,
353                _: gpui::geometry::rect::RectF,
354                _: gpui::geometry::rect::RectF,
355                element: &gpui::elements::AnyElement<V>,
356                _: &(),
357                view: &V,
358                cx: &gpui::ViewContext<V>,
359            ) -> Option<gpui::geometry::rect::RectF> {
360                element.rect_for_text_range(range_utf16, view, cx)
361            }
362
363            fn debug(
364                &self,
365                _: gpui::geometry::rect::RectF,
366                element: &gpui::elements::AnyElement<V>,
367                _: &(),
368                view: &V,
369                cx: &gpui::ViewContext<V>,
370            ) -> gpui::json::Value {
371                element.debug(view, cx)
372            }
373        }
374    };
375
376    gen.into()
377}