// mlx_rs/module/module.rs

use std::{borrow::Borrow, collections::HashMap, hash::Hash, path::Path, rc::Rc};

use crate::{
    error::{Exception, IoError},
    nested::{NestedHashMap, NestedValue},
    Array,
};

/// Type alias for owned module parameters.
pub type ModuleParam = NestedHashMap<Rc<str>, Array>;

/// Type alias for borrowed module parameters.
pub type ModuleParamRef<'a> = NestedHashMap<Rc<str>, &'a Array>;

/// Type alias for mutably borrowed module parameters.
pub type ModuleParamMut<'a> = NestedHashMap<Rc<str>, &'a mut Array>;

/// Type alias for flattened module parameters.
pub type FlattenedModuleParam = HashMap<Rc<str>, Array>;

/// Type alias for borrowed flattened module parameters.
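///
/// A flattened map pairs each parameter's full key with a reference to its [`Array`]. A minimal
/// sketch of iterating over one, assuming a value `model` that implements [`ModuleParameters`]
/// (the variable name is illustrative only):
///
/// ```rust,ignore
/// let flat: FlattenedModuleParamRef<'_> = model.parameters().flatten();
/// for (name, value) in &flat {
///     println!("{name}: shape {:?}", value.shape());
/// }
/// ```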
pub type FlattenedModuleParamRef<'a> = HashMap<Rc<str>, &'a Array>;

/// Type alias for mutably borrowed flattened module parameters.
pub type FlattenedModuleParamMut<'a> = HashMap<Rc<str>, &'a mut Array>;

/// Trait for a neural network module.
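///
/// A minimal sketch of implementing [`Module`] for a hypothetical bias-only layer. The `Bias`
/// struct, its [`ModuleParameters`] implementation, and the `add` operation are assumed to exist
/// elsewhere; the example is illustrative rather than part of this crate.
///
/// ```rust,ignore
/// #[derive(Debug)]
/// struct Bias {
///     value: Array,
/// }
///
/// // `ModuleParameters` is a supertrait, so `Bias` must also implement it (omitted here).
/// impl<'a> Module<&'a Array> for Bias {
///     type Output = Array;
///     type Error = Exception;
///
///     fn forward(&mut self, input: &'a Array) -> Result<Array, Exception> {
///         input.add(&self.value)
///     }
///
///     // There is no layer-specific training behaviour, so this is a no-op.
///     fn training_mode(&mut self, _mode: bool) {}
/// }
/// ```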
pub trait Module<Input>: ModuleParameters + std::fmt::Debug {
    /// Output type of the module.
    type Output;

    /// Error type for the module.
    type Error: std::error::Error;

    /// Forward pass of the module.
    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>;

    /// Set whether the module is in training mode.
    ///
    /// Training mode only affects certain layers. For example, a dropout layer applies a random
    /// mask in training mode but acts as the identity in evaluation mode. Implementations of
    /// nested modules should propagate the training mode to all child modules.
    fn training_mode(&mut self, mode: bool);
}

/// Marker trait for a unary neural network module.
///
/// This trait should not be implemented directly. Instead, implement [`Module`] with the input
/// type being a reference to [`Array`] and with `Output = Array`; the blanket implementation
/// below then applies automatically.
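///
/// A minimal sketch of a helper that is generic over any [`UnaryModule`]; `forward_twice` is an
/// illustrative function, not part of this crate:
///
/// ```rust,ignore
/// fn forward_twice<M: UnaryModule>(module: &mut M, x: &Array) -> Array {
///     // `unwrap` keeps the sketch short; real code should propagate the error.
///     let y = module.forward(x).unwrap();
///     module.forward(&y).unwrap()
/// }
/// ```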
pub trait UnaryModule: for<'a> Module<&'a Array, Output = Array> {}

impl<T> UnaryModule for T where T: for<'a> Module<&'a Array, Output = Array> {}

/// Trait for accessing and updating module parameters.
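///
/// A minimal sketch of inspecting a module through this trait, assuming a value `model` that
/// implements [`ModuleParameters`] (the variable name is illustrative only):
///
/// ```rust,ignore
/// // Report the parameter count and the overall frozen state.
/// println!("parameters: {}", model.num_parameters());
/// println!("all frozen: {:?}", model.all_frozen());
///
/// // List the keys of the trainable (non-frozen) parameters.
/// for key in model.trainable_parameters().flatten().keys() {
///     println!("trainable: {key}");
/// }
/// ```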
pub trait ModuleParameters {
    /// Get the total number of parameters in the module.
    ///
    /// This should be equivalent to `module.parameters().flatten().len()`, but implementations
    /// may compute it without building and iterating the parameter map.
    fn num_parameters(&self) -> usize;

    /// Get references to the module parameters.
    fn parameters(&self) -> ModuleParamRef<'_>;

    /// Get mutable references to the module parameters.
    fn parameters_mut(&mut self) -> ModuleParamMut<'_>;

    /// Get references to the trainable parameters. A parameter is trainable if it is NOT frozen.
    fn trainable_parameters(&self) -> ModuleParamRef<'_>;

    /// Update the module parameters.
    fn update(&mut self, parameters: ModuleParam) {
        let flattened_parameters = parameters.flatten();
        update_parameters(self, flattened_parameters)
    }

    /// Update the module parameters from a flattened representation.
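    ///
    /// A minimal sketch of copying the trainable parameters of one module into another, assuming
    /// two values `src` and `dst` that implement [`ModuleParameters`] and share the same
    /// parameter keys (both names are illustrative only):
    ///
    /// ```rust,ignore
    /// let copied: FlattenedModuleParam = src
    ///     .trainable_parameters()
    ///     .flatten()
    ///     .into_iter()
    ///     .map(|(key, value)| (key, (*value).clone()))
    ///     .collect();
    /// dst.update_flattened(copied);
    /// ```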
    fn update_flattened(&mut self, flattened_parameters: FlattenedModuleParam) {
        update_parameters(self, flattened_parameters)
    }

    /// Freeze all parameters in the module.
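    ///
    /// A minimal sketch, assuming a value `model` that implements [`ModuleParameters`] and holds
    /// at least one parameter (the variable name is illustrative only):
    ///
    /// ```rust,ignore
    /// // Freeze every parameter; `true` asks nested child modules to be frozen as well.
    /// model.freeze_parameters(true);
    /// assert_eq!(model.all_frozen(), Some(true));
    /// // Frozen parameters are no longer reported as trainable.
    /// assert!(model.trainable_parameters().flatten().is_empty());
    /// ```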
    fn freeze_parameters(&mut self, recursive: bool);

    /// Unfreeze all parameters in the module.
    fn unfreeze_parameters(&mut self, recursive: bool);

    /// Check if all parameters in the module are frozen. Returns `None` if there are no parameters.
    fn all_frozen(&self) -> Option<bool>;

    /// Check if any parameter in the module is frozen. Returns `None` if there are no parameters.
    fn any_frozen(&self) -> Option<bool>;
}

/// Update the module parameters from an iterator of `(key, value)` pairs. Keys that do not match
/// an existing parameter are silently ignored.
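///
/// A minimal sketch, assuming a value `model` that implements [`ModuleParameters`] and exposes a
/// parameter under the (illustrative) key `"bias"`:
///
/// ```rust,ignore
/// use std::rc::Rc;
///
/// let new_bias = Array::from_slice(&[0.0f32, 0.0, 0.0], &[3]);
/// update_parameters(&mut model, [(Rc::<str>::from("bias"), new_bias)]);
/// ```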
pub fn update_parameters<M, I, Q>(module: &mut M, parameters: I)
where
    M: ModuleParameters + ?Sized,
    I: IntoIterator<Item = (Q, Array)>,
    Q: Hash + Eq,
    Rc<str>: Borrow<Q>,
{
    let mut flattened_self_parameters = module.parameters_mut().flatten();

    for (key, value) in parameters {
        if let Some(self_value) = flattened_self_parameters.get_mut(&key) {
            **self_value = value;
        }
    }
}

impl<T> ModuleParameters for &'_ mut T
where
    T: ModuleParameters + ?Sized,
{
    fn num_parameters(&self) -> usize {
        (**self).num_parameters()
    }

    fn parameters(&self) -> ModuleParamRef<'_> {
        (**self).parameters()
    }

    fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
        (**self).parameters_mut()
    }

    fn trainable_parameters(&self) -> ModuleParamRef<'_> {
        (**self).trainable_parameters()
    }

    fn freeze_parameters(&mut self, recursive: bool) {
        (**self).freeze_parameters(recursive);
    }

    fn unfreeze_parameters(&mut self, recursive: bool) {
        (**self).unfreeze_parameters(recursive);
    }

    fn all_frozen(&self) -> Option<bool> {
        (**self).all_frozen()
    }

    fn any_frozen(&self) -> Option<bool> {
        (**self).any_frozen()
    }
}

impl<T> ModuleParameters for Box<T>
where
    T: ModuleParameters + ?Sized,
{
    fn num_parameters(&self) -> usize {
        self.as_ref().num_parameters()
    }

    fn parameters(&self) -> ModuleParamRef<'_> {
        self.as_ref().parameters()
    }

    fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
        self.as_mut().parameters_mut()
    }

    fn trainable_parameters(&self) -> ModuleParamRef<'_> {
        self.as_ref().trainable_parameters()
    }

    fn freeze_parameters(&mut self, recursive: bool) {
        self.as_mut().freeze_parameters(recursive);
    }

    fn unfreeze_parameters(&mut self, recursive: bool) {
        self.as_mut().unfreeze_parameters(recursive);
    }

    fn all_frozen(&self) -> Option<bool> {
        self.as_ref().all_frozen()
    }

    fn any_frozen(&self) -> Option<bool> {
        self.as_ref().any_frozen()
    }
}

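// Each element's parameters are nested under its index rendered as a string key, so the first
// element's parameters live under "0", the second under "1", and so on.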
impl<T> ModuleParameters for Vec<T>
where
    T: ModuleParameters,
{
    fn num_parameters(&self) -> usize {
        self.iter().map(|module| module.num_parameters()).sum()
    }

    fn parameters(&self) -> ModuleParamRef<'_> {
        let mut parameters = NestedHashMap::new();
        self.iter().enumerate().for_each(|(i, module)| {
            let value = module.parameters();
            parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
        });
        parameters
    }

    fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
        let mut parameters = NestedHashMap::new();
        self.iter_mut().enumerate().for_each(|(i, module)| {
            let value = module.parameters_mut();
            parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
        });
        parameters
    }

    fn trainable_parameters(&self) -> ModuleParamRef<'_> {
        let mut parameters = NestedHashMap::new();
        self.iter().enumerate().for_each(|(i, module)| {
            let value = module.trainable_parameters();
            parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
        });
        parameters
    }

    fn freeze_parameters(&mut self, recursive: bool) {
        self.iter_mut().for_each(|module| {
            module.freeze_parameters(recursive);
        });
    }

    fn unfreeze_parameters(&mut self, recursive: bool) {
        self.iter_mut().for_each(|module| {
            module.unfreeze_parameters(recursive);
        });
    }

    fn all_frozen(&self) -> Option<bool> {
        let mut result = None;
        for module in self.iter() {
            match module.all_frozen() {
                Some(true) => result = Some(true),
                Some(false) => return Some(false),
                None => {}
            }
        }
        result
    }

    fn any_frozen(&self) -> Option<bool> {
        let mut result = None;
        for module in self.iter() {
            match module.any_frozen() {
                Some(true) => return Some(true),
                Some(false) => result = Some(false),
                None => {}
            }
        }
        result
    }
}

/// Extension trait for [`ModuleParameters`]. This is implemented for all types that implement
/// [`ModuleParameters`].
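///
/// A minimal sketch of a save/load round trip, assuming a value `model` that implements
/// [`ModuleParameters`] (the variable and file names are illustrative only):
///
/// ```rust,ignore
/// // Persist the current parameters, then restore them later.
/// model.save_safetensors("model.safetensors")?;
/// model.load_safetensors("model.safetensors")?;
/// ```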
pub trait ModuleParametersExt: ModuleParameters {
    /// Evaluate the module parameters.
    fn eval(&self) -> Result<(), Exception> {
        crate::transforms::eval_params(self.parameters())
    }

    /// Load module parameters from a `safetensors` file.
    fn load_safetensors(&mut self, path: impl AsRef<Path>) -> Result<(), IoError> {
        let loaded = Array::load_safetensors(path)?;

        // Load the parameters
        let mut params = self.parameters_mut().flatten();
        for (key, value) in loaded {
            if let Some(param) = params.get_mut(&*key) {
                **param = value;
            }
        }

        // Loading is lazy, eval after loading
        self.eval()?;

        Ok(())
    }

    /// Save module parameters to a file in `safetensors` format.
    fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
        let params = self.parameters().flatten();
        Array::save_safetensors(params, None, path)?;
        Ok(())
    }
}

impl<T: ModuleParameters> ModuleParametersExt for T {}