1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{format_ident, quote};
4use std::mem;
5use syn::{
6 parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, DeriveInput, FnArg,
7 GenericParam, Generics, ItemFn, Lit, Meta, NestedMeta, Type, WhereClause,
8};
9
10#[proc_macro_attribute]
11pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
12 let mut namespace = format_ident!("gpui");
13
14 let args = syn::parse_macro_input!(args as AttributeArgs);
15 let mut max_retries = 0;
16 let mut num_iterations = 1;
17 let mut starting_seed = 0;
18 let mut detect_nondeterminism = false;
19 let mut on_failure_fn_name = quote!(None);
20
21 for arg in args {
22 match arg {
23 NestedMeta::Meta(Meta::Path(name))
24 if name.get_ident().map_or(false, |n| n == "self") =>
25 {
26 namespace = format_ident!("crate");
27 }
28 NestedMeta::Meta(Meta::NameValue(meta)) => {
29 let key_name = meta.path.get_ident().map(|i| i.to_string());
30 let result = (|| {
31 match key_name.as_deref() {
32 Some("detect_nondeterminism") => {
33 detect_nondeterminism = parse_bool(&meta.lit)?
34 }
35 Some("retries") => max_retries = parse_int(&meta.lit)?,
36 Some("iterations") => num_iterations = parse_int(&meta.lit)?,
37 Some("seed") => starting_seed = parse_int(&meta.lit)?,
38 Some("on_failure") => {
39 if let Lit::Str(name) = meta.lit {
40 let mut path = syn::Path {
41 leading_colon: None,
42 segments: Default::default(),
43 };
44 for part in name.value().split("::") {
45 path.segments.push(Ident::new(part, name.span()).into());
46 }
47 on_failure_fn_name = quote!(Some(#path));
48 } else {
49 return Err(TokenStream::from(
50 syn::Error::new(
51 meta.lit.span(),
52 "on_failure argument must be a string",
53 )
54 .into_compile_error(),
55 ));
56 }
57 }
58 _ => {
59 return Err(TokenStream::from(
60 syn::Error::new(meta.path.span(), "invalid argument")
61 .into_compile_error(),
62 ))
63 }
64 }
65 Ok(())
66 })();
67
68 if let Err(tokens) = result {
69 return tokens;
70 }
71 }
72 other => {
73 return TokenStream::from(
74 syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
75 )
76 }
77 }
78 }
79
80 let mut inner_fn = parse_macro_input!(function as ItemFn);
81 if max_retries > 0 && num_iterations > 1 {
82 return TokenStream::from(
83 syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
84 .into_compile_error(),
85 );
86 }
87 let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
88 let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
89 let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
90
91 let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
92 // Pass to the test function the number of app contexts that it needs,
93 // based on its parameter list.
94 let mut cx_vars = proc_macro2::TokenStream::new();
95 let mut cx_teardowns = proc_macro2::TokenStream::new();
96 let mut inner_fn_args = proc_macro2::TokenStream::new();
97 for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
98 if let FnArg::Typed(arg) = arg {
99 if let Type::Path(ty) = &*arg.ty {
100 let last_segment = ty.path.segments.last();
101 match last_segment.map(|s| s.ident.to_string()).as_deref() {
102 Some("StdRng") => {
103 inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
104 continue;
105 }
106 Some("Arc") => {
107 if let syn::PathArguments::AngleBracketed(args) =
108 &last_segment.unwrap().arguments
109 {
110 if let Some(syn::GenericArgument::Type(syn::Type::Path(ty))) =
111 args.args.last()
112 {
113 let last_segment = ty.path.segments.last();
114 if let Some("Deterministic") =
115 last_segment.map(|s| s.ident.to_string()).as_deref()
116 {
117 inner_fn_args.extend(quote!(deterministic.clone(),));
118 continue;
119 }
120 }
121 }
122 }
123 _ => {}
124 }
125 } else if let Type::Reference(ty) = &*arg.ty {
126 if let Type::Path(ty) = &*ty.elem {
127 let last_segment = ty.path.segments.last();
128 if let Some("TestAppContext") =
129 last_segment.map(|s| s.ident.to_string()).as_deref()
130 {
131 let first_entity_id = ix * 100_000;
132 let cx_varname = format_ident!("cx_{}", ix);
133 cx_vars.extend(quote!(
134 let mut #cx_varname = #namespace::TestAppContext::new(
135 foreground_platform.clone(),
136 cx.platform().clone(),
137 deterministic.build_foreground(#ix),
138 deterministic.build_background(),
139 cx.font_cache().clone(),
140 cx.leak_detector(),
141 #first_entity_id,
142 stringify!(#outer_fn_name).to_string(),
143 );
144 ));
145 cx_teardowns.extend(quote!(
146 #cx_varname.remove_all_windows();
147 deterministic.run_until_parked();
148 #cx_varname.update(|cx| cx.clear_globals());
149 ));
150 inner_fn_args.extend(quote!(&mut #cx_varname,));
151 continue;
152 }
153 }
154 }
155 }
156
157 return TokenStream::from(
158 syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
159 );
160 }
161
162 parse_quote! {
163 #[test]
164 fn #outer_fn_name() {
165 #inner_fn
166
167 #namespace::test::run_test(
168 #num_iterations as u64,
169 #starting_seed as u64,
170 #max_retries,
171 #detect_nondeterminism,
172 &mut |cx, foreground_platform, deterministic, seed| {
173 // some of the macro contents do not use all variables, silence the warnings
174 let _ = (&cx, &foreground_platform, &deterministic, &seed);
175 #cx_vars
176 cx.foreground().run(#inner_fn_name(#inner_fn_args));
177 #cx_teardowns
178 },
179 #on_failure_fn_name,
180 stringify!(#outer_fn_name).to_string(),
181 );
182 }
183 }
184 } else {
185 // Pass to the test function the number of app contexts that it needs,
186 // based on its parameter list.
187 let mut cx_vars = proc_macro2::TokenStream::new();
188 let mut cx_teardowns = proc_macro2::TokenStream::new();
189 let mut inner_fn_args = proc_macro2::TokenStream::new();
190 for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
191 if let FnArg::Typed(arg) = arg {
192 if let Type::Path(ty) = &*arg.ty {
193 let last_segment = ty.path.segments.last();
194
195 if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
196 inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
197 continue;
198 }
199 } else if let Type::Reference(ty) = &*arg.ty {
200 if let Type::Path(ty) = &*ty.elem {
201 let last_segment = ty.path.segments.last();
202 match last_segment.map(|s| s.ident.to_string()).as_deref() {
203 Some("AppContext") => {
204 inner_fn_args.extend(quote!(cx,));
205 continue;
206 }
207 Some("TestAppContext") => {
208 let first_entity_id = ix * 100_000;
209 let cx_varname = format_ident!("cx_{}", ix);
210 cx_vars.extend(quote!(
211 let mut #cx_varname = #namespace::TestAppContext::new(
212 foreground_platform.clone(),
213 cx.platform().clone(),
214 deterministic.build_foreground(#ix),
215 deterministic.build_background(),
216 cx.font_cache().clone(),
217 cx.leak_detector(),
218 #first_entity_id,
219 stringify!(#outer_fn_name).to_string(),
220 );
221 ));
222 cx_teardowns.extend(quote!(
223 #cx_varname.remove_all_windows();
224 deterministic.run_until_parked();
225 #cx_varname.update(|cx| cx.clear_globals());
226 ));
227 inner_fn_args.extend(quote!(&mut #cx_varname,));
228 continue;
229 }
230 _ => {}
231 }
232 }
233 }
234 }
235
236 return TokenStream::from(
237 syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
238 );
239 }
240
241 parse_quote! {
242 #[test]
243 fn #outer_fn_name() {
244 #inner_fn
245
246 #namespace::test::run_test(
247 #num_iterations as u64,
248 #starting_seed as u64,
249 #max_retries,
250 #detect_nondeterminism,
251 &mut |cx, foreground_platform, deterministic, seed| {
252 // some of the macro contents do not use all variables, silence the warnings
253 let _ = (&cx, &foreground_platform, &deterministic, &seed);
254 #cx_vars
255 #inner_fn_name(#inner_fn_args);
256 #cx_teardowns
257 },
258 #on_failure_fn_name,
259 stringify!(#outer_fn_name).to_string(),
260 );
261 }
262 }
263 };
264 outer_fn.attrs.extend(inner_fn_attributes);
265
266 TokenStream::from(quote!(#outer_fn))
267}
268
269fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
270 let result = if let Lit::Int(int) = &literal {
271 int.base10_parse()
272 } else {
273 Err(syn::Error::new(literal.span(), "must be an integer"))
274 };
275
276 result.map_err(|err| TokenStream::from(err.into_compile_error()))
277}
278
279fn parse_bool(literal: &Lit) -> Result<bool, TokenStream> {
280 let result = if let Lit::Bool(result) = &literal {
281 Ok(result.value)
282 } else {
283 Err(syn::Error::new(literal.span(), "must be a boolean"))
284 };
285
286 result.map_err(|err| TokenStream::from(err.into_compile_error()))
287}
288
289#[proc_macro_derive(Element)]
290pub fn element_derive(input: TokenStream) -> TokenStream {
291 let ast = parse_macro_input!(input as DeriveInput);
292 let type_name = ast.ident;
293
294 let placeholder_view_generics: Generics = parse_quote! { <V: 'static> };
295 let placeholder_view_type_name: Ident = parse_quote! { V };
296 let view_type_name: Ident;
297 let impl_generics: syn::ImplGenerics<'_>;
298 let type_generics: Option<syn::TypeGenerics<'_>>;
299 let where_clause: Option<&'_ WhereClause>;
300
301 match ast.generics.params.iter().find_map(|param| {
302 if let GenericParam::Type(type_param) = param {
303 Some(type_param.ident.clone())
304 } else {
305 None
306 }
307 }) {
308 Some(type_name) => {
309 view_type_name = type_name;
310 let generics = ast.generics.split_for_impl();
311 impl_generics = generics.0;
312 type_generics = Some(generics.1);
313 where_clause = generics.2;
314 }
315 _ => {
316 view_type_name = placeholder_view_type_name;
317 let generics = placeholder_view_generics.split_for_impl();
318 impl_generics = generics.0;
319 type_generics = None;
320 where_clause = generics.2;
321 }
322 }
323
324 let gen = quote! {
325 impl #impl_generics Element<#view_type_name> for #type_name #type_generics
326 #where_clause
327 {
328
329 type LayoutState = gpui::elements::AnyElement<V>;
330 type PaintState = ();
331
332 fn layout(
333 &mut self,
334 constraint: gpui::SizeConstraint,
335 view: &mut V,
336 cx: &mut gpui::ViewContext<V>,
337 ) -> (gpui::geometry::vector::Vector2F, gpui::elements::AnyElement<V>) {
338 let mut element = self.render(view, cx).into_any();
339 let size = element.layout(constraint, view, cx);
340 (size, element)
341 }
342
343 fn paint(
344 &mut self,
345 bounds: gpui::geometry::rect::RectF,
346 visible_bounds: gpui::geometry::rect::RectF,
347 element: &mut gpui::elements::AnyElement<V>,
348 view: &mut V,
349 cx: &mut gpui::ViewContext<V>,
350 ) {
351 element.paint(bounds.origin(), visible_bounds, view, cx);
352 }
353
354 fn rect_for_text_range(
355 &self,
356 range_utf16: std::ops::Range<usize>,
357 _: gpui::geometry::rect::RectF,
358 _: gpui::geometry::rect::RectF,
359 element: &gpui::elements::AnyElement<V>,
360 _: &(),
361 view: &V,
362 cx: &gpui::ViewContext<V>,
363 ) -> Option<gpui::geometry::rect::RectF> {
364 element.rect_for_text_range(range_utf16, view, cx)
365 }
366
367 fn debug(
368 &self,
369 _: gpui::geometry::rect::RectF,
370 element: &gpui::elements::AnyElement<V>,
371 _: &(),
372 view: &V,
373 cx: &gpui::ViewContext<V>,
374 ) -> gpui::json::Value {
375 element.debug(view, cx)
376 }
377 }
378 };
379
380 gen.into()
381}