1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
6 Type, TypeParamBound, WhereClause, WherePredicate,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let impl_debug_on_refinement = attrs
20 .iter()
21 .any(|attr| attr.path.is_ident("refineable") && attr.tokens.to_string().contains("debug"));
22
23 let refinement_ident = format_ident!("{}Refinement", ident);
24 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
25
26 let fields = match data {
27 syn::Data::Struct(syn::DataStruct {
28 fields: syn::Fields::Named(FieldsNamed { named, .. }),
29 ..
30 }) => named.into_iter().collect::<Vec<Field>>(),
31 _ => panic!("This derive macro only supports structs with named fields"),
32 };
33
34 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
35 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
36 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
37
38 // Create trait bound that each wrapped type must implement Clone & Default
39 let type_param_bounds: Vec<_> = wrapped_types
40 .iter()
41 .map(|ty| {
42 WherePredicate::Type(PredicateType {
43 lifetimes: None,
44 bounded_ty: ty.clone(),
45 colon_token: Default::default(),
46 bounds: {
47 let mut punctuated = syn::punctuated::Punctuated::new();
48 punctuated.push_value(TypeParamBound::Trait(TraitBound {
49 paren_token: None,
50 modifier: syn::TraitBoundModifier::None,
51 lifetimes: None,
52 path: parse_quote!(Clone),
53 }));
54 punctuated.push_punct(syn::token::Add::default());
55 punctuated.push_value(TypeParamBound::Trait(TraitBound {
56 paren_token: None,
57 modifier: syn::TraitBoundModifier::None,
58 lifetimes: None,
59 path: parse_quote!(Default),
60 }));
61 punctuated
62 },
63 })
64 })
65 .collect();
66
67 // Append to where_clause or create a new one if it doesn't exist
68 let where_clause = match where_clause.cloned() {
69 Some(mut where_clause) => {
70 where_clause
71 .predicates
72 .extend(type_param_bounds.into_iter());
73 where_clause.clone()
74 }
75 None => WhereClause {
76 where_token: Default::default(),
77 predicates: type_param_bounds.into_iter().collect(),
78 },
79 };
80
81 let field_assignments: Vec<TokenStream2> = fields
82 .iter()
83 .map(|field| {
84 let name = &field.ident;
85 let is_refineable = is_refineable_field(field);
86 let is_optional = is_optional_field(field);
87
88 if is_refineable {
89 quote! {
90 self.#name.refine(&refinement.#name);
91 }
92 } else if is_optional {
93 quote! {
94 if let Some(ref value) = &refinement.#name {
95 self.#name = Some(value.clone());
96 }
97 }
98 } else {
99 quote! {
100 if let Some(ref value) = &refinement.#name {
101 self.#name = value.clone();
102 }
103 }
104 }
105 })
106 .collect();
107
108 let refinement_field_assignments: Vec<TokenStream2> = fields
109 .iter()
110 .map(|field| {
111 let name = &field.ident;
112 let is_refineable = is_refineable_field(field);
113
114 if is_refineable {
115 quote! {
116 self.#name.refine(&refinement.#name);
117 }
118 } else {
119 quote! {
120 if let Some(ref value) = &refinement.#name {
121 self.#name = Some(value.clone());
122 }
123 }
124 }
125 })
126 .collect();
127
128 let debug_impl = if impl_debug_on_refinement {
129 let refinement_field_debugs: Vec<TokenStream2> = fields
130 .iter()
131 .map(|field| {
132 let name = &field.ident;
133 quote! {
134 if self.#name.is_some() {
135 debug_struct.field(stringify!(#name), &self.#name);
136 } else {
137 all_some = false;
138 }
139 }
140 })
141 .collect();
142
143 quote! {
144 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
145 #where_clause
146 {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
149 let mut all_some = true;
150 #( #refinement_field_debugs )*
151 if all_some {
152 debug_struct.finish()
153 } else {
154 debug_struct.finish_non_exhaustive()
155 }
156 }
157 }
158 }
159 } else {
160 quote! {}
161 };
162
163 let gen = quote! {
164 #[derive(Default, Clone)]
165 pub struct #refinement_ident #impl_generics {
166 #( #field_visibilities #field_names: #wrapped_types ),*
167 }
168
169 impl #impl_generics Refineable for #ident #ty_generics
170 #where_clause
171 {
172 type Refinement = #refinement_ident #ty_generics;
173
174 fn refine(&mut self, refinement: &Self::Refinement) {
175 #( #field_assignments )*
176 }
177 }
178
179 impl #impl_generics Refineable for #refinement_ident #ty_generics
180 #where_clause
181 {
182 type Refinement = #refinement_ident #ty_generics;
183
184 fn refine(&mut self, refinement: &Self::Refinement) {
185 #( #refinement_field_assignments )*
186 }
187 }
188
189 impl #impl_generics #refinement_ident #ty_generics
190 #where_clause
191 {
192 pub fn is_some(&self) -> bool {
193 #(
194 if self.#field_names.is_some() {
195 return true;
196 }
197 )*
198 false
199 }
200 }
201
202 #debug_impl
203 };
204 gen.into()
205}
206
207fn is_refineable_field(f: &Field) -> bool {
208 f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
209}
210
211fn is_optional_field(f: &Field) -> bool {
212 if let Type::Path(typepath) = &f.ty {
213 if typepath.qself.is_none() {
214 let segments = &typepath.path.segments;
215 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
216 return true;
217 }
218 }
219 }
220 false
221}
222
223fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
224 if is_refineable_field(field) {
225 let struct_name = if let Type::Path(tp) = ty {
226 tp.path.segments.last().unwrap().ident.clone()
227 } else {
228 panic!("Expected struct type for a refineable field");
229 };
230 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
231 let generics = if let Type::Path(tp) = ty {
232 &tp.path.segments.last().unwrap().arguments
233 } else {
234 &syn::PathArguments::None
235 };
236 parse_quote!(#refinement_struct_name #generics)
237 } else if is_optional_field(field) {
238 ty.clone()
239 } else {
240 parse_quote!(Option<#ty>)
241 }
242}