1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
6 Type, TypeParamBound, WhereClause, WherePredicate,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let impl_debug_on_refinement = attrs
20 .iter()
21 .any(|attr| attr.path.is_ident("refineable") && attr.tokens.to_string().contains("debug"));
22
23 let refinement_ident = format_ident!("{}Refinement", ident);
24 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
25
26 let fields = match data {
27 syn::Data::Struct(syn::DataStruct {
28 fields: syn::Fields::Named(FieldsNamed { named, .. }),
29 ..
30 }) => named.into_iter().collect::<Vec<Field>>(),
31 _ => panic!("This derive macro only supports structs with named fields"),
32 };
33
34 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
35 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
36 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
37
38 // Create trait bound that each wrapped type must implement Clone // & Default
39 let type_param_bounds: Vec<_> = wrapped_types
40 .iter()
41 .map(|ty| {
42 WherePredicate::Type(PredicateType {
43 lifetimes: None,
44 bounded_ty: ty.clone(),
45 colon_token: Default::default(),
46 bounds: {
47 let mut punctuated = syn::punctuated::Punctuated::new();
48 punctuated.push_value(TypeParamBound::Trait(TraitBound {
49 paren_token: None,
50 modifier: syn::TraitBoundModifier::None,
51 lifetimes: None,
52 path: parse_quote!(Clone),
53 }));
54
55 // punctuated.push_punct(syn::token::Add::default());
56 // punctuated.push_value(TypeParamBound::Trait(TraitBound {
57 // paren_token: None,
58 // modifier: syn::TraitBoundModifier::None,
59 // lifetimes: None,
60 // path: parse_quote!(Default),
61 // }));
62 punctuated
63 },
64 })
65 })
66 .collect();
67
68 // Append to where_clause or create a new one if it doesn't exist
69 let where_clause = match where_clause.cloned() {
70 Some(mut where_clause) => {
71 where_clause
72 .predicates
73 .extend(type_param_bounds.into_iter());
74 where_clause.clone()
75 }
76 None => WhereClause {
77 where_token: Default::default(),
78 predicates: type_param_bounds.into_iter().collect(),
79 },
80 };
81
82 // refinable_refine_assignments
83 // refinable_refined_assignments
84 // refinement_refine_assignments
85
86 let refineable_refine_assignments: Vec<TokenStream2> = fields
87 .iter()
88 .map(|field| {
89 let name = &field.ident;
90 let is_refineable = is_refineable_field(field);
91 let is_optional = is_optional_field(field);
92
93 if is_refineable {
94 quote! {
95 self.#name.refine(&refinement.#name);
96 }
97 } else if is_optional {
98 quote! {
99 if let Some(ref value) = &refinement.#name {
100 self.#name = Some(value.clone());
101 }
102 }
103 } else {
104 quote! {
105 if let Some(ref value) = &refinement.#name {
106 self.#name = value.clone();
107 }
108 }
109 }
110 })
111 .collect();
112
113 let refineable_refined_assignments: Vec<TokenStream2> = fields
114 .iter()
115 .map(|field| {
116 let name = &field.ident;
117 let is_refineable = is_refineable_field(field);
118 let is_optional = is_optional_field(field);
119
120 if is_refineable {
121 quote! {
122 self.#name = self.#name.refined(refinement.#name);
123 }
124 } else if is_optional {
125 quote! {
126 if let Some(value) = refinement.#name {
127 self.#name = Some(value);
128 }
129 }
130 } else {
131 quote! {
132 if let Some(value) = refinement.#name {
133 self.#name = value;
134 }
135 }
136 }
137 })
138 .collect();
139
140 let refinement_refine_assigments: Vec<TokenStream2> = fields
141 .iter()
142 .map(|field| {
143 let name = &field.ident;
144 let is_refineable = is_refineable_field(field);
145
146 if is_refineable {
147 quote! {
148 self.#name.refine(&refinement.#name);
149 }
150 } else {
151 quote! {
152 if let Some(ref value) = &refinement.#name {
153 self.#name = Some(value.clone());
154 }
155 }
156 }
157 })
158 .collect();
159
160 let refinement_refined_assigments: Vec<TokenStream2> = fields
161 .iter()
162 .map(|field| {
163 let name = &field.ident;
164 let is_refineable = is_refineable_field(field);
165
166 if is_refineable {
167 quote! {
168 self.#name = self.#name.refined(refinement.#name);
169 }
170 } else {
171 quote! {
172 if let Some(value) = refinement.#name {
173 self.#name = Some(value);
174 }
175 }
176 }
177 })
178 .collect();
179
180 let from_refinement_assigments: Vec<TokenStream2> = fields
181 .iter()
182 .map(|field| {
183 let name = &field.ident;
184 let is_refineable = is_refineable_field(field);
185 let is_optional = is_optional_field(field);
186
187 if is_refineable {
188 quote! {
189 #name: value.#name.into(),
190 }
191 } else if is_optional {
192 quote! {
193 #name: value.#name.map(|v| v.into()),
194 }
195 } else {
196 quote! {
197 #name: value.#name.map(|v| v.into()).unwrap_or_default(),
198 }
199 }
200 })
201 .collect();
202
203 let debug_impl = if impl_debug_on_refinement {
204 let refinement_field_debugs: Vec<TokenStream2> = fields
205 .iter()
206 .map(|field| {
207 let name = &field.ident;
208 quote! {
209 if self.#name.is_some() {
210 debug_struct.field(stringify!(#name), &self.#name);
211 } else {
212 all_some = false;
213 }
214 }
215 })
216 .collect();
217
218 quote! {
219 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
220 #where_clause
221 {
222 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
224 let mut all_some = true;
225 #( #refinement_field_debugs )*
226 if all_some {
227 debug_struct.finish()
228 } else {
229 debug_struct.finish_non_exhaustive()
230 }
231 }
232 }
233 }
234 } else {
235 quote! {}
236 };
237
238 let gen = quote! {
239 #[derive(Clone)]
240 pub struct #refinement_ident #impl_generics {
241 #( #field_visibilities #field_names: #wrapped_types ),*
242 }
243
244 impl #impl_generics Refineable for #ident #ty_generics
245 #where_clause
246 {
247 type Refinement = #refinement_ident #ty_generics;
248
249 fn refine(&mut self, refinement: &Self::Refinement) {
250 #( #refineable_refine_assignments )*
251 }
252
253 fn refined(mut self, refinement: Self::Refinement) -> Self {
254 #( #refineable_refined_assignments )*
255 self
256 }
257 }
258
259 impl #impl_generics Refineable for #refinement_ident #ty_generics
260 #where_clause
261 {
262 type Refinement = #refinement_ident #ty_generics;
263
264 fn refine(&mut self, refinement: &Self::Refinement) {
265 #( #refinement_refine_assigments )*
266 }
267
268 fn refined(mut self, refinement: Self::Refinement) -> Self {
269 #( #refinement_refined_assigments )*
270 self
271 }
272 }
273
274 impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
275 #where_clause
276 {
277 fn from(value: #refinement_ident #ty_generics) -> Self {
278 Self {
279 #( #from_refinement_assigments )*
280 }
281 }
282 }
283
284 impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
285 #where_clause
286 {
287 fn default() -> Self {
288 #refinement_ident {
289 #( #field_names: Default::default() ),*
290 }
291 }
292 }
293
294 impl #impl_generics #refinement_ident #ty_generics
295 #where_clause
296 {
297 pub fn is_some(&self) -> bool {
298 #(
299 if self.#field_names.is_some() {
300 return true;
301 }
302 )*
303 false
304 }
305 }
306
307 #debug_impl
308 };
309 gen.into()
310}
311
312fn is_refineable_field(f: &Field) -> bool {
313 f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
314}
315
316fn is_optional_field(f: &Field) -> bool {
317 if let Type::Path(typepath) = &f.ty {
318 if typepath.qself.is_none() {
319 let segments = &typepath.path.segments;
320 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
321 return true;
322 }
323 }
324 }
325 false
326}
327
328fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
329 if is_refineable_field(field) {
330 let struct_name = if let Type::Path(tp) = ty {
331 tp.path.segments.last().unwrap().ident.clone()
332 } else {
333 panic!("Expected struct type for a refineable field");
334 };
335 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
336 let generics = if let Type::Path(tp) = ty {
337 &tp.path.segments.last().unwrap().arguments
338 } else {
339 &syn::PathArguments::None
340 };
341 parse_quote!(#refinement_struct_name #generics)
342 } else if is_optional_field(field) {
343 ty.clone()
344 } else {
345 parse_quote!(Option<#ty>)
346 }
347}