mlx_internal_macros/
shared.rs

1use std::fmt::Display;
2
3use darling::{FromDeriveInput, FromField};
4use quote::{quote, ToTokens};
5use syn::{DeriveInput, Ident, ImplGenerics, TypeGenerics, WhereClause};
6
/// Crate-internal result type whose error is a boxed `std::error::Error` trait object, so that
/// plain string messages and `darling` parse errors can both be propagated with `?`.
pub(crate) type Result<T> = ::std::result::Result<T, Box<dyn ::std::error::Error>>;
8
/// Struct-level options read by darling from `#[builder(...)]` attributes on the input struct.
///
/// NOTE(review): these values are presumably turned into a `BuilderStructAnalyzer` by the
/// derive entry point; that conversion is not visible here — confirm at the call site.
#[derive(Debug, Clone, FromDeriveInput)]
#[darling(attributes(builder))]
pub(crate) struct BuilderStructProperty {
    /// Identifier of the struct the derive is applied to (supplied by darling).
    pub ident: syn::Ident,

    /// `build_with = ...`: name of a custom build function. Presumably forwarded to
    /// `BuilderStructAnalyzer::build_with`, where it is invoked as `f(builder)`.
    pub build_with: Option<syn::Ident>,

    /// `root = ...`: path prefix of the crate that provides the `builder` module.
    /// NOTE(review): the fallback used when this is absent is decided elsewhere — confirm.
    pub root: Option<syn::Path>,

    /// `err = ...`: error type of the generated `Builder` impl. Presumably forwarded to
    /// `BuilderStructAnalyzer::err`, which falls back to `Infallible` when unset.
    pub err: Option<syn::Path>,

