1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use std::mem;
4use syn::{
5 parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, ItemFn, Lit, Meta,
6 NestedMeta,
7};
8
9#[proc_macro_attribute]
10pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
11 let mut namespace = format_ident!("gpui");
12
13 let args = syn::parse_macro_input!(args as AttributeArgs);
14 let mut max_retries = 0;
15 let mut num_iterations = 1;
16 let mut starting_seed = std::env::var("SEED")
17 .map(|i| i.parse().expect("invalid `SEED`"))
18 .ok();
19
20 for arg in args {
21 match arg {
22 NestedMeta::Meta(Meta::Path(name))
23 if name.get_ident().map_or(false, |n| n == "self") =>
24 {
25 namespace = format_ident!("crate");
26 }
27 NestedMeta::Meta(Meta::NameValue(meta)) => {
28 let key_name = meta.path.get_ident().map(|i| i.to_string());
29 let result = (|| {
30 match key_name.as_ref().map(String::as_str) {
31 Some("retries") => max_retries = parse_int(&meta.lit)?,
32 Some("iterations") => {
33 if let Ok(iters) = std::env::var("ITERATIONS") {
34 num_iterations = iters.parse().expect("invalid `ITERATIONS`");
35 } else {
36 num_iterations = parse_int(&meta.lit)?;
37 }
38 }
39 Some("seed") => {
40 if starting_seed.is_none() {
41 starting_seed = Some(parse_int(&meta.lit)?);
42 }
43 }
44 _ => {
45 return Err(TokenStream::from(
46 syn::Error::new(meta.path.span(), "invalid argument")
47 .into_compile_error(),
48 ))
49 }
50 }
51 Ok(())
52 })();
53
54 if let Err(tokens) = result {
55 return tokens;
56 }
57 }
58 other => {
59 return TokenStream::from(
60 syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
61 )
62 }
63 }
64 }
65 let starting_seed = starting_seed.unwrap_or(0);
66
67 let mut inner_fn = parse_macro_input!(function as ItemFn);
68 let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
69 let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
70 let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
71
72 // Pass to the test function the number of app contexts that it needs,
73 // based on its parameter list.
74 let inner_fn_args = (0..inner_fn.sig.inputs.len())
75 .map(|i| {
76 let first_entity_id = i * 100_000;
77 quote!(#namespace::TestAppContext::new(foreground.clone(), background.clone(), #first_entity_id),)
78 })
79 .collect::<proc_macro2::TokenStream>();
80
81 let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
82 parse_quote! {
83 #[test]
84 fn #outer_fn_name() {
85 #inner_fn
86
87 let mut retries = 0;
88 let mut i = 0;
89 loop {
90 let seed = #starting_seed + i;
91 let result = std::panic::catch_unwind(|| {
92 let (foreground, background) = #namespace::executor::deterministic(seed as u64);
93 foreground.run(#inner_fn_name(#inner_fn_args));
94 });
95
96 match result {
97 Ok(result) => {
98 retries = 0;
99 i += 1;
100 if i == #num_iterations {
101 return result
102 }
103 }
104 Err(error) => {
105 if retries < #max_retries {
106 retries += 1;
107 println!("retrying: attempt {}", retries);
108 } else {
109 if #num_iterations > 1 {
110 eprintln!("failing seed: {}", seed);
111 }
112 std::panic::resume_unwind(error);
113 }
114 }
115 }
116 }
117 }
118 }
119 } else {
120 parse_quote! {
121 #[test]
122 fn #outer_fn_name() {
123 #inner_fn
124
125 if #max_retries > 0 {
126 let mut retries = 0;
127 loop {
128 let result = std::panic::catch_unwind(|| {
129 #namespace::App::test(|cx| {
130 #inner_fn_name(cx);
131 });
132 });
133
134 match result {
135 Ok(result) => return result,
136 Err(error) => {
137 if retries < #max_retries {
138 retries += 1;
139 println!("retrying: attempt {}", retries);
140 } else {
141 std::panic::resume_unwind(error);
142 }
143 }
144 }
145 }
146 } else {
147 #namespace::App::test(|cx| {
148 #inner_fn_name(cx);
149 });
150 }
151 }
152 }
153 };
154 outer_fn.attrs.extend(inner_fn_attributes);
155
156 TokenStream::from(quote!(#outer_fn))
157}
158
159fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
160 let result = if let Lit::Int(int) = &literal {
161 int.base10_parse()
162 } else {
163 Err(syn::Error::new(literal.span(), "must be an integer"))
164 };
165
166 result.map_err(|err| TokenStream::from(err.into_compile_error()))
167}