1use std::{borrow::Borrow, collections::HashMap, hash::Hash, path::Path, rc::Rc};
2
3use crate::{
4 error::{Exception, IoError},
5 nested::{NestedHashMap, NestedValue},
6 Array,
7};
8
/// Owned module parameters, organised as a nested map keyed by name.
pub type ModuleParam = NestedHashMap<Rc<str>, Array>;

/// Shared references to module parameters, organised as a nested map keyed by name.
pub type ModuleParamRef<'a> = NestedHashMap<Rc<str>, &'a Array>;

/// Mutable references to module parameters, organised as a nested map keyed by name.
pub type ModuleParamMut<'a> = NestedHashMap<Rc<str>, &'a mut Array>;

/// Owned module parameters in a single-level map.
///
/// NOTE(review): keys are presumably the nested path produced by `NestedHashMap::flatten`
/// (the separator is defined in `crate::nested`) — confirm before relying on the format.
pub type FlattenedModuleParam = HashMap<Rc<str>, Array>;

/// Shared references to module parameters in a single-level map (see
/// [`FlattenedModuleParam`] for the key format).
pub type FlattenedModuleParamRef<'a> = HashMap<Rc<str>, &'a Array>;

/// Mutable references to module parameters in a single-level map (see
/// [`FlattenedModuleParam`] for the key format).
pub type FlattenedModuleParamMut<'a> = HashMap<Rc<str>, &'a mut Array>;
26
/// A computation unit that maps an `Input` to [`Module::Output`].
///
/// Every module must also expose its parameters through [`ModuleParameters`] and implement
/// `Debug`.
pub trait Module<Input>: ModuleParameters + std::fmt::Debug {
    /// Value produced by a successful [`Module::forward`] call.
    type Output;

    /// Error returned when [`Module::forward`] fails.
    type Error: std::error::Error;

    /// Runs the module on `input`.
    ///
    /// Takes `&mut self`, so implementations are allowed to mutate internal state during
    /// the call.
    fn forward(&mut self, input: Input) -> Result<Self::Output, Self::Error>;