    /// `default_infallible = ...`: whether building with only the default parameters is
    /// infallible, i.e. cannot fail.
    /// NOTE(review): presumably passed as `is_default_infallible` to
    /// `BuilderStructAnalyzer::impl_struct_new`; confirm at the call site.
    pub default_infallible: Option<bool>,
}
23
/// Borrowed view of everything needed to generate the builder code for one struct.
///
/// Each method in the `impl` returns a `proc_macro2::TokenStream` for one piece of the output:
/// the builder struct declaration, the builder's `new` and setters, the `Builder` trait impl,
/// and the target struct's own `new`.
pub(crate) struct BuilderStructAnalyzer<'a> {
    /// Name of the target struct.
    pub struct_ident: &'a Ident,
    /// Name of the generated builder type.
    pub builder_struct_ident: &'a PathOrIdent,
    /// Path prefix under which the `builder` module (`Builder`, `Buildable`) is reached.
    pub root: &'a syn::Path,
    /// Generics of the target struct in `impl` form (with bounds); presumably from
    /// `syn::Generics::split_for_impl`.
    pub impl_generics: &'a ImplGenerics<'a>,
    /// Generics of the target struct in type-use form (bare parameter names).
    pub type_generics: &'a TypeGenerics<'a>,
    /// Where clause of the target struct, if any.
    pub where_clause: Option<&'a WhereClause>,
    /// Fields that become parameters of the generated `new` constructors.
    pub mandatory_fields: &'a [MandatoryField],
    /// Fields initialised from a default value; each gets a setter unless `skip_setter`.
    pub optional_fields: &'a [OptionalField],
    /// Custom build function, called as `f(builder)`; when `None`, `build` moves every field
    /// into a struct literal.
    pub build_with: Option<&'a Ident>,
    /// Error type of the `Builder` impl; `Infallible` is used when `None`.
    pub err: Option<&'a syn::Path>,
}
36
37impl BuilderStructAnalyzer<'_> {
38    pub fn generate_builder_struct(&self) -> proc_macro2::TokenStream {
39        let struct_ident = self.struct_ident;
40        let builder_ident = self.builder_struct_ident;
41        let type_generics = self.type_generics;
42        let where_clause = self.where_clause;
43
44        let mandatory_field_idents = self.mandatory_fields.iter().map(|field| &field.ident);
45        let mandatory_field_tys = self.mandatory_fields.iter().map(|field| &field.ty);
46
47        let optional_field_idents = self
48            .optional_fields
49            .iter()
50            .map(|field| &field.ident)
51            .collect::<Vec<_>>();
52        let optional_field_tys = self.optional_fields.iter().map(|field| &field.ty);
53        let optional_field_defaults = self.optional_fields.iter().map(|field| &field.default);
54
55        let doc = format!("Builder for `{}`.", struct_ident);
56
57        let mandatory_field_doc = format!("See [`{}`] for more information.", struct_ident);
58        let optional_field_doc =
59            optional_field_idents
60                .iter()
61                .zip(optional_field_defaults)
62                .map(|(ident, default)| {
63                    format!(
64                    "See [`{}::{}`] for more information. Initialized with default value [`{}`].",
65                    struct_ident, ident, default.to_token_stream()
66                )
67                });
68
69        quote! {
70            #[doc = #doc]
71            #[derive(Debug, Clone)]
72            pub struct #builder_ident #type_generics #where_clause {
73                #(
74                    #[doc = #mandatory_field_doc]
75                    #mandatory_field_idents: #mandatory_field_tys,
76                )*
77                #(
78                    #[doc = #optional_field_doc]
79                    #optional_field_idents: #optional_field_tys,
80                )*
81            }
82        }
83    }
84
85    pub fn impl_builder_new(&self) -> proc_macro2::TokenStream {
86        let builder_struct_ident = self.builder_struct_ident;
87        let impl_generics = self.impl_generics;
88        let type_generics = self.type_generics;
89        let where_clause = self.where_clause;
90        let mandatory_field_idents = self
91            .mandatory_fields
92            .iter()
93            .map(|field| &field.ident)
94            .collect::<Vec<_>>();
95        let mandatory_field_types = self.mandatory_fields.iter().map(|field| &field.ty);
96
97        let optional_field_idents = self.optional_fields.iter().map(|field| &field.ident);
98        let optional_field_defaults = self.optional_fields.iter().map(|field| &field.default);
99
100        let doc = format!("Creates a new [`{}`].", builder_struct_ident);
101
102        quote! {
103            impl #impl_generics #builder_struct_ident #type_generics #where_clause {
104                #[doc = #doc]
105                pub fn new(#(#mandatory_field_idents: impl Into<#mandatory_field_types>),*) -> Self {
106                    Self {
107                        #(#mandatory_field_idents: #mandatory_field_idents.into(),)*
108                        #(#optional_field_idents: #optional_field_defaults,)*
109                    }
110                }
111            }
112        }
113    }
114
115    pub fn impl_builder_setters(&self) -> proc_macro2::TokenStream {
116        let builder_struct_ident = self.builder_struct_ident;
117        let impl_generics = self.impl_generics;
118        let type_generics = self.type_generics;
119        let where_clause = self.where_clause;
120        let setters = self.optional_fields.iter().filter_map(|field| {
121            if field.skip_setter {
122                return None;
123            }
124
125            let ident = &field.ident;
126            let ty = &field.ty;
127            let doc = format!("Sets the value of [`{}`].", ident);
128            Some(quote! {
129                #[doc = #doc]
130                pub fn #ident(mut self, #ident: impl Into<#ty>) -> Self {
131                    self.#ident = #ident.into();
132                    self
133                }
134            })
135        });
136
137        quote! {
138            impl #impl_generics #builder_struct_ident #type_generics #where_clause {
139                #(#setters)*
140            }
141        }
142    }
143
144    pub fn impl_builder_trait(&self) -> proc_macro2::TokenStream {
145        let struct_ident = self.struct_ident;
146        let builder_struct_ident = self.builder_struct_ident;
147        let root = self.root;
148        let impl_generics = self.impl_generics;
149        let type_generics = self.type_generics;
150        let where_clause = self.where_clause;
151        let mandatory_field_idents = self.mandatory_fields.iter().map(|field| &field.ident);
152        let optional_field_idents = self.optional_fields.iter().map(|field| &field.ident);
153
154        let err_ty = match self.err {
155            Some(err) => quote! { #err },
156            None => quote! { std::convert::Infallible },
157        };
158
159        let build_body = match self.build_with {
160            Some(f) => quote! {
161                #f(self)
162            },
163            None => quote! {
164                Ok(#struct_ident {
165                    #(#mandatory_field_idents: self.#mandatory_field_idents,)*
166                    #(#optional_field_idents: self.#optional_field_idents,)*
167                })
168            },
169        };
170
171        quote! {
172            impl #impl_generics #root::builder::Builder<#struct_ident #type_generics> for #builder_struct_ident #type_generics #where_clause {
173                type Error = #err_ty;
174
175                fn build(self) -> std::result::Result<#struct_ident #type_generics, Self::Error> {
176                    #build_body
177                }
178            }
179        }
180    }
181
182    pub(crate) fn impl_builder(&self) -> proc_macro2::TokenStream {
183        let builder_new = self.impl_builder_new();
184        let builder_setters = self.impl_builder_setters();
185        let builder_trait = self.impl_builder_trait();
186
187        quote! {
188            #builder_new
189            #builder_setters
190            #builder_trait
191        }
192    }
193
194    pub(crate) fn impl_struct_new(&self, is_default_infallible: bool) -> proc_macro2::TokenStream {
195        let struct_ident = self.struct_ident;
196        let root = self.root;
197        let impl_generics = self.impl_generics;
198        let type_generics = self.type_generics;
199        let where_clause = self.where_clause;
200
201        let mandatory_field_idents = self
202            .mandatory_fields
203            .iter()
204            .map(|field| &field.ident)
205            .collect::<Vec<_>>();
206        let mandatory_field_types = self.mandatory_fields.iter().map(|field| &field.ty);
207
208        let doc = format!("Creates a new instance of `{}`.", struct_ident);
209
210        // TODO: do we want to generate different code for infallible and fallible cases
211        let ret = if is_default_infallible {
212            quote! { -> Self }
213        } else {
214            quote! { -> std::result::Result<Self, <<Self as #root::builder::Buildable>::Builder as #root::builder::Builder<Self>>::Error> }
215        };
216
217        let unwrap_result = if is_default_infallible {
218            quote! { .expect("Build with default parameters should not fail") }
219        } else {
220            quote! {}
221        };
222
223        quote! {
224            impl #impl_generics #struct_ident #type_generics #where_clause {
225                #[doc = #doc]
226                pub fn new(#(#mandatory_field_idents: impl Into<#mandatory_field_types>),*) #ret
227                {
228                    use #root::builder::Builder;
229                    <Self as #root::builder::Buildable>::Builder::new(#(#mandatory_field_idents),*).build()
230                        #unwrap_result
231                }
232            }
233        }
234    }
235}
236
/// Field-level options read by darling from `#[builder(...)]` attributes on each struct field.
#[derive(Debug, darling::FromField, PartialEq)]
#[darling(attributes(builder))]
pub(crate) struct BuilderFieldProperty {
    /// Field name; `None` for tuple-struct fields, which `parse_fields` rejects.
    pub ident: Option<syn::Ident>,

