derive_refineable.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::TokenStream as TokenStream2;
  3use quote::{format_ident, quote};
  4use syn::{
  5    parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
  6    Type, TypeParamBound, WhereClause, WherePredicate,
  7};
  8
  9#[proc_macro_derive(Refineable, attributes(refineable))]
 10pub fn derive_refineable(input: TokenStream) -> TokenStream {
 11    let DeriveInput {
 12        ident,
 13        data,
 14        generics,
 15        attrs,
 16        ..
 17    } = parse_macro_input!(input);
 18
 19    let refineable_attr = attrs.iter().find(|attr| attr.path.is_ident("refineable"));
 20
 21    let mut impl_debug_on_refinement = false;
 22    let mut refinement_traits_to_derive = vec![];
 23
 24    if let Some(refineable_attr) = refineable_attr {
 25        if let Ok(syn::Meta::List(meta_list)) = refineable_attr.parse_meta() {
 26            for nested in meta_list.nested {
 27                let syn::NestedMeta::Meta(syn::Meta::Path(path)) = nested else {
 28                    continue;
 29                };
 30
 31                if path.is_ident("Debug") {
 32                    impl_debug_on_refinement = true;
 33                } else {
 34                    refinement_traits_to_derive.push(path);
 35                }
 36            }
 37        }
 38    }
 39
 40    let refinement_ident = format_ident!("{}Refinement", ident);
 41    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
 42
 43    let fields = match data {
 44        syn::Data::Struct(syn::DataStruct {
 45            fields: syn::Fields::Named(FieldsNamed { named, .. }),
 46            ..
 47        }) => named.into_iter().collect::<Vec<Field>>(),
 48        _ => panic!("This derive macro only supports structs with named fields"),
 49    };
 50
 51    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
 52    let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
 53    let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
 54
 55    // Create trait bound that each wrapped type must implement Clone // & Default
 56    let type_param_bounds: Vec<_> = wrapped_types
 57        .iter()
 58        .map(|ty| {
 59            WherePredicate::Type(PredicateType {
 60                lifetimes: None,
 61                bounded_ty: ty.clone(),
 62                colon_token: Default::default(),
 63                bounds: {
 64                    let mut punctuated = syn::punctuated::Punctuated::new();
 65                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 66                        paren_token: None,
 67                        modifier: syn::TraitBoundModifier::None,
 68                        lifetimes: None,
 69                        path: parse_quote!(Clone),
 70                    }));
 71
 72                    // punctuated.push_punct(syn::token::Add::default());
 73                    // punctuated.push_value(TypeParamBound::Trait(TraitBound {
 74                    //     paren_token: None,
 75                    //     modifier: syn::TraitBoundModifier::None,
 76                    //     lifetimes: None,
 77                    //     path: parse_quote!(Default),
 78                    // }));
 79                    punctuated
 80                },
 81            })
 82        })
 83        .collect();
 84
 85    // Append to where_clause or create a new one if it doesn't exist
 86    let where_clause = match where_clause.cloned() {
 87        Some(mut where_clause) => {
 88            where_clause.predicates.extend(type_param_bounds);
 89            where_clause.clone()
 90        }
 91        None => WhereClause {
 92            where_token: Default::default(),
 93            predicates: type_param_bounds.into_iter().collect(),
 94        },
 95    };
 96
 97    // refinable_refine_assignments
 98    // refinable_refined_assignments
 99    // refinement_refine_assignments
