derive_refineable.rs

  1use proc_macro::TokenStream;
  2use proc_macro2::TokenStream as TokenStream2;
  3use quote::{format_ident, quote};
  4use syn::{
  5    DeriveInput, Field, FieldsNamed, PredicateType, TraitBound, Type, TypeParamBound, WhereClause,
  6    WherePredicate, parse_macro_input, parse_quote,
  7};
  8
  9#[proc_macro_derive(Refineable, attributes(refineable))]
 10pub fn derive_refineable(input: TokenStream) -> TokenStream {
 11    let DeriveInput {
 12        ident,
 13        data,
 14        generics,
 15        attrs,
 16        ..
 17    } = parse_macro_input!(input);
 18
 19    let refineable_attr = attrs.iter().find(|attr| attr.path.is_ident("refineable"));
 20
 21    let mut impl_debug_on_refinement = false;
 22    let mut refinement_traits_to_derive = vec![];
 23
 24    if let Some(refineable_attr) = refineable_attr {
 25        if let Ok(syn::Meta::List(meta_list)) = refineable_attr.parse_meta() {
 26            for nested in meta_list.nested {
 27                let syn::NestedMeta::Meta(syn::Meta::Path(path)) = nested else {
 28                    continue;
 29                };
 30
 31                if path.is_ident("Debug") {
 32                    impl_debug_on_refinement = true;
 33                } else {
 34                    refinement_traits_to_derive.push(path);
 35                }
 36            }
 37        }
 38    }
 39
 40    let refinement_ident = format_ident!("{}Refinement", ident);
 41    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
 42
 43    let fields = match data {
 44        syn::Data::Struct(syn::DataStruct {
 45            fields: syn::Fields::Named(FieldsNamed { named, .. }),
 46            ..
 47        }) => named.into_iter().collect::<Vec<Field>>(),
 48        _ => panic!("This derive macro only supports structs with named fields"),
 49    };
 50
 51    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
 52    let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
 53    let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
 54
 55    // Create trait bound that each wrapped type must implement Clone // & Default
 56    let type_param_bounds: Vec<_> = wrapped_types
 57        .iter()
 58        .map(|ty| {
 59            WherePredicate::Type(PredicateType {
 60                lifetimes: None,
 61                bounded_ty: ty.clone(),
 62                colon_token: Default::default(),
 63                bounds: {
 64                    let mut punctuated = syn::punctuated::Punctuated::new();
 65                    punctuated.push_value(TypeParamBound::Trait(TraitBound {
 66                        paren_token: None,
 67                        modifier: syn::TraitBoundModifier::None,
 68                        lifetimes: None,
 69                        path: parse_quote!(Clone),
 70                    }));
 71
 72                    punctuated
 73                },
 74            })
 75        })
 76        .collect();
 77
 78    // Append to where_clause or create a new one if it doesn't exist
 79    let where_clause = match where_clause.cloned() {
 80        Some(mut where_clause) => {
 81            where_clause.predicates.extend(type_param_bounds);
 82            where_clause.clone()
 83        }
 84        None => WhereClause {
 85            where_token: Default::default(),
 86            predicates: type_param_bounds.into_iter().collect(),
 87        },
 88    };
 89
 90    let refineable_refine_assignments: Vec<TokenStream2> = fields
 91        .iter()
 92        .map(|field| {
 93            let name = &field.ident;
 94            let is_refineable = is_refineable_field(field);
 95            let is_optional = is_optional_field(field);
 96
 97            if is_refineable {
 98                quote! {
 99                    self.#name.refine(&refinement.#name);
100                }
101            } else if is_optional {
102                quote! {
103                    if let Some(value) = &refinement.#name {
104                        self.#name = Some(value.clone());
105                    }
106                }
107            } else {
108                quote! {
109                    if let Some(value) = &refinement.#name {
110                        self.#name = value.clone();
111                    }
112                }
113            }
114        })
115        .collect();
116
117    let refineable_refined_assignments: Vec<TokenStream2> = fields
118        .iter()
119        .map(|field| {
120            let name = &field.ident;
121            let is_refineable = is_refineable_field(field);
122            let is_optional = is_optional_field(field);
123
124            if is_refineable {
125                quote! {
126                    self.#name = self.#name.refined(refinement.#name);
127                }
128            } else if is_optional {
129                quote! {
130                    if let Some(value) = refinement.#name {
131                        self.#name = Some(value);
132                    }
133                }
134            } else {
135                quote! {
136                    if let Some(value) = refinement.#name {
137                        self.#name = value;
138                    }
139                }
140            }
141        })
142        .collect();
143
144    let refinement_refine_assignments: Vec<TokenStream2> = fields
145        .iter()
146        .map(|field| {
147            let name = &field.ident;
148            let is_refineable = is_refineable_field(field);
149
150            if is_refineable {
151                quote! {
152                    self.#name.refine(&refinement.#name);
153                }
154            } else {
155                quote! {
156                    if let Some(value) = &refinement.#name {
157                        self.#name = Some(value.clone());
158                    }
159                }
160            }
161        })
162        .collect();
163
164    let refinement_refined_assignments: Vec<TokenStream2> = fields
165        .iter()
166        .map(|field| {
167            let name = &field.ident;
168            let is_refineable = is_refineable_field(field);
169
170            if is_refineable {
171                quote! {
172                    self.#name = self.#name.refined(refinement.#name);
173                }
174            } else {
175                quote! {
176                    if let Some(value) = refinement.#name {
177                        self.#name = Some(value);
178                    }
179                }
180            }
181        })
182        .collect();
183
184    let from_refinement_assignments: Vec<TokenStream2> = fields
185        .iter()
186        .map(|field| {
187            let name = &field.ident;
188            let is_refineable = is_refineable_field(field);
189            let is_optional = is_optional_field(field);
190
191            if is_refineable {
192                quote! {
193                    #name: value.#name.into(),
194                }
195            } else if is_optional {
196                quote! {
197                    #name: value.#name.map(|v| v.into()),
198                }
199            } else {
200                quote! {
201                    #name: value.#name.map(|v| v.into()).unwrap_or_default(),
202                }
203            }
204        })
205        .collect();
206
207    let debug_impl = if impl_debug_on_refinement {
208        let refinement_field_debugs: Vec<TokenStream2> = fields
209            .iter()
210            .map(|field| {
211                let name = &field.ident;
212                quote! {
213                    if self.#name.is_some() {
214                        debug_struct.field(stringify!(#name), &self.#name);
215                    } else {
216                        all_some = false;
217                    }
218                }
219            })
220            .collect();
221
222        quote! {
223            impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
224                #where_clause
225            {
226                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227                    let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
228                    let mut all_some = true;
229                    #( #refinement_field_debugs )*
230                    if all_some {
231                        debug_struct.finish()
232                    } else {
233                        debug_struct.finish_non_exhaustive()
234                    }
235                }
236            }
237        }
238    } else {
239        quote! {}
240    };
241
242    let mut derive_stream = quote! {};
243    for trait_to_derive in refinement_traits_to_derive {
244        derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
245    }
246
247    let r#gen = quote! {
248        /// A refinable version of [`#ident`], see that documentation for details.
249        #[derive(Clone)]
250        #derive_stream
251        pub struct #refinement_ident #impl_generics {
252            #(
253                #[allow(missing_docs)]
254                #field_visibilities #field_names: #wrapped_types
255            ),*
256        }
257
258        impl #impl_generics Refineable for #ident #ty_generics
259            #where_clause
260        {
261            type Refinement = #refinement_ident #ty_generics;
262
263            fn refine(&mut self, refinement: &Self::Refinement) {
264                #( #refineable_refine_assignments )*
265            }
266
267            fn refined(mut self, refinement: Self::Refinement) -> Self {
268                #( #refineable_refined_assignments )*
269                self
270            }
271        }
272
273        impl #impl_generics Refineable for #refinement_ident #ty_generics
274            #where_clause
275        {
276            type Refinement = #refinement_ident #ty_generics;
277
278            fn refine(&mut self, refinement: &Self::Refinement) {
279                #( #refinement_refine_assignments )*
280            }
281
282            fn refined(mut self, refinement: Self::Refinement) -> Self {
283                #( #refinement_refined_assignments )*
284                self
285            }
286        }
287
288        impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
289            #where_clause
290        {
291            fn from(value: #refinement_ident #ty_generics) -> Self {
292                Self {
293                    #( #from_refinement_assignments )*
294                }
295            }
296        }
297
298        impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
299            #where_clause
300        {
301            fn default() -> Self {
302                #refinement_ident {
303                    #( #field_names: Default::default() ),*
304                }
305            }
306        }
307
308        impl #impl_generics #refinement_ident #ty_generics
309            #where_clause
310        {
311            /// Returns `true` if all fields are `Some`
312            pub fn is_some(&self) -> bool {
313                #(
314                    if self.#field_names.is_some() {
315                        return true;
316                    }
317                )*
318                false
319            }
320        }
321
322        #debug_impl
323    };
324    r#gen.into()
325}
326
327fn is_refineable_field(f: &Field) -> bool {
328    f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
329}
330
331fn is_optional_field(f: &Field) -> bool {
332    if let Type::Path(typepath) = &f.ty {
333        if typepath.qself.is_none() {
334            let segments = &typepath.path.segments;
335            if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
336                return true;
337            }
338        }
339    }
340    false
341}
342
343fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
344    if is_refineable_field(field) {
345        let struct_name = if let Type::Path(tp) = ty {
346            tp.path.segments.last().unwrap().ident.clone()
347        } else {
348            panic!("Expected struct type for a refineable field");
349        };
350        let refinement_struct_name = format_ident!("{}Refinement", struct_name);
351        let generics = if let Type::Path(tp) = ty {
352            &tp.path.segments.last().unwrap().arguments
353        } else {
354            &syn::PathArguments::None
355        };
356        parse_quote!(#refinement_struct_name #generics)
357    } else if is_optional_field(field) {
358        ty.clone()
359    } else {
360        parse_quote!(Option<#ty>)
361    }
362}