    /// Declared type of the field.
    pub ty: syn::Type,

    /// `optional`: the field gets a default value (and normally a setter) instead of being a
    /// constructor parameter. Defaults to `false`.
    #[darling(default)]
    pub optional: bool,

    /// `default = ...`: path used as the initial value; required when `optional` is set.
    pub default: Option<syn::Path>,

    /// `rename = "..."`: name to use for the field on the builder side.
    pub rename: Option<String>,

    /// `ignore`: leave this field out of the builder entirely. Defaults to `false`.
    #[darling(default)]
    pub ignore: bool,

    /// `ty_override = ...`: type path to use on the builder side instead of `ty`.
    pub ty_override: Option<syn::Path>,

    /// `skip_setter`: do not generate a setter for this (optional) field. Defaults to `false`.
    #[darling(default)]
    pub skip_setter: bool,
}
259
/// A builder field that must be passed to the generated `new` constructors.
pub(crate) struct MandatoryField {
    /// Builder-side name: the struct field's name, or its `rename` value.
    pub ident: syn::Ident,
    /// Builder-side type: the struct field's type, or its `ty_override` path.
    pub ty: syn::Type,
}
264
/// A builder field that starts from a default value and may be changed via a setter.
pub(crate) struct OptionalField {
    /// Builder-side name: the struct field's name, or its `rename` value.
    pub ident: syn::Ident,
    /// Builder-side type: the struct field's type, or its `ty_override` path.
    pub ty: syn::Type,
    /// Path interpolated as the field's initial value in the builder's `new`.
    pub default: syn::Path,
    /// When `true`, no setter is generated for this field.
    pub skip_setter: bool,
}
271
272pub(crate) fn parse_fields_from_derive_input(
273    item: &DeriveInput,
274) -> Result<(Vec<MandatoryField>, Vec<OptionalField>)> {
275    match &item.data {
276        syn::Data::Struct(data) => parse_fields_from_datastruct(data),
277        _ => Err("Only structs are supported".into()),
278    }
279}
280
281fn parse_fields_from_datastruct(
282    item: &syn::DataStruct,
283) -> Result<(Vec<MandatoryField>, Vec<OptionalField>)> {
284    parse_fields(&item.fields)
285}
286
287fn parse_fields(fields: &syn::Fields) -> Result<(Vec<MandatoryField>, Vec<OptionalField>)> {
288    let mut mandatory_fields = Vec::new();
289    let mut optional_fields = Vec::new();
290
291    let field_props = fields.iter().map(BuilderFieldProperty::from_field);
292
293    for field_prop in field_props {
294        let field_prop = field_prop?;
295        if field_prop.ignore {
296            continue;
297        }
298
299        let mut ident = match field_prop.ident {
300            Some(ident) => ident,
301            None => return Err("Unnamed fields are not supported".into()),
302        };
303
304        if let Some(rename) = field_prop.rename {
305            ident = syn::Ident::new(&rename, ident.span());
306        }
307
308        let ty = match field_prop.ty_override {
309            Some(ty_override) => syn::Type::Path(syn::TypePath {
310                qself: None,
311                path: ty_override,
312            }),
313            None => field_prop.ty,
314        };
315
316        if field_prop.optional {
317            let default = match field_prop.default {
318                Some(default) => default,
319                None => {
320                    return Err(
321                        format!("Field {} is optional but has no default value", ident).into(),
322                    )
323                }
324            };
325
326            optional_fields.push(OptionalField {
327                ident,
328                ty,
329                default,
330                skip_setter: field_prop.skip_setter,
331            });
332        } else {
333            mandatory_fields.push(MandatoryField { ident, ty });
334        }
335    }
336
337    Ok((mandatory_fields, optional_fields))
338}
339
/// Either a full path or a bare identifier; used to name the generated builder type.
#[derive(Debug, Clone)]
pub(crate) enum PathOrIdent {
    /// A (possibly multi-segment) path.
    Path(syn::Path),
    /// A single identifier.
    Ident(syn::Ident),
}
345
346impl ToTokens for PathOrIdent {
347    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
348        match self {
349            PathOrIdent::Path(path) => path.to_tokens(tokens),
350            PathOrIdent::Ident(ident) => ident.to_tokens(tokens),
351        }
352    }
353}
354
355impl Display for PathOrIdent {
356    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
357        match self {
358            PathOrIdent::Path(path) => path.to_token_stream().fmt(f),
359            PathOrIdent::Ident(ident) => Display::fmt(ident, f),
360        }
361    }
362}