1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::{format_ident, quote};
4use syn::{
5 DeriveInput, Field, FieldsNamed, PredicateType, TraitBound, Type, TypeParamBound, WhereClause,
6 WherePredicate, parse_macro_input, parse_quote,
7};
8
9#[proc_macro_derive(Refineable, attributes(refineable))]
10pub fn derive_refineable(input: TokenStream) -> TokenStream {
11 let DeriveInput {
12 ident,
13 data,
14 generics,
15 attrs,
16 ..
17 } = parse_macro_input!(input);
18
19 let refineable_attr = attrs.iter().find(|attr| attr.path.is_ident("refineable"));
20
21 let mut impl_debug_on_refinement = false;
22 let mut refinement_traits_to_derive = vec![];
23
24 if let Some(refineable_attr) = refineable_attr {
25 if let Ok(syn::Meta::List(meta_list)) = refineable_attr.parse_meta() {
26 for nested in meta_list.nested {
27 let syn::NestedMeta::Meta(syn::Meta::Path(path)) = nested else {
28 continue;
29 };
30
31 if path.is_ident("Debug") {
32 impl_debug_on_refinement = true;
33 } else {
34 refinement_traits_to_derive.push(path);
35 }
36 }
37 }
38 }
39
40 let refinement_ident = format_ident!("{}Refinement", ident);
41 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
42
43 let fields = match data {
44 syn::Data::Struct(syn::DataStruct {
45 fields: syn::Fields::Named(FieldsNamed { named, .. }),
46 ..
47 }) => named.into_iter().collect::<Vec<Field>>(),
48 _ => panic!("This derive macro only supports structs with named fields"),
49 };
50
51 let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
52 let field_visibilities: Vec<_> = fields.iter().map(|f| &f.vis).collect();
53 let wrapped_types: Vec<_> = fields.iter().map(|f| get_wrapper_type(f, &f.ty)).collect();
54
55 // Create trait bound that each wrapped type must implement Clone // & Default
56 let type_param_bounds: Vec<_> = wrapped_types
57 .iter()
58 .map(|ty| {
59 WherePredicate::Type(PredicateType {
60 lifetimes: None,
61 bounded_ty: ty.clone(),
62 colon_token: Default::default(),
63 bounds: {
64 let mut punctuated = syn::punctuated::Punctuated::new();
65 punctuated.push_value(TypeParamBound::Trait(TraitBound {
66 paren_token: None,
67 modifier: syn::TraitBoundModifier::None,
68 lifetimes: None,
69 path: parse_quote!(Clone),
70 }));
71
72 punctuated
73 },
74 })
75 })
76 .collect();
77
78 // Append to where_clause or create a new one if it doesn't exist
79 let where_clause = match where_clause.cloned() {
80 Some(mut where_clause) => {
81 where_clause.predicates.extend(type_param_bounds);
82 where_clause.clone()
83 }
84 None => WhereClause {
85 where_token: Default::default(),
86 predicates: type_param_bounds.into_iter().collect(),
87 },
88 };
89
90 let refineable_refine_assignments: Vec<TokenStream2> = fields
91 .iter()
92 .map(|field| {
93 let name = &field.ident;
94 let is_refineable = is_refineable_field(field);
95 let is_optional = is_optional_field(field);
96
97 if is_refineable {
98 quote! {
99 self.#name.refine(&refinement.#name);
100 }
101 } else if is_optional {
102 quote! {
103 if let Some(value) = &refinement.#name {
104 self.#name = Some(value.clone());
105 }
106 }
107 } else {
108 quote! {
109 if let Some(value) = &refinement.#name {
110 self.#name = value.clone();
111 }
112 }
113 }
114 })
115 .collect();
116
117 let refineable_refined_assignments: Vec<TokenStream2> = fields
118 .iter()
119 .map(|field| {
120 let name = &field.ident;
121 let is_refineable = is_refineable_field(field);
122 let is_optional = is_optional_field(field);
123
124 if is_refineable {
125 quote! {
126 self.#name = self.#name.refined(refinement.#name);
127 }
128 } else if is_optional {
129 quote! {
130 if let Some(value) = refinement.#name {
131 self.#name = Some(value);
132 }
133 }
134 } else {
135 quote! {
136 if let Some(value) = refinement.#name {
137 self.#name = value;
138 }
139 }
140 }
141 })
142 .collect();
143
144 let refinement_refine_assignments: Vec<TokenStream2> = fields
145 .iter()
146 .map(|field| {
147 let name = &field.ident;
148 let is_refineable = is_refineable_field(field);
149
150 if is_refineable {
151 quote! {
152 self.#name.refine(&refinement.#name);
153 }
154 } else {
155 quote! {
156 if let Some(value) = &refinement.#name {
157 self.#name = Some(value.clone());
158 }
159 }
160 }
161 })
162 .collect();
163
164 let refinement_refined_assignments: Vec<TokenStream2> = fields
165 .iter()
166 .map(|field| {
167 let name = &field.ident;
168 let is_refineable = is_refineable_field(field);
169
170 if is_refineable {
171 quote! {
172 self.#name = self.#name.refined(refinement.#name);
173 }
174 } else {
175 quote! {
176 if let Some(value) = refinement.#name {
177 self.#name = Some(value);
178 }
179 }
180 }
181 })
182 .collect();
183
184 let from_refinement_assignments: Vec<TokenStream2> = fields
185 .iter()
186 .map(|field| {
187 let name = &field.ident;
188 let is_refineable = is_refineable_field(field);
189 let is_optional = is_optional_field(field);
190
191 if is_refineable {
192 quote! {
193 #name: value.#name.into(),
194 }
195 } else if is_optional {
196 quote! {
197 #name: value.#name.map(|v| v.into()),
198 }
199 } else {
200 quote! {
201 #name: value.#name.map(|v| v.into()).unwrap_or_default(),
202 }
203 }
204 })
205 .collect();
206
207 let debug_impl = if impl_debug_on_refinement {
208 let refinement_field_debugs: Vec<TokenStream2> = fields
209 .iter()
210 .map(|field| {
211 let name = &field.ident;
212 quote! {
213 if self.#name.is_some() {
214 debug_struct.field(stringify!(#name), &self.#name);
215 } else {
216 all_some = false;
217 }
218 }
219 })
220 .collect();
221
222 quote! {
223 impl #impl_generics std::fmt::Debug for #refinement_ident #ty_generics
224 #where_clause
225 {
226 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227 let mut debug_struct = f.debug_struct(stringify!(#refinement_ident));
228 let mut all_some = true;
229 #( #refinement_field_debugs )*
230 if all_some {
231 debug_struct.finish()
232 } else {
233 debug_struct.finish_non_exhaustive()
234 }
235 }
236 }
237 }
238 } else {
239 quote! {}
240 };
241
242 let mut derive_stream = quote! {};
243 for trait_to_derive in refinement_traits_to_derive {
244 derive_stream.extend(quote! { #[derive(#trait_to_derive)] })
245 }
246
247 let r#gen = quote! {
248 /// A refinable version of [`#ident`], see that documentation for details.
249 #[derive(Clone)]
250 #derive_stream
251 pub struct #refinement_ident #impl_generics {
252 #(
253 #[allow(missing_docs)]
254 #field_visibilities #field_names: #wrapped_types
255 ),*
256 }
257
258 impl #impl_generics Refineable for #ident #ty_generics
259 #where_clause
260 {
261 type Refinement = #refinement_ident #ty_generics;
262
263 fn refine(&mut self, refinement: &Self::Refinement) {
264 #( #refineable_refine_assignments )*
265 }
266
267 fn refined(mut self, refinement: Self::Refinement) -> Self {
268 #( #refineable_refined_assignments )*
269 self
270 }
271 }
272
273 impl #impl_generics Refineable for #refinement_ident #ty_generics
274 #where_clause
275 {
276 type Refinement = #refinement_ident #ty_generics;
277
278 fn refine(&mut self, refinement: &Self::Refinement) {
279 #( #refinement_refine_assignments )*
280 }
281
282 fn refined(mut self, refinement: Self::Refinement) -> Self {
283 #( #refinement_refined_assignments )*
284 self
285 }
286 }
287
288 impl #impl_generics From<#refinement_ident #ty_generics> for #ident #ty_generics
289 #where_clause
290 {
291 fn from(value: #refinement_ident #ty_generics) -> Self {
292 Self {
293 #( #from_refinement_assignments )*
294 }
295 }
296 }
297
298 impl #impl_generics ::core::default::Default for #refinement_ident #ty_generics
299 #where_clause
300 {
301 fn default() -> Self {
302 #refinement_ident {
303 #( #field_names: Default::default() ),*
304 }
305 }
306 }
307
308 impl #impl_generics #refinement_ident #ty_generics
309 #where_clause
310 {
311 /// Returns `true` if all fields are `Some`
312 pub fn is_some(&self) -> bool {
313 #(
314 if self.#field_names.is_some() {
315 return true;
316 }
317 )*
318 false
319 }
320 }
321
322 #debug_impl
323 };
324 r#gen.into()
325}
326
327fn is_refineable_field(f: &Field) -> bool {
328 f.attrs.iter().any(|attr| attr.path.is_ident("refineable"))
329}
330
331fn is_optional_field(f: &Field) -> bool {
332 if let Type::Path(typepath) = &f.ty {
333 if typepath.qself.is_none() {
334 let segments = &typepath.path.segments;
335 if segments.len() == 1 && segments.iter().any(|s| s.ident == "Option") {
336 return true;
337 }
338 }
339 }
340 false
341}
342
343fn get_wrapper_type(field: &Field, ty: &Type) -> syn::Type {
344 if is_refineable_field(field) {
345 let struct_name = if let Type::Path(tp) = ty {
346 tp.path.segments.last().unwrap().ident.clone()
347 } else {
348 panic!("Expected struct type for a refineable field");
349 };
350 let refinement_struct_name = format_ident!("{}Refinement", struct_name);
351 let generics = if let Type::Path(tp) = ty {
352 &tp.path.segments.last().unwrap().arguments
353 } else {
354 &syn::PathArguments::None
355 };
356 parse_quote!(#refinement_struct_name #generics)
357 } else if is_optional_field(field) {
358 ty.clone()
359 } else {
360 parse_quote!(Option<#ty>)
361 }
362}