1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
6 Type, TypeParamBound, WhereClause, WherePredicate,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let refineable_attr = attrs.iter().find(|attr| attr.path.is_ident("refineable"));
20
21 let mut impl_debug_on_refinement = false;
22 let mut refinement_traits_to_derive = vec![];
23
24 if let Some(refineable_attr) = refineable_attr {
25 if let Ok(syn::Meta::List(meta_list)) = refineable_attr.parse_meta() {
26 for nested in meta_list.nested {
27 let syn::NestedMeta::Meta(syn::Meta::Path(path)) = nested else {
28 continue;
29 };
30
31 if path.is_ident("Debug") {
32 impl_debug_on_refinement = true;
33 } else {
34 refinement_traits_to_derive.push(path);
35 }
36 }
37 }
38 }
39
40 let refinement_ident = format_ident!("{}Refinement", ident);
41 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
42
43 let fields = match data {
44 syn::Data::Struct(syn::DataStruct {
45 fields: syn::Fields::Named(FieldsNamed { named, .. }),
46 ..
47 }) => named.into_iter().collect::<Vec<Field>>(),
48 _ => panic!("This derive macro only supports structs with named fields"),
49 };
50
51 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
52 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
53 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
54
55 // Create trait bound that each wrapped type must implement Clone // & Default
56 let type_param_bounds: Vec<_> = wrapped_types
57 .iter()
58 .map(|ty| {
59 WherePredicate::Type(PredicateType {
60 lifetimes: None,
61 bounded_ty: ty.clone(),
62 colon_token: Default::default(),
63 bounds: {
64 let mut punctuated = syn::punctuated::Punctuated::new();
65 punctuated.push_value(TypeParamBound::Trait(TraitBound {
66 paren_token: None,
67 modifier: syn::TraitBoundModifier::None,
68 lifetimes: None,
69 path: parse_quote!(Clone),
70 }));
71
72 // punctuated.push_punct(syn::token::Add::default());
73 // punctuated.push_value(TypeParamBound::Trait(TraitBound {
74 // paren_token: None,
75 // modifier: syn::TraitBoundModifier::None,
76 // lifetimes: None,
77 // path: parse_quote!(Default),
78 // }));
79 punctuated
80 },
81 })
82 })
83 .collect();
84
85 // Append to where_clause or create a new one if it doesn't exist
86 let where_clause = match where_clause.cloned() {
87 Some(mut where_clause) => {
88 where_clause
89 .predicates
90 .extend(type_param_bounds.into_iter());
91 where_clause.clone()
92 }
93 None => WhereClause {
94 where_token: Default::default(),
95 predicates: type_param_bounds.into_iter().collect(),
96 },
97 };
98
99 // refinable_refine_assignments
100 // refinable_refined_assignments
101 // refinement_refine_assignments
102
103 let refineable_refine_assignments: Vec<TokenStream2> = fields
104 .iter()
105 .map(|field| {
106 let name = &field.ident;
107 let is_refineable = is_refineable_field(field);
108 let is_optional = is_optional_field(field);
109
110 if is_refineable {
111 quote! {
112 self.#name.refine(&refinement.#name);
113 }
114 } else if is_optional {
115 quote! {
116 if let Some(ref value) = &refinement.#name {
117 self.#name = Some(value.clone());
118 }
119 }
120 } else {
121 quote! {
122 if let Some(ref value) = &refinement.#name {
123 self.#name = value.clone();
124 }
125 }
126 }
127 })
128 .collect();
129
130 let refineable_refined_assignments: Vec<TokenStream2> = fields
131 .iter()
132 .map(|field| {
133 let name = &field.ident;
134 let is_refineable = is_refineable_field(field);
135 let is_optional = is_optional_field(field);
136
137 if is_refineable {
138 quote! {
139 self.#name = self.#name.refined(refinement.#name);
140 }
141 } else if is_optional {
142 quote! {
143 if let Some(value) = refinement.#name {
144 self.#name = Some(value);
145 }
146 }
147 } else {
148 quote! {
149 if let Some(value) = refinement.#name {
150 self.#name = value;
151 }
152 }
153 }
154 })
155 .collect();
156
157 let refinement_refine_assigments: Vec<TokenStream2> = fields
158 .iter()
159 .map(|field| {
160 let name = &field.ident;
161 let is_refineable = is_refineable_field(field);
162
163 if is_refineable {
164 quote! {
165 self.#name.refine(&refinement.#name);
166 }
167 } else {
168 quote! {
169 if let Some(ref value) = &refinement.#name {
170 self.#name = Some(value.clone());
171 }
172 }
173 }
174 })
175 .collect();
176
177 let refinement_refined_assigments: Vec<TokenStream2> = fields
178 .iter()
179 .map(|field| {
180 let name = &field.ident;
181 let is_refineable = is_refineable_field(field);
182
183 if is_refineable {
184 quote! {
185 self.#name = self.#name.refined(refinement.#name);
186 }
187 } else {
188 quote! {
189 if let Some(value) = refinement.#name {
190 self.#name = Some(value);
191 }
192 }
193 }
194 })
195 .collect();
196
197 let from_refinement_assigments: Vec<TokenStream2> = fields
198 .iter()
199 .map(|field| {
200 let name = &field.ident;
201 let is_refineable = is_refineable_field(field);
202 let is_optional = is_optional_field(field);
203
204 if is_refineable {
205 quote! {
206 #name: value.#name.into(),
207 }
208 } else if is_optional {
209 quote! {
210 #name: value.#name.map(|v| v.into()),
211 }
212 } else {
213 quote! {
214 #name: value.#name.map(|v| v.into()).unwrap_or_default(),
215 }
216 }
217 })
218 .collect();
219
220 let debug_impl = if impl_debug_on_refinement {
221 let refinement_field_debugs: Vec<TokenStream2> = fields
222 .iter()
223 .map(|field| {
224 let name = &field.ident;
225 quote! {
226 if self.#name.is_some() {
227 debug_struct.field(stringify!(#name), &self.#name);
228 } else {
229 all_some = false;
230 }
231 }
232 })
233 .collect();
234
235 quote! {
236 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
237 #where_clause
238 {
239 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
241 let mut all_some = true;
242 #( #refinement_field_debugs )*
243 if all_some {
244 debug_struct.finish()
245 } else {
246 debug_struct.finish_non_exhaustive()
247 }
248 }
249 }
250 }
251 } else {
252 quote! {}
253 };
254
255 let mut derive_stream = quote! {};
256 for trait_to_derive in refinement_traits_to_derive {
257 derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
258 }
259
260 let gen = quote! {
261 #[derive(Clone)]
262 #derive_stream
263 pub struct #refinement_ident #impl_generics {
264 #( #field_visibilities #field_names: #wrapped_types ),*
265 }
266
267 impl #impl_generics Refineable for #ident #ty_generics
268 #where_clause
269 {
270 type Refinement = #refinement_ident #ty_generics;
271
272 fn refine(&mut self, refinement: &Self::Refinement) {
273 #( #refineable_refine_assignments )*
274 }
275
276 fn refined(mut self, refinement: Self::Refinement) -> Self {
277 #( #refineable_refined_assignments )*
278 self
279 }
280 }
281
282 impl #impl_generics Refineable for #refinement_ident #ty_generics
283 #where_clause
284 {
285 type Refinement = #refinement_ident #ty_generics;
286
287 fn refine(&mut self, refinement: &Self::Refinement) {
288 #( #refinement_refine_assigments )*
289 }
290
291 fn refined(mut self, refinement: Self::Refinement) -> Self {
292 #( #refinement_refined_assigments )*
293 self
294 }
295 }
296
297 impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
298 #where_clause
299 {
300 fn from(value: #refinement_ident #ty_generics) -> Self {
301 Self {
302 #( #from_refinement_assigments )*
303 }
304 }
305 }
306
307 impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
308 #where_clause
309 {
310 fn default() -> Self {
311 #refinement_ident {
312 #( #field_names: Default::default() ),*
313 }
314 }
315 }
316
317 impl #impl_generics #refinement_ident #ty_generics
318 #where_clause
319 {
320 pub fn is_some(&self) -> bool {
321 #(
322 if self.#field_names.is_some() {
323 return true;
324 }
325 )*
326 false
327 }
328 }
329
330 #debug_impl
331 };
332 gen.into()
333}
334
335fn is_refineable_field(f: &Field) -> bool {
336 f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
337}
338
339fn is_optional_field(f: &Field) -> bool {
340 if let Type::Path(typepath) = &f.ty {
341 if typepath.qself.is_none() {
342 let segments = &typepath.path.segments;
343 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
344 return true;
345 }
346 }
347 }
348 false
349}
350
351fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
352 if is_refineable_field(field) {
353 let struct_name = if let Type::Path(tp) = ty {
354 tp.path.segments.last().unwrap().ident.clone()
355 } else {
356 panic!("Expected struct type for a refineable field");
357 };
358 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
359 let generics = if let Type::Path(tp) = ty {
360 &tp.path.segments.last().unwrap().arguments
361 } else {
362 &syn::PathArguments::None
363 };
364 parse_quote!(#refinement_struct_name #generics)
365 } else if is_optional_field(field) {
366 ty.clone()
367 } else {
368 parse_quote!(Option<#ty>)
369 }
370}