1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{format_ident, quote};
4use std::mem;
5use syn::{
6 parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, FnArg, ItemFn, Lit, Meta,
7 NestedMeta, Type,
8};
9
10pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
11 let args = syn::parse_macro_input!(args as AttributeArgs);
12 let mut max_retries = 0;
13 let mut num_iterations = 1;
14 let mut on_failure_fn_name = quote!(None);
15
16 for arg in args {
17 match arg {
18 NestedMeta::Meta(Meta::NameValue(meta)) => {
19 let key_name = meta.path.get_ident().map(|i| i.to_string());
20 let result = (|| {
21 match key_name.as_deref() {
22 Some("retries") => max_retries = parse_int(&meta.lit)?,
23 Some("iterations") => num_iterations = parse_int(&meta.lit)?,
24 Some("on_failure") => {
25 if let Lit::Str(name) = meta.lit {
26 let mut path = syn::Path {
27 leading_colon: None,
28 segments: Default::default(),
29 };
30 for part in name.value().split("::") {
31 path.segments.push(Ident::new(part, name.span()).into());
32 }
33 on_failure_fn_name = quote!(Some(#path));
34 } else {
35 return Err(TokenStream::from(
36 syn::Error::new(
37 meta.lit.span(),
38 "on_failure argument must be a string",
39 )
40 .into_compile_error(),
41 ));
42 }
43 }
44 _ => {
45 return Err(TokenStream::from(
46 syn::Error::new(meta.path.span(), "invalid argument")
47 .into_compile_error(),
48 ))
49 }
50 }
51 Ok(())
52 })();
53
54 if let Err(tokens) = result {
55 return tokens;
56 }
57 }
58 other => {
59 return TokenStream::from(
60 syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
61 )
62 }
63 }
64 }
65
66 let mut inner_fn = parse_macro_input!(function as ItemFn);
67 if max_retries > 0 && num_iterations > 1 {
68 return TokenStream::from(
69 syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
70 .into_compile_error(),
71 );
72 }
73 let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
74 let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
75 let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
76
77 let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
78 // Pass to the test function the number of app contexts that it needs,
79 // based on its parameter list.
80 let mut cx_vars = proc_macro2::TokenStream::new();
81 let mut cx_teardowns = proc_macro2::TokenStream::new();
82 let mut inner_fn_args = proc_macro2::TokenStream::new();
83 for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
84 if let FnArg::Typed(arg) = arg {
85 if let Type::Path(ty) = &*arg.ty {
86 let last_segment = ty.path.segments.last();
87 match last_segment.map(|s| s.ident.to_string()).as_deref() {
88 Some("StdRng") => {
89 inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
90 continue;
91 }
92 Some("BackgroundExecutor") => {
93 inner_fn_args.extend(quote!(gpui::BackgroundExecutor::new(
94 std::sync::Arc::new(dispatcher.clone()),
95 ),));
96 continue;
97 }
98 _ => {}
99 }
100 } else if let Type::Reference(ty) = &*arg.ty {
101 if let Type::Path(ty) = &*ty.elem {
102 let last_segment = ty.path.segments.last();
103 if let Some("TestAppContext") =
104 last_segment.map(|s| s.ident.to_string()).as_deref()
105 {
106 let cx_varname = format_ident!("cx_{}", ix);
107 cx_vars.extend(quote!(
108 let mut #cx_varname = gpui::TestAppContext::new(
109 dispatcher.clone(),
110 Some(stringify!(#outer_fn_name)),
111 );
112 ));
113 cx_teardowns.extend(quote!(
114 dispatcher.run_until_parked();
115 #cx_varname.quit();
116 dispatcher.run_until_parked();
117 ));
118 inner_fn_args.extend(quote!(&mut #cx_varname,));
119 continue;
120 }
121 }
122 }
123 }
124
125 return TokenStream::from(
126 syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
127 );
128 }
129
130 parse_quote! {
131 #[test]
132 fn #outer_fn_name() {
133 #inner_fn
134
135 gpui::run_test(
136 #num_iterations as u64,
137 #max_retries,
138 &mut |dispatcher, _seed| {
139 let executor = gpui::BackgroundExecutor::new(std::sync::Arc::new(dispatcher.clone()));
140 #cx_vars
141 executor.block_test(#inner_fn_name(#inner_fn_args));
142 #cx_teardowns
143 },
144 #on_failure_fn_name
145 );
146 }
147 }
148 } else {
149 // Pass to the test function the number of app contexts that it needs,
150 // based on its parameter list.
151 let mut cx_vars = proc_macro2::TokenStream::new();
152 let mut cx_teardowns = proc_macro2::TokenStream::new();
153 let mut inner_fn_args = proc_macro2::TokenStream::new();
154 for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
155 if let FnArg::Typed(arg) = arg {
156 if let Type::Path(ty) = &*arg.ty {
157 let last_segment = ty.path.segments.last();
158
159 if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
160 inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
161 continue;
162 }
163 } else if let Type::Reference(ty) = &*arg.ty {
164 if let Type::Path(ty) = &*ty.elem {
165 let last_segment = ty.path.segments.last();
166 match last_segment.map(|s| s.ident.to_string()).as_deref() {
167 Some("App") => {
168 let cx_varname = format_ident!("cx_{}", ix);
169 let cx_varname_lock = format_ident!("cx_{}_lock", ix);
170 cx_vars.extend(quote!(
171 let mut #cx_varname = gpui::TestAppContext::new(
172 dispatcher.clone(),
173 Some(stringify!(#outer_fn_name))
174 );
175 let mut #cx_varname_lock = #cx_varname.app.borrow_mut();
176 ));
177 inner_fn_args.extend(quote!(&mut #cx_varname_lock,));
178 cx_teardowns.extend(quote!(
179 drop(#cx_varname_lock);
180 dispatcher.run_until_parked();
181 #cx_varname.update(|cx| { cx.quit() });
182 dispatcher.run_until_parked();
183 ));
184 continue;
185 }
186 Some("TestAppContext") => {
187 let cx_varname = format_ident!("cx_{}", ix);
188 cx_vars.extend(quote!(
189 let mut #cx_varname = gpui::TestAppContext::new(
190 dispatcher.clone(),
191 Some(stringify!(#outer_fn_name))
192 );
193 ));
194 cx_teardowns.extend(quote!(
195 dispatcher.run_until_parked();
196 #cx_varname.quit();
197 dispatcher.run_until_parked();
198 ));
199 inner_fn_args.extend(quote!(&mut #cx_varname,));
200 continue;
201 }
202 _ => {}
203 }
204 }
205 }
206 }
207
208 return TokenStream::from(
209 syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
210 );
211 }
212
213 parse_quote! {
214 #[test]
215 fn #outer_fn_name() {
216 #inner_fn
217
218 gpui::run_test(
219 #num_iterations as u64,
220 #max_retries,
221 &mut |dispatcher, _seed| {
222 #cx_vars
223 #inner_fn_name(#inner_fn_args);
224 #cx_teardowns
225 },
226 #on_failure_fn_name,
227 );
228 }
229 }
230 };
231 outer_fn.attrs.extend(inner_fn_attributes);
232
233 TokenStream::from(quote!(#outer_fn))
234}
235
236fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
237 let result = if let Lit::Int(int) = &literal {
238 int.base10_parse()
239 } else {
240 Err(syn::Error::new(literal.span(), "must be an integer"))
241 };
242
243 result.map_err(|err| TokenStream::from(err.into_compile_error()))
244}