1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
6 Type, TypeParamBound, WhereClause, WherePredicate,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let impl_debug_on_refinement = attrs
20 .iter()
21 .any(|attr| attr.path.is_ident("refineable") && attr.tokens.to_string().contains("debug"));
22
23 let refinement_ident = format_ident!("{}Refinement", ident);
24 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
25
26 let fields = match data {
27 syn::Data::Struct(syn::DataStruct {
28 fields: syn::Fields::Named(FieldsNamed { named, .. }),
29 ..
30 }) => named.into_iter().collect::<Vec<Field>>(),
31 _ => panic!("This derive macro only supports structs with named fields"),
32 };
33
34 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
35 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
36 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
37
38 // Create trait bound that each wrapped type must implement Clone // & Default
39 let type_param_bounds: Vec<_> = wrapped_types
40 .iter()
41 .map(|ty| {
42 WherePredicate::Type(PredicateType {
43 lifetimes: None,
44 bounded_ty: ty.clone(),
45 colon_token: Default::default(),
46 bounds: {
47 let mut punctuated = syn::punctuated::Punctuated::new();
48 punctuated.push_value(TypeParamBound::Trait(TraitBound {
49 paren_token: None,
50 modifier: syn::TraitBoundModifier::None,
51 lifetimes: None,
52 path: parse_quote!(Clone),
53 }));
54
55 // punctuated.push_punct(syn::token::Add::default());
56 // punctuated.push_value(TypeParamBound::Trait(TraitBound {
57 // paren_token: None,
58 // modifier: syn::TraitBoundModifier::None,
59 // lifetimes: None,
60 // path: parse_quote!(Default),
61 // }));
62 punctuated
63 },
64 })
65 })
66 .collect();
67
68 // Append to where_clause or create a new one if it doesn't exist
69 let where_clause = match where_clause.cloned() {
70 Some(mut where_clause) => {
71 where_clause
72 .predicates
73 .extend(type_param_bounds.into_iter());
74 where_clause.clone()
75 }
76 None => WhereClause {
77 where_token: Default::default(),
78 predicates: type_param_bounds.into_iter().collect(),
79 },
80 };
81
82 let field_assignments: Vec<TokenStream2> = fields
83 .iter()
84 .map(|field| {
85 let name = &field.ident;
86 let is_refineable = is_refineable_field(field);
87 let is_optional = is_optional_field(field);
88
89 if is_refineable {
90 quote! {
91 self.#name.refine(&refinement.#name);
92 }
93 } else if is_optional {
94 quote! {
95 if let Some(ref value) = &refinement.#name {
96 self.#name = Some(value.clone());
97 }
98 }
99 } else {
100 quote! {
101 if let Some(ref value) = &refinement.#name {
102 self.#name = value.clone();
103 }
104 }
105 }
106 })
107 .collect();
108
109 let refinement_field_assignments: Vec<TokenStream2> = fields
110 .iter()
111 .map(|field| {
112 let name = &field.ident;
113 let is_refineable = is_refineable_field(field);
114
115 if is_refineable {
116 quote! {
117 self.#name.refine(&refinement.#name);
118 }
119 } else {
120 quote! {
121 if let Some(ref value) = &refinement.#name {
122 self.#name = Some(value.clone());
123 }
124 }
125 }
126 })
127 .collect();
128
129 let debug_impl = if impl_debug_on_refinement {
130 let refinement_field_debugs: Vec<TokenStream2> = fields
131 .iter()
132 .map(|field| {
133 let name = &field.ident;
134 quote! {
135 if self.#name.is_some() {
136 debug_struct.field(stringify!(#name), &self.#name);
137 } else {
138 all_some = false;
139 }
140 }
141 })
142 .collect();
143
144 quote! {
145 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
146 #where_clause
147 {
148 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
149 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
150 let mut all_some = true;
151 #( #refinement_field_debugs )*
152 if all_some {
153 debug_struct.finish()
154 } else {
155 debug_struct.finish_non_exhaustive()
156 }
157 }
158 }
159 }
160 } else {
161 quote! {}
162 };
163
164 let gen = quote! {
165 #[derive(Clone)]
166 pub struct #refinement_ident #impl_generics {
167 #( #field_visibilities #field_names: #wrapped_types ),*
168 }
169
170 impl #impl_generics Refineable for #ident #ty_generics
171 #where_clause
172 {
173 type Refinement = #refinement_ident #ty_generics;
174
175 fn refine(&mut self, refinement: &Self::Refinement) {
176 #( #field_assignments )*
177 }
178 }
179
180 impl #impl_generics Refineable for #refinement_ident #ty_generics
181 #where_clause
182 {
183 type Refinement = #refinement_ident #ty_generics;
184
185 fn refine(&mut self, refinement: &Self::Refinement) {
186 #( #refinement_field_assignments )*
187 }
188 }
189
190 impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
191 #where_clause
192 {
193 fn default() -> Self {
194 #refinement_ident {
195 #( #field_names: Default::default() ),*
196 }
197 }
198 }
199
200 impl #impl_generics #refinement_ident #ty_generics
201 #where_clause
202 {
203 pub fn is_some(&self) -> bool {
204 #(
205 if self.#field_names.is_some() {
206 return true;
207 }
208 )*
209 false
210 }
211 }
212
213 #debug_impl
214 };
215 gen.into()
216}
217
218fn is_refineable_field(f: &Field) -> bool {
219 f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
220}
221
222fn is_optional_field(f: &Field) -> bool {
223 if let Type::Path(typepath) = &f.ty {
224 if typepath.qself.is_none() {
225 let segments = &typepath.path.segments;
226 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
227 return true;
228 }
229 }
230 }
231 false
232}
233
234fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
235 if is_refineable_field(field) {
236 let struct_name = if let Type::Path(tp) = ty {
237 tp.path.segments.last().unwrap().ident.clone()
238 } else {
239 panic!("Expected struct type for a refineable field");
240 };
241 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
242 let generics = if let Type::Path(tp) = ty {
243 &tp.path.segments.last().unwrap().arguments
244 } else {
245 &syn::PathArguments::None
246 };
247 parse_quote!(#refinement_struct_name #generics)
248 } else if is_optional_field(field) {
249 ty.clone()
250 } else {
251 parse_quote!(Option<#ty>)
252 }
253}