// mlx_macros/module_parameters.rs
use darling::FromDeriveInput;
use syn::ext::IdentExt;
use syn::{DataStruct, DeriveInput, Generics, Ident};

use crate::util::filter_fields_with_attr;

6#[derive(Debug, Clone, FromDeriveInput)]
7#[darling(attributes(module))]
8struct ModuleProperties {
9 root: Option<syn::Path>,
10}
11
12pub(crate) fn expand_module_parameters(
13 input: &DeriveInput,
14) -> Result<proc_macro2::TokenStream, syn::Error> {
15 let prop = ModuleProperties::from_derive_input(input)?;
16 let struct_ident = &input.ident;
17 let generics = &input.generics;
18 match &input.data {
19 syn::Data::Struct(data) => {
20 expand_module_parameters_for_struct(struct_ident, generics, data, prop.root)
21 }
22 _ => Err(syn::Error::new_spanned(
23 input,
24 "ModuleParameters can only be derived for structs",
25 )),
26 }
27}
28
29fn expand_module_parameters_for_struct(
30 ident: &Ident,
31 generics: &Generics,
32 data: &DataStruct,
33 root: Option<syn::Path>,
34) -> Result<proc_macro2::TokenStream, syn::Error> {
35 let fields = filter_fields_with_attr(&data.fields, "param")?;
36
37 Ok(impl_module_parameters_for_struct(
38 ident,
39 generics,
40 fields.filtered,
41 root,
42 ))
43}
44
45fn impl_module_parameters_for_struct(
46 ident: &Ident,
47 generics: &Generics,
48 fields: Vec<&syn::Field>,
49 root: Option<syn::Path>,
50) -> proc_macro2::TokenStream {
51 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
52 let field_names: Vec<_> = fields.iter().map(|field| &field.ident).collect();
53
54 let default_all_frozen = match field_names.len() {
56 0 => quote::quote! { None },
57 _ => quote::quote! { Some(true) },
58 };
59
60 let default_any_frozen = match field_names.len() {
62 0 => quote::quote! { None },
63 _ => quote::quote! { Some(false) },
64 };
65
66 let (extern_import, root) = match root {
67 Some(root) => (quote::quote! {}, quote::quote! { #root }),
68 None => (
69 quote::quote! { extern crate mlx_rs as _mlx_rs; },
70 quote::quote! { _mlx_rs },
71 ),
72 };
73
74 quote::quote! {
75 const _: () = {
76 #extern_import
77 impl #impl_generics #root::module::ModuleParameters for #ident #ty_generics #where_clause {
78 fn num_parameters(&self) -> usize {
79 use #root::module::Parameter;
80 let mut count = 0;
81 #(
82 count += self.#field_names.count();
83 )*
84 count
85 }
86
87 fn freeze_parameters(&mut self, recursive: bool) {
88 use #root::module::Parameter;
89 #(self.#field_names.freeze(recursive);)*
90 }
91
92 fn unfreeze_parameters(&mut self, recursive: bool) {
93 use #root::module::Parameter;
94 #(self.#field_names.unfreeze(recursive);)*
95 }
96
97 fn parameters(&self) -> #root::module::ModuleParamRef<'_> {
98 let mut parameters = #root::nested::NestedHashMap::new();
99 #(parameters.insert(std::rc::Rc::from(stringify!(#field_names)), #root::module::Parameter::as_nested_value(&self.#field_names));)*
100 parameters
101 }
102
103 fn parameters_mut(&mut self) -> #root::module::ModuleParamMut<'_> {
104 let mut parameters = #root::nested::NestedHashMap::new();
105 #(parameters.insert(std::rc::Rc::from(stringify!(#field_names)), #root::module::Parameter::as_nested_value_mut(&mut self.#field_names));)*
106 parameters
107 }
108
109 fn trainable_parameters(&self) -> #root::module::ModuleParamRef<'_> {
110 let mut parameters = #root::nested::NestedHashMap::new();
111 #(
112 if let Some(field) = #root::module::Parameter::as_trainable_nested_value(&self.#field_names) {
113 parameters.insert(std::rc::Rc::from(stringify!(#field_names)), field);
114 }
115 )*
116 parameters
117 }
118
119 fn all_frozen(&self) -> Option<bool> {
120 use #root::module::Parameter;
121 #(
122 if matches!(self.#field_names.is_frozen(), Some(false)) {
123 return Some(false);
124 }
125 )*
126 #default_all_frozen
127 }
128
129 fn any_frozen(&self) -> Option<bool> {
130 use #root::module::Parameter;
131 #(
132 if matches!(self.#field_names.is_frozen(), Some(true)) {
133 return Some(true);
134 }
135 )*
136 #default_any_frozen
137 }
138 }
139 };
140 }
141}