1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use std::mem;
4use syn::{
5 parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, ItemFn, Lit, Meta,
6 NestedMeta,
7};
8
9#[proc_macro_attribute]
10pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
11 let mut namespace = format_ident!("gpui");
12
13 let args = syn::parse_macro_input!(args as AttributeArgs);
14 let mut max_retries = 0;
15 let mut num_iterations = 1;
16 let mut starting_seed = 0;
17 for arg in args {
18 match arg {
19 NestedMeta::Meta(Meta::Path(name))
20 if name.get_ident().map_or(false, |n| n == "self") =>
21 {
22 namespace = format_ident!("crate");
23 }
24 NestedMeta::Meta(Meta::NameValue(meta)) => {
25 let key_name = meta.path.get_ident().map(|i| i.to_string());
26 let variable = match key_name.as_ref().map(String::as_str) {
27 Some("retries") => &mut max_retries,
28 Some("iterations") => &mut num_iterations,
29 Some("seed") => &mut starting_seed,
30 _ => {
31 return TokenStream::from(
32 syn::Error::new(meta.path.span(), "invalid argument")
33 .into_compile_error(),
34 )
35 }
36 };
37
38 *variable = match parse_int(&meta.lit) {
39 Ok(value) => value,
40 Err(error) => {
41 return TokenStream::from(error.into_compile_error());
42 }
43 };
44 }
45 other => {
46 return TokenStream::from(
47 syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
48 )
49 }
50 }
51 }
52
53 let mut inner_fn = parse_macro_input!(function as ItemFn);
54 let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
55 let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
56 let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
57
58 // Pass to the test function the number of app contexts that it needs,
59 // based on its parameter list.
60 let inner_fn_args = (0..inner_fn.sig.inputs.len())
61 .map(|i| {
62 let first_entity_id = i * 100_000;
63 quote!(#namespace::TestAppContext::new(foreground.clone(), background.clone(), #first_entity_id),)
64 })
65 .collect::<proc_macro2::TokenStream>();
66
67 let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
68 parse_quote! {
69 #[test]
70 fn #outer_fn_name() {
71 #inner_fn
72
73 let mut retries = 0;
74 let mut i = 0;
75 loop {
76 let seed = #starting_seed + i;
77 let result = std::panic::catch_unwind(|| {
78 let (foreground, background) = #namespace::executor::deterministic(seed as u64);
79 foreground.run(#inner_fn_name(#inner_fn_args));
80 });
81
82 match result {
83 Ok(result) => {
84 retries = 0;
85 i += 1;
86 if i == #num_iterations {
87 return result
88 }
89 }
90 Err(error) => {
91 if retries < #max_retries {
92 retries += 1;
93 println!("retrying: attempt {}", retries);
94 } else {
95 if #num_iterations > 1 {
96 eprintln!("failing seed: {}", seed);
97 }
98 std::panic::resume_unwind(error);
99 }
100 }
101 }
102 }
103 }
104 }
105 } else {
106 parse_quote! {
107 #[test]
108 fn #outer_fn_name() {
109 #inner_fn
110
111 if #max_retries > 0 {
112 let mut retries = 0;
113 loop {
114 let result = std::panic::catch_unwind(|| {
115 #namespace::App::test(|cx| {
116 #inner_fn_name(cx);
117 });
118 });
119
120 match result {
121 Ok(result) => return result,
122 Err(error) => {
123 if retries < #max_retries {
124 retries += 1;
125 println!("retrying: attempt {}", retries);
126 } else {
127 std::panic::resume_unwind(error);
128 }
129 }
130 }
131 }
132 } else {
133 #namespace::App::test(|cx| {
134 #inner_fn_name(cx);
135 });
136 }
137 }
138 }
139 };
140 outer_fn.attrs.extend(inner_fn_attributes);
141
142 TokenStream::from(quote!(#outer_fn))
143}
144
145fn parse_int(literal: &Lit) -> syn::Result<usize> {
146 if let Lit::Int(int) = &literal {
147 int.base10_parse()
148 } else {
149 Err(syn::Error::new(literal.span(), "must be an integer"))
150 }
151}