derive_refineable.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::TokenStream as TokenStream2;
  3use quote::{format_ident, quote};
  4use syn::{
  5    parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
  6    Type, TypeParamBound, WhereClause, WherePredicate,
  7};
  8
  9#[proc_macro_derive(Refineable, attributes(refineable))]
 10pub fn derive_refineable(input: TokenStream) -> TokenStream {
 11    let DeriveInput {
 12        ident,
 13        data,
 14        generics,
 15        attrs,
 16        ..
 17    } = parse_macro_input!(input);
 18
 19    let impl_debug_on_refinement = attrs
 20        .iter()
 21        .any(|attr| attr.path.is_ident("refineable") && attr.tokens.to_string().contains("debug"));
 22
 23    let refinement_ident = format_ident!("{}Refinement", ident);
 24    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
 25
 26    let fields = match data {
 27        syn::Data::Struct(syn::DataStruct {
 28            fields: syn::Fields::Named(FieldsNamed { named, .. }),
 29            ..
 30        }) => named.into_iter().collect::<Vec<Field>>(),
 31        _ => panic!("This derive macro only supports structs with named fields"),
 32    };
 33
 34    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
 35    let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
 36    let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
 37
 38    // Create trait bound that each wrapped type must implement Clone & Default
 39    let type_param_bounds: Vec<_> = wrapped_types
 40        .iter()
 41        .map(|ty| {
 42            WherePredicate::Type(PredicateType {
 43                lifetimes: None,
 44                bounded_ty: ty.clone(),
 45                colon_token: Default::default(),
 46                bounds: {
 47                    let mut punctuated = syn::punctuated::Punctuated::new();
 48                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 49                        paren_token: None,
 50                        modifier: syn::TraitBoundModifier::None,
 51                        lifetimes: None,
 52                        path: parse_quote!(Clone),
 53                    }));
 54                    punctuated.push_punct(syn::token::Add::default());
 55                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 56                        paren_token: None,
 57                        modifier: syn::TraitBoundModifier::None,
 58                        lifetimes: None,
 59                        path: parse_quote!(Default),
 60                    }));
 61                    punctuated
 62                },
 63            })
 64        })
 65        .collect();
 66
 67    // Append to where_clause or create a new one if it doesn't exist
 68    let where_clause = match where_clause.cloned() {
 69        Some(mut where_clause) => {
 70            where_clause
 71                .predicates
 72                .extend(type_param_bounds.into_iter());
 73            where_clause.clone()
 74        }
 75        None => WhereClause {
 76            where_token: Default::default(),
 77            predicates: type_param_bounds.into_iter().collect(),
 78        },
 79    };
 80
 81    let field_assignments: Vec<TokenStream2> = fields
 82        .iter()
 83        .map(|field| {
 84            let name = &field.ident;
 85            let is_refineable = is_refineable_field(field);
 86            let is_optional = is_optional_field(field);
 87
 88            if is_refineable {
 89                quote! {
 90                    self.#name.refine(&refinement.#name);
 91                }
 92            } else if is_optional {
 93                quote! {
 94                    if let Some(ref value) = &refinement.#name {
 95                        self.#name = Some(value.clone());
 96                    }
 97                }
 98            } else {
 99                quote! {
100                    if let Some(ref value) = &refinement.#name {
101                        self.#name = value.clone();
102                    }
103                }
104            }
105        })
106        .collect();
107
108    let refinement_field_assignments: Vec<TokenStream2> = fields
109        .iter()
110        .map(|field| {
111            let name = &field.ident;
112            let is_refineable = is_refineable_field(field);
113
114            if is_refineable {
115                quote! {
116                    self.#name.refine(&refinement.#name);
117                }
118            } else {
119                quote! {
120                    if let Some(ref value) = &refinement.#name {
121                        self.#name = Some(value.clone());
122                    }
123                }
124            }
125        })
126        .collect();
127
128    let debug_impl = if impl_debug_on_refinement {
129        let refinement_field_debugs: Vec<TokenStream2> = fields
130            .iter()
131            .map(|field| {
132                let name = &field.ident;
133                quote! {
134                    if self.#name.is_some() {
135                        debug_struct.field(stringify!(#name), &self.#name);
136                    } else {
137                        all_some = false;
138                    }
139                }
140            })
141            .collect();
142
143        quote! {
144            impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
145                #where_clause
146            {
147                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148                    let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
149                    let mut all_some = true;
150                    #( #refinement_field_debugs )*
151                    if all_some {
152                        debug_struct.finish()
153                    } else {
154                        debug_struct.finish_non_exhaustive()
155                    }
156                }
157            }
158        }
159    } else {
160        quote! {}
161    };
162
163    let gen = quote! {
164        #[derive(Default, Clone)]
165        pub struct #refinement_ident #impl_generics {
166            #( #field_visibilities #field_names: #wrapped_types ),*
167        }
168
169        impl #impl_generics Refineable for #ident #ty_generics
170            #where_clause
171        {
172            type Refinement = #refinement_ident #ty_generics;
173
174            fn refine(&mut self, refinement: &Self::Refinement) {
175                #( #field_assignments )*
176            }
177        }
178
179        impl #impl_generics Refineable for #refinement_ident #ty_generics
180            #where_clause
181        {
182            type Refinement = #refinement_ident #ty_generics;
183
184            fn refine(&mut self, refinement: &Self::Refinement) {
185                #( #refinement_field_assignments )*
186            }
187        }
188
189        impl #impl_generics #refinement_ident #ty_generics
190            #where_clause
191        {
192            pub fn is_some(&self) -> bool {
193                #(
194                    if self.#field_names.is_some() {
195                        return true;
196                    }
197                )*
198                false
199            }
200        }
201
202        #debug_impl
203    };
204    gen.into()
205}
206
207fn is_refineable_field(f: &Field) -> bool {
208    f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
209}
210
211fn is_optional_field(f: &Field) -> bool {
212    if let Type::Path(typepath) = &f.ty {
213        if typepath.qself.is_none() {
214            let segments = &typepath.path.segments;
215            if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
216                return true;
217            }
218        }
219    }
220    false
221}
222
223fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
224    if is_refineable_field(field) {
225        let struct_name = if let Type::Path(tp) = ty {
226            tp.path.segments.last().unwrap().ident.clone()
227        } else {
228            panic!("Expected struct type for a refineable field");
229        };
230        let refinement_struct_name = format_ident!("{}Refinement", struct_name);
231        let generics = if let Type::Path(tp) = ty {
232            &tp.path.segments.last().unwrap().arguments
233        } else {
234            &syn::PathArguments::None
235        };
236        parse_quote!(#refinement_struct_name #generics)
237    } else if is_optional_field(field) {
238        ty.clone()
239    } else {
240        parse_quote!(Option<#ty>)
241    }
242}