derive_refineable.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::TokenStream as TokenStream2;
  3use quote::{format_ident, quote};
  4use syn::{
  5    parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
  6    Type, TypeParamBound, WhereClause, WherePredicate,
  7};
  8
  9#[proc_macro_derive(Refineable, attributes(refineable))]
 10pub fn derive_refineable(input: TokenStream) -> TokenStream {
 11    let DeriveInput {
 12        ident,
 13        data,
 14        generics,
 15        attrs,
 16        ..
 17    } = parse_macro_input!(input);
 18
 19    let impl_debug_on_refinement = attrs
 20        .iter()
 21        .any(|attr| attr.path.is_ident("refineable") && attr.tokens.to_string().contains("debug"));
 22
 23    let refinement_ident = format_ident!("{}Refinement", ident);
 24    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
 25
 26    let fields = match data {
 27        syn::Data::Struct(syn::DataStruct {
 28            fields: syn::Fields::Named(FieldsNamed { named, .. }),
 29            ..
 30        }) => named.into_iter().collect::<Vec<Field>>(),
 31        _ => panic!("This derive macro only supports structs with named fields"),
 32    };
 33
 34    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
 35    let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
 36    let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
 37
 38    // Create trait bound that each wrapped type must implement Clone // & Default
 39    let type_param_bounds: Vec<_> = wrapped_types
 40        .iter()
 41        .map(|ty| {
 42            WherePredicate::Type(PredicateType {
 43                lifetimes: None,
 44                bounded_ty: ty.clone(),
 45                colon_token: Default::default(),
 46                bounds: {
 47                    let mut punctuated = syn::punctuated::Punctuated::new();
 48                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 49                        paren_token: None,
 50                        modifier: syn::TraitBoundModifier::None,
 51                        lifetimes: None,
 52                        path: parse_quote!(Clone),
 53                    }));
 54
 55                    // punctuated.push_punct(syn::token::Add::default());
 56                    // punctuated.push_value(TypeParamBound::Trait(TraitBound {
 57                    //     paren_token: None,
 58                    //     modifier: syn::TraitBoundModifier::None,
 59                    //     lifetimes: None,
 60                    //     path: parse_quote!(Default),
 61                    // }));
 62                    punctuated
 63                },
 64            })
 65        })
 66        .collect();
 67
 68    // Append to where_clause or create a new one if it doesn't exist
 69    let where_clause = match where_clause.cloned() {
 70        Some(mut where_clause) => {
 71            where_clause
 72                .predicates
 73                .extend(type_param_bounds.into_iter());
 74            where_clause.clone()
 75        }
 76        None => WhereClause {
 77            where_token: Default::default(),
 78            predicates: type_param_bounds.into_iter().collect(),
 79        },
 80    };
 81
 82    // refinable_refine_assignments
 83    // refinable_refined_assignments
 84    // refinement_refine_assignments
 85
 86    let refineable_refine_assignments: Vec<TokenStream2> = fields
 87        .iter()
 88        .map(|field| {
 89            let name = &field.ident;
 90            let is_refineable = is_refineable_field(field);
 91            let is_optional = is_optional_field(field);
 92
 93            if is_refineable {
 94                quote! {
 95                    self.#name.refine(&refinement.#name);
 96                }
 97            } else if is_optional {
 98                quote! {
 99                    if let Some(ref value) = &refinement.#name {
100                        self.#name = Some(value.clone());
101                    }
102                }
103            } else {
104                quote! {
105                    if let Some(ref value) = &refinement.#name {
106                        self.#name = value.clone();
107                    }
108                }
109            }
110        })
111        .collect();
112
113    let refineable_refined_assignments: Vec<TokenStream2> = fields
114        .iter()
115        .map(|field| {
116            let name = &field.ident;
117            let is_refineable = is_refineable_field(field);
118            let is_optional = is_optional_field(field);
119
120            if is_refineable {
121                quote! {
122                    self.#name = self.#name.refined(refinement.#name);
123                }
124            } else if is_optional {
125                quote! {
126                    if let Some(value) = refinement.#name {
127                        self.#name = Some(value);
128                    }
129                }
130            } else {
131                quote! {
132                    if let Some(value) = refinement.#name {
133                        self.#name = value;
134                    }
135                }
136            }
137        })
138        .collect();
139
140    let refinement_refine_assigments: Vec<TokenStream2> = fields
141        .iter()
142        .map(|field| {
143            let name = &field.ident;
144            let is_refineable = is_refineable_field(field);
145
146            if is_refineable {
147                quote! {
148                    self.#name.refine(&refinement.#name);
149                }
150            } else {
151                quote! {
152                    if let Some(ref value) = &refinement.#name {
153                        self.#name = Some(value.clone());
154                    }
155                }
156            }
157        })
158        .collect();
159
160    let refinement_refined_assignments: Vec<TokenStream2> = fields
161        .iter()
162        .map(|field| {
163            let name = &field.ident;
164            let is_refineable = is_refineable_field(field);
165
166            if is_refineable {
167                quote! {
168                    self.#name = self.#name.refined(refinement.#name);
169                }
170            } else {
171                quote! {
172                    if refinement.#name.is_some() {
173                        self.#name = refinement.#name;
174                    }
175                }
176            }
177        })
178        .collect();
179
180    let debug_impl = if impl_debug_on_refinement {
181        let refinement_field_debugs: Vec<TokenStream2> = fields
182            .iter()
183            .map(|field| {
184                let name = &field.ident;
185                quote! {
186                    if self.#name.is_some() {
187                        debug_struct.field(stringify!(#name), &self.#name);
188                    } else {
189                        all_some = false;
190                    }
191                }
192            })
193            .collect();
194
195        quote! {
196            impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
197                #where_clause
198            {
199                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200                    let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
201                    let mut all_some = true;
202                    #( #refinement_field_debugs )*
203                    if all_some {
204                        debug_struct.finish()
205                    } else {
206                        debug_struct.finish_non_exhaustive()
207                    }
208                }
209            }
210        }
211    } else {
212        quote! {}
213    };
214
215    let gen = quote! {
216        #[derive(Clone)]
217        pub struct #refinement_ident #impl_generics {
218            #( #field_visibilities #field_names: #wrapped_types ),*
219        }
220
221        impl #impl_generics Refineable for #ident #ty_generics
222            #where_clause
223        {
224            type Refinement = #refinement_ident #ty_generics;
225
226            fn refine(&mut self, refinement: &Self::Refinement) {
227                #( #refineable_refine_assignments )*
228            }
229
230            fn refined(mut self, refinement: Self::Refinement) -> Self {
231                #( #refineable_refined_assignments )*
232                self
233            }
234        }
235
236        impl #impl_generics Refineable for #refinement_ident #ty_generics
237            #where_clause
238        {
239            type Refinement = #refinement_ident #ty_generics;
240
241            fn refine(&mut self, refinement: &Self::Refinement) {
242                #( #refinement_refine_assigments )*
243            }
244
245            fn refined(mut self, refinement: Self::Refinement) -> Self {
246                #( #refinement_refined_assignments )*
247                self
248            }
249        }
250
251        impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
252            #where_clause
253        {
254            fn default() -> Self {
255                #refinement_ident {
256                    #( #field_names: Default::default() ),*
257                }
258            }
259        }
260
261        impl #impl_generics #refinement_ident #ty_generics
262            #where_clause
263        {
264            pub fn is_some(&self) -> bool {
265                #(
266                    if self.#field_names.is_some() {
267                        return true;
268                    }
269                )*
270                false
271            }
272        }
273
274        #debug_impl
275    };
276
277    gen.into()
278}
279
280fn is_refineable_field(f: &Field) -> bool {
281    f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
282}
283
284fn is_optional_field(f: &Field) -> bool {
285    if let Type::Path(typepath) = &f.ty {
286        if typepath.qself.is_none() {
287            let segments = &typepath.path.segments;
288            if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
289                return true;
290            }
291        }
292    }
293    false
294}
295
296fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
297    if is_refineable_field(field) {
298        let struct_name = if let Type::Path(tp) = ty {
299            tp.path.segments.last().unwrap().ident.clone()
300        } else {
301            panic!("Expected struct type for a refineable field");
302        };
303        let refinement_struct_name = format_ident!("{}Refinement", struct_name);
304        let generics = if let Type::Path(tp) = ty {
305            &tp.path.segments.last().unwrap().arguments
306        } else {
307            &syn::PathArguments::None
308        };
309        parse_quote!(#refinement_struct_name #generics)
310    } else if is_optional_field(field) {
311        ty.clone()
312    } else {
313        parse_quote!(Option<#ty>)
314    }
315}