1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{format_ident, quote};
4use std::mem;
5use syn::{
6 parse_macro_input, parse_quote, spanned::Spanned as _, AttributeArgs, FnArg, ItemFn, Lit, Meta,
7 NestedMeta, Type,
8};
9
10pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
11 let args = syn::parse_macro_input!(args as AttributeArgs);
12 let mut max_retries = 0;
13 let mut num_iterations = 1;
14 let mut on_failure_fn_name = quote!(None);
15
16 for arg in args {
17 match arg {
18 NestedMeta::Meta(Meta::NameValue(meta)) => {
19 let key_name = meta.path.get_ident().map(|i| i.to_string());
20 let result = (|| {
21 match key_name.as_deref() {
22 Some("retries") => max_retries = parse_int(&meta.lit)?,
23 Some("iterations") => num_iterations = parse_int(&meta.lit)?,
24 Some("on_failure") => {
25 if let Lit::Str(name) = meta.lit {
26 let mut path = syn::Path {
27 leading_colon: None,
28 segments: Default::default(),
29 };
30 for part in name.value().split("::") {
31 path.segments.push(Ident::new(part, name.span()).into());
32 }
33 on_failure_fn_name = quote!(Some(#path));
34 } else {
35 return Err(TokenStream::from(
36 syn::Error::new(
37 meta.lit.span(),
38 "on_failure argument must be a string",
39 )
40 .into_compile_error(),
41 ));
42 }
43 }
44 _ => {
45 return Err(TokenStream::from(
46 syn::Error::new(meta.path.span(), "invalid argument")
47 .into_compile_error(),
48 ))
49 }
50 }
51 Ok(())
52 })();
53
54 if let Err(tokens) = result {
55 return tokens;
56 }
57 }
58 other => {
59 return TokenStream::from(
60 syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
61 )
62 }
63 }
64 }
65
66 let mut inner_fn = parse_macro_input!(function as ItemFn);
67 if max_retries > 0 && num_iterations > 1 {
68 return TokenStream::from(
69 syn::Error::new_spanned(inner_fn, "retries and randomized iterations can't be mixed")
70 .into_compile_error(),
71 );
72 }
73 let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
74 let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
75 let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
76
77 let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
78 // Pass to the test function the number of app contexts that it needs,
79 // based on its parameter list.
80 let mut cx_vars = proc_macro2::TokenStream::new();
81 let mut cx_teardowns = proc_macro2::TokenStream::new();
82 let mut inner_fn_args = proc_macro2::TokenStream::new();
83 for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
84 if let FnArg::Typed(arg) = arg {
85 if let Type::Path(ty) = &*arg.ty {
86 let last_segment = ty.path.segments.last();
87 match last_segment.map(|s| s.ident.to_string()).as_deref() {
88 Some("StdRng") => {
89 inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
90 continue;
91 }
92 Some("BackgroundExecutor") => {
93 inner_fn_args.extend(quote!(gpui::BackgroundExecutor::new(
94 std::sync::Arc::new(dispatcher.clone()),
95 ),));
96 continue;
97 }
98 _ => {}
99 }
100 } else if let Type::Reference(ty) = &*arg.ty {
101 if let Type::Path(ty) = &*ty.elem {
102 let last_segment = ty.path.segments.last();
103 if let Some("TestAppContext") =
104 last_segment.map(|s| s.ident.to_string()).as_deref()
105 {
106 let cx_varname = format_ident!("cx_{}", ix);
107 cx_vars.extend(quote!(
108 let mut #cx_varname = gpui::TestAppContext::new(
109 dispatcher.clone()
110 );
111 ));
112 cx_teardowns.extend(quote!(
113 #cx_varname.quit();
114 dispatcher.run_until_parked();
115 ));
116 inner_fn_args.extend(quote!(&mut #cx_varname,));
117 continue;
118 }
119 }
120 }
121 }
122
123 return TokenStream::from(
124 syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
125 );
126 }
127
128 parse_quote! {
129 #[test]
130 fn #outer_fn_name() {
131 #inner_fn
132
133 gpui::run_test(
134 #num_iterations as u64,
135 #max_retries,
136 &mut |dispatcher, _seed| {
137 let executor = gpui::BackgroundExecutor::new(std::sync::Arc::new(dispatcher.clone()));
138 #cx_vars
139 executor.block_test(#inner_fn_name(#inner_fn_args));
140 #cx_teardowns
141 },
142 #on_failure_fn_name,
143 stringify!(#outer_fn_name).to_string(),
144 );
145 }
146 }
147 } else {
148 // Pass to the test function the number of app contexts that it needs,
149 // based on its parameter list.
150 let mut cx_vars = proc_macro2::TokenStream::new();
151 let mut cx_teardowns = proc_macro2::TokenStream::new();
152 let mut inner_fn_args = proc_macro2::TokenStream::new();
153 for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
154 if let FnArg::Typed(arg) = arg {
155 if let Type::Path(ty) = &*arg.ty {
156 let last_segment = ty.path.segments.last();
157
158 if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
159 inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
160 continue;
161 }
162 } else if let Type::Reference(ty) = &*arg.ty {
163 if let Type::Path(ty) = &*ty.elem {
164 let last_segment = ty.path.segments.last();
165 match last_segment.map(|s| s.ident.to_string()).as_deref() {
166 Some("AppContext") => {
167 let cx_varname = format_ident!("cx_{}", ix);
168 let cx_varname_lock = format_ident!("cx_{}_lock", ix);
169 cx_vars.extend(quote!(
170 let mut #cx_varname = gpui::TestAppContext::new(
171 dispatcher.clone()
172 );
173 let mut #cx_varname_lock = #cx_varname.app.borrow_mut();
174 ));
175 inner_fn_args.extend(quote!(&mut #cx_varname_lock,));
176 cx_teardowns.extend(quote!(
177 #cx_varname_lock.quit();
178 drop(#cx_varname_lock);
179 dispatcher.run_until_parked();
180 ));
181 continue;
182 }
183 Some("TestAppContext") => {
184 let cx_varname = format_ident!("cx_{}", ix);
185 cx_vars.extend(quote!(
186 let mut #cx_varname = gpui::TestAppContext::new(
187 dispatcher.clone()
188 );
189 ));
190 cx_teardowns.extend(quote!(
191 #cx_varname.quit();
192 dispatcher.run_until_parked();
193 ));
194 inner_fn_args.extend(quote!(&mut #cx_varname,));
195 continue;
196 }
197 _ => {}
198 }
199 }
200 }
201 }
202
203 return TokenStream::from(
204 syn::Error::new_spanned(arg, "invalid argument").into_compile_error(),
205 );
206 }
207
208 parse_quote! {
209 #[test]
210 fn #outer_fn_name() {
211 #inner_fn
212
213 gpui::run_test(
214 #num_iterations as u64,
215 #max_retries,
216 &mut |dispatcher, _seed| {
217 #cx_vars
218 #inner_fn_name(#inner_fn_args);
219 #cx_teardowns
220 },
221 #on_failure_fn_name,
222 stringify!(#outer_fn_name).to_string(),
223 );
224 }
225 }
226 };
227 outer_fn.attrs.extend(inner_fn_attributes);
228
229 TokenStream::from(quote!(#outer_fn))
230}
231
232fn parse_int(literal: &Lit) -> Result<usize, TokenStream> {
233 let result = if let Lit::Int(int) = &literal {
234 int.base10_parse()
235 } else {
236 Err(syn::Error::new(literal.span(), "must be an integer"))
237 };
238
239 result.map_err(|err| TokenStream::from(err.into_compile_error()))
240}