1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 parse_macro_input, parse_quote, DeriveInput, Field, FieldsNamed, PredicateType, TraitBound,
6 Type, TypeParamBound, WhereClause, WherePredicate,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let refineable_attr = attrs.iter().find(|attr| attr.path.is_ident("refineable"));
20
21 let mut impl_debug_on_refinement = false;
22 let mut refinement_traits_to_derive = vec![];
23
24 if let Some(refineable_attr) = refineable_attr {
25 if let Ok(syn::Meta::List(meta_list)) = refineable_attr.parse_meta() {
26 for nested in meta_list.nested {
27 let syn::NestedMeta::Meta(syn::Meta::Path(path)) = nested else {
28 continue;
29 };
30
31 if path.is_ident("Debug") {
32 impl_debug_on_refinement = true;
33 } else {
34 refinement_traits_to_derive.push(path);
35 }
36 }
37 }
38 }
39
40 let refinement_ident = format_ident!("{}Refinement", ident);
41 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
42
43 let fields = match data {
44 syn::Data::Struct(syn::DataStruct {
45 fields: syn::Fields::Named(FieldsNamed { named, .. }),
46 ..
47 }) => named.into_iter().collect::<Vec<Field>>(),
48 _ => panic!("This derive macro only supports structs with named fields"),
49 };
50
51 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
52 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
53 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
54
55 // Create trait bound that each wrapped type must implement Clone // & Default
56 let type_param_bounds: Vec<_> = wrapped_types
57 .iter()
58 .map(|ty| {
59 WherePredicate::Type(PredicateType {
60 lifetimes: None,
61 bounded_ty: ty.clone(),
62 colon_token: Default::default(),
63 bounds: {
64 let mut punctuated = syn::punctuated::Punctuated::new();
65 punctuated.push_value(TypeParamBound::Trait(TraitBound {
66 paren_token: None,
67 modifier: syn::TraitBoundModifier::None,
68 lifetimes: None,
69 path: parse_quote!(Clone),
70 }));
71
72 punctuated
73 },
74 })
75 })
76 .collect();
77
78 // Append to where_clause or create a new one if it doesn't exist
79 let where_clause = match where_clause.cloned() {
80 Some(mut where_clause) => {
81 where_clause.predicates.extend(type_param_bounds);
82 where_clause.clone()
83 }
84 None => WhereClause {
85 where_token: Default::default(),
86 predicates: type_param_bounds.into_iter().collect(),
87 },
88 };
89
90 let refineable_refine_assignments: Vec<TokenStream2> = fields
91 .iter()
92 .map(|field| {
93 let name = &field.ident;
94 let is_refineable = is_refineable_field(field);
95 let is_optional = is_optional_field(field);
96
97 if is_refineable {
98 quote! {
99 self.#name.refine(&refinement.#name);
100 }
101 } else if is_optional {
102 quote! {
103 if let Some(ref value) = &refinement.#name {
104 self.#name = Some(value.clone());
105 }
106 }
107 } else {
108 quote! {
109 if let Some(ref value) = &refinement.#name {
110 self.#name = value.clone();
111 }
112 }
113 }
114 })
115 .collect();
116
117 let refineable_refined_assignments: Vec<TokenStream2> = fields
118 .iter()
119 .map(|field| {
120 let name = &field.ident;
121 let is_refineable = is_refineable_field(field);
122 let is_optional = is_optional_field(field);
123
124 if is_refineable {
125 quote! {
126 self.#name = self.#name.refined(refinement.#name);
127 }
128 } else if is_optional {
129 quote! {
130 if let Some(value) = refinement.#name {
131 self.#name = Some(value);
132 }
133 }
134 } else {
135 quote! {
136 if let Some(value) = refinement.#name {
137 self.#name = value;
138 }
139 }
140 }
141 })
142 .collect();
143
144 let refinement_refine_assigments: Vec<TokenStream2> = fields
145 .iter()
146 .map(|field| {
147 let name = &field.ident;
148 let is_refineable = is_refineable_field(field);
149
150 if is_refineable {
151 quote! {
152 self.#name.refine(&refinement.#name);
153 }
154 } else {
155 quote! {
156 if let Some(ref value) = &refinement.#name {
157 self.#name = Some(value.clone());
158 }
159 }
160 }
161 })
162 .collect();
163
164 let refinement_refined_assigments: Vec<TokenStream2> = fields
165 .iter()
166 .map(|field| {
167 let name = &field.ident;
168 let is_refineable = is_refineable_field(field);
169
170 if is_refineable {
171 quote! {
172 self.#name = self.#name.refined(refinement.#name);
173 }
174 } else {
175 quote! {
176 if let Some(value) = refinement.#name {
177 self.#name = Some(value);
178 }
179 }
180 }
181 })
182 .collect();
183
184 let from_refinement_assigments: Vec<TokenStream2> = fields
185 .iter()
186 .map(|field| {
187 let name = &field.ident;
188 let is_refineable = is_refineable_field(field);
189 let is_optional = is_optional_field(field);
190
191 if is_refineable {
192 quote! {
193 #name: value.#name.into(),
194 }
195 } else if is_optional {
196 quote! {
197 #name: value.#name.map(|v| v.into()),
198 }
199 } else {
200 quote! {
201 #name: value.#name.map(|v| v.into()).unwrap_or_default(),
202 }
203 }
204 })
205 .collect();
206
207 let debug_impl = if impl_debug_on_refinement {
208 let refinement_field_debugs: Vec<TokenStream2> = fields
209 .iter()
210 .map(|field| {
211 let name = &field.ident;
212 quote! {
213 if self.#name.is_some() {
214 debug_struct.field(stringify!(#name), &self.#name);
215 } else {
216 all_some = false;
217 }
218 }
219 })
220 .collect();
221
222 quote! {
223 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
224 #where_clause
225 {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
228 let mut all_some = true;
229 #( #refinement_field_debugs )*
230 if all_some {
231 debug_struct.finish()
232 } else {
233 debug_struct.finish_non_exhaustive()
234 }
235 }
236 }
237 }
238 } else {
239 quote! {}
240 };
241
242 let mut derive_stream = quote! {};
243 for trait_to_derive in refinement_traits_to_derive {
244 derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
245 }
246
247 let gen = quote! {
248 #[derive(Clone)]
249 #derive_stream
250 pub struct #refinement_ident #impl_generics {
251 #( #field_visibilities #field_names: #wrapped_types ),*
252 }
253
254 impl #impl_generics Refineable for #ident #ty_generics
255 #where_clause
256 {
257 type Refinement = #refinement_ident #ty_generics;
258
259 fn refine(&mut self, refinement: &Self::Refinement) {
260 #( #refineable_refine_assignments )*
261 }
262
263 fn refined(mut self, refinement: Self::Refinement) -> Self {
264 #( #refineable_refined_assignments )*
265 self
266 }
267 }
268
269 impl #impl_generics Refineable for #refinement_ident #ty_generics
270 #where_clause
271 {
272 type Refinement = #refinement_ident #ty_generics;
273
274 fn refine(&mut self, refinement: &Self::Refinement) {
275 #( #refinement_refine_assigments )*
276 }
277
278 fn refined(mut self, refinement: Self::Refinement) -> Self {
279 #( #refinement_refined_assigments )*
280 self
281 }
282 }
283
284 impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
285 #where_clause
286 {
287 fn from(value: #refinement_ident #ty_generics) -> Self {
288 Self {
289 #( #from_refinement_assigments )*
290 }
291 }
292 }
293
294 impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
295 #where_clause
296 {
297 fn default() -> Self {
298 #refinement_ident {
299 #( #field_names: Default::default() ),*
300 }
301 }
302 }
303
304 impl #impl_generics #refinement_ident #ty_generics
305 #where_clause
306 {
307 pub fn is_some(&self) -> bool {
308 #(
309 if self.#field_names.is_some() {
310 return true;
311 }
312 )*
313 false
314 }
315 }
316
317 #debug_impl
318 };
319 gen.into()
320}
321
322fn is_refineable_field(f: &Field) -> bool {
323 f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
324}
325
326fn is_optional_field(f: &Field) -> bool {
327 if let Type::Path(typepath) = &f.ty {
328 if typepath.qself.is_none() {
329 let segments = &typepath.path.segments;
330 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
331 return true;
332 }
333 }
334 }
335 false
336}
337
338fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
339 if is_refineable_field(field) {
340 let struct_name = if let Type::Path(tp) = ty {
341 tp.path.segments.last().unwrap().ident.clone()
342 } else {
343 panic!("Expected struct type for a refineable field");
344 };
345 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
346 let generics = if let Type::Path(tp) = ty {
347 &tp.path.segments.last().unwrap().arguments
348 } else {
349 &syn::PathArguments::None
350 };
351 parse_quote!(#refinement_struct_name #generics)
352 } else if is_optional_field(field) {
353 ty.clone()
354 } else {
355 parse_quote!(Option<#ty>)
356 }
357}