    /// Sets the module's training mode.
    ///
    /// NOTE(review): `true` presumably selects training behaviour and `false` evaluation
    /// behaviour; what actually changes is implementation-defined — confirm per module.
    fn training_mode(&mut self, mode: bool);
}
45
46pub trait UnaryModule: for<'a> Module<&'a Array, Output = Array> {}
51
52impl<T> UnaryModule for T where T: for<'a> Module<&'a Array, Output = Array> {}
53
54pub trait ModuleParameters {
56 fn num_parameters(&self) -> usize;
61
62 fn parameters(&self) -> ModuleParamRef<'_>;
64
65 fn parameters_mut(&mut self) -> ModuleParamMut<'_>;
67
68 fn trainable_parameters(&self) -> ModuleParamRef<'_>;
70
71 fn update(&mut self, parameters: ModuleParam) {
73 let flattened_parameters = parameters.flatten();
74 update_parameters(self, flattened_parameters)
75 }
76
77 fn update_flattened(&mut self, flattened_parameters: FlattenedModuleParam) {
79 update_parameters(self, flattened_parameters)
80 }
81
82 fn freeze_parameters(&mut self, recursive: bool);
84
85 fn unfreeze_parameters(&mut self, recursive: bool);
87
88 fn all_frozen(&self) -> Option<bool>;
90
91 fn any_frozen(&self) -> Option<bool>;
93}
94
95pub fn update_parameters<M, I, Q>(module: &mut M, parameters: I)
97where
98 M: ModuleParameters + ?Sized,
99 I: IntoIterator<Item = (Q, Array)>,
100 Q: Hash + Eq,
101 Rc<str>: Borrow<Q>,
102{
103 let mut flattened_self_parameters = module.parameters_mut().flatten();
104
105 for (key, value) in parameters {
106 if let Some(self_value) = flattened_self_parameters.get_mut(&key) {
107 **self_value = value;
108 }
109 }
110}
111
112impl<T> ModuleParameters for &'_ mut T
113where
114 T: ModuleParameters + ?Sized,
115{
116 fn num_parameters(&self) -> usize {
117 (**self).num_parameters()
118 }
119
120 fn parameters(&self) -> ModuleParamRef<'_> {
121 (**self).parameters()
122 }
123
124 fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
125 (**self).parameters_mut()
126 }
127
128 fn trainable_parameters(&self) -> ModuleParamRef<'_> {
129 (**self).trainable_parameters()
130 }
131
132 fn freeze_parameters(&mut self, recursive: bool) {
133 (**self).freeze_parameters(recursive);
134 }
135
136 fn unfreeze_parameters(&mut self, recursive: bool) {
137 (**self).unfreeze_parameters(recursive);
138 }
139
140 fn all_frozen(&self) -> Option<bool> {
141 (**self).all_frozen()
142 }
143
144 fn any_frozen(&self) -> Option<bool> {
145 (**self).any_frozen()
146 }
147}
148
149impl<T> ModuleParameters for Box<T>
150where
151 T: ModuleParameters + ?Sized,
152{
153 fn num_parameters(&self) -> usize {
154 self.as_ref().num_parameters()
155 }
156
157 fn parameters(&self) -> ModuleParamRef<'_> {
158 self.as_ref().parameters()
159 }
160
161 fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
162 self.as_mut().parameters_mut()
163 }
164
165 fn trainable_parameters(&self) -> ModuleParamRef<'_> {
166 self.as_ref().trainable_parameters()
167 }
168
169 fn freeze_parameters(&mut self, recursive: bool) {
170 self.as_mut().freeze_parameters(recursive);
171 }
172
173 fn unfreeze_parameters(&mut self, recursive: bool) {
174 self.as_mut().unfreeze_parameters(recursive);
175 }
176
177 fn all_frozen(&self) -> Option<bool> {
178 self.as_ref().all_frozen()
179 }
180
181 fn any_frozen(&self) -> Option<bool> {
182 self.as_ref().any_frozen()
183 }
184}
185
186impl<T> ModuleParameters for Vec<T>
187where
188 T: ModuleParameters,
189{
190 fn num_parameters(&self) -> usize {
191 self.iter().map(|module| module.num_parameters()).sum()
192 }
193
194 fn parameters(&self) -> ModuleParamRef<'_> {
195 let mut parameters = NestedHashMap::new();
196 self.iter().enumerate().for_each(|(i, module)| {
197 let value = module.parameters();
198 parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
199 });
200 parameters
201 }
202
203 fn parameters_mut(&mut self) -> ModuleParamMut<'_> {
204 let mut parameters = NestedHashMap::new();
205 self.iter_mut().enumerate().for_each(|(i, module)| {
206 let value = module.parameters_mut();
207 parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
208 });
209 parameters
210 }
211
212 fn trainable_parameters(&self) -> ModuleParamRef<'_> {
213 let mut parameters = NestedHashMap::new();
214 self.iter().enumerate().for_each(|(i, module)| {
215 let value = module.trainable_parameters();
216 parameters.insert(Rc::from(i.to_string()), NestedValue::Map(value.entries));
217 });
218 parameters
219 }
220
221 fn freeze_parameters(&mut self, recursive: bool) {
222 self.iter_mut().for_each(|module| {
223 module.freeze_parameters(recursive);
224 });
225 }
226
227 fn unfreeze_parameters(&mut self, recursive: bool) {
228 self.iter_mut().for_each(|module| {
229 module.unfreeze_parameters(recursive);
230 });
231 }
232
233 fn all_frozen(&self) -> Option<bool> {
234 let mut result = None;
235 for module in self.iter() {
236 match module.all_frozen() {
237 Some(true) => result = Some(true),
238 Some(false) => return Some(false),
239 None => {}
240 }
241 }
242 result
243 }
244
245 fn any_frozen(&self) -> Option<bool> {
246 let mut result = None;
247 for module in self.iter() {
248 match module.any_frozen() {
249 Some(true) => return Some(true),
250 Some(false) => result = Some(false),
251 None => {}
252 }
253 }
254 result
255 }
256}
257
258pub trait ModuleParametersExt: ModuleParameters {
261 fn eval(&self) -> Result<(), Exception> {
263 crate::transforms::eval_params(self.parameters())
264 }
265
266 fn load_safetensors(&mut self, path: impl AsRef<Path>) -> Result<(), IoError> {
268 let loaded = Array::load_safetensors(path)?;
269
270 let mut params = self.parameters_mut().flatten();
272 for (key, value) in loaded {
273 if let Some(param) = params.get_mut(&*key) {
274 **param = value;
275 }
276 }
277
278 self.eval()?;
280
281 Ok(())
282 }
283
284 fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
286 let params = self.parameters().flatten();
287 Array::save_safetensors(params, None, path)?;
288 Ok(())
289 }
290}
291
292impl<T: ModuleParameters> ModuleParametersExt for T {}