100
101    let refineable_refine_assignments: Vec<TokenStream2> = fields
102        .iter()
103        .map(|field| {
104            let name = &field.ident;
105            let is_refineable = is_refineable_field(field);
106            let is_optional = is_optional_field(field);
107
108            if is_refineable {
109                quote! {
110                    self.#name.refine(&refinement.#name);
111                }
112            } else if is_optional {
113                quote! {
114                    if let Some(ref value) = &refinement.#name {
115                        self.#name = Some(value.clone());
116                    }
117                }
118            } else {
119                quote! {
120                    if let Some(ref value) = &refinement.#name {
121                        self.#name = value.clone();
122                    }
123                }
124            }
125        })
126        .collect();
127
128    let refineable_refined_assignments: Vec<TokenStream2> = fields
129        .iter()
130        .map(|field| {
131            let name = &field.ident;
132            let is_refineable = is_refineable_field(field);
133            let is_optional = is_optional_field(field);
134
135            if is_refineable {
136                quote! {
137                    self.#name = self.#name.refined(refinement.#name);
138                }
139            } else if is_optional {
140                quote! {
141                    if let Some(value) = refinement.#name {
142                        self.#name = Some(value);
143                    }
144                }
145            } else {
146                quote! {
147                    if let Some(value) = refinement.#name {
148                        self.#name = value;
149                    }
150                }
151            }
152        })
153        .collect();
154
155    let refinement_refine_assigments: Vec<TokenStream2> = fields
156        .iter()
157        .map(|field| {
158            let name = &field.ident;
159            let is_refineable = is_refineable_field(field);
160
161            if is_refineable {
162                quote! {
163                    self.#name.refine(&refinement.#name);
164                }
165            } else {
166                quote! {
167                    if let Some(ref value) = &refinement.#name {
168                        self.#name = Some(value.clone());
169                    }
170                }
171            }
172        })
173        .collect();
174
175    let refinement_refined_assigments: Vec<TokenStream2> = fields
176        .iter()
177        .map(|field| {
178            let name = &field.ident;
179            let is_refineable = is_refineable_field(field);
180
181            if is_refineable {
182                quote! {
183                    self.#name = self.#name.refined(refinement.#name);
184                }
185            } else {
186                quote! {
187                    if let Some(value) = refinement.#name {
188                        self.#name = Some(value);
189                    }
190                }
191            }
192        })
193        .collect();
194
195    let from_refinement_assigments: Vec<TokenStream2> = fields
196        .iter()
197        .map(|field| {
198            let name = &field.ident;
199            let is_refineable = is_refineable_field(field);
200            let is_optional = is_optional_field(field);
201
202            if is_refineable {
203                quote! {
204                    #name: value.#name.into(),
205                }
206            } else if is_optional {
207                quote! {
208                    #name: value.#name.map(|v| v.into()),
209                }
210            } else {
211                quote! {
212                    #name: value.#name.map(|v| v.into()).unwrap_or_default(),
213                }
214            }
215        })
216        .collect();
217
218    let debug_impl = if impl_debug_on_refinement {
219        let refinement_field_debugs: Vec<TokenStream2> = fields
220            .iter()
221            .map(|field| {
222                let name = &field.ident;
223                quote! {
224                    if self.#name.is_some() {
225                        debug_struct.field(stringify!(#name), &self.#name);
226                    } else {
227                        all_some = false;
228                    }
229                }
230            })
231            .collect();
232
233        quote! {
234            impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
235                #where_clause
236            {
237                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
238                    let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
239                    let mut all_some = true;
240                    #( #refinement_field_debugs )*
241                    if all_some {
242                        debug_struct.finish()
243                    } else {
244                        debug_struct.finish_non_exhaustive()
245                    }
246                }
247            }
248        }
249    } else {
250        quote! {}
251    };
252
253    let mut derive_stream = quote! {};
254    for trait_to_derive in refinement_traits_to_derive {
255        derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
256    }
257
258    let gen = quote! {
259        #[derive(Clone)]
260        #derive_stream
261        pub struct #refinement_ident #impl_generics {
262            #( #field_visibilities #field_names: #wrapped_types ),*
263        }
264
265        impl #impl_generics Refineable for #ident #ty_generics
266            #where_clause
267        {
268            type Refinement = #refinement_ident #ty_generics;
269
270            fn refine(&mut self, refinement: &Self::Refinement) {
271                #( #refineable_refine_assignments )*
272            }
273
274            fn refined(mut self, refinement: Self::Refinement) -> Self {
275                #( #refineable_refined_assignments )*
276                self
277            }
278        }
279
280        impl #impl_generics Refineable for #refinement_ident #ty_generics
281            #where_clause
282        {
283            type Refinement = #refinement_ident #ty_generics;
284
285            fn refine(&mut self, refinement: &Self::Refinement) {
286                #( #refinement_refine_assigments )*
287            }
288
289            fn refined(mut self, refinement: Self::Refinement) -> Self {
290                #( #refinement_refined_assigments )*
291                self
292            }
293        }
294
295        impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
296            #where_clause
297        {
298            fn from(value: #refinement_ident #ty_generics) -> Self {
299                Self {
300                    #( #from_refinement_assigments )*
301                }
302            }
303        }
304
305        impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
306            #where_clause
307        {
308            fn default() -> Self {
309                #refinement_ident {
310                    #( #field_names: Default::default() ),*
311                }
312            }
313        }
314
315        impl #impl_generics #refinement_ident #ty_generics
316            #where_clause
317        {
318            pub fn is_some(&self) -> bool {
319                #(
320                    if self.#field_names.is_some() {
321                        return true;
322                    }
323                )*
324                false
325            }
326        }
327
328        #debug_impl
329    };
330    gen.into()
331}
332
333fn is_refineable_field(f: &Field) -> bool {
334    f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
335}
336
337fn is_optional_field(f: &Field) -> bool {
338    if let Type::Path(typepath) = &f.ty {
339        if typepath.qself.is_none() {
340            let segments = &typepath.path.segments;
341            if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
342                return true;
343            }
344        }
345    }
346    false
347}
348
349fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
350    if is_refineable_field(field) {
351        let struct_name = if let Type::Path(tp) = ty {
352            tp.path.segments.last().unwrap().ident.clone()
353        } else {
354            panic!("Expected struct type for a refineable field");
355        };
356        let refinement_struct_name = format_ident!("{}Refinement", struct_name);
357        let generics = if let Type::Path(tp) = ty {
358            &tp.path.segments.last().unwrap().arguments
359        } else {
360            &syn::PathArguments::None
361        };
362        parse_quote!(#refinement_struct_name #generics)
363    } else if is_optional_field(field) {
364        ty.clone()
365    } else {
366        parse_quote!(Option<#ty>)
367    }
368}