gpui_macros.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::Ident;
  3use quote::{format_ident, quote};
  4use std::mem;
  5use syn::{
  6    parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, DeriveInput, FnArg,
  7    GenericParam, Generics, ItemFn, Lit, Meta, NestedMeta, Type, WhereClause,
  8};
  9
 10#[proc_macro_attribute]
 11pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 12    let mut namespace = format_ident!("gpui");
 13
 14    let args = syn::parse_macro_input!(args as AttributeArgs);
 15    let mut max_retries = 0;
 16    let mut num_iterations = 1;
 17    let mut starting_seed = 0;
 18    let mut detect_nondeterminism = false;
 19    let mut on_failure_fn_name = quote!(None);
 20
 21    for arg in args {
 22        match arg {
 23            NestedMeta::Meta(Meta::Path(name))
 24                if name.get_ident().map_or(false, |n| n == "self") =>
 25            {
 26                namespace = format_ident!("crate");
 27            }
 28            NestedMeta::Meta(Meta::NameValue(meta)) => {
 29                let key_name = meta.path.get_ident().map(|i| i.to_string());
 30                let result = (|| {
 31                    match key_name.as_deref() {
 32                        Some("detect_nondeterminism") => {
 33                            detect_nondeterminism = parse_bool(&meta.lit)?
 34                        }
 35                        Some("retries") => max_retries = parse_int(&meta.lit)?,
 36                        Some("iterations") => num_iterations = parse_int(&meta.lit)?,
 37                        Some("seed") => starting_seed = parse_int(&meta.lit)?,
 38                        Some("on_failure") => {
 39                            if let Lit::Str(name) = meta.lit {
 40                                let ident = Ident::new(&name.value(), name.span());
 41                                on_failure_fn_name = quote!(Some(#ident));
 42                            } else {
 43                                return Err(TokenStream::from(
 44                                    syn::Error::new(
 45                                        meta.lit.span(),
 46                                        "on_failure argument must be a string",
 47                                    )
 48                                    .into_compile_error(),
 49                                ));
 50                            }
 51                        }
 52                        _ => {
 53                            return Err(TokenStream::from(
 54                                syn::Error::new(meta.path.span(), "invalid argument")
 55                                    .into_compile_error(),
 56                            ))
 57                        }
 58                    }
 59                    Ok(())
 60                })();
 61
 62                if let Err(tokens) = result {
 63                    return tokens;
 64                }
 65            }
 66            other => {
 67                return TokenStream::from(
 68                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 69                )
 70            }
 71        }
 72    }
 73
 74    let mut inner_fn = parse_macro_input!(function as ItemFn);
 75    if max_retries > 0 && num_iterations > 1 {
 76        return TokenStream::from(
 77            syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
 78                .into_compile_error(),
 79        );
 80    }
 81    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 82    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 83    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 84
 85    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 86        // Pass to the test function the number of app contexts that it needs,
 87        // based on its parameter list.
 88        let mut cx_vars = proc_macro2::TokenStream::new();
 89        let mut cx_teardowns = proc_macro2::TokenStream::new();
 90        let mut inner_fn_args = proc_macro2::TokenStream::new();
 91        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
 92            if let FnArg::Typed(arg) = arg {
 93                if let Type::Path(ty) = &*arg.ty {
 94                    let last_segment = ty.path.segments.last();
 95                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
 96                        Some("StdRng") => {
 97                            inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
 98                            continue;
 99                        }
100                        Some("Arc") => {
101                            if let syn::PathArguments::AngleBracketed(args) =
102                                &last_segment.unwrap().arguments
103                            {
104                                if let Some(syn::GenericArgument::Type(syn::Type::Path(ty))) =
105                                    args.args.last()
106                                {
107                                    let last_segment = ty.path.segments.last();
108                                    if let Some("Deterministic") =
109                                        last_segment.map(|s| s.ident.to_string()).as_deref()
110                                    {
111                                        inner_fn_args.extend(quote!(deterministic.clone(),));
112                                        continue;
113                                    }
114                                }
115                            }
116                        }
117                        _ => {}
118                    }
119                } else if let Type::Reference(ty) = &*arg.ty {
120                    if let Type::Path(ty) = &*ty.elem {
121                        let last_segment = ty.path.segments.last();
122                        if let Some("TestAppContext") =
123                            last_segment.map(|s| s.ident.to_string()).as_deref()
124                        {
125                            let first_entity_id = ix * 100_000;
126                            let cx_varname = format_ident!("cx_{}", ix);
127                            cx_vars.extend(quote!(
128                                let mut #cx_varname = #namespace::TestAppContext::new(
129                                    foreground_platform.clone(),
130                                    cx.platform().clone(),
131                                    deterministic.build_foreground(#ix),
132                                    deterministic.build_background(),
133                                    cx.font_cache().clone(),
134                                    cx.leak_detector(),
135                                    #first_entity_id,
136                                    stringify!(#outer_fn_name).to_string(),
137                                );
138                            ));
139                            cx_teardowns.extend(quote!(
140                                #cx_varname.remove_all_windows();
141                                deterministic.run_until_parked();
142                                #cx_varname.update(|cx| cx.clear_globals());
143                            ));
144                            inner_fn_args.extend(quote!(&mut #cx_varname,));
145                            continue;
146                        }
147                    }
148                }
149            }
150
151            return TokenStream::from(
152                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
153            );
154        }
155
156        parse_quote! {
157            #[test]
158            fn #outer_fn_name() {
159                #inner_fn
160
161                #namespace::test::run_test(
162                    #num_iterations as u64,
163                    #starting_seed as u64,
164                    #max_retries,
165                    #detect_nondeterminism,
166                    &mut |cx, foreground_platform, deterministic, seed| {
167                        #cx_vars
168                        cx.foreground().run(#inner_fn_name(#inner_fn_args));
169                        #cx_teardowns
170                    },
171                    #on_failure_fn_name,
172                    stringify!(#outer_fn_name).to_string(),
173                );
174            }
175        }
176    } else {
177        // Pass to the test function the number of app contexts that it needs,
178        // based on its parameter list.
179        let mut cx_vars = proc_macro2::TokenStream::new();
180        let mut cx_teardowns = proc_macro2::TokenStream::new();
181        let mut inner_fn_args = proc_macro2::TokenStream::new();
182        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
183            if let FnArg::Typed(arg) = arg {
184                if let Type::Path(ty) = &*arg.ty {
185                    let last_segment = ty.path.segments.last();
186
187                    if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
188                        inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
189                        continue;
190                    }
191                } else if let Type::Reference(ty) = &*arg.ty {
192                    if let Type::Path(ty) = &*ty.elem {
193                        let last_segment = ty.path.segments.last();
194                        match last_segment.map(|s| s.ident.to_string()).as_deref() {
195                            Some("AppContext") => {
196                                inner_fn_args.extend(quote!(cx,));
197                                continue;
198                            }
199                            Some("TestAppContext") => {
200                                let first_entity_id = ix * 100_000;
201                                let cx_varname = format_ident!("cx_{}", ix);
202                                cx_vars.extend(quote!(
203                                    let mut #cx_varname = #namespace::TestAppContext::new(
204                                        foreground_platform.clone(),
205                                        cx.platform().clone(),
206                                        deterministic.build_foreground(#ix),
207                                        deterministic.build_background(),
208                                        cx.font_cache().clone(),
209                                        cx.leak_detector(),
210                                        #first_entity_id,
211                                        stringify!(#outer_fn_name).to_string(),
212                                    );
213                                ));
214                                cx_teardowns.extend(quote!(
215                                    #cx_varname.remove_all_windows();
216                                    deterministic.run_until_parked();
217                                    #cx_varname.update(|cx| cx.clear_globals());
218                                ));
219                                inner_fn_args.extend(quote!(&mut #cx_varname,));
220                                continue;
221                            }
222                            _ => {}
223                        }
224                    }
225                }
226            }
227
228            return TokenStream::from(
229                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
230            );
231        }
232
233        parse_quote! {
234            #[test]
235            fn #outer_fn_name() {
236                #inner_fn
237
238                #namespace::test::run_test(
239                    #num_iterations as u64,
240                    #starting_seed as u64,
241                    #max_retries,
242                    #detect_nondeterminism,
243                    &mut |cx, foreground_platform, deterministic, seed| {
244                        #cx_vars
245                        #inner_fn_name(#inner_fn_args);
246                        #cx_teardowns
247                    },
248                    #on_failure_fn_name,
249                    stringify!(#outer_fn_name).to_string(),
250                );
251            }
252        }
253    };
254    outer_fn.attrs.extend(inner_fn_attributes);
255
256    TokenStream::from(quote!(#outer_fn))
257}
258
259fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
260    let result = if let Lit::Int(int) = &literal {
261        int.base10_parse()
262    } else {
263        Err(syn::Error::new(literal.span(), "must be an integer"))
264    };
265
266    result.map_err(|err| TokenStream::from(err.into_compile_error()))
267}
268
269fn parse_bool(literal: &Lit) -> Result<bool, TokenStream> {
270    let result = if let Lit::Bool(result) = &literal {
271        Ok(result.value)
272    } else {
273        Err(syn::Error::new(literal.span(), "must be a boolean"))
274    };
275
276    result.map_err(|err| TokenStream::from(err.into_compile_error()))
277}
278
279#[proc_macro_derive(Element)]
280pub fn element_derive(input: TokenStream) -> TokenStream {
281    let ast = parse_macro_input!(input as DeriveInput);
282    let type_name = ast.ident;
283
284    let placeholder_view_generics: Generics = parse_quote! { <V: 'static> };
285    let placeholder_view_type_name: Ident = parse_quote! { V };
286    let view_type_name: Ident;
287    let impl_generics: syn::ImplGenerics<'_>;
288    let type_generics: Option<syn::TypeGenerics<'_>>;
289    let where_clause: Option<&'_ WhereClause>;
290
291    match ast.generics.params.iter().find_map(|param| {
292        if let GenericParam::Type(type_param) = param {
293            Some(type_param.ident.clone())
294        } else {
295            None
296        }
297    }) {
298        Some(type_name) => {
299            view_type_name = type_name;
300            let generics = ast.generics.split_for_impl();
301            impl_generics = generics.0;
302            type_generics = Some(generics.1);
303            where_clause = generics.2;
304        }
305        _ => {
306            view_type_name = placeholder_view_type_name;
307            let generics = placeholder_view_generics.split_for_impl();
308            impl_generics = generics.0;
309            type_generics = None;
310            where_clause = generics.2;
311        }
312    }
313
314    let gen = quote! {
315        impl #impl_generics Element<#view_type_name> for #type_name #type_generics
316        #where_clause
317        {
318
319            type LayoutState = gpui::elements::AnyElement<V>;
320            type PaintState = ();
321
322            fn layout(
323                &mut self,
324                constraint: gpui::SizeConstraint,
325                view: &mut V,
326                cx: &mut gpui::LayoutContext<V>,
327            ) -> (gpui::geometry::vector::Vector2F, gpui::elements::AnyElement<V>) {
328                let mut element = self.render(view, cx).into_any();
329                let size = element.layout(constraint, view, cx);
330                (size, element)
331            }
332
333            fn paint(
334                &mut self,
335                scene: &mut gpui::SceneBuilder,
336                bounds: gpui::geometry::rect::RectF,
337                visible_bounds: gpui::geometry::rect::RectF,
338                element: &mut gpui::elements::AnyElement<V>,
339                view: &mut V,
340                cx: &mut gpui::PaintContext<V>,
341            ) {
342                element.paint(scene, bounds.origin(), visible_bounds, view, cx);
343            }
344
345            fn rect_for_text_range(
346                &self,
347                range_utf16: std::ops::Range<usize>,
348                _: gpui::geometry::rect::RectF,
349                _: gpui::geometry::rect::RectF,
350                element: &gpui::elements::AnyElement<V>,
351                _: &(),
352                view: &V,
353                cx: &gpui::ViewContext<V>,
354            ) -> Option<gpui::geometry::rect::RectF> {
355                element.rect_for_text_range(range_utf16, view, cx)
356            }
357
358            fn debug(
359                &self,
360                _: gpui::geometry::rect::RectF,
361                element: &gpui::elements::AnyElement<V>,
362                _: &(),
363                view: &V,
364                cx: &gpui::ViewContext<V>,
365            ) -> gpui::json::Value {
366                element.debug(view, cx)
367            }
368        }
369    };
370
371    gen.into()
372}