derive_refineable.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::TokenStream as TokenStream2;
  3use quote::{format_ident, quote};
  4use syn::{
  5    parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
  6    Type, TypeParamBound, WhereClause, WherePredicate,
  7};
  8
  9#[proc_macro_derive(Refineable, attributes(refineable))]
 10pub fn derive_refineable(input: TokenStream) -> TokenStream {
 11    let DeriveInput {
 12        ident,
 13        data,
 14        generics,
 15        attrs,
 16        ..
 17    } = parse_macro_input!(input);
 18
 19    let impl_debug_on_refinement = attrs
 20        .iter()
 21        .any(|attr| attr.path.is_ident("refineable") && attr.tokens.to_string().contains("debug"));
 22
 23    let refinement_ident = format_ident!("{}Refinement", ident);
 24    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
 25
 26    let fields = match data {
 27        syn::Data::Struct(syn::DataStruct {
 28            fields: syn::Fields::Named(FieldsNamed { named, .. }),
 29            ..
 30        }) => named.into_iter().collect::<Vec<Field>>(),
 31        _ => panic!("This derive macro only supports structs with named fields"),
 32    };
 33
 34    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
 35    let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
 36    let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
 37
 38    // Create trait bound that each wrapped type must implement Clone // & Default
 39    let type_param_bounds: Vec<_> = wrapped_types
 40        .iter()
 41        .map(|ty| {
 42            WherePredicate::Type(PredicateType {
 43                lifetimes: None,
 44                bounded_ty: ty.clone(),
 45                colon_token: Default::default(),
 46                bounds: {
 47                    let mut punctuated = syn::punctuated::Punctuated::new();
 48                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 49                        paren_token: None,
 50                        modifier: syn::TraitBoundModifier::None,
 51                        lifetimes: None,
 52                        path: parse_quote!(Clone),
 53                    }));
 54
 55                    // punctuated.push_punct(syn::token::Add::default());
 56                    // punctuated.push_value(TypeParamBound::Trait(TraitBound {
 57                    //     paren_token: None,
 58                    //     modifier: syn::TraitBoundModifier::None,
 59                    //     lifetimes: None,
 60                    //     path: parse_quote!(Default),
 61                    // }));
 62                    punctuated
 63                },
 64            })
 65        })
 66        .collect();
 67
 68    // Append to where_clause or create a new one if it doesn't exist
 69    let where_clause = match where_clause.cloned() {
 70        Some(mut where_clause) => {
 71            where_clause
 72                .predicates
 73                .extend(type_param_bounds.into_iter());
 74            where_clause.clone()
 75        }
 76        None => WhereClause {
 77            where_token: Default::default(),
 78            predicates: type_param_bounds.into_iter().collect(),
 79        },
 80    };
 81
 82    // refinable_refine_assignments
 83    // refinable_refined_assignments
 84    // refinement_refine_assignments
 85
 86    let refineable_refine_assignments: Vec<TokenStream2> = fields
 87        .iter()
 88        .map(|field| {
 89            let name = &field.ident;
 90            let is_refineable = is_refineable_field(field);
 91            let is_optional = is_optional_field(field);
 92
 93            if is_refineable {
 94                quote! {
 95                    self.#name.refine(&refinement.#name);
 96                }
 97            } else if is_optional {
 98                quote! {
 99                    if let Some(ref value) = &refinement.#name {
100                        self.#name = Some(value.clone());
101                    }
102                }
103            } else {
104                quote! {
105                    if let Some(ref value) = &refinement.#name {
106                        self.#name = value.clone();
107                    }
108                }
109            }
110        })
111        .collect();
112
113    let refineable_refined_assignments: Vec<TokenStream2> = fields
114        .iter()
115        .map(|field| {
116            let name = &field.ident;
117            let is_refineable = is_refineable_field(field);
118            let is_optional = is_optional_field(field);
119
120            if is_refineable {
121                quote! {
122                    self.#name = self.#name.refined(refinement.#name);
123                }
124            } else if is_optional {
125                quote! {
126                    if let Some(value) = refinement.#name {
127                        self.#name = Some(value);
128                    }
129                }
130            } else {
131                quote! {
132                    if let Some(value) = refinement.#name {
133                        self.#name = value;
134                    }
135                }
136            }
137        })
138        .collect();
139
140    let refinement_refine_assigments: Vec<TokenStream2> = fields
141        .iter()
142        .map(|field| {
143            let name = &field.ident;
144            let is_refineable = is_refineable_field(field);
145
146            if is_refineable {
147                quote! {
148                    self.#name.refine(&refinement.#name);
149                }
150            } else {
151                quote! {
152                    if let Some(ref value) = &refinement.#name {
153                        self.#name = Some(value.clone());
154                    }
155                }
156            }
157        })
158        .collect();
159
160    let refinement_refined_assigments: Vec<TokenStream2> = fields
161        .iter()
162        .map(|field| {
163            let name = &field.ident;
164            let is_refineable = is_refineable_field(field);
165
166            if is_refineable {
167                quote! {
168                    self.#name = self.#name.refined(refinement.#name);
169                }
170            } else {
171                quote! {
172                    if let Some(value) = refinement.#name {
173                        self.#name = Some(value);
174                    }
175                }
176            }
177        })
178        .collect();
179
180    let from_refinement_assigments: Vec<TokenStream2> = fields
181        .iter()
182        .map(|field| {
183            let name = &field.ident;
184            let is_refineable = is_refineable_field(field);
185            let is_optional = is_optional_field(field);
186
187            if is_refineable {
188                quote! {
189                    #name: value.#name.into(),
190                }
191            } else if is_optional {
192                quote! {
193                    #name: value.#name.map(|v| v.into()),
194                }
195            } else {
196                quote! {
197                    #name: value.#name.map(|v| v.into()).unwrap_or_default(),
198                }
199            }
200        })
201        .collect();
202
203    let debug_impl = if impl_debug_on_refinement {
204        let refinement_field_debugs: Vec<TokenStream2> = fields
205            .iter()
206            .map(|field| {
207                let name = &field.ident;
208                quote! {
209                    if self.#name.is_some() {
210                        debug_struct.field(stringify!(#name), &self.#name);
211                    } else {
212                        all_some = false;
213                    }
214                }
215            })
216            .collect();
217
218        quote! {
219            impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
220                #where_clause
221            {
222                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223                    let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
224                    let mut all_some = true;
225                    #( #refinement_field_debugs )*
226                    if all_some {
227                        debug_struct.finish()
228                    } else {
229                        debug_struct.finish_non_exhaustive()
230                    }
231                }
232            }
233        }
234    } else {
235        quote! {}
236    };
237
238    let gen = quote! {
239        #[derive(Clone)]
240        pub struct #refinement_ident #impl_generics {
241            #( #field_visibilities #field_names: #wrapped_types ),*
242        }
243
244        impl #impl_generics Refineable for #ident #ty_generics
245            #where_clause
246        {
247            type Refinement = #refinement_ident #ty_generics;
248
249            fn refine(&mut self, refinement: &Self::Refinement) {
250                #( #refineable_refine_assignments )*
251            }
252
253            fn refined(mut self, refinement: Self::Refinement) -> Self {
254                #( #refineable_refined_assignments )*
255                self
256            }
257        }
258
259        impl #impl_generics Refineable for #refinement_ident #ty_generics
260            #where_clause
261        {
262            type Refinement = #refinement_ident #ty_generics;
263
264            fn refine(&mut self, refinement: &Self::Refinement) {
265                #( #refinement_refine_assigments )*
266            }
267
268            fn refined(mut self, refinement: Self::Refinement) -> Self {
269                #( #refinement_refined_assigments )*
270                self
271            }
272        }
273
274        impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
275            #where_clause
276        {
277            fn from(value: #refinement_ident #ty_generics) -> Self {
278                Self {
279                    #( #from_refinement_assigments )*
280                }
281            }
282        }
283
284        impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
285            #where_clause
286        {
287            fn default() -> Self {
288                #refinement_ident {
289                    #( #field_names: Default::default() ),*
290                }
291            }
292        }
293
294        impl #impl_generics #refinement_ident #ty_generics
295            #where_clause
296        {
297            pub fn is_some(&self) -> bool {
298                #(
299                    if self.#field_names.is_some() {
300                        return true;
301                    }
302                )*
303                false
304            }
305        }
306
307        #debug_impl
308    };
309    gen.into()
310}
311
312fn is_refineable_field(f: &Field) -> bool {
313    f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
314}
315
316fn is_optional_field(f: &Field) -> bool {
317    if let Type::Path(typepath) = &f.ty {
318        if typepath.qself.is_none() {
319            let segments = &typepath.path.segments;
320            if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
321                return true;
322            }
323        }
324    }
325    false
326}
327
328fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
329    if is_refineable_field(field) {
330        let struct_name = if let Type::Path(tp) = ty {
331            tp.path.segments.last().unwrap().ident.clone()
332        } else {
333            panic!("Expected struct type for a refineable field");
334        };
335        let refinement_struct_name = format_ident!("{}Refinement", struct_name);
336        let generics = if let Type::Path(tp) = ty {
337            &tp.path.segments.last().unwrap().arguments
338        } else {
339            &syn::PathArguments::None
340        };
341        parse_quote!(#refinement_struct_name #generics)
342    } else if is_optional_field(field) {
343        ty.clone()
344    } else {
345        parse_quote!(Option<#ty>)
346    }
347}