mlx_macros/
module_parameters.rs1use darling::FromDeriveInput;
2use syn::{DataStruct, DeriveInput, Generics, Ident};
3
4use crate::util::filter_fields_with_attr;
5
6#[derive(Debug, Clone, FromDeriveInput)]
7#[darling(attributes(module))]
8struct ModuleProperties {
9 root: Option<syn::Path>,
10}
11
12pub(crate) fn expand_module_parameters(
13 input: &DeriveInput,
14) -> Result<proc_macro2::TokenStream, syn::Error> {
15 let prop = ModuleProperties::from_derive_input(input)?;
16 let struct_ident = &input.ident;
17 let generics = &input.generics;
18 match &input.data {
19 syn::Data::Struct(data) => {
20 expand_module_parameters_for_struct(struct_ident, generics, data, prop.root)
21 }
22 _ => Err(syn::Error::new_spanned(
23 input,
24 "ModuleParameters can only be derived for structs",
25 )),
26 }
27}
28
29fn expand_module_parameters_for_struct(
30 ident: &Ident,
31 generics: &Generics,
32 data: &DataStruct,
33 root: Option<syn::Path>,
34) -> Result<proc_macro2::TokenStream, syn::Error> {
35 let fields = filter_fields_with_attr(&data.fields, "param")?;
36
37 Ok(impl_module_parameters_for_struct(
38 ident,
39 generics,
40 fields.filtered,
41 root,
42 ))
43}
44
45fn impl_module_parameters_for_struct(
46 ident: &Ident,
47 generics: &Generics,
48 fields: Vec<&syn::Field>,
49 root: Option<syn::Path>,
50) -> proc_macro2::TokenStream {
51 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
52 let field_names: Vec<_> = fields.iter().map(|field| &field.ident).collect();
53
54 let default_all_frozen = match field_names.len() {
56 0 => quote::quote! { None },
57 _ => quote::quote! { Some(true) },
58 };
59
60 let default_any_frozen = match field_names.len() {
62 0 => quote::quote! { None },
63 _ => quote::quote! { Some(false) },
64 };
65
66 let (extern_import, root) = match root {
67 Some(root) => (quote::quote! {}, quote::quote! { #root }),
68 None => (
69 quote::quote! { extern crate mlx_rs as _mlx_rs; },
70 quote::quote! { _mlx_rs },
71 ),
72 };
73
74 quote::quote! {
75 const _: () = {
76 #extern_import
77 impl #impl_generics #root::module::ModuleParameters for #ident #ty_generics #where_clause {
78 fn freeze_parameters(&mut self, recursive: bool) {
79 use #root::module::Parameter;
80 #(self.#field_names.freeze(recursive);)*
81 }
82
83 fn unfreeze_parameters(&mut self, recursive: bool) {
84 use #root::module::Parameter;
85 #(self.#field_names.unfreeze(recursive);)*
86 }
87
88 fn parameters(&self) -> #root::module::ModuleParamRef<'_> {
89 let mut parameters = #root::nested::NestedHashMap::new();
90 #(parameters.insert(std::rc::Rc::from(stringify!(#field_names)), #root::module::Parameter::as_nested_value(&self.#field_names));)*
91 parameters
92 }
93
94 fn parameters_mut(&mut self) -> #root::module::ModuleParamMut<'_> {
95 let mut parameters = #root::nested::NestedHashMap::new();
96 #(parameters.insert(std::rc::Rc::from(stringify!(#field_names)), #root::module::Parameter::as_nested_value_mut(&mut self.#field_names));)*
97 parameters
98 }
99
100 fn trainable_parameters(&self) -> #root::module::ModuleParamRef<'_> {
101 let mut parameters = #root::nested::NestedHashMap::new();
102 #(
103 if let Some(field) = #root::module::Parameter::as_trainable_nested_value(&self.#field_names) {
104 parameters.insert(std::rc::Rc::from(stringify!(#field_names)), field);
105 }
106 )*
107 parameters
108 }
109
110 fn all_frozen(&self) -> Option<bool> {
111 use #root::module::Parameter;
112 #(
113 if matches!(self.#field_names.is_frozen(), Some(false)) {
114 return Some(false);
115 }
116 )*
117 #default_all_frozen
118 }
119
120 fn any_frozen(&self) -> Option<bool> {
121 use #root::module::Parameter;
122 #(
123 if matches!(self.#field_names.is_frozen(), Some(true)) {
124 return Some(true);
125 }
126 )*
127 #default_any_frozen
128 }
129 }
130 };
131 }
132}