1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{format_ident, quote};
4use std::mem;
5use syn::{
6 parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, DeriveInput, FnArg,
7 GenericParam, Generics, ItemFn, Lit, Meta, NestedMeta, Type, WhereClause,
8};
9
10#[proc_macro_attribute]
11pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
12 let mut namespace = format_ident!("gpui");
13
14 let args = syn::parse_macro_input!(args as AttributeArgs);
15 let mut max_retries = 0;
16 let mut num_iterations = 1;
17 let mut starting_seed = 0;
18 let mut detect_nondeterminism = false;
19 let mut on_failure_fn_name = quote!(None);
20
21 for arg in args {
22 match arg {
23 NestedMeta::Meta(Meta::Path(name))
24 if name.get_ident().map_or(false, |n| n == "self") =>
25 {
26 namespace = format_ident!("crate");
27 }
28 NestedMeta::Meta(Meta::NameValue(meta)) => {
29 let key_name = meta.path.get_ident().map(|i| i.to_string());
30 let result = (|| {
31 match key_name.as_deref() {
32 Some("detect_nondeterminism") => {
33 detect_nondeterminism = parse_bool(&meta.lit)?
34 }
35 Some("retries") => max_retries = parse_int(&meta.lit)?,
36 Some("iterations") => num_iterations = parse_int(&meta.lit)?,
37 Some("seed") => starting_seed = parse_int(&meta.lit)?,
38 Some("on_failure") => {
39 if let Lit::Str(name) = meta.lit {
40 let ident = Ident::new(&name.value(), name.span());
41 on_failure_fn_name = quote!(Some(#ident));
42 } else {
43 return Err(TokenStream::from(
44 syn::Error::new(
45 meta.lit.span(),
46 "on_failure argument must be a string",
47 )
48 .into_compile_error(),
49 ));
50 }
51 }
52 _ => {
53 return Err(TokenStream::from(
54 syn::Error::new(meta.path.span(), "invalid argument")
55 .into_compile_error(),
56 ))
57 }
58 }
59 Ok(())
60 })();
61
62 if let Err(tokens) = result {
63 return tokens;
64 }
65 }
66 other => {
67 return TokenStream::from(
68 syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
69 )
70 }
71 }
72 }
73
74 let mut inner_fn = parse_macro_input!(function as ItemFn);
75 if max_retries > 0 && num_iterations > 1 {
76 return TokenStream::from(
77 syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
78 .into_compile_error(),
79 );
80 }
81 let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
82 let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
83 let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
84
85 let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
86 // Pass to the test function the number of app contexts that it needs,
87 // based on its parameter list.
88 let mut cx_vars = proc_macro2::TokenStream::new();
89 let mut cx_teardowns = proc_macro2::TokenStream::new();
90 let mut inner_fn_args = proc_macro2::TokenStream::new();
91 for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
92 if let FnArg::Typed(arg) = arg {
93 if let Type::Path(ty) = &*arg.ty {
94 let last_segment = ty.path.segments.last();
95 match last_segment.map(|s| s.ident.to_string()).as_deref() {
96 Some("StdRng") => {
97 inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
98 continue;
99 }
100 Some("Arc") => {
101 if let syn::PathArguments::AngleBracketed(args) =
102 &last_segment.unwrap().arguments
103 {
104 if let Some(syn::GenericArgument::Type(syn::Type::Path(ty))) =
105 args.args.last()
106 {
107 let last_segment = ty.path.segments.last();
108 if let Some("Deterministic") =
109 last_segment.map(|s| s.ident.to_string()).as_deref()
110 {
111 inner_fn_args.extend(quote!(deterministic.clone(),));
112 continue;
113 }
114 }
115 }
116 }
117 _ => {}
118 }
119 } else if let Type::Reference(ty) = &*arg.ty {
120 if let Type::Path(ty) = &*ty.elem {
121 let last_segment = ty.path.segments.last();
122 if let Some("TestAppContext") =
123 last_segment.map(|s| s.ident.to_string()).as_deref()
124 {
125 let first_entity_id = ix * 100_000;
126 let cx_varname = format_ident!("cx_{}", ix);
127 cx_vars.extend(quote!(
128 let mut #cx_varname = #namespace::TestAppContext::new(
129 foreground_platform.clone(),
130 cx.platform().clone(),
131 deterministic.build_foreground(#ix),
132 deterministic.build_background(),
133 cx.font_cache().clone(),
134 cx.leak_detector(),
135 #first_entity_id,
136 stringify!(#outer_fn_name).to_string(),
137 );
138 ));
139 cx_teardowns.extend(quote!(
140 #cx_varname.remove_all_windows();
141 deterministic.run_until_parked();
142 #cx_varname.update(|cx| cx.clear_globals());
143 ));
144 inner_fn_args.extend(quote!(&mut #cx_varname,));
145 continue;
146 }
147 }
148 }
149 }
150
151 return TokenStream::from(
152 syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
153 );
154 }
155
156 parse_quote! {
157 #[test]
158 fn #outer_fn_name() {
159 #inner_fn
160
161 #namespace::test::run_test(
162 #num_iterations as u64,
163 #starting_seed as u64,
164 #max_retries,
165 #detect_nondeterminism,
166 &mut |cx, foreground_platform, deterministic, seed| {
167 #cx_vars
168 cx.foreground().run(#inner_fn_name(#inner_fn_args));
169 #cx_teardowns
170 },
171 #on_failure_fn_name,
172 stringify!(#outer_fn_name).to_string(),
173 );
174 }
175 }
176 } else {
177 // Pass to the test function the number of app contexts that it needs,
178 // based on its parameter list.
179 let mut cx_vars = proc_macro2::TokenStream::new();
180 let mut cx_teardowns = proc_macro2::TokenStream::new();
181 let mut inner_fn_args = proc_macro2::TokenStream::new();
182 for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
183 if let FnArg::Typed(arg) = arg {
184 if let Type::Path(ty) = &*arg.ty {
185 let last_segment = ty.path.segments.last();
186
187 if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
188 inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
189 continue;
190 }
191 } else if let Type::Reference(ty) = &*arg.ty {
192 if let Type::Path(ty) = &*ty.elem {
193 let last_segment = ty.path.segments.last();
194 match last_segment.map(|s| s.ident.to_string()).as_deref() {
195 Some("AppContext") => {
196 inner_fn_args.extend(quote!(cx,));
197 continue;
198 }
199 Some("TestAppContext") => {
200 let first_entity_id = ix * 100_000;
201 let cx_varname = format_ident!("cx_{}", ix);
202 cx_vars.extend(quote!(
203 let mut #cx_varname = #namespace::TestAppContext::new(
204 foreground_platform.clone(),
205 cx.platform().clone(),
206 deterministic.build_foreground(#ix),
207 deterministic.build_background(),
208 cx.font_cache().clone(),
209 cx.leak_detector(),
210 #first_entity_id,
211 stringify!(#outer_fn_name).to_string(),
212 );
213 ));
214 cx_teardowns.extend(quote!(
215 #cx_varname.remove_all_windows();
216 deterministic.run_until_parked();
217 #cx_varname.update(|cx| cx.clear_globals());
218 ));
219 inner_fn_args.extend(quote!(&mut #cx_varname,));
220 continue;
221 }
222 _ => {}
223 }
224 }
225 }
226 }
227
228 return TokenStream::from(
229 syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
230 );
231 }
232
233 parse_quote! {
234 #[test]
235 fn #outer_fn_name() {
236 #inner_fn
237
238 #namespace::test::run_test(
239 #num_iterations as u64,
240 #starting_seed as u64,
241 #max_retries,
242 #detect_nondeterminism,
243 &mut |cx, foreground_platform, deterministic, seed| {
244 #cx_vars
245 #inner_fn_name(#inner_fn_args);
246 #cx_teardowns
247 },
248 #on_failure_fn_name,
249 stringify!(#outer_fn_name).to_string(),
250 );
251 }
252 }
253 };
254 outer_fn.attrs.extend(inner_fn_attributes);
255
256 TokenStream::from(quote!(#outer_fn))
257}
258
259fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
260 let result = if let Lit::Int(int) = &literal {
261 int.base10_parse()
262 } else {
263 Err(syn::Error::new(literal.span(), "must be an integer"))
264 };
265
266 result.map_err(|err| TokenStream::from(err.into_compile_error()))
267}
268
269fn parse_bool(literal: &Lit) -> Result<bool, TokenStream> {
270 let result = if let Lit::Bool(result) = &literal {
271 Ok(result.value)
272 } else {
273 Err(syn::Error::new(literal.span(), "must be a boolean"))
274 };
275
276 result.map_err(|err| TokenStream::from(err.into_compile_error()))
277}
278
279#[proc_macro_derive(Element)]
280pub fn element_derive(input: TokenStream) -> TokenStream {
281 let ast = parse_macro_input!(input as DeriveInput);
282 let type_name = ast.ident;
283
284 let placeholder_view_generics: Generics = parse_quote! { <V: 'static> };
285 let placeholder_view_type_name: Ident = parse_quote! { V };
286 let view_type_name: Ident;
287 let impl_generics: syn::ImplGenerics<'_>;
288 let type_generics: Option<syn::TypeGenerics<'_>>;
289 let where_clause: Option<&'_ WhereClause>;
290
291 match ast.generics.params.iter().find_map(|param| {
292 if let GenericParam::Type(type_param) = param {
293 Some(type_param.ident.clone())
294 } else {
295 None
296 }
297 }) {
298 Some(type_name) => {
299 view_type_name = type_name;
300 let generics = ast.generics.split_for_impl();
301 impl_generics = generics.0;
302 type_generics = Some(generics.1);
303 where_clause = generics.2;
304 }
305 _ => {
306 view_type_name = placeholder_view_type_name;
307 let generics = placeholder_view_generics.split_for_impl();
308 impl_generics = generics.0;
309 type_generics = None;
310 where_clause = generics.2;
311 }
312 }
313
314 let gen = quote! {
315 impl #impl_generics Element<#view_type_name> for #type_name #type_generics
316 #where_clause
317 {
318
319 type LayoutState = gpui::elements::AnyElement<V>;
320 type PaintState = ();
321
322 fn layout(
323 &mut self,
324 constraint: gpui::SizeConstraint,
325 view: &mut V,
326 cx: &mut gpui::LayoutContext<V>,
327 ) -> (gpui::geometry::vector::Vector2F, gpui::elements::AnyElement<V>) {
328 let mut element = self.render(view, cx).into_any();
329 let size = element.layout(constraint, view, cx);
330 (size, element)
331 }
332
333 fn paint(
334 &mut self,
335 scene: &mut gpui::SceneBuilder,
336 bounds: gpui::geometry::rect::RectF,
337 visible_bounds: gpui::geometry::rect::RectF,
338 element: &mut gpui::elements::AnyElement<V>,
339 view: &mut V,
340 cx: &mut gpui::PaintContext<V>,
341 ) {
342 element.paint(scene, bounds.origin(), visible_bounds, view, cx);
343 }
344
345 fn rect_for_text_range(
346 &self,
347 range_utf16: std::ops::Range<usize>,
348 _: gpui::geometry::rect::RectF,
349 _: gpui::geometry::rect::RectF,
350 element: &gpui::elements::AnyElement<V>,
351 _: &(),
352 view: &V,
353 cx: &gpui::ViewContext<V>,
354 ) -> Option<gpui::geometry::rect::RectF> {
355 element.rect_for_text_range(range_utf16, view, cx)
356 }
357
358 fn debug(
359 &self,
360 _: gpui::geometry::rect::RectF,
361 element: &gpui::elements::AnyElement<V>,
362 _: &(),
363 view: &V,
364 cx: &gpui::ViewContext<V>,
365 ) -> gpui::json::Value {
366 element.debug(view, cx)
367 }
368 }
369 };
370
371 gen.into()
372}