derive_refineable.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::TokenStream as TokenStream2;
  3use quote::{format_ident, quote};
  4use syn::{
  5    DeriveInput, Field, FieldsNamed, PredicateType, TraitBound, Type, TypeParamBound, WhereClause,
  6    WherePredicate, parse_macro_input, parse_quote,
  7};
  8
  9#[proc_macro_derive(Refineable, attributes(refineable))]
 10pub fn derive_refineable(input: TokenStream) -> TokenStream {
 11    let DeriveInput {
 12        ident,
 13        data,
 14        generics,
 15        attrs,
 16        ..
 17    } = parse_macro_input!(input);
 18
 19    let refineable_attr = attrs.iter().find(|attr| attr.path().is_ident("refineable"));
 20
 21    let mut impl_debug_on_refinement = false;
 22    let mut refinement_traits_to_derive = vec![];
 23
 24    if let Some(refineable_attr) = refineable_attr {
 25        let _ = refineable_attr.parse_nested_meta(|meta| {
 26            if meta.path.is_ident("Debug") {
 27                impl_debug_on_refinement = true;
 28            } else {
 29                refinement_traits_to_derive.push(meta.path);
 30            }
 31            Ok(())
 32        });
 33    }
 34
 35    let refinement_ident = format_ident!("{}Refinement", ident);
 36    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
 37
 38    let fields = match data {
 39        syn::Data::Struct(syn::DataStruct {
 40            fields: syn::Fields::Named(FieldsNamed { named, .. }),
 41            ..
 42        }) => named.into_iter().collect::<Vec<Field>>(),
 43        _ => panic!("This derive macro only supports structs with named fields"),
 44    };
 45
 46    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
 47    let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
 48    let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
 49
 50    // Create trait bound that each wrapped type must implement Clone // & Default
 51    let type_param_bounds: Vec<_> = wrapped_types
 52        .iter()
 53        .map(|ty| {
 54            WherePredicate::Type(PredicateType {
 55                lifetimes: None,
 56                bounded_ty: ty.clone(),
 57                colon_token: Default::default(),
 58                bounds: {
 59                    let mut punctuated = syn::punctuated::Punctuated::new();
 60                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 61                        paren_token: None,
 62                        modifier: syn::TraitBoundModifier::None,
 63                        lifetimes: None,
 64                        path: parse_quote!(Clone),
 65                    }));
 66
 67                    punctuated
 68                },
 69            })
 70        })
 71        .collect();
 72
 73    // Append to where_clause or create a new one if it doesn't exist
 74    let where_clause = match where_clause.cloned() {
 75        Some(mut where_clause) => {
 76            where_clause.predicates.extend(type_param_bounds);
 77            where_clause.clone()
 78        }
 79        None => WhereClause {
 80            where_token: Default::default(),
 81            predicates: type_param_bounds.into_iter().collect(),
 82        },
 83    };
 84
 85    let refineable_refine_assignments: Vec<TokenStream2> = fields
 86        .iter()
 87        .map(|field| {
 88            let name = &field.ident;
 89            let is_refineable = is_refineable_field(field);
 90            let is_optional = is_optional_field(field);
 91
 92            if is_refineable {
 93                quote! {
 94                    self.#name.refine(&refinement.#name);
 95                }
 96            } else if is_optional {
 97                quote! {
 98                    if let Some(value) = &refinement.#name {
 99                        self.#name = Some(value.clone());
100                    }
101                }
102            } else {
103                quote! {
104                    if let Some(value) = &refinement.#name {
105                        self.#name = value.clone();
106                    }
107                }
108            }
109        })
110        .collect();
111
112    let refineable_refined_assignments: Vec<TokenStream2> = fields
113        .iter()
114        .map(|field| {
115            let name = &field.ident;
116            let is_refineable = is_refineable_field(field);
117            let is_optional = is_optional_field(field);
118
119            if is_refineable {
120                quote! {
121                    self.#name = self.#name.refined(refinement.#name);
122                }
123            } else if is_optional {
124                quote! {
125                    if let Some(value) = refinement.#name {
126                        self.#name = Some(value);
127                    }
128                }
129            } else {
130                quote! {
131                    if let Some(value) = refinement.#name {
132                        self.#name = value;
133                    }
134                }
135            }
136        })
137        .collect();
138
139    let refinement_refine_assignments: Vec<TokenStream2> = fields
140        .iter()
141        .map(|field| {
142            let name = &field.ident;
143            let is_refineable = is_refineable_field(field);
144
145            if is_refineable {
146                quote! {
147                    self.#name.refine(&refinement.#name);
148                }
149            } else {
150                quote! {
151                    if let Some(value) = &refinement.#name {
152                        self.#name = Some(value.clone());
153                    }
154                }
155            }
156        })
157        .collect();
158
159    let refinement_refined_assignments: Vec<TokenStream2> = fields
160        .iter()
161        .map(|field| {
162            let name = &field.ident;
163            let is_refineable = is_refineable_field(field);
164
165            if is_refineable {
166                quote! {
167                    self.#name = self.#name.refined(refinement.#name);
168                }
169            } else {
170                quote! {
171                    if let Some(value) = refinement.#name {
172                        self.#name = Some(value);
173                    }
174                }
175            }
176        })
177        .collect();
178
179    let from_refinement_assignments: Vec<TokenStream2> = fields
180        .iter()
181        .map(|field| {
182            let name = &field.ident;
183            let is_refineable = is_refineable_field(field);
184            let is_optional = is_optional_field(field);
185
186            if is_refineable {
187                quote! {
188                    #name: value.#name.into(),
189                }
190            } else if is_optional {
191                quote! {
192                    #name: value.#name.map(|v| v.into()),
193                }
194            } else {
195                quote! {
196                    #name: value.#name.map(|v| v.into()).unwrap_or_default(),
197                }
198            }
199        })
200        .collect();
201
202    let debug_impl = if impl_debug_on_refinement {
203        let refinement_field_debugs: Vec<TokenStream2> = fields
204            .iter()
205            .map(|field| {
206                let name = &field.ident;
207                quote! {
208                    if self.#name.is_some() {
209                        debug_struct.field(stringify!(#name), &self.#name);
210                    } else {
211                        all_some = false;
212                    }
213                }
214            })
215            .collect();
216
217        quote! {
218            impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
219                #where_clause
220            {
221                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222                    let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
223                    let mut all_some = true;
224                    #( #refinement_field_debugs )*
225                    if all_some {
226                        debug_struct.finish()
227                    } else {
228                        debug_struct.finish_non_exhaustive()
229                    }
230                }
231            }
232        }
233    } else {
234        quote! {}
235    };
236
237    let mut derive_stream = quote! {};
238    for trait_to_derive in refinement_traits_to_derive {
239        derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
240    }
241
242    let r#gen = quote! {
243        /// A refinable version of [`#ident`], see that documentation for details.
244        #[derive(Clone)]
245        #derive_stream
246        pub struct #refinement_ident #impl_generics {
247            #(
248                #[allow(missing_docs)]
249                #field_visibilities #field_names: #wrapped_types
250            ),*
251        }
252
253        impl #impl_generics Refineable for #ident #ty_generics
254            #where_clause
255        {
256            type Refinement = #refinement_ident #ty_generics;
257
258            fn refine(&mut self, refinement: &Self::Refinement) {
259                #( #refineable_refine_assignments )*
260            }
261
262            fn refined(mut self, refinement: Self::Refinement) -> Self {
263                #( #refineable_refined_assignments )*
264                self
265            }
266        }
267
268        impl #impl_generics Refineable for #refinement_ident #ty_generics
269            #where_clause
270        {
271            type Refinement = #refinement_ident #ty_generics;
272
273            fn refine(&mut self, refinement: &Self::Refinement) {
274                #( #refinement_refine_assignments )*
275            }
276
277            fn refined(mut self, refinement: Self::Refinement) -> Self {
278                #( #refinement_refined_assignments )*
279                self
280            }
281        }
282
283        impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
284            #where_clause
285        {
286            fn from(value: #refinement_ident #ty_generics) -> Self {
287                Self {
288                    #( #from_refinement_assignments )*
289                }
290            }
291        }
292
293        impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
294            #where_clause
295        {
296            fn default() -> Self {
297                #refinement_ident {
298                    #( #field_names: Default::default() ),*
299                }
300            }
301        }
302
303        impl #impl_generics #refinement_ident #ty_generics
304            #where_clause
305        {
306            /// Returns `true` if all fields are `Some`
307            pub fn is_some(&self) -> bool {
308                #(
309                    if self.#field_names.is_some() {
310                        return true;
311                    }
312                )*
313                false
314            }
315        }
316
317        #debug_impl
318    };
319    r#gen.into()
320}
321
322fn is_refineable_field(f: &Field) -> bool {
323    f.attrs
324        .iter()
325        .any(|attr| attr.path().is_ident("refineable"))
326}
327
328fn is_optional_field(f: &Field) -> bool {
329    if let Type::Path(typepath) = &f.ty {
330        if typepath.qself.is_none() {
331            let segments = &typepath.path.segments;
332            if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
333                return true;
334            }
335        }
336    }
337    false
338}
339
340fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
341    if is_refineable_field(field) {
342        let struct_name = if let Type::Path(tp) = ty {
343            tp.path.segments.last().unwrap().ident.clone()
344        } else {
345            panic!("Expected struct type for a refineable field");
346        };
347        let refinement_struct_name = format_ident!("{}Refinement", struct_name);
348        let generics = if let Type::Path(tp) = ty {
349            &tp.path.segments.last().unwrap().arguments
350        } else {
351            &syn::PathArguments::None
352        };
353        parse_quote!(#refinement_struct_name #generics)
354    } else if is_optional_field(field) {
355        ty.clone()
356    } else {
357        parse_quote!(Option<#ty>)
358    }
359}