derive_refineable.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::TokenStream as TokenStream2;
  3use quote::{format_ident, quote};
  4use syn::{
  5    parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
  6    Type, TypeParamBound, WhereClause, WherePredicate,
  7};
  8
  9#[proc_macro_derive(Refineable, attributes(refineable))]
 10pub fn derive_refineable(input: TokenStream) -> TokenStream {
 11    let DeriveInput {
 12        ident,
 13        data,
 14        generics,
 15        attrs,
 16        ..
 17    } = parse_macro_input!(input);
 18
 19    let refineable_attr = attrs.iter().find(|attr| attr.path.is_ident("refineable"));
 20
 21    let mut impl_debug_on_refinement = false;
 22    let mut derive_serialize_on_refinement = false;
 23    let mut derive_deserialize_on_refinement = false;
 24
 25    if let Some(refineable_attr) = refineable_attr {
 26        if let Ok(syn::Meta::List(meta_list)) = refineable_attr.parse_meta() {
 27            for nested in meta_list.nested {
 28                let syn::NestedMeta::Meta(syn::Meta::Path(path)) = nested else {
 29                    continue;
 30                };
 31
 32                if path.is_ident("debug") {
 33                    impl_debug_on_refinement = true;
 34                }
 35
 36                if path.is_ident("serialize") {
 37                    derive_serialize_on_refinement = true;
 38                }
 39
 40                if path.is_ident("deserialize") {
 41                    derive_deserialize_on_refinement = true;
 42                }
 43            }
 44        }
 45    }
 46
 47    let refinement_ident = format_ident!("{}Refinement", ident);
 48    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
 49
 50    let fields = match data {
 51        syn::Data::Struct(syn::DataStruct {
 52            fields: syn::Fields::Named(FieldsNamed { named, .. }),
 53            ..
 54        }) => named.into_iter().collect::<Vec<Field>>(),
 55        _ => panic!("This derive macro only supports structs with named fields"),
 56    };
 57
 58    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
 59    let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
 60    let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
 61
 62    // Create trait bound that each wrapped type must implement Clone // & Default
 63    let type_param_bounds: Vec<_> = wrapped_types
 64        .iter()
 65        .map(|ty| {
 66            WherePredicate::Type(PredicateType {
 67                lifetimes: None,
 68                bounded_ty: ty.clone(),
 69                colon_token: Default::default(),
 70                bounds: {
 71                    let mut punctuated = syn::punctuated::Punctuated::new();
 72                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 73                        paren_token: None,
 74                        modifier: syn::TraitBoundModifier::None,
 75                        lifetimes: None,
 76                        path: parse_quote!(Clone),
 77                    }));
 78
 79                    // punctuated.push_punct(syn::token::Add::default());
 80                    // punctuated.push_value(TypeParamBound::Trait(TraitBound {
 81                    //     paren_token: None,
 82                    //     modifier: syn::TraitBoundModifier::None,
 83                    //     lifetimes: None,
 84                    //     path: parse_quote!(Default),
 85                    // }));
 86                    punctuated
 87                },
 88            })
 89        })
 90        .collect();
 91
 92    // Append to where_clause or create a new one if it doesn't exist
 93    let where_clause = match where_clause.cloned() {
 94        Some(mut where_clause) => {
 95            where_clause
 96                .predicates
 97                .extend(type_param_bounds.into_iter());
 98            where_clause.clone()
 99        }
100        None => WhereClause {
101            where_token: Default::default(),
102            predicates: type_param_bounds.into_iter().collect(),
103        },
104    };
105
106    // refinable_refine_assignments
107    // refinable_refined_assignments
108    // refinement_refine_assignments
109
110    let refineable_refine_assignments: Vec<TokenStream2> = fields
111        .iter()
112        .map(|field| {
113            let name = &field.ident;
114            let is_refineable = is_refineable_field(field);
115            let is_optional = is_optional_field(field);
116
117            if is_refineable {
118                quote! {
119                    self.#name.refine(&refinement.#name);
120                }
121            } else if is_optional {
122                quote! {
123                    if let Some(ref value) = &refinement.#name {
124                        self.#name = Some(value.clone());
125                    }
126                }
127            } else {
128                quote! {
129                    if let Some(ref value) = &refinement.#name {
130                        self.#name = value.clone();
131                    }
132                }
133            }
134        })
135        .collect();
136
137    let refineable_refined_assignments: Vec<TokenStream2> = fields
138        .iter()
139        .map(|field| {
140            let name = &field.ident;
141            let is_refineable = is_refineable_field(field);
142            let is_optional = is_optional_field(field);
143
144            if is_refineable {
145                quote! {
146                    self.#name = self.#name.refined(refinement.#name);
147                }
148            } else if is_optional {
149                quote! {
150                    if let Some(value) = refinement.#name {
151                        self.#name = Some(value);
152                    }
153                }
154            } else {
155                quote! {
156                    if let Some(value) = refinement.#name {
157                        self.#name = value;
158                    }
159                }
160            }
161        })
162        .collect();
163
164    let refinement_refine_assigments: Vec<TokenStream2> = fields
165        .iter()
166        .map(|field| {
167            let name = &field.ident;
168            let is_refineable = is_refineable_field(field);
169
170            if is_refineable {
171                quote! {
172                    self.#name.refine(&refinement.#name);
173                }
174            } else {
175                quote! {
176                    if let Some(ref value) = &refinement.#name {
177                        self.#name = Some(value.clone());
178                    }
179                }
180            }
181        })
182        .collect();
183
184    let refinement_refined_assigments: Vec<TokenStream2> = fields
185        .iter()
186        .map(|field| {
187            let name = &field.ident;
188            let is_refineable = is_refineable_field(field);
189
190            if is_refineable {
191                quote! {
192                    self.#name = self.#name.refined(refinement.#name);
193                }
194            } else {
195                quote! {
196                    if let Some(value) = refinement.#name {
197                        self.#name = Some(value);
198                    }
199                }
200            }
201        })
202        .collect();
203
204    let from_refinement_assigments: Vec<TokenStream2> = fields
205        .iter()
206        .map(|field| {
207            let name = &field.ident;
208            let is_refineable = is_refineable_field(field);
209            let is_optional = is_optional_field(field);
210
211            if is_refineable {
212                quote! {
213                    #name: value.#name.into(),
214                }
215            } else if is_optional {
216                quote! {
217                    #name: value.#name.map(|v| v.into()),
218                }
219            } else {
220                quote! {
221                    #name: value.#name.map(|v| v.into()).unwrap_or_default(),
222                }
223            }
224        })
225        .collect();
226
227    let debug_impl = if impl_debug_on_refinement {
228        let refinement_field_debugs: Vec<TokenStream2> = fields
229            .iter()
230            .map(|field| {
231                let name = &field.ident;
232                quote! {
233                    if self.#name.is_some() {
234                        debug_struct.field(stringify!(#name), &self.#name);
235                    } else {
236                        all_some = false;
237                    }
238                }
239            })
240            .collect();
241
242        quote! {
243            impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
244                #where_clause
245            {
246                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
247                    let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
248                    let mut all_some = true;
249                    #( #refinement_field_debugs )*
250                    if all_some {
251                        debug_struct.finish()
252                    } else {
253                        debug_struct.finish_non_exhaustive()
254                    }
255                }
256            }
257        }
258    } else {
259        quote! {}
260    };
261
262    let derive_serialize = if derive_serialize_on_refinement {
263        quote! { #[derive(serde::Serialize)]}
264    } else {
265        quote! {}
266    };
267
268    let derive_deserialize = if derive_deserialize_on_refinement {
269        quote! { #[derive(serde::Deserialize)]}
270    } else {
271        quote! {}
272    };
273
274    let gen = quote! {
275        #[derive(Clone)]
276        #derive_serialize
277        #derive_deserialize
278        pub struct #refinement_ident #impl_generics {
279            #( #field_visibilities #field_names: #wrapped_types ),*
280        }
281
282        impl #impl_generics Refineable for #ident #ty_generics
283            #where_clause
284        {
285            type Refinement = #refinement_ident #ty_generics;
286
287            fn refine(&mut self, refinement: &Self::Refinement) {
288                #( #refineable_refine_assignments )*
289            }
290
291            fn refined(mut self, refinement: Self::Refinement) -> Self {
292                #( #refineable_refined_assignments )*
293                self
294            }
295        }
296
297        impl #impl_generics Refineable for #refinement_ident #ty_generics
298            #where_clause
299        {
300            type Refinement = #refinement_ident #ty_generics;
301
302            fn refine(&mut self, refinement: &Self::Refinement) {
303                #( #refinement_refine_assigments )*
304            }
305
306            fn refined(mut self, refinement: Self::Refinement) -> Self {
307                #( #refinement_refined_assigments )*
308                self
309            }
310        }
311
312        impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
313            #where_clause
314        {
315            fn from(value: #refinement_ident #ty_generics) -> Self {
316                Self {
317                    #( #from_refinement_assigments )*
318                }
319            }
320        }
321
322        impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
323            #where_clause
324        {
325            fn default() -> Self {
326                #refinement_ident {
327                    #( #field_names: Default::default() ),*
328                }
329            }
330        }
331
332        impl #impl_generics #refinement_ident #ty_generics
333            #where_clause
334        {
335            pub fn is_some(&self) -> bool {
336                #(
337                    if self.#field_names.is_some() {
338                        return true;
339                    }
340                )*
341                false
342            }
343        }
344
345        #debug_impl
346    };
347    gen.into()
348}
349
350fn is_refineable_field(f: &Field) -> bool {
351    f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
352}
353
354fn is_optional_field(f: &Field) -> bool {
355    if let Type::Path(typepath) = &f.ty {
356        if typepath.qself.is_none() {
357            let segments = &typepath.path.segments;
358            if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
359                return true;
360            }
361        }
362    }
363    false
364}
365
366fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
367    if is_refineable_field(field) {
368        let struct_name = if let Type::Path(tp) = ty {
369            tp.path.segments.last().unwrap().ident.clone()
370        } else {
371            panic!("Expected struct type for a refineable field");
372        };
373        let refinement_struct_name = format_ident!("{}Refinement", struct_name);
374        let generics = if let Type::Path(tp) = ty {
375            &tp.path.segments.last().unwrap().arguments
376        } else {
377            &syn::PathArguments::None
378        };
379        parse_quote!(#refinement_struct_name #generics)
380    } else if is_optional_field(field) {
381        ty.clone()
382    } else {
383        parse_quote!(Option<#ty>)
384    }
385}