mlx_rs/module/
module.rs

1use std::{borrow::Borrow, collections::HashMap, hash::Hash, path::Path, rc::Rc};
2
3use crate::{
4    error::{Exception, IoError},
5    nested::{NestedHashMap, NestedValue},
6    Array,
7};
8
9/// Type alias for owned module parameters.
10pub type ModuleParam = NestedHashMap<Rc<str>, Array>;
11
12/// Type alias for borrowed module parameters.
13pub type ModuleParamRef<'a> = NestedHashMap<Rc<str>, &'a Array>;
14
15/// Type alias for mutably borrowed module parameters.
16pub type ModuleParamMut<'a> = NestedHashMap<Rc<str>, &'a mut Array>;
17
18/// Type alias for flattened module parameters.
19pub type FlattenedModuleParam = HashMap<Rc<str>, Array>;
20
21/// Type alias for borrowed flattened module parameters.
22pub type FlattenedModuleParamRef<'a> = HashMap<Rc<str>, &'a Array>;
23
24/// Type alias for mutably borrowed flattened module parameters.
25pub type FlattenedModuleParamMut<'a> = HashMap<Rc<str>, &'a mut Array>;
26
27/// Trait for a neural network module.
28pub trait Module<Input>: ModuleParameters + std::fmt::Debug {
29    /// Output type of the module.
30    type Output;
31
32    /// Error type for the module.
33    type Error: std::error::Error;
34
35    /// Forward pass of the module.
36    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>;
37
38    /// Set whether the module is in training mode.
39    ///
40    /// Training mode only applies to certain layers. For example, dropout layers applies a random
41    /// mask in training mode, but is the identity in evaluation mode. Implementations of nested
42    /// modules should propagate the training mode to all child modules.
43    fn training_mode(&mut self, mode: bool);
44}
45
46/// Marker trait for a unary neural network module.
47///
48/// This trait should not be implemented directly. Instead, implement [`Module`] with `Args` as a
49/// reference to the input.
50pub trait UnaryModule: for<'a> Module<&'a Array, Output = Array> {}
51
52impl<T> UnaryModule for T where T: for<'a> Module<&'a Array, Output = Array> {}
53
54/// Trait for accessing and updating module parameters.
55pub trait ModuleParameters {
56    /// Get references to the module parameters.
57    fn parameters(&self) -> ModuleParamRef<'_>;
58
59    /// Get mutable references to the module parameters.
60    fn parameters_mut(&mut self) -> ModuleParamMut<'_>;
61
62    /// Get references to the trainable parameters. A parameter is trainable if it is NOT frozen.
63    fn trainable_parameters(&self) -> ModuleParamRef<'_>;
64
65    /// Update the module parameters.
66    fn update(&mut self, parameters: ModuleParam) {
67        let flattened_parameters = parameters.flatten();
68        update_parameters(self, flattened_parameters)
69    }
70
71    /// Update the module parameters from a flattened representation.
72    fn update_flattened(&mut self, flattened_parameters: FlattenedModuleParam) {
73        update_parameters(self, flattened_parameters)
74    }
75
76    /// Freeze all parameters in the module.
77    fn freeze_parameters(&mut self, recursive: bool);
78
79    /// Unfreeze all parameters in the module.
80    fn unfreeze_parameters(&mut self, recursive: bool);
81
82    /// Check if all parameters in the module are frozen. Returns `None` if there are no parameters.
83    fn all_frozen(&self) -> Option<bool>;
84
85    /// Check if any parameter in the module is frozen. Returns `None` if there are no parameters.
86    fn any_frozen(&self) -> Option<bool>;
87}
88
89/// Update the module parameters from an iterator of (key, value) tuples.
90pub fn update_parameters<M, I, Q>(module: &mut M, parameters: I)
91where
92    M: ModuleParameters + ?Sized,
93    I: IntoIterator<Item = (Q, Array)>,
94    Q: Hash + Eq,
95    Rc<str>: Borrow<Q>,
96{
97    let mut flattened_self_parameters = module.parameters_mut().flatten();
98
99    for (key, value) in parameters {
100        if let Some(self_value) = flattened_self_parameters.get_mut(&key) {
101            **self_value = value;
102        }
103    }
104}
105
106impl<T> ModuleParameters for &'_ mut T
107where
108    T: ModuleParameters + ?Sized,
109{
110    fn parameters(&self) -> ModuleParamRef<'_> {
111        (**self).parameters()
112    }
113
114    fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
115        (**self).parameters_mut()
116    }
117
118    fn trainable_parameters(&self) -> ModuleParamRef<'_> {
119        (**self).trainable_parameters()
120    }
121
122    fn freeze_parameters(&mut self, recursive: bool) {
123        (**self).freeze_parameters(recursive);
124    }
125
126    fn unfreeze_parameters(&mut self, recursive: bool) {
127        (**self).unfreeze_parameters(recursive);
128    }
129
130    fn all_frozen(&self) -> Option<bool> {
131        (**self).all_frozen()
132    }
133
134    fn any_frozen(&self) -> Option<bool> {
135        (**self).any_frozen()
136    }
137}
138
139impl<T> ModuleParameters for Box<T>
140where
141    T: ModuleParameters + ?Sized,
142{
143    fn parameters(&self) -> ModuleParamRef<'_> {
144        self.as_ref().parameters()
145    }
146
147    fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
148        self.as_mut().parameters_mut()
149    }
150
151    fn trainable_parameters(&self) -> ModuleParamRef<'_> {
152        self.as_ref().trainable_parameters()
153    }
154
155    fn freeze_parameters(&mut self, recursive: bool) {
156        self.as_mut().freeze_parameters(recursive);
157    }
158
159    fn unfreeze_parameters(&mut self, recursive: bool) {
160        self.as_mut().unfreeze_parameters(recursive);
161    }
162
163    fn all_frozen(&self) -> Option<bool> {
164        self.as_ref().all_frozen()
165    }
166
167    fn any_frozen(&self) -> Option<bool> {
168        self.as_ref().any_frozen()
169    }
170}
171
172impl<T> ModuleParameters for Vec<T>
173where
174    T: ModuleParameters,
175{
176    fn parameters(&self) -> ModuleParamRef<'_> {
177        let mut parameters = NestedHashMap::new();
178        self.iter().enumerate().for_each(|(i, module)| {
179            let value = module.parameters();
180            parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
181        });
182        parameters
183    }
184
185    fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
186        let mut parameters = NestedHashMap::new();
187        self.iter_mut().enumerate().for_each(|(i, module)| {
188            let value = module.parameters_mut();
189            parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
190        });
191        parameters
192    }
193
194    fn trainable_parameters(&self) -> ModuleParamRef<'_> {
195        let mut parameters = NestedHashMap::new();
196        self.iter().enumerate().for_each(|(i, module)| {
197            let value = module.trainable_parameters();
198            parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
199        });
200        parameters
201    }
202
203    fn freeze_parameters(&mut self, recursive: bool) {
204        self.iter_mut().for_each(|module| {
205            module.freeze_parameters(recursive);
206        });
207    }
208
209    fn unfreeze_parameters(&mut self, recursive: bool) {
210        self.iter_mut().for_each(|module| {
211            module.unfreeze_parameters(recursive);
212        });
213    }
214
215    fn all_frozen(&self) -> Option<bool> {
216        let mut result = None;
217        for module in self.iter() {
218            match module.all_frozen() {
219                Some(true) => result = Some(true),
220                Some(false) => return Some(false),
221                None => {}
222            }
223        }
224        result
225    }
226
227    fn any_frozen(&self) -> Option<bool> {
228        let mut result = None;
229        for module in self.iter() {
230            match module.any_frozen() {
231                Some(true) => return Some(true),
232                Some(false) => result = Some(false),
233                None => {}
234            }
235        }
236        result
237    }
238}
239
240/// Extension trait for `ModuleParameters`. This is implemented for all types that implement
241/// `ModuleParameters`.
242pub trait ModuleParametersExt: ModuleParameters {
243    /// Evaluate the module parameters.
244    fn eval(&self) -> Result<(), Exception> {
245        crate::transforms::eval_params(self.parameters())
246    }
247
248    /// Load module parameters from a `safetensors` file.
249    fn load_safetensors(&mut self, path: impl AsRef<Path>) -> Result<(), IoError> {
250        let loaded = Array::load_safetensors(path)?;
251
252        // Load the parameters
253        let mut params = self.parameters_mut().flatten();
254        for (key, value) in loaded {
255            if let Some(param) = params.get_mut(&*key) {
256                **param = value;
257            }
258        }
259
260        // Loading is lazy, eval after loading
261        self.eval()?;
262
263        Ok(())
264    }
265
266    /// Save module parameters to a file in `safetensors` format.
267    fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
268        let params = self.parameters().flatten();
269        Array::save_safetensors(params, None, path)?;
270        Ok(())
271    }
272}
273
274impl<T: ModuleParameters> ModuleParametersExt for T {}