gpui_macros.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::Ident;
  3use quote::{format_ident, quote};
  4use std::mem;
  5use syn::{
  6    parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, DeriveInput, FnArg,
  7    GenericParam, Generics, ItemFn, Lit, Meta, NestedMeta, Type, WhereClause,
  8};
  9
 10#[proc_macro_attribute]
 11pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
 12    let mut namespace = format_ident!("gpui");
 13
 14    let args = syn::parse_macro_input!(args as AttributeArgs);
 15    let mut max_retries = 0;
 16    let mut num_iterations = 1;
 17    let mut starting_seed = 0;
 18    let mut detect_nondeterminism = false;
 19    let mut on_failure_fn_name = quote!(None);
 20
 21    for arg in args {
 22        match arg {
 23            NestedMeta::Meta(Meta::Path(name))
 24                if name.get_ident().map_or(false, |n| n == "self") =>
 25            {
 26                namespace = format_ident!("crate");
 27            }
 28            NestedMeta::Meta(Meta::NameValue(meta)) => {
 29                let key_name = meta.path.get_ident().map(|i| i.to_string());
 30                let result = (|| {
 31                    match key_name.as_deref() {
 32                        Some("detect_nondeterminism") => {
 33                            detect_nondeterminism = parse_bool(&meta.lit)?
 34                        }
 35                        Some("retries") => max_retries = parse_int(&meta.lit)?,
 36                        Some("iterations") => num_iterations = parse_int(&meta.lit)?,
 37                        Some("seed") => starting_seed = parse_int(&meta.lit)?,
 38                        Some("on_failure") => {
 39                            if let Lit::Str(name) = meta.lit {
 40                                let mut path = syn::Path {
 41                                    leading_colon: None,
 42                                    segments: Default::default(),
 43                                };
 44                                for part in name.value().split("::") {
 45                                    path.segments.push(Ident::new(part, name.span()).into());
 46                                }
 47                                on_failure_fn_name = quote!(Some(#path));
 48                            } else {
 49                                return Err(TokenStream::from(
 50                                    syn::Error::new(
 51                                        meta.lit.span(),
 52                                        "on_failure argument must be a string",
 53                                    )
 54                                    .into_compile_error(),
 55                                ));
 56                            }
 57                        }
 58                        _ => {
 59                            return Err(TokenStream::from(
 60                                syn::Error::new(meta.path.span(), "invalid argument")
 61                                    .into_compile_error(),
 62                            ))
 63                        }
 64                    }
 65                    Ok(())
 66                })();
 67
 68                if let Err(tokens) = result {
 69                    return tokens;
 70                }
 71            }
 72            other => {
 73                return TokenStream::from(
 74                    syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
 75                )
 76            }
 77        }
 78    }
 79
 80    let mut inner_fn = parse_macro_input!(function as ItemFn);
 81    if max_retries > 0 && num_iterations > 1 {
 82        return TokenStream::from(
 83            syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
 84                .into_compile_error(),
 85        );
 86    }
 87    let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
 88    let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
 89    let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
 90
 91    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
 92        // Pass to the test function the number of app contexts that it needs,
 93        // based on its parameter list.
 94        let mut cx_vars = proc_macro2::TokenStream::new();
 95        let mut cx_teardowns = proc_macro2::TokenStream::new();
 96        let mut inner_fn_args = proc_macro2::TokenStream::new();
 97        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
 98            if let FnArg::Typed(arg) = arg {
 99                if let Type::Path(ty) = &*arg.ty {
100                    let last_segment = ty.path.segments.last();
101                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
102                        Some("StdRng") => {
103                            inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
104                            continue;
105                        }
106                        Some("Arc") => {
107                            if let syn::PathArguments::AngleBracketed(args) =
108                                &last_segment.unwrap().arguments
109                            {
110                                if let Some(syn::GenericArgument::Type(syn::Type::Path(ty))) =
111                                    args.args.last()
112                                {
113                                    let last_segment = ty.path.segments.last();
114                                    if let Some("Deterministic") =
115                                        last_segment.map(|s| s.ident.to_string()).as_deref()
116                                    {
117                                        inner_fn_args.extend(quote!(deterministic.clone(),));
118                                        continue;
119                                    }
120                                }
121                            }
122                        }
123                        _ => {}
124                    }
125                } else if let Type::Reference(ty) = &*arg.ty {
126                    if let Type::Path(ty) = &*ty.elem {
127                        let last_segment = ty.path.segments.last();
128                        if let Some("TestAppContext") =
129                            last_segment.map(|s| s.ident.to_string()).as_deref()
130                        {
131                            let first_entity_id = ix * 100_000;
132                            let cx_varname = format_ident!("cx_{}", ix);
133                            cx_vars.extend(quote!(
134                                let mut #cx_varname = #namespace::TestAppContext::new(
135                                    foreground_platform.clone(),
136                                    cx.platform().clone(),
137                                    deterministic.build_foreground(#ix),
138                                    deterministic.build_background(),
139                                    cx.font_cache().clone(),
140                                    cx.leak_detector(),
141                                    #first_entity_id,
142                                    stringify!(#outer_fn_name).to_string(),
143                                );
144                            ));
145                            cx_teardowns.extend(quote!(
146                                #cx_varname.remove_all_windows();
147                                deterministic.run_until_parked();
148                                #cx_varname.update(|cx| cx.clear_globals());
149                            ));
150                            inner_fn_args.extend(quote!(&mut #cx_varname,));
151                            continue;
152                        }
153                    }
154                }
155            }
156
157            return TokenStream::from(
158                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
159            );
160        }
161
162        parse_quote! {
163            #[test]
164            fn #outer_fn_name() {
165                #inner_fn
166
167                #namespace::test::run_test(
168                    #num_iterations as u64,
169                    #starting_seed as u64,
170                    #max_retries,
171                    #detect_nondeterminism,
172                    &mut |cx, foreground_platform, deterministic, seed| {
173                        // some of the macro contents do not use all variables, silence the warnings
174                        let _ = (&cx, &foreground_platform, &deterministic, &seed);
175                        #cx_vars
176                        cx.foreground().run(#inner_fn_name(#inner_fn_args));
177                        #cx_teardowns
178                    },
179                    #on_failure_fn_name,
180                    stringify!(#outer_fn_name).to_string(),
181                );
182            }
183        }
184    } else {
185        // Pass to the test function the number of app contexts that it needs,
186        // based on its parameter list.
187        let mut cx_vars = proc_macro2::TokenStream::new();
188        let mut cx_teardowns = proc_macro2::TokenStream::new();
189        let mut inner_fn_args = proc_macro2::TokenStream::new();
190        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
191            if let FnArg::Typed(arg) = arg {
192                if let Type::Path(ty) = &*arg.ty {
193                    let last_segment = ty.path.segments.last();
194
195                    if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
196                        inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
197                        continue;
198                    }
199                } else if let Type::Reference(ty) = &*arg.ty {
200                    if let Type::Path(ty) = &*ty.elem {
201                        let last_segment = ty.path.segments.last();
202                        match last_segment.map(|s| s.ident.to_string()).as_deref() {
203                            Some("AppContext") => {
204                                inner_fn_args.extend(quote!(cx,));
205                                continue;
206                            }
207                            Some("TestAppContext") => {
208                                let first_entity_id = ix * 100_000;
209                                let cx_varname = format_ident!("cx_{}", ix);
210                                cx_vars.extend(quote!(
211                                    let mut #cx_varname = #namespace::TestAppContext::new(
212                                        foreground_platform.clone(),
213                                        cx.platform().clone(),
214                                        deterministic.build_foreground(#ix),
215                                        deterministic.build_background(),
216                                        cx.font_cache().clone(),
217                                        cx.leak_detector(),
218                                        #first_entity_id,
219                                        stringify!(#outer_fn_name).to_string(),
220                                    );
221                                ));
222                                cx_teardowns.extend(quote!(
223                                    #cx_varname.remove_all_windows();
224                                    deterministic.run_until_parked();
225                                    #cx_varname.update(|cx| cx.clear_globals());
226                                ));
227                                inner_fn_args.extend(quote!(&mut #cx_varname,));
228                                continue;
229                            }
230                            _ => {}
231                        }
232                    }
233                }
234            }
235
236            return TokenStream::from(
237                syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
238            );
239        }
240
241        parse_quote! {
242            #[test]
243            fn #outer_fn_name() {
244                #inner_fn
245
246                #namespace::test::run_test(
247                    #num_iterations as u64,
248                    #starting_seed as u64,
249                    #max_retries,
250                    #detect_nondeterminism,
251                    &mut |cx, foreground_platform, deterministic, seed| {
252                        // some of the macro contents do not use all variables, silence the warnings
253                        let _ = (&cx, &foreground_platform, &deterministic, &seed);
254                        #cx_vars
255                        #inner_fn_name(#inner_fn_args);
256                        #cx_teardowns
257                    },
258                    #on_failure_fn_name,
259                    stringify!(#outer_fn_name).to_string(),
260                );
261            }
262        }
263    };
264    outer_fn.attrs.extend(inner_fn_attributes);
265
266    TokenStream::from(quote!(#outer_fn))
267}
268
269fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
270    let result = if let Lit::Int(int) = &literal {
271        int.base10_parse()
272    } else {
273        Err(syn::Error::new(literal.span(), "must be an integer"))
274    };
275
276    result.map_err(|err| TokenStream::from(err.into_compile_error()))
277}
278
279fn parse_bool(literal: &Lit) -> Result<bool, TokenStream> {
280    let result = if let Lit::Bool(result) = &literal {
281        Ok(result.value)
282    } else {
283        Err(syn::Error::new(literal.span(), "must be a boolean"))
284    };
285
286    result.map_err(|err| TokenStream::from(err.into_compile_error()))
287}
288
289#[proc_macro_derive(Element)]
290pub fn element_derive(input: TokenStream) -> TokenStream {
291    let ast = parse_macro_input!(input as DeriveInput);
292    let type_name = ast.ident;
293
294    let placeholder_view_generics: Generics = parse_quote! { <V: 'static> };
295    let placeholder_view_type_name: Ident = parse_quote! { V };
296    let view_type_name: Ident;
297    let impl_generics: syn::ImplGenerics<'_>;
298    let type_generics: Option<syn::TypeGenerics<'_>>;
299    let where_clause: Option<&'_ WhereClause>;
300
301    match ast.generics.params.iter().find_map(|param| {
302        if let GenericParam::Type(type_param) = param {
303            Some(type_param.ident.clone())
304        } else {
305            None
306        }
307    }) {
308        Some(type_name) => {
309            view_type_name = type_name;
310            let generics = ast.generics.split_for_impl();
311            impl_generics = generics.0;
312            type_generics = Some(generics.1);
313            where_clause = generics.2;
314        }
315        _ => {
316            view_type_name = placeholder_view_type_name;
317            let generics = placeholder_view_generics.split_for_impl();
318            impl_generics = generics.0;
319            type_generics = None;
320            where_clause = generics.2;
321        }
322    }
323
324    let gen = quote! {
325        impl #impl_generics Element<#view_type_name> for #type_name #type_generics
326        #where_clause
327        {
328
329            type LayoutState = gpui::elements::AnyElement<V>;
330            type PaintState = ();
331
332            fn layout(
333                &mut self,
334                constraint: gpui::SizeConstraint,
335                view: &mut V,
336                cx: &mut gpui::ViewContext<V>,
337            ) -> (gpui::geometry::vector::Vector2F, gpui::elements::AnyElement<V>) {
338                let mut element = self.render(view, cx).into_any();
339                let size = element.layout(constraint, view, cx);
340                (size, element)
341            }
342
343            fn paint(
344                &mut self,
345                bounds: gpui::geometry::rect::RectF,
346                visible_bounds: gpui::geometry::rect::RectF,
347                element: &mut gpui::elements::AnyElement<V>,
348                view: &mut V,
349                cx: &mut gpui::ViewContext<V>,
350            ) {
351                element.paint(bounds.origin(), visible_bounds, view, cx);
352            }
353
354            fn rect_for_text_range(
355                &self,
356                range_utf16: std::ops::Range<usize>,
357                _: gpui::geometry::rect::RectF,
358                _: gpui::geometry::rect::RectF,
359                element: &gpui::elements::AnyElement<V>,
360                _: &(),
361                view: &V,
362                cx: &gpui::ViewContext<V>,
363            ) -> Option<gpui::geometry::rect::RectF> {
364                element.rect_for_text_range(range_utf16, view, cx)
365            }
366
367            fn debug(
368                &self,
369                _: gpui::geometry::rect::RectF,
370                element: &gpui::elements::AnyElement<V>,
371                _: &(),
372                view: &V,
373                cx: &gpui::ViewContext<V>,
374            ) -> gpui::json::Value {
375                element.debug(view, cx)
376            }
377        }
378    };
379
380    gen.into()
381}