1use proc_macro::TokenStream;
2use proc_macro2::Ident;
3use quote::{format_ident, quote};
4use std::mem;
5use syn::{
6 self, Expr, ExprLit, FnArg, ItemFn, Lit, Meta, MetaList, PathSegment, Token, Type,
7 parse::{Parse, ParseStream},
8 parse_quote,
9 punctuated::Punctuated,
10 spanned::Spanned,
11};
12
13struct Args {
14 seeds: Vec<u64>,
15 max_retries: usize,
16 max_iterations: usize,
17 on_failure_fn_name: proc_macro2::TokenStream,
18}
19
20impl Parse for Args {
21 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
22 let mut seeds = Vec::<u64>::new();
23 let mut max_retries = 0;
24 let mut max_iterations = 1;
25 let mut on_failure_fn_name = quote!(None);
26
27 let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
28
29 for meta in metas {
30 let ident = {
31 let meta_path = match &meta {
32 Meta::NameValue(meta) => &meta.path,
33 Meta::List(list) => &list.path,
34 Meta::Path(path) => {
35 return Err(syn::Error::new(path.span(), "invalid path argument"));
36 }
37 };
38 let Some(ident) = meta_path.get_ident() else {
39 return Err(syn::Error::new(meta_path.span(), "unexpected path"));
40 };
41 ident.to_string()
42 };
43
44 match (&meta, ident.as_str()) {
45 (Meta::NameValue(meta), "retries") => {
46 max_retries = parse_usize_from_expr(&meta.value)?
47 }
48 (Meta::NameValue(meta), "iterations") => {
49 max_iterations = parse_usize_from_expr(&meta.value)?
50 }
51 (Meta::NameValue(meta), "on_failure") => {
52 let Expr::Lit(ExprLit {
53 lit: Lit::Str(name),
54 ..
55 }) = &meta.value
56 else {
57 return Err(syn::Error::new(
58 meta.value.span(),
59 "on_failure argument must be a string",
60 ));
61 };
62 let segments = name
63 .value()
64 .split("::")
65 .map(|part| PathSegment::from(Ident::new(part, name.span())))
66 .collect();
67 let path = syn::Path {
68 leading_colon: None,
69 segments,
70 };
71 on_failure_fn_name = quote!(Some(#path));
72 }
73 (Meta::NameValue(meta), "seed") => {
74 seeds = vec![parse_usize_from_expr(&meta.value)? as u64]
75 }
76 (Meta::List(list), "seeds") => seeds = parse_u64_array(list)?,
77 (Meta::Path(_), _) => {
78 return Err(syn::Error::new(meta.span(), "invalid path argument"));
79 }
80 (_, _) => {
81 return Err(syn::Error::new(meta.span(), "invalid argument name"));
82 }
83 }
84 }
85
86 Ok(Args {
87 seeds,
88 max_retries,
89 max_iterations,
90 on_failure_fn_name,
91 })
92 }
93}
94
95pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
96 let args = syn::parse_macro_input!(args as Args);
97 let mut inner_fn = match syn::parse::<ItemFn>(function) {
98 Ok(f) => f,
99 Err(err) => return error_to_stream(err),
100 };
101
102 let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
103 let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
104 let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
105
106 let result = generate_test_function(
107 args,
108 inner_fn,
109 inner_fn_attributes,
110 inner_fn_name,
111 outer_fn_name,
112 );
113 match result {
114 Ok(tokens) => tokens,
115 Err(tokens) => tokens,
116 }
117}
118
/// Builds the `#[test]` wrapper function for a test annotated with the test
/// attribute.
///
/// The generated `fn #outer_fn_name()` contains the (already renamed) user
/// function `inner_fn` as a nested item. It calls `gpui::run_test` with the
/// iteration count, seeds, retry count and optional failure callback from
/// `args`. The callback passed to `run_test` receives `dispatcher` and `_seed`.
/// Inside it, the macro:
/// 1. creates one context per context-typed parameter (`cx_vars`);
/// 2. calls the inner function with generated arguments (`inner_fn_args`);
/// 3. runs a per-context teardown (`cx_teardowns`).
///
/// Async functions are driven with `executor.block_test(...)`; sync functions
/// are called directly. Parameters are recognized only by the last segment of
/// their type path:
/// - async: `StdRng`, `BackgroundExecutor`, `&TestAppContext` / `&mut TestAppContext`
/// - sync:  `StdRng`, `&App` / `&mut App`, `&TestAppContext` / `&mut TestAppContext`
///
/// `inner_fn_attributes` are appended to the generated outer function after
/// its `#[test]` attribute.
///
/// Returns `Err` with `compile_error!` tokens ("invalid function signature")
/// for any other parameter, including a `self` receiver.
fn generate_test_function(
    args: Args,
    inner_fn: ItemFn,
    inner_fn_attributes: Vec<syn::Attribute>,
    inner_fn_name: Ident,
    outer_fn_name: Ident,
) -> Result<TokenStream, TokenStream> {
    let seeds = &args.seeds;
    let max_retries = args.max_retries;
    let num_iterations = args.max_iterations;
    // Expands to either `None` or `Some(path)`.
    let on_failure_fn_name = &args.on_failure_fn_name;
    // Comma-separated seed literals, spliced into `&[#seeds]` below.
    let seeds = quote!( #(#seeds),* );

    let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
        // Pass to the test function the number of app contexts that it needs,
        // based on its parameter list.
        let mut cx_vars = proc_macro2::TokenStream::new();
        let mut cx_teardowns = proc_macro2::TokenStream::new();
        let mut inner_fn_args = proc_macro2::TokenStream::new();
        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
            if let FnArg::Typed(arg) = arg {
                if let Type::Path(ty) = &*arg.ty {
                    let last_segment = ty.path.segments.last();
                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
                        // A seeded RNG built from the current `run_test` seed.
                        Some("StdRng") => {
                            inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
                            continue;
                        }
                        // An executor sharing the callback's `dispatcher`.
                        Some("BackgroundExecutor") => {
                            inner_fn_args.extend(quote!(gpui::BackgroundExecutor::new(
                                std::sync::Arc::new(dispatcher.clone()),
                            ),));
                            continue;
                        }
                        _ => {}
                    }
                } else if let Type::Reference(ty) = &*arg.ty
                    && let Type::Path(ty) = &*ty.elem
                {
                    let last_segment = ty.path.segments.last();
                    if let Some("TestAppContext") =
                        last_segment.map(|s| s.ident.to_string()).as_deref()
                    {
                        // The parameter index keeps each context variable name unique.
                        let cx_varname = format_ident!("cx_{}", ix);
                        cx_vars.extend(quote!(
                            let mut #cx_varname = gpui::TestAppContext::build(
                                dispatcher.clone(),
                                Some(stringify!(#outer_fn_name)),
                            );
                        ));
                        // Teardown: run pending work, forbid further parking, quit the
                        // context, then run pending work again. Presumably this lets
                        // tasks spawned during quit complete before the next iteration;
                        // confirm against gpui's `TestAppContext`.
                        cx_teardowns.extend(quote!(
                            dispatcher.run_until_parked();
                            #cx_varname.executor().forbid_parking();
                            #cx_varname.quit();
                            dispatcher.run_until_parked();
                        ));
                        inner_fn_args.extend(quote!(&mut #cx_varname,));
                        continue;
                    }
                }
            }

            // Reached only when no branch above recognized the parameter.
            return Err(error_with_message("invalid function signature", arg));
        }

        // The async inner fn's future is driven to completion by `block_test`
        // on an executor built from the callback's `dispatcher`.
        parse_quote! {
            #[test]
            fn #outer_fn_name() {
                #inner_fn

                gpui::run_test(
                    #num_iterations,
                    &[#seeds],
                    #max_retries,
                    &mut |dispatcher, _seed| {
                        let executor = gpui::BackgroundExecutor::new(std::sync::Arc::new(dispatcher.clone()));
                        #cx_vars
                        executor.block_test(#inner_fn_name(#inner_fn_args));
                        #cx_teardowns
                    },
                    #on_failure_fn_name
                );
            }
        }
    } else {
        // Pass to the test function the number of app contexts that it needs,
        // based on its parameter list.
        let mut cx_vars = proc_macro2::TokenStream::new();
        let mut cx_teardowns = proc_macro2::TokenStream::new();
        let mut inner_fn_args = proc_macro2::TokenStream::new();
        for (ix, arg) in inner_fn.sig.inputs.iter().enumerate() {
            if let FnArg::Typed(arg) = arg {
                if let Type::Path(ty) = &*arg.ty {
                    let last_segment = ty.path.segments.last();

                    // A seeded RNG built from the current `run_test` seed.
                    if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
                        inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(_seed),));
                        continue;
                    }
                } else if let Type::Reference(ty) = &*arg.ty
                    && let Type::Path(ty) = &*ty.elem
                {
                    let last_segment = ty.path.segments.last();
                    match last_segment.map(|s| s.ident.to_string()).as_deref() {
                        // `&App` / `&mut App`: build a TestAppContext and pass a
                        // `borrow_mut()` guard of its `app` field for the duration of
                        // the call.
                        Some("App") => {
                            let cx_varname = format_ident!("cx_{}", ix);
                            let cx_varname_lock = format_ident!("cx_{}_lock", ix);
                            cx_vars.extend(quote!(
                                let mut #cx_varname = gpui::TestAppContext::build(
                                    dispatcher.clone(),
                                    Some(stringify!(#outer_fn_name))
                                );
                                let mut #cx_varname_lock = #cx_varname.app.borrow_mut();
                            ));
                            inner_fn_args.extend(quote!(&mut #cx_varname_lock,));
                            // The guard is dropped before teardown, presumably because
                            // `#cx_varname.update(...)` needs to borrow `app` again.
                            cx_teardowns.extend(quote!(
                                drop(#cx_varname_lock);
                                dispatcher.run_until_parked();
                                #cx_varname.update(|cx| { cx.background_executor().forbid_parking(); cx.quit(); });
                                dispatcher.run_until_parked();
                            ));
                            continue;
                        }
                        // Same setup/teardown as the async branch's TestAppContext case.
                        Some("TestAppContext") => {
                            let cx_varname = format_ident!("cx_{}", ix);
                            cx_vars.extend(quote!(
                                let mut #cx_varname = gpui::TestAppContext::build(
                                    dispatcher.clone(),
                                    Some(stringify!(#outer_fn_name))
                                );
                            ));
                            cx_teardowns.extend(quote!(
                                dispatcher.run_until_parked();
                                #cx_varname.executor().forbid_parking();
                                #cx_varname.quit();
                                dispatcher.run_until_parked();
                            ));
                            inner_fn_args.extend(quote!(&mut #cx_varname,));
                            continue;
                        }
                        _ => {}
                    }
                }
            }

            // Reached only when no branch above recognized the parameter.
            return Err(error_with_message("invalid function signature", arg));
        }

        // The sync inner fn is called directly inside the `run_test` callback.
        parse_quote! {
            #[test]
            fn #outer_fn_name() {
                #inner_fn

                gpui::run_test(
                    #num_iterations,
                    &[#seeds],
                    #max_retries,
                    &mut |dispatcher, _seed| {
                        #cx_vars
                        #inner_fn_name(#inner_fn_args);
                        #cx_teardowns
                    },
                    #on_failure_fn_name,
                );
            }
        }
    };
    // Re-apply the user's original attributes to the generated test fn.
    outer_fn.attrs.extend(inner_fn_attributes);

    Ok(TokenStream::from(quote!(#outer_fn)))
}
290
291fn parse_usize_from_expr(expr: &Expr) -> Result<usize, syn::Error> {
292 let Expr::Lit(ExprLit {
293 lit: Lit::Int(int), ..
294 }) = expr
295 else {
296 return Err(syn::Error::new(expr.span(), "expected an integer"));
297 };
298 int.base10_parse()
299 .map_err(|_| syn::Error::new(int.span(), "failed to parse integer"))
300}
301
302fn parse_u64_array(meta_list: &MetaList) -> Result<Vec<u64>, syn::Error> {
303 let mut result = Vec::new();
304 let tokens = &meta_list.tokens;
305 let parser = |input: ParseStream| {
306 let exprs = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
307 for expr in exprs {
308 if let Expr::Lit(ExprLit {
309 lit: Lit::Int(int), ..
310 }) = expr
311 {
312 let value: usize = int.base10_parse()?;
313 result.push(value as u64);
314 } else {
315 return Err(syn::Error::new(expr.span(), "expected an integer"));
316 }
317 }
318 Ok(())
319 };
320 syn::parse::Parser::parse2(parser, tokens.clone())?;
321 Ok(result)
322}
323
324fn error_with_message(message: &str, spanned: impl Spanned) -> TokenStream {
325 error_to_stream(syn::Error::new(spanned.span(), message))
326}
327
328fn error_to_stream(err: syn::Error) -> TokenStream {
329 TokenStream::from(err.into_compile_error())
330}