mlx_internal_macros/
generate_builder.rs

1use darling::FromDeriveInput;
2use proc_macro2::TokenTree;
3use quote::quote;
4use syn::DeriveInput;
5
6use crate::{
7    derive_buildable::StructProperty,
8    shared::{
9        parse_fields_from_derive_input, BuilderStructAnalyzer, BuilderStructProperty, PathOrIdent,
10        Result,
11    },
12};
13
14pub(crate) fn expand_generate_builder(input: &DeriveInput) -> Result<proc_macro2::TokenStream> {
15    // Make sure the struct does NOT have #[derive(Default)]
16    if struct_attr_derive_default(&input.attrs) {
17        return Err("Struct with #[derive(Default)] cannot derive Buildable".into());
18    }
19
20    let struct_prop = StructProperty::from_derive_input(input)?;
21    let builder_struct_prop = BuilderStructProperty::from_derive_input(input)?;
22    let (impl_generics, type_generics, where_clause) = input.generics.split_for_impl();
23
24    let struct_ident = &struct_prop.ident;
25    let builder_struct_ident =
26        syn::Ident::new(&format!("{}Builder", struct_ident), struct_ident.span());
27    let root = match struct_prop.root {
28        Some(path) => path,
29        None => syn::parse_quote!(::mlx_rs),
30    };
31
32    let (mandatory_fields, optional_fields) = parse_fields_from_derive_input(input)?;
33    let is_default_infallible = builder_struct_prop
34        .default_infallible
35        .unwrap_or_else(|| builder_struct_prop.err.is_none());
36
37    let builder_struct_ident = match &struct_prop.builder {
38        Some(path) => PathOrIdent::Path(path.clone()),
39        None => PathOrIdent::Ident(builder_struct_ident.clone()),
40    };
41    let builder_struct_analyzer = BuilderStructAnalyzer {
42        struct_ident,
43        builder_struct_ident: &builder_struct_ident,
44        root: &root,
45        impl_generics: &impl_generics,
46        type_generics: &type_generics,
47        where_clause,
48        mandatory_fields: &mandatory_fields,
49        optional_fields: &optional_fields,
50        build_with: builder_struct_prop.build_with.as_ref(),
51        err: builder_struct_prop.err.as_ref(),
52    };
53    let builder_struct = if struct_prop.builder.is_none() {
54        builder_struct_analyzer.generate_builder_struct()
55    } else {
56        quote! {}
57    };
58    let impl_builder = builder_struct_analyzer.impl_builder();
59    let impl_struct_new = builder_struct_analyzer.impl_struct_new(is_default_infallible);
60
61    Ok(quote! {
62        #builder_struct
63        #impl_builder
64        #impl_struct_new
65    })
66}
67
68fn struct_attr_derive_default(attrs: &[syn::Attribute]) -> bool {
69    attrs
70        .iter()
71        .filter_map(|attr| {
72            if attr.path().is_ident("derive") {
73                attr.meta
74                    .require_list()
75                    .map(|list| list.tokens.clone())
76                    .ok()
77            } else {
78                None
79            }
80        })
81        .any(|tokens| {
82            tokens.into_iter().any(|tree| {
83                if let TokenTree::Ident(ident) = tree {
84                    ident == "Default"
85                } else {
86                    false
87                }
88            })
89        })
90}