mlx_macros/
module_parameters.rs

1use darling::FromDeriveInput;
2use syn::{DataStruct, DeriveInput, Generics, Ident};
3
4use crate::util::filter_fields_with_attr;
5
6#[derive(Debug, Clone, FromDeriveInput)]
7#[darling(attributes(module))]
8struct ModuleProperties {
9    root: Option<syn::Path>,
10}
11
12pub(crate) fn expand_module_parameters(
13    input: &DeriveInput,
14) -> Result<proc_macro2::TokenStream, syn::Error> {
15    let prop = ModuleProperties::from_derive_input(input)?;
16    let struct_ident = &input.ident;
17    let generics = &input.generics;
18    match &input.data {
19        syn::Data::Struct(data) => {
20            expand_module_parameters_for_struct(struct_ident, generics, data, prop.root)
21        }
22        _ => Err(syn::Error::new_spanned(
23            input,
24            "ModuleParameters can only be derived for structs",
25        )),
26    }
27}
28
29fn expand_module_parameters_for_struct(
30    ident: &Ident,
31    generics: &Generics,
32    data: &DataStruct,
33    root: Option<syn::Path>,
34) -> Result<proc_macro2::TokenStream, syn::Error> {
35    let fields = filter_fields_with_attr(&data.fields, "param")?;
36
37    Ok(impl_module_parameters_for_struct(
38        ident,
39        generics,
40        fields.filtered,
41        root,
42    ))
43}
44
45fn impl_module_parameters_for_struct(
46    ident: &Ident,
47    generics: &Generics,
48    fields: Vec<&syn::Field>,
49    root: Option<syn::Path>,
50) -> proc_macro2::TokenStream {
51    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
52    let field_names: Vec<_> = fields.iter().map(|field| &field.ident).collect();
53
54    // Returns None if there are no fields
55    let default_all_frozen = match field_names.len() {
56        0 => quote::quote! { None },
57        _ => quote::quote! { Some(true) },
58    };
59
60    // Returns None if there are no fields
61    let default_any_frozen = match field_names.len() {
62        0 => quote::quote! { None },
63        _ => quote::quote! { Some(false) },
64    };
65
66    let (extern_import, root) = match root {
67        Some(root) => (quote::quote! {}, quote::quote! { #root }),
68        None => (
69            quote::quote! { extern crate mlx_rs as _mlx_rs; },
70            quote::quote! { _mlx_rs },
71        ),
72    };
73
74    quote::quote! {
75        const _: () = {
76            #extern_import
77            impl #impl_generics #root::module::ModuleParameters for #ident #ty_generics #where_clause {
78                fn freeze_parameters(&mut self, recursive: bool) {
79                    use #root::module::Parameter;
80                    #(self.#field_names.freeze(recursive);)*
81                }
82
83                fn unfreeze_parameters(&mut self, recursive: bool) {
84                    use #root::module::Parameter;
85                    #(self.#field_names.unfreeze(recursive);)*
86                }
87
88                fn parameters(&self) -> #root::module::ModuleParamRef<'_> {
89                    let mut parameters = #root::nested::NestedHashMap::new();
90                    #(parameters.insert(std::rc::Rc::from(stringify!(#field_names)), #root::module::Parameter::as_nested_value(&self.#field_names));)*
91                    parameters
92                }
93
94                fn parameters_mut(&mut self) -> #root::module::ModuleParamMut<'_> {
95                    let mut parameters = #root::nested::NestedHashMap::new();
96                    #(parameters.insert(std::rc::Rc::from(stringify!(#field_names)), #root::module::Parameter::as_nested_value_mut(&mut self.#field_names));)*
97                    parameters
98                }
99
100                fn trainable_parameters(&self) -> #root::module::ModuleParamRef<'_> {
101                    let mut parameters = #root::nested::NestedHashMap::new();
102                    #(
103                        if let Some(field) = #root::module::Parameter::as_trainable_nested_value(&self.#field_names) {
104                            parameters.insert(std::rc::Rc::from(stringify!(#field_names)), field);
105                        }
106                    )*
107                    parameters
108                }
109
110                fn all_frozen(&self) -> Option<bool> {
111                    use #root::module::Parameter;
112                    #(
113                        if matches!(self.#field_names.is_frozen(), Some(false)) {
114                            return Some(false);
115                        }
116                    )*
117                    #default_all_frozen
118                }
119
120                fn any_frozen(&self) -> Option<bool> {
121                    use #root::module::Parameter;
122                    #(
123                        if matches!(self.#field_names.is_frozen(), Some(true)) {
124                            return Some(true);
125                        }
126                    )*
127                    #default_any_frozen
128                }
129            }
130        };
131    }
132}