1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use std::mem;
4use syn::{
5 parse_macro_input, parse_quote, AttributeArgs, ItemFn, Lit, Meta, MetaNameValue, NestedMeta,
6};
7
8#[proc_macro_attribute]
9pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
10 let mut namespace = format_ident!("gpui");
11
12 let args = syn::parse_macro_input!(args as AttributeArgs);
13 let mut max_retries = 0;
14 for arg in args {
15 match arg {
16 NestedMeta::Meta(Meta::Path(name))
17 if name.get_ident().map_or(false, |n| n == "self") =>
18 {
19 namespace = format_ident!("crate");
20 }
21 NestedMeta::Meta(Meta::NameValue(meta)) => {
22 if let Some(result) = parse_retries(&meta) {
23 match result {
24 Ok(retries) => max_retries = retries,
25 Err(error) => return TokenStream::from(error.into_compile_error()),
26 }
27 }
28 }
29 other => {
30 return TokenStream::from(
31 syn::Error::new_spanned(other, "invalid argument").into_compile_error(),
32 )
33 }
34 }
35 }
36
37 let mut inner_fn = parse_macro_input!(function as ItemFn);
38 let inner_fn_attributes = mem::take(&mut inner_fn.attrs);
39 let inner_fn_name = format_ident!("_{}", inner_fn.sig.ident);
40 let outer_fn_name = mem::replace(&mut inner_fn.sig.ident, inner_fn_name.clone());
41
42 // Pass to the test function the number of app contexts that it needs,
43 // based on its parameter list.
44 let inner_fn_args = (0..inner_fn.sig.inputs.len())
45 .map(|i| {
46 let first_entity_id = i * 100_000;
47 quote!(#namespace::TestAppContext::new(foreground.clone(), #first_entity_id),)
48 })
49 .collect::<proc_macro2::TokenStream>();
50
51 let mut outer_fn: ItemFn = if inner_fn.sig.asyncness.is_some() {
52 parse_quote! {
53 #[test]
54 fn #outer_fn_name() {
55 #inner_fn
56
57 if #max_retries > 0 {
58 let mut retries = 0;
59 loop {
60 let result = std::panic::catch_unwind(|| {
61 let foreground = ::std::rc::Rc::new(#namespace::executor::Foreground::test());
62 #namespace::block_on(foreground.run(#inner_fn_name(#inner_fn_args)));
63 });
64
65 match result {
66 Ok(result) => return result,
67 Err(error) => {
68 if retries < #max_retries {
69 retries += 1;
70 println!("retrying: attempt {}", retries);
71 } else {
72 std::panic::resume_unwind(error);
73 }
74 }
75 }
76 }
77 } else {
78 let foreground = ::std::rc::Rc::new(#namespace::executor::Foreground::test());
79 #namespace::block_on(foreground.run(#inner_fn_name(#inner_fn_args)));
80 }
81 }
82 }
83 } else {
84 parse_quote! {
85 #[test]
86 fn #outer_fn_name() {
87 #inner_fn
88
89 if #max_retries > 0 {
90 let mut retries = 0;
91 loop {
92 let result = std::panic::catch_unwind(|| {
93 #namespace::App::test(|cx| {
94 #inner_fn_name(cx);
95 });
96 });
97
98 match result {
99 Ok(result) => return result,
100 Err(error) => {
101 if retries < #max_retries {
102 retries += 1;
103 println!("retrying: attempt {}", retries);
104 } else {
105 std::panic::resume_unwind(error);
106 }
107 }
108 }
109 }
110 } else {
111 #namespace::App::test(|cx| {
112 #inner_fn_name(cx);
113 });
114 }
115 }
116 }
117 };
118 outer_fn.attrs.extend(inner_fn_attributes);
119
120 TokenStream::from(quote!(#outer_fn))
121}
122
123fn parse_retries(meta: &MetaNameValue) -> Option<syn::Result<usize>> {
124 let ident = meta.path.get_ident();
125 if ident.map_or(false, |n| n == "retries") {
126 if let Lit::Int(int) = &meta.lit {
127 Some(int.base10_parse())
128 } else {
129 Some(Err(syn::Error::new(
130 meta.lit.span(),
131 "retries mut be an integer",
132 )))
133 }
134 } else {
135 None
136 }
137}