derive_refineable.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::TokenStream as TokenStream2;
  3use quote::{format_ident, quote};
  4use syn::{
  5    DeriveInput, Field, FieldsNamed, PredicateType, TraitBound, Type, TypeParamBound, WhereClause,
  6    WherePredicate, parse_macro_input, parse_quote,
  7};
  8
  9#[proc_macro_derive(Refineable, attributes(refineable))]
 10pub fn derive_refineable(input: TokenStream) -> TokenStream {
 11    let DeriveInput {
 12        ident,
 13        data,
 14        generics,
 15        attrs,
 16        ..
 17    } = parse_macro_input!(input);
 18
 19    let refineable_attr = attrs.iter().find(|attr| attr.path().is_ident("refineable"));
 20
 21    let mut impl_debug_on_refinement = false;
 22    let mut derives_serialize = false;
 23    let mut refinement_traits_to_derive = vec![];
 24
 25    if let Some(refineable_attr) = refineable_attr {
 26        let _ = refineable_attr.parse_nested_meta(|meta| {
 27            if meta.path.is_ident("Debug") {
 28                impl_debug_on_refinement = true;
 29            } else {
 30                if meta.path.is_ident("Serialize") {
 31                    derives_serialize = true;
 32                }
 33                refinement_traits_to_derive.push(meta.path);
 34            }
 35            Ok(())
 36        });
 37    }
 38
 39    let refinement_ident = format_ident!("{}Refinement", ident);
 40    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
 41
 42    let fields = match data {
 43        syn::Data::Struct(syn::DataStruct {
 44            fields: syn::Fields::Named(FieldsNamed { named, .. }),
 45            ..
 46        }) => named.into_iter().collect::<Vec<Field>>(),
 47        _ => panic!("This derive macro only supports structs with named fields"),
 48    };
 49
 50    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
 51    let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
 52    let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
 53
 54    let field_attributes: Vec<TokenStream2> = fields
 55        .iter()
 56        .map(|f| {
 57            if derives_serialize {
 58                if is_refineable_field(f) {
 59                    quote! { #[serde(default, skip_serializing_if = "::refineable::IsEmpty::is_empty")] }
 60                } else {
 61                    quote! { #[serde(skip_serializing_if = "::std::option::Option::is_none")] }
 62                }
 63            } else {
 64                quote! {}
 65            }
 66        })
 67        .collect();
 68
 69    // Create trait bound that each wrapped type must implement Clone
 70    let type_param_bounds: Vec<_> = wrapped_types
 71        .iter()
 72        .map(|ty| {
 73            WherePredicate::Type(PredicateType {
 74                lifetimes: None,
 75                bounded_ty: ty.clone(),
 76                colon_token: Default::default(),
 77                bounds: {
 78                    let mut punctuated = syn::punctuated::Punctuated::new();
 79                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 80                        paren_token: None,
 81                        modifier: syn::TraitBoundModifier::None,
 82                        lifetimes: None,
 83                        path: parse_quote!(Clone),
 84                    }));
 85
 86                    punctuated
 87                },
 88            })
 89        })
 90        .collect();
 91
 92    // Append to where_clause or create a new one if it doesn't exist
 93    let where_clause = match where_clause.cloned() {
 94        Some(mut where_clause) => {
 95            where_clause.predicates.extend(type_param_bounds);
 96            where_clause.clone()
 97        }
 98        None => WhereClause {
 99            where_token: Default::default(),
100            predicates: type_param_bounds.into_iter().collect(),
101        },
102    };
103
104    let refineable_refine_assignments: Vec<TokenStream2> = fields
105        .iter()
106        .map(|field| {
107            let name = &field.ident;
108            let is_refineable = is_refineable_field(field);
109            let is_optional = is_optional_field(field);
110
111            if is_refineable {
112                quote! {
113                    self.#name.refine(&refinement.#name);
114                }
115            } else if is_optional {
116                quote! {
117                    if let Some(value) = &refinement.#name {
118                        self.#name = Some(value.clone());
119                    }
120                }
121            } else {
122                quote! {
123                    if let Some(value) = &refinement.#name {
124                        self.#name = value.clone();
125                    }
126                }
127            }
128        })
129        .collect();
130
131    let refineable_refined_assignments: Vec<TokenStream2> = fields
132        .iter()
133        .map(|field| {
134            let name = &field.ident;
135            let is_refineable = is_refineable_field(field);
136            let is_optional = is_optional_field(field);
137
138            if is_refineable {
139                quote! {
140                    self.#name = self.#name.refined(refinement.#name);
141                }
142            } else if is_optional {
143                quote! {
144                    if let Some(value) = refinement.#name {
145                        self.#name = Some(value);
146                    }
147                }
148            } else {
149                quote! {
150                    if let Some(value) = refinement.#name {
151                        self.#name = value;
152                    }
153                }
154            }
155        })
156        .collect();
157
158    let refinement_refine_assignments: Vec<TokenStream2> = fields
159        .iter()
160        .map(|field| {
161            let name = &field.ident;
162            let is_refineable = is_refineable_field(field);
163
164            if is_refineable {
165                quote! {
166                    self.#name.refine(&refinement.#name);
167                }
168            } else {
169                quote! {
170                    if let Some(value) = &refinement.#name {
171                        self.#name = Some(value.clone());
172                    }
173                }
174            }
175        })
176        .collect();
177
178    let refinement_refined_assignments: Vec<TokenStream2> = fields
179        .iter()
180        .map(|field| {
181            let name = &field.ident;
182            let is_refineable = is_refineable_field(field);
183
184            if is_refineable {
185                quote! {
186                    self.#name = self.#name.refined(refinement.#name);
187                }
188            } else {
189                quote! {
190                    if let Some(value) = refinement.#name {
191                        self.#name = Some(value);
192                    }
193                }
194            }
195        })
196        .collect();
197
198    let from_refinement_assignments: Vec<TokenStream2> = fields
199        .iter()
200        .map(|field| {
201            let name = &field.ident;
202            let is_refineable = is_refineable_field(field);
203            let is_optional = is_optional_field(field);
204
205            if is_refineable {
206                quote! {
207                    #name: value.#name.into(),
208                }
209            } else if is_optional {
210                quote! {
211                    #name: value.#name.map(|v| v.into()),
212                }
213            } else {
214                quote! {
215                    #name: value.#name.map(|v| v.into()).unwrap_or_default(),
216                }
217            }
218        })
219        .collect();
220
221    let debug_impl = if impl_debug_on_refinement {
222        let refinement_field_debugs: Vec<TokenStream2> = fields
223            .iter()
224            .map(|field| {
225                let name = &field.ident;
226                quote! {
227                    if self.#name.is_some() {
228                        debug_struct.field(stringify!(#name), &self.#name);
229                    } else {
230                        all_some = false;
231                    }
232                }
233            })
234            .collect();
235
236        quote! {
237            impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
238                #where_clause
239            {
240                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241                    let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
242                    let mut all_some = true;
243                    #( #refinement_field_debugs )*
244                    if all_some {
245                        debug_struct.finish()
246                    } else {
247                        debug_struct.finish_non_exhaustive()
248                    }
249                }
250            }
251        }
252    } else {
253        quote! {}
254    };
255
256    let refinement_is_empty_conditions: Vec<TokenStream2> = fields
257        .iter()
258        .enumerate()
259        .map(|(i, field)| {
260            let name = &field.ident;
261
262            let condition = if is_refineable_field(field) {
263                quote! { self.#name.is_empty() }
264            } else {
265                quote! { self.#name.is_none() }
266            };
267
268            if i < fields.len() - 1 {
269                quote! { #condition && }
270            } else {
271                condition
272            }
273        })
274        .collect();
275
276    let refineable_is_superset_conditions: Vec<TokenStream2> = fields
277        .iter()
278        .map(|field| {
279            let name = &field.ident;
280            let is_refineable = is_refineable_field(field);
281            let is_optional = is_optional_field(field);
282
283            if is_refineable {
284                quote! {
285                    if !self.#name.is_superset_of(&refinement.#name) {
286                        return false;
287                    }
288                }
289            } else if is_optional {
290                quote! {
291                    if refinement.#name.is_some() && &self.#name != &refinement.#name {
292                        return false;
293                    }
294                }
295            } else {
296                quote! {
297                    if let Some(refinement_value) = &refinement.#name {
298                        if &self.#name != refinement_value {
299                            return false;
300                        }
301                    }
302                }
303            }
304        })
305        .collect();
306
307    let refinement_is_superset_conditions: Vec<TokenStream2> = fields
308        .iter()
309        .map(|field| {
310            let name = &field.ident;
311            let is_refineable = is_refineable_field(field);
312
313            if is_refineable {
314                quote! {
315                    if !self.#name.is_superset_of(&refinement.#name) {
316                        return false;
317                    }
318                }
319            } else {
320                quote! {
321                    if refinement.#name.is_some() && &self.#name != &refinement.#name {
322                        return false;
323                    }
324                }
325            }
326        })
327        .collect();
328
329    let refineable_subtract_assignments: Vec<TokenStream2> = fields
330        .iter()
331        .map(|field| {
332            let name = &field.ident;
333            let is_refineable = is_refineable_field(field);
334            let is_optional = is_optional_field(field);
335
336            if is_refineable {
337                quote! {
338                    #name: self.#name.subtract(&refinement.#name),
339                }
340            } else if is_optional {
341                quote! {
342                    #name: if &self.#name == &refinement.#name {
343                        None
344                    } else {
345                        self.#name.clone()
346                    },
347                }
348            } else {
349                quote! {
350                    #name: if let Some(refinement_value) = &refinement.#name {
351                        if &self.#name == refinement_value {
352                            None
353                        } else {
354                            Some(self.#name.clone())
355                        }
356                    } else {
357                        Some(self.#name.clone())
358                    },
359                }
360            }
361        })
362        .collect();
363
364    let refinement_subtract_assignments: Vec<TokenStream2> = fields
365        .iter()
366        .map(|field| {
367            let name = &field.ident;
368            let is_refineable = is_refineable_field(field);
369
370            if is_refineable {
371                quote! {
372                    #name: self.#name.subtract(&refinement.#name),
373                }
374            } else {
375                quote! {
376                    #name: if &self.#name == &refinement.#name {
377                        None
378                    } else {
379                        self.#name.clone()
380                    },
381                }
382            }
383        })
384        .collect();
385
386    let mut derive_stream = quote! {};
387    for trait_to_derive in refinement_traits_to_derive {
388        derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
389    }
390
391    let r#gen = quote! {
392        /// A refinable version of [`#ident`], see that documentation for details.
393        #[derive(Clone)]
394        #derive_stream
395        pub struct #refinement_ident #impl_generics {
396            #(
397                #[allow(missing_docs)]
398                #field_attributes
399                #field_visibilities #field_names: #wrapped_types
400            ),*
401        }
402
403        impl #impl_generics Refineable for #ident #ty_generics
404            #where_clause
405        {
406            type Refinement = #refinement_ident #ty_generics;
407
408            fn refine(&mut self, refinement: &Self::Refinement) {
409                #( #refineable_refine_assignments )*
410            }
411
412            fn refined(mut self, refinement: Self::Refinement) -> Self {
413                #( #refineable_refined_assignments )*
414                self
415            }
416
417            fn is_superset_of(&self, refinement: &Self::Refinement) -> bool
418            {
419                #( #refineable_is_superset_conditions )*
420                true
421            }
422
423            fn subtract(&self, refinement: &Self::Refinement) -> Self::Refinement
424            {
425                #refinement_ident {
426                    #( #refineable_subtract_assignments )*
427                }
428            }
429        }
430
431        impl #impl_generics Refineable for #refinement_ident #ty_generics
432            #where_clause
433        {
434            type Refinement = #refinement_ident #ty_generics;
435
436            fn refine(&mut self, refinement: &Self::Refinement) {
437                #( #refinement_refine_assignments )*
438            }
439
440            fn refined(mut self, refinement: Self::Refinement) -> Self {
441                #( #refinement_refined_assignments )*
442                self
443            }
444
445            fn is_superset_of(&self, refinement: &Self::Refinement) -> bool
446            {
447                #( #refinement_is_superset_conditions )*
448                true
449            }
450
451            fn subtract(&self, refinement: &Self::Refinement) -> Self::Refinement
452            {
453                #refinement_ident {
454                    #( #refinement_subtract_assignments )*
455                }
456            }
457        }
458
459        impl #impl_generics ::refineable::IsEmpty for #refinement_ident #ty_generics
460            #where_clause
461        {
462            fn is_empty(&self) -> bool {
463                #( #refinement_is_empty_conditions )*
464            }
465        }
466
467        impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
468            #where_clause
469        {
470            fn from(value: #refinement_ident #ty_generics) -> Self {
471                Self {
472                    #( #from_refinement_assignments )*
473                }
474            }
475        }
476
477        impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
478            #where_clause
479        {
480            fn default() -> Self {
481                #refinement_ident {
482                    #( #field_names: Default::default() ),*
483                }
484            }
485        }
486
487        impl #impl_generics #refinement_ident #ty_generics
488            #where_clause
489        {
490            /// Returns `true` if all fields are `Some`
491            pub fn is_some(&self) -> bool {
492                #(
493                    if self.#field_names.is_some() {
494                        return true;
495                    }
496                )*
497                false
498            }
499        }
500
501        #debug_impl
502    };
503    r#gen.into()
504}
505
506fn is_refineable_field(f: &Field) -> bool {
507    f.attrs
508        .iter()
509        .any(|attr| attr.path().is_ident("refineable"))
510}
511
512fn is_optional_field(f: &Field) -> bool {
513    if let Type::Path(typepath) = &f.ty
514        && typepath.qself.is_none()
515    {
516        let segments = &typepath.path.segments;
517        if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
518            return true;
519        }
520    }
521    false
522}
523
524fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
525    if is_refineable_field(field) {
526        let struct_name = if let Type::Path(tp) = ty {
527            tp.path.segments.last().unwrap().ident.clone()
528        } else {
529            panic!("Expected struct type for a refineable field");
530        };
531        let refinement_struct_name = format_ident!("{}Refinement", struct_name);
532        let generics = if let Type::Path(tp) = ty {
533            &tp.path.segments.last().unwrap().arguments
534        } else {
535            &syn::PathArguments::None
536        };
537        parse_quote!(#refinement_struct_name #generics)
538    } else if is_optional_field(field) {
539        ty.clone()
540    } else {
541        parse_quote!(Option<#ty>)
542    }
543}