1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 DeriveInput, Field, FieldsNamed, PredicateType, TraitBound, Type, TypeParamBound, WhereClause,
6 WherePredicate, parse_macro_input, parse_quote,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let refineable_attr = attrs.iter().find(|attr| attr.path().is_ident("refineable"));
20
21 let mut impl_debug_on_refinement = false;
22 let mut derives_serialize = false;
23 let mut refinement_traits_to_derive = vec![];
24
25 if let Some(refineable_attr) = refineable_attr {
26 let _ = refineable_attr.parse_nested_meta(|meta| {
27 if meta.path.is_ident("Debug") {
28 impl_debug_on_refinement = true;
29 } else {
30 if meta.path.is_ident("Serialize") {
31 derives_serialize = true;
32 }
33 refinement_traits_to_derive.push(meta.path);
34 }
35 Ok(())
36 });
37 }
38
39 let refinement_ident = format_ident!("{}Refinement", ident);
40 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
41
42 let fields = match data {
43 syn::Data::Struct(syn::DataStruct {
44 fields: syn::Fields::Named(FieldsNamed { named, .. }),
45 ..
46 }) => named.into_iter().collect::<Vec<Field>>(),
47 _ => panic!("This derive macro only supports structs with named fields"),
48 };
49
50 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
51 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
52 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
53
54 let field_attributes: Vec<TokenStream2> = fields
55 .iter()
56 .map(|f| {
57 if derives_serialize {
58 if is_refineable_field(f) {
59 quote! { #[serde(default, skip_serializing_if = "::refineable::IsEmpty::is_empty")] }
60 } else {
61 quote! { #[serde(skip_serializing_if = "::std::option::Option::is_none")] }
62 }
63 } else {
64 quote! {}
65 }
66 })
67 .collect();
68
69 // Create trait bound that each wrapped type must implement Clone // & Default
70 let type_param_bounds: Vec<_> = wrapped_types
71 .iter()
72 .map(|ty| {
73 WherePredicate::Type(PredicateType {
74 lifetimes: None,
75 bounded_ty: ty.clone(),
76 colon_token: Default::default(),
77 bounds: {
78 let mut punctuated = syn::punctuated::Punctuated::new();
79 punctuated.push_value(TypeParamBound::Trait(TraitBound {
80 paren_token: None,
81 modifier: syn::TraitBoundModifier::None,
82 lifetimes: None,
83 path: parse_quote!(Clone),
84 }));
85
86 punctuated
87 },
88 })
89 })
90 .collect();
91
92 // Append to where_clause or create a new one if it doesn't exist
93 let where_clause = match where_clause.cloned() {
94 Some(mut where_clause) => {
95 where_clause.predicates.extend(type_param_bounds);
96 where_clause.clone()
97 }
98 None => WhereClause {
99 where_token: Default::default(),
100 predicates: type_param_bounds.into_iter().collect(),
101 },
102 };
103
104 let refineable_refine_assignments: Vec<TokenStream2> = fields
105 .iter()
106 .map(|field| {
107 let name = &field.ident;
108 let is_refineable = is_refineable_field(field);
109 let is_optional = is_optional_field(field);
110
111 if is_refineable {
112 quote! {
113 self.#name.refine(&refinement.#name);
114 }
115 } else if is_optional {
116 quote! {
117 if let Some(value) = &refinement.#name {
118 self.#name = Some(value.clone());
119 }
120 }
121 } else {
122 quote! {
123 if let Some(value) = &refinement.#name {
124 self.#name = value.clone();
125 }
126 }
127 }
128 })
129 .collect();
130
131 let refineable_refined_assignments: Vec<TokenStream2> = fields
132 .iter()
133 .map(|field| {
134 let name = &field.ident;
135 let is_refineable = is_refineable_field(field);
136 let is_optional = is_optional_field(field);
137
138 if is_refineable {
139 quote! {
140 self.#name = self.#name.refined(refinement.#name);
141 }
142 } else if is_optional {
143 quote! {
144 if let Some(value) = refinement.#name {
145 self.#name = Some(value);
146 }
147 }
148 } else {
149 quote! {
150 if let Some(value) = refinement.#name {
151 self.#name = value;
152 }
153 }
154 }
155 })
156 .collect();
157
158 let refinement_refine_assignments: Vec<TokenStream2> = fields
159 .iter()
160 .map(|field| {
161 let name = &field.ident;
162 let is_refineable = is_refineable_field(field);
163
164 if is_refineable {
165 quote! {
166 self.#name.refine(&refinement.#name);
167 }
168 } else {
169 quote! {
170 if let Some(value) = &refinement.#name {
171 self.#name = Some(value.clone());
172 }
173 }
174 }
175 })
176 .collect();
177
178 let refinement_refined_assignments: Vec<TokenStream2> = fields
179 .iter()
180 .map(|field| {
181 let name = &field.ident;
182 let is_refineable = is_refineable_field(field);
183
184 if is_refineable {
185 quote! {
186 self.#name = self.#name.refined(refinement.#name);
187 }
188 } else {
189 quote! {
190 if let Some(value) = refinement.#name {
191 self.#name = Some(value);
192 }
193 }
194 }
195 })
196 .collect();
197
198 let from_refinement_assignments: Vec<TokenStream2> = fields
199 .iter()
200 .map(|field| {
201 let name = &field.ident;
202 let is_refineable = is_refineable_field(field);
203 let is_optional = is_optional_field(field);
204
205 if is_refineable {
206 quote! {
207 #name: value.#name.into(),
208 }
209 } else if is_optional {
210 quote! {
211 #name: value.#name.map(|v| v.into()),
212 }
213 } else {
214 quote! {
215 #name: value.#name.map(|v| v.into()).unwrap_or_default(),
216 }
217 }
218 })
219 .collect();
220
221 let debug_impl = if impl_debug_on_refinement {
222 let refinement_field_debugs: Vec<TokenStream2> = fields
223 .iter()
224 .map(|field| {
225 let name = &field.ident;
226 quote! {
227 if self.#name.is_some() {
228 debug_struct.field(stringify!(#name), &self.#name);
229 } else {
230 all_some = false;
231 }
232 }
233 })
234 .collect();
235
236 quote! {
237 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
238 #where_clause
239 {
240 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
241 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
242 let mut all_some = true;
243 #( #refinement_field_debugs )*
244 if all_some {
245 debug_struct.finish()
246 } else {
247 debug_struct.finish_non_exhaustive()
248 }
249 }
250 }
251 }
252 } else {
253 quote! {}
254 };
255
256 let refinement_is_empty_conditions: Vec<TokenStream2> = fields
257 .iter()
258 .enumerate()
259 .map(|(i, field)| {
260 let name = &field.ident;
261
262 let condition = if is_refineable_field(field) {
263 quote! { self.#name.is_empty() }
264 } else {
265 quote! { self.#name.is_none() }
266 };
267
268 if i < fields.len() - 1 {
269 quote! { #condition && }
270 } else {
271 condition
272 }
273 })
274 .collect();
275
276 let mut derive_stream = quote! {};
277 for trait_to_derive in refinement_traits_to_derive {
278 derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
279 }
280
281 let r#gen = quote! {
282 /// A refinable version of [`#ident`], see that documentation for details.
283 #[derive(Clone)]
284 #derive_stream
285 pub struct #refinement_ident #impl_generics {
286 #(
287 #[allow(missing_docs)]
288 #field_attributes
289 #field_visibilities #field_names: #wrapped_types
290 ),*
291 }
292
293 impl #impl_generics Refineable for #ident #ty_generics
294 #where_clause
295 {
296 type Refinement = #refinement_ident #ty_generics;
297
298 fn refine(&mut self, refinement: &Self::Refinement) {
299 #( #refineable_refine_assignments )*
300 }
301
302 fn refined(mut self, refinement: Self::Refinement) -> Self {
303 #( #refineable_refined_assignments )*
304 self
305 }
306 }
307
308 impl #impl_generics Refineable for #refinement_ident #ty_generics
309 #where_clause
310 {
311 type Refinement = #refinement_ident #ty_generics;
312
313 fn refine(&mut self, refinement: &Self::Refinement) {
314 #( #refinement_refine_assignments )*
315 }
316
317 fn refined(mut self, refinement: Self::Refinement) -> Self {
318 #( #refinement_refined_assignments )*
319 self
320 }
321 }
322
323 impl #impl_generics ::refineable::IsEmpty for #refinement_ident #ty_generics
324 #where_clause
325 {
326 fn is_empty(&self) -> bool {
327 #( #refinement_is_empty_conditions )*
328 }
329 }
330
331 impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
332 #where_clause
333 {
334 fn from(value: #refinement_ident #ty_generics) -> Self {
335 Self {
336 #( #from_refinement_assignments )*
337 }
338 }
339 }
340
341 impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
342 #where_clause
343 {
344 fn default() -> Self {
345 #refinement_ident {
346 #( #field_names: Default::default() ),*
347 }
348 }
349 }
350
351 impl #impl_generics #refinement_ident #ty_generics
352 #where_clause
353 {
354 /// Returns `true` if all fields are `Some`
355 pub fn is_some(&self) -> bool {
356 #(
357 if self.#field_names.is_some() {
358 return true;
359 }
360 )*
361 false
362 }
363 }
364
365 #debug_impl
366 };
367 r#gen.into()
368}
369
370fn is_refineable_field(f: &Field) -> bool {
371 f.attrs
372 .iter()
373 .any(|attr| attr.path().is_ident("refineable"))
374}
375
376fn is_optional_field(f: &Field) -> bool {
377 if let Type::Path(typepath) = &f.ty {
378 if typepath.qself.is_none() {
379 let segments = &typepath.path.segments;
380 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
381 return true;
382 }
383 }
384 }
385 false
386}
387
388fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
389 if is_refineable_field(field) {
390 let struct_name = if let Type::Path(tp) = ty {
391 tp.path.segments.last().unwrap().ident.clone()
392 } else {
393 panic!("Expected struct type for a refineable field");
394 };
395 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
396 let generics = if let Type::Path(tp) = ty {
397 &tp.path.segments.last().unwrap().arguments
398 } else {
399 &syn::PathArguments::None
400 };
401 parse_quote!(#refinement_struct_name #generics)
402 } else if is_optional_field(field) {
403 ty.clone()
404 } else {
405 parse_quote!(Option<#ty>)
406 }
407}