// mlx_rs/module/module.rs
1use std::{borrow::Borrow, collections::HashMap, hash::Hash, path::Path, rc::Rc};
2
3use crate::{
4 error::{Exception, IoError},
5 nested::{NestedHashMap, NestedValue},
6 Array,
7};
8
9pub type ModuleParam = NestedHashMap<Rc<str>, Array>;
11
12pub type ModuleParamRef<'a> = NestedHashMap<Rc<str>, &'a Array>;
14
15pub type ModuleParamMut<'a> = NestedHashMap<Rc<str>, &'a mut Array>;
17
18pub type FlattenedModuleParam = HashMap<Rc<str>, Array>;
20
21pub type FlattenedModuleParamRef<'a> = HashMap<Rc<str>, &'a Array>;
23
24pub type FlattenedModuleParamMut<'a> = HashMap<Rc<str>, &'a mut Array>;
26
27pub trait Module<Input>: ModuleParameters + std::fmt::Debug {
29 type Output;
31
32 type Error: std::error::Error;
34
35 fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>;
37
38 fn training_mode(&mut self, mode: bool);
44}
45
46pub trait UnaryModule: for<'a> Module<&'a Array, Output = Array> {}
51
52impl<T> UnaryModule for T where T: for<'a> Module<&'a Array, Output = Array> {}
53
54pub trait ModuleParameters {
56 fn parameters(&self) -> ModuleParamRef<'_>;
58
59 fn parameters_mut(&mut self) -> ModuleParamMut<'_>;
61
62 fn trainable_parameters(&self) -> ModuleParamRef<'_>;
64
65 fn update(&mut self, parameters: ModuleParam) {
67 let flattened_parameters = parameters.flatten();
68 update_parameters(self, flattened_parameters)
69 }
70
71 fn update_flattened(&mut self, flattened_parameters: FlattenedModuleParam) {
73 update_parameters(self, flattened_parameters)
74 }
75
76 fn freeze_parameters(&mut self, recursive: bool);
78
79 fn unfreeze_parameters(&mut self, recursive: bool);
81
82 fn all_frozen(&self) -> Option<bool>;
84
85 fn any_frozen(&self) -> Option<bool>;
87}
88
89pub fn update_parameters<M, I, Q>(module: &mut M, parameters: I)
91where
92 M: ModuleParameters + ?Sized,
93 I: IntoIterator<Item = (Q, Array)>,
94 Q: Hash + Eq,
95 Rc<str>: Borrow<Q>,
96{
97 let mut flattened_self_parameters = module.parameters_mut().flatten();
98
99 for (key, value) in parameters {
100 if let Some(self_value) = flattened_self_parameters.get_mut(&key) {
101 **self_value = value;
102 }
103 }
104}
105
106impl<T> ModuleParameters for &'_ mut T
107where
108 T: ModuleParameters + ?Sized,
109{
110 fn parameters(&self) -> ModuleParamRef<'_> {
111 (**self).parameters()
112 }
113
114 fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
115 (**self).parameters_mut()
116 }
117
118 fn trainable_parameters(&self) -> ModuleParamRef<'_> {
119 (**self).trainable_parameters()
120 }
121
122 fn freeze_parameters(&mut self, recursive: bool) {
123 (**self).freeze_parameters(recursive);
124 }
125
126 fn unfreeze_parameters(&mut self, recursive: bool) {
127 (**self).unfreeze_parameters(recursive);
128 }
129
130 fn all_frozen(&self) -> Option<bool> {
131 (**self).all_frozen()
132 }
133
134 fn any_frozen(&self) -> Option<bool> {
135 (**self).any_frozen()
136 }
137}
138
139impl<T> ModuleParameters for Box<T>
140where
141 T: ModuleParameters + ?Sized,
142{
143 fn parameters(&self) -> ModuleParamRef<'_> {
144 self.as_ref().parameters()
145 }
146
147 fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
148 self.as_mut().parameters_mut()
149 }
150
151 fn trainable_parameters(&self) -> ModuleParamRef<'_> {
152 self.as_ref().trainable_parameters()
153 }
154
155 fn freeze_parameters(&mut self, recursive: bool) {
156 self.as_mut().freeze_parameters(recursive);
157 }
158
159 fn unfreeze_parameters(&mut self, recursive: bool) {
160 self.as_mut().unfreeze_parameters(recursive);
161 }
162
163 fn all_frozen(&self) -> Option<bool> {
164 self.as_ref().all_frozen()
165 }
166
167 fn any_frozen(&self) -> Option<bool> {
168 self.as_ref().any_frozen()
169 }
170}
171
172impl<T> ModuleParameters for Vec<T>
173where
174 T: ModuleParameters,
175{
176 fn parameters(&self) -> ModuleParamRef<'_> {
177 let mut parameters = NestedHashMap::new();
178 self.iter().enumerate().for_each(|(i, module)| {
179 let value = module.parameters();
180 parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
181 });
182 parameters
183 }
184
185 fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
186 let mut parameters = NestedHashMap::new();
187 self.iter_mut().enumerate().for_each(|(i, module)| {
188 let value = module.parameters_mut();
189 parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
190 });
191 parameters
192 }
193
194 fn trainable_parameters(&self) -> ModuleParamRef<'_> {
195 let mut parameters = NestedHashMap::new();
196 self.iter().enumerate().for_each(|(i, module)| {
197 let value = module.trainable_parameters();
198 parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
199 });
200 parameters
201 }
202
203 fn freeze_parameters(&mut self, recursive: bool) {
204 self.iter_mut().for_each(|module| {
205 module.freeze_parameters(recursive);
206 });
207 }
208
209 fn unfreeze_parameters(&mut self, recursive: bool) {
210 self.iter_mut().for_each(|module| {
211 module.unfreeze_parameters(recursive);
212 });
213 }
214
215 fn all_frozen(&self) -> Option<bool> {
216 let mut result = None;
217 for module in self.iter() {
218 match module.all_frozen() {
219 Some(true) => result = Some(true),
220 Some(false) => return Some(false),
221 None => {}
222 }
223 }
224 result
225 }
226
227 fn any_frozen(&self) -> Option<bool> {
228 let mut result = None;
229 for module in self.iter() {
230 match module.any_frozen() {
231 Some(true) => return Some(true),
232 Some(false) => result = Some(false),
233 None => {}
234 }
235 }
236 result
237 }
238}
239
240pub trait ModuleParametersExt: ModuleParameters {
243 fn eval(&self) -> Result<(), Exception> {
245 crate::transforms::eval_params(self.parameters())
246 }
247
248 fn load_safetensors(&mut self, path: impl AsRef<Path>) -> Result<(), IoError> {
250 let loaded = Array::load_safetensors(path)?;
251
252 let mut params = self.parameters_mut().flatten();
254 for (key, value) in loaded {
255 if let Some(param) = params.get_mut(&*key) {
256 **param = value;
257 }
258 }
259
260 self.eval()?;
262
263 Ok(())
264 }
265
266 fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
268 let params = self.parameters().flatten();
269 Array::save_safetensors(params, None, path)?;
270 Ok(())
271 }
272}
273
274impl<T: ModuleParameters> ModuleParametersExt for T {}