mlx_macros/
module_parameters.rs

1use darling::FromDeriveInput;
2use syn::{DataStruct, DeriveInput, Generics, Ident};
3
4use crate::util::filter_fields_with_attr;
5
/// Container-level options for `#[derive(ModuleParameters)]`.
///
/// darling parses these from the `#[module(...)]` attributes placed on the
/// deriving struct (see `#[darling(attributes(module))]`).
#[derive(Debug, Clone, FromDeriveInput)]
#[darling(attributes(module))]
struct ModuleProperties {
    /// Optional path used in place of the `mlx_rs` crate in the generated
    /// code, given through the `root` key of `#[module(...)]`. When absent,
    /// the generated impl declares `extern crate mlx_rs as _mlx_rs;` itself
    /// and refers to the crate as `_mlx_rs`.
    root: Option<syn::Path>,
}
11
12pub(crate) fn expand_module_parameters(
13    input: &DeriveInput,
14) -> Result<proc_macro2::TokenStream, syn::Error> {
15    let prop = ModuleProperties::from_derive_input(input)?;
16    let struct_ident = &input.ident;
17    let generics = &input.generics;
18    match &input.data {
19        syn::Data::Struct(data) => {
20            expand_module_parameters_for_struct(struct_ident, generics, data, prop.root)
21        }
22        _ => Err(syn::Error::new_spanned(
23            input,
24            "ModuleParameters can only be derived for structs",
25        )),
26    }
27}
28
29fn expand_module_parameters_for_struct(
30    ident: &Ident,
31    generics: &Generics,
32    data: &DataStruct,
33    root: Option<syn::Path>,
34) -> Result<proc_macro2::TokenStream, syn::Error> {
35    let fields = filter_fields_with_attr(&data.fields, "param")?;
36
37    Ok(impl_module_parameters_for_struct(
38        ident,
39        generics,
40        fields.filtered,
41        root,
42    ))
43}
44
45fn impl_module_parameters_for_struct(
46    ident: &Ident,
47    generics: &Generics,
48    fields: Vec<&syn::Field>,
49    root: Option<syn::Path>,
50) -> proc_macro2::TokenStream {
51    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
52    let field_names: Vec<_> = fields.iter().map(|field| &field.ident).collect();
53
54    // Returns None if there are no fields
55    let default_all_frozen = match field_names.len() {
56        0 => quote::quote! { None },
57        _ => quote::quote! { Some(true) },
58    };
59
60    // Returns None if there are no fields
61    let default_any_frozen = match field_names.len() {
62        0 => quote::quote! { None },
63        _ => quote::quote! { Some(false) },
64    };
65
66    let (extern_import, root) = match root {
67        Some(root) => (quote::quote! {}, quote::quote! { #root }),
68        None => (
69            quote::quote! { extern crate mlx_rs as _mlx_rs; },
70            quote::quote! { _mlx_rs },
71        ),
72    };
73
74    quote::quote! {
75        const _: () = {
76            #extern_import
77            impl #impl_generics #root::module::ModuleParameters for #ident #ty_generics #where_clause {
78                fn num_parameters(&self) -> usize {
79                    use #root::module::Parameter;
80                    let mut count = 0;
81                    #(
82                        count += self.#field_names.count();
83                    )*
84                    count
85                }
86
87                fn freeze_parameters(&mut self, recursive: bool) {
88                    use #root::module::Parameter;
89                    #(self.#field_names.freeze(recursive);)*
90                }
91
92                fn unfreeze_parameters(&mut self, recursive: bool) {
93                    use #root::module::Parameter;
94                    #(self.#field_names.unfreeze(recursive);)*
95                }
96
97                fn parameters(&self) -> #root::module::ModuleParamRef<'_> {
98                    let mut parameters = #root::nested::NestedHashMap::new();
99                    #(parameters.insert(std::rc::Rc::from(stringify!(#field_names)), #root::module::Parameter::as_nested_value(&self.#field_names));)*
100                    parameters
101                }
102
103                fn parameters_mut(&mut self) -> #root::module::ModuleParamMut<'_> {
104                    let mut parameters = #root::nested::NestedHashMap::new();
105                    #(parameters.insert(std::rc::Rc::from(stringify!(#field_names)), #root::module::Parameter::as_nested_value_mut(&mut self.#field_names));)*
106                    parameters
107                }
108
109                fn trainable_parameters(&self) -> #root::module::ModuleParamRef<'_> {
110                    let mut parameters = #root::nested::NestedHashMap::new();
111                    #(
112                        if let Some(field) = #root::module::Parameter::as_trainable_nested_value(&self.#field_names) {
113                            parameters.insert(std::rc::Rc::from(stringify!(#field_names)), field);
114                        }
115                    )*
116                    parameters
117                }
118
119                fn all_frozen(&self) -> Option<bool> {
120                    use #root::module::Parameter;
121                    #(
122                        if matches!(self.#field_names.is_frozen(), Some(false)) {
123                            return Some(false);
124                        }
125                    )*
126                    #default_all_frozen
127                }
128
129                fn any_frozen(&self) -> Option<bool> {
130                    use #root::module::Parameter;
131                    #(
132                        if matches!(self.#field_names.is_frozen(), Some(true)) {
133                            return Some(true);
134                        }
135                    )*
136                    #default_any_frozen
137                }
138            }
139        };
140    }
141}