derive_refineable.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::TokenStream as TokenStream2;
  3use quote::{format_ident, quote};
  4use syn::{
  5    parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
  6    Type, TypeParamBound, WhereClause, WherePredicate,
  7};
  8
  9#[proc_macro_derive(Refineable, attributes(refineable))]
 10pub fn derive_refineable(input: TokenStream) -> TokenStream {
 11    let DeriveInput {
 12        ident,
 13        data,
 14        generics,
 15        attrs,
 16        ..
 17    } = parse_macro_input!(input);
 18
 19    let refineable_attr = attrs.iter().find(|attr| attr.path.is_ident("refineable"));
 20
 21    let mut impl_debug_on_refinement = false;
 22    let mut refinement_traits_to_derive = vec![];
 23
 24    if let Some(refineable_attr) = refineable_attr {
 25        if let Ok(syn::Meta::List(meta_list)) = refineable_attr.parse_meta() {
 26            for nested in meta_list.nested {
 27                let syn::NestedMeta::Meta(syn::Meta::Path(path)) = nested else {
 28                    continue;
 29                };
 30
 31                if path.is_ident("Debug") {
 32                    impl_debug_on_refinement = true;
 33                } else {
 34                    refinement_traits_to_derive.push(path);
 35                }
 36            }
 37        }
 38    }
 39
 40    let refinement_ident = format_ident!("{}Refinement", ident);
 41    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
 42
 43    let fields = match data {
 44        syn::Data::Struct(syn::DataStruct {
 45            fields: syn::Fields::Named(FieldsNamed { named, .. }),
 46            ..
 47        }) => named.into_iter().collect::<Vec<Field>>(),
 48        _ => panic!("This derive macro only supports structs with named fields"),
 49    };
 50
 51    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
 52    let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
 53    let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
 54
 55    // Create trait bound that each wrapped type must implement Clone // & Default
 56    let type_param_bounds: Vec<_> = wrapped_types
 57        .iter()
 58        .map(|ty| {
 59            WherePredicate::Type(PredicateType {
 60                lifetimes: None,
 61                bounded_ty: ty.clone(),
 62                colon_token: Default::default(),
 63                bounds: {
 64                    let mut punctuated = syn::punctuated::Punctuated::new();
 65                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 66                        paren_token: None,
 67                        modifier: syn::TraitBoundModifier::None,
 68                        lifetimes: None,
 69                        path: parse_quote!(Clone),
 70                    }));
 71
 72                    // punctuated.push_punct(syn::token::Add::default());
 73                    // punctuated.push_value(TypeParamBound::Trait(TraitBound {
 74                    //     paren_token: None,
 75                    //     modifier: syn::TraitBoundModifier::None,
 76                    //     lifetimes: None,
 77                    //     path: parse_quote!(Default),
 78                    // }));
 79                    punctuated
 80                },
 81            })
 82        })
 83        .collect();
 84
 85    // Append to where_clause or create a new one if it doesn't exist
 86    let where_clause = match where_clause.cloned() {
 87        Some(mut where_clause) => {
 88            where_clause
 89                .predicates
 90                .extend(type_param_bounds.into_iter());
 91            where_clause.clone()
 92        }
 93        None => WhereClause {
 94            where_token: Default::default(),
 95            predicates: type_param_bounds.into_iter().collect(),
 96        },
 97    };
 98
 99    // refinable_refine_assignments
100    // refinable_refined_assignments
101    // refinement_refine_assignments
102
103    let refineable_refine_assignments: Vec<TokenStream2> = fields
104        .iter()
105        .map(|field| {
106            let name = &field.ident;
107            let is_refineable = is_refineable_field(field);
108            let is_optional = is_optional_field(field);
109
110            if is_refineable {
111                quote! {
112                    self.#name.refine(&refinement.#name);
113                }
114            } else if is_optional {
115                quote! {
116                    if let Some(ref value) = &refinement.#name {
117                        self.#name = Some(value.clone());
118                    }
119                }
120            } else {
121                quote! {
122                    if let Some(ref value) = &refinement.#name {
123                        self.#name = value.clone();
124                    }
125                }
126            }
127        })
128        .collect();
129
130    let refineable_refined_assignments: Vec<TokenStream2> = fields
131        .iter()
132        .map(|field| {
133            let name = &field.ident;
134            let is_refineable = is_refineable_field(field);
135            let is_optional = is_optional_field(field);
136
137            if is_refineable {
138                quote! {
139                    self.#name = self.#name.refined(refinement.#name);
140                }
141            } else if is_optional {
142                quote! {
143                    if let Some(value) = refinement.#name {
144                        self.#name = Some(value);
145                    }
146                }
147            } else {
148                quote! {
149                    if let Some(value) = refinement.#name {
150                        self.#name = value;
151                    }
152                }
153            }
154        })
155        .collect();
156
157    let refinement_refine_assigments: Vec<TokenStream2> = fields
158        .iter()
159        .map(|field| {
160            let name = &field.ident;
161            let is_refineable = is_refineable_field(field);
162
163            if is_refineable {
164                quote! {
165                    self.#name.refine(&refinement.#name);
166                }
167            } else {
168                quote! {
169                    if let Some(ref value) = &refinement.#name {
170                        self.#name = Some(value.clone());
171                    }
172                }
173            }
174        })
175        .collect();
176
177    let refinement_refined_assigments: Vec<TokenStream2> = fields
178        .iter()
179        .map(|field| {
180            let name = &field.ident;
181            let is_refineable = is_refineable_field(field);
182
183            if is_refineable {
184                quote! {
185                    self.#name = self.#name.refined(refinement.#name);
186                }
187            } else {
188                quote! {
189                    if let Some(value) = refinement.#name {
190                        self.#name = Some(value);
191                    }
192                }
193            }
194        })
195        .collect();
196
197    let from_refinement_assigments: Vec<TokenStream2> = fields
198        .iter()
199        .map(|field| {
200            let name = &field.ident;
201            let is_refineable = is_refineable_field(field);
202            let is_optional = is_optional_field(field);
203
204            if is_refineable {
205                quote! {
206                    #name: value.#name.into(),
207                }
208            } else if is_optional {
209                quote! {
210                    #name: value.#name.map(|v| v.into()),
211                }
212            } else {
213                quote! {
214                    #name: value.#name.map(|v| v.into()).unwrap_or_default(),
215                }
216            }
217        })
218        .collect();
219
220    let debug_impl = if impl_debug_on_refinement {
221        let refinement_field_debugs: Vec<TokenStream2> = fields
222            .iter()
223            .map(|field| {
224                let name = &field.ident;
225                quote! {
226                    if self.#name.is_some() {
227                        debug_struct.field(stringify!(#name), &self.#name);
228                    } else {
229                        all_some = false;
230                    }
231                }
232            })
233            .collect();
234
235        quote! {
236            impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
237                #where_clause
238            {
239                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240                    let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
241                    let mut all_some = true;
242                    #( #refinement_field_debugs )*
243                    if all_some {
244                        debug_struct.finish()
245                    } else {
246                        debug_struct.finish_non_exhaustive()
247                    }
248                }
249            }
250        }
251    } else {
252        quote! {}
253    };
254
255    let mut derive_stream = quote! {};
256    for trait_to_derive in refinement_traits_to_derive {
257        derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
258    }
259
260    let gen = quote! {
261        #[derive(Clone)]
262        #derive_stream
263        pub struct #refinement_ident #impl_generics {
264            #( #field_visibilities #field_names: #wrapped_types ),*
265        }
266
267        impl #impl_generics Refineable for #ident #ty_generics
268            #where_clause
269        {
270            type Refinement = #refinement_ident #ty_generics;
271
272            fn refine(&mut self, refinement: &Self::Refinement) {
273                #( #refineable_refine_assignments )*
274            }
275
276            fn refined(mut self, refinement: Self::Refinement) -> Self {
277                #( #refineable_refined_assignments )*
278                self
279            }
280        }
281
282        impl #impl_generics Refineable for #refinement_ident #ty_generics
283            #where_clause
284        {
285            type Refinement = #refinement_ident #ty_generics;
286
287            fn refine(&mut self, refinement: &Self::Refinement) {
288                #( #refinement_refine_assigments )*
289            }
290
291            fn refined(mut self, refinement: Self::Refinement) -> Self {
292                #( #refinement_refined_assigments )*
293                self
294            }
295        }
296
297        impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
298            #where_clause
299        {
300            fn from(value: #refinement_ident #ty_generics) -> Self {
301                Self {
302                    #( #from_refinement_assigments )*
303                }
304            }
305        }
306
307        impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
308            #where_clause
309        {
310            fn default() -> Self {
311                #refinement_ident {
312                    #( #field_names: Default::default() ),*
313                }
314            }
315        }
316
317        impl #impl_generics #refinement_ident #ty_generics
318            #where_clause
319        {
320            pub fn is_some(&self) -> bool {
321                #(
322                    if self.#field_names.is_some() {
323                        return true;
324                    }
325                )*
326                false
327            }
328        }
329
330        #debug_impl
331    };
332    gen.into()
333}
334
335fn is_refineable_field(f: &Field) -> bool {
336    f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
337}
338
339fn is_optional_field(f: &Field) -> bool {
340    if let Type::Path(typepath) = &f.ty {
341        if typepath.qself.is_none() {
342            let segments = &typepath.path.segments;
343            if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
344                return true;
345            }
346        }
347    }
348    false
349}
350
351fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
352    if is_refineable_field(field) {
353        let struct_name = if let Type::Path(tp) = ty {
354            tp.path.segments.last().unwrap().ident.clone()
355        } else {
356            panic!("Expected struct type for a refineable field");
357        };
358        let refinement_struct_name = format_ident!("{}Refinement", struct_name);
359        let generics = if let Type::Path(tp) = ty {
360            &tp.path.segments.last().unwrap().arguments
361        } else {
362            &syn::PathArguments::None
363        };
364        parse_quote!(#refinement_struct_name #generics)
365    } else if is_optional_field(field) {
366        ty.clone()
367    } else {
368        parse_quote!(Option<#ty>)
369    }
370}