#![deny(missing_docs)]

//! Trait and implementations for optimizers.

use std::{
    borrow::{Borrow, Cow},
    collections::HashMap,
    path::Path,
    rc::Rc,
};

use crate::{
    array,
    error::{IoError, UnflattenError},
    module::{FlattenedModuleParam, ModuleParameters},
    utils::Updatable,
    Array,
};

mod adadelta;
mod adafactor;
mod adagrad;
mod adam;
mod adamax;
mod adamw;
mod lion;
mod rmsprop;
mod sgd;

pub use adadelta::*;
pub use adafactor::*;
pub use adagrad::*;
pub use adam::*;
pub use adamax::*;
pub use adamw::*;
use itertools::Itertools;
pub use lion::*;
pub use rmsprop::*;
pub use sgd::*;

// Implements `Updatable` for `&mut O` by delegating to `O`'s own
// implementation, so an optimizer can be passed by mutable reference wherever
// `Updatable` is expected.
macro_rules! impl_updatable_for_mut_optimizer {
    ($optimizer:ty) => {
        impl Updatable for &'_ mut $optimizer {
            fn updatable_states_len(&self) -> usize {
                <$optimizer as Updatable>::updatable_states_len(&**self)
            }

            fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
                <$optimizer as Updatable>::updatable_states(&**self)
            }

            fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
                <$optimizer as Updatable>::updatable_states_mut(&mut **self)
            }
        }
    };
}
use impl_updatable_for_mut_optimizer;
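
// Usage sketch (hypothetical; each optimizer module is assumed to invoke the
// macro for its concrete type):
//
//     impl_updatable_for_mut_optimizer!(Sgd);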

/// Type alias for common optimizer state: a map from parameter path to value.
pub type State<T = Array> = HashMap<Rc<str>, T>;

/// Trait for optimizer states.
pub trait OptimizerState: Sized {
    /// Error type returned when unflattening fails.
    type UnflattenError: std::error::Error + Into<IoError>;

    /// Flatten the optimizer state into key-value pairs.
    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)>;

    /// Flatten the optimizer state into mutable key-value pairs.
    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)>;

    /// Unflatten key-value pairs back into an optimizer state.
    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>;

    /// Save the optimizer state to a safetensors file.
    fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
        let state = self.flatten();
        Array::save_safetensors(state, None, path)
    }

    /// Load the optimizer state from a safetensors file.
    fn load_safetensors(&mut self, path: impl AsRef<Path>) -> Result<(), IoError> {
        let loaded = Array::load_safetensors(path)?;
        let unflattened = Self::unflatten(loaded).map_err(Into::into)?;

        *self = unflattened;

        Ok(())
    }
}
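
// Persistence sketch (hypothetical file name): any `OptimizerState` can be
// checkpointed to and restored from a safetensors file via the provided
// default methods.
//
//     optimizer.state().save_safetensors("optimizer.safetensors")?;
//     optimizer.state_mut().load_safetensors("optimizer.safetensors")?;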

impl OptimizerState for State {
    type UnflattenError = std::convert::Infallible;

    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.iter().map(|(k, v)| (k.clone(), v))
    }

    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
        self.iter_mut().map(|(k, v)| (k.clone(), v))
    }

    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        Self: Sized,
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        Ok(input.into_iter().map(|(k, v)| (k.into(), v)).collect())
    }
}

impl OptimizerState for State<(Array, Array)> {
    type UnflattenError = UnflattenError;

    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        // Each pair is flattened into two entries with ".0" and ".1" suffixes
        // appended to the base key.
        self.iter().flat_map(|(k, (first, second))| {
            let first_k: Rc<str> = Rc::from(format!("{}.0", k));
            let second_k: Rc<str> = Rc::from(format!("{}.1", k));

            [(first_k, first), (second_k, second)]
        })
    }

    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
        self.iter_mut().flat_map(|(k, (first, second))| {
            let first_k: Rc<str> = Rc::from(format!("{}.0", k));
            let second_k: Rc<str> = Rc::from(format!("{}.1", k));

            [(first_k, first), (second_k, second)]
        })
    }

    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        Self: Sized,
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        let mut state = State::new();
        // Sort so that "key.0" and "key.1" become adjacent, then pair them up.
        let iter = input
            .into_iter()
            .sorted_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()))
            .chunks(2);

        for mut chunk in &iter {
            let first = chunk.next().ok_or(UnflattenError::ExpectingNextPair)?;
            let second = chunk.next().ok_or(UnflattenError::ExpectingNextPair)?;

            // Both entries must carry the expected suffixes and share the
            // same base key.
            let first_key = first.0.as_ref();
            let second_key = second.0.as_ref();
            if !first_key.ends_with(".0") || !second_key.ends_with(".1") {
                return Err(UnflattenError::InvalidKey);
            }
            if first_key[..first_key.len() - 2] != second_key[..second_key.len() - 2] {
                return Err(UnflattenError::InvalidKey);
            }

            let key = &first_key[..first_key.len() - 2];
            let key: Rc<str> = Rc::from(key);
            state.insert(key, (first.1, second.1));
        }
        Ok(state)
    }
}

/// Trait for optimizers. The typical usage is to call [`Optimizer::update`]
/// with a model and its gradients.
pub trait Optimizer: Updatable {
    /// State of the optimizer.
    type State: OptimizerState;

    /// Get the state of the optimizer.
    fn state(&self) -> &Self::State;

    /// Get the mutable state of the optimizer.
    fn state_mut(&mut self) -> &mut Self::State;

    /// Update a single parameter with the given gradient.
    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()>;

    /// Apply the gradients to the parameters of the model, updating the model
    /// in place. Gradients without a matching parameter are ignored.
    fn update<M>(
        &mut self,
        model: &mut M,
        gradients: impl Borrow<FlattenedModuleParam>,
    ) -> crate::error::Result<()>
    where
        M: ModuleParameters,
    {
        let mut parameters = model.parameters_mut().flatten();

        for (key, gradient) in gradients.borrow().iter() {
            if let Some(parameter) = parameters.get_mut(key) {
                self.update_single(key, gradient, parameter)?;
            }
        }

        Ok(())
    }
}
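
// Training-loop sketch (hypothetical; assumes an optimizer such as `Sgd` from
// the `sgd` module, a model implementing `ModuleParameters`, and gradients
// computed elsewhere as a `FlattenedModuleParam`):
//
//     let mut optimizer = Sgd::new(0.01);
//     optimizer.update(&mut model, &gradients)?;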

/// Type alias for the gradients returned by [`clip_grad_norm`], which may be
/// either borrowed (unchanged) or owned (rescaled).
pub type MaybeClippedGrads<'a> = HashMap<Rc<str>, Cow<'a, Array>>;

/// Clips the global norm of the gradients.
///
/// If the global (L2) norm of the gradients exceeds `max_norm`, all gradients
/// are scaled down proportionally. Returns the possibly clipped gradients
/// together with the total norm before clipping.
pub fn clip_grad_norm(
    gradients: &FlattenedModuleParam,
    max_norm: f32,
) -> crate::error::Result<(MaybeClippedGrads, f32)> {
    // Global norm: sqrt of the sum of squared entries across all gradients.
    let total_norm: f32 = gradients
        .values()
        .try_fold(array!(0.0), |acc, grad| acc.add(&grad.square()?.sum(None)?))?
        .sqrt()?
        .item();
    let normalizer = array!(max_norm / (total_norm + 1e-6));

    let clipped_gradients: HashMap<_, _> = gradients
        .iter()
        .map(|(key, grad)| {
            // Only rescale when the norm exceeds the limit; otherwise borrow
            // the gradient unchanged.
            let clipped_grad = if total_norm < max_norm {
                Cow::Borrowed(grad)
            } else {
                Cow::Owned(grad * &normalizer)
            };
            (key.clone(), clipped_grad)
        })
        .collect();
    Ok((clipped_gradients, total_norm))
}
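
// Usage sketch (hypothetical values): clip the global gradient norm to 1.0
// before applying an optimizer update. Unchanged gradients are borrowed;
// rescaled ones are owned.
//
//     let (clipped, total_norm) = clip_grad_norm(&gradients, 1.0)?;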

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::{array, module::FlattenedModuleParam, Array};

    use super::clip_grad_norm;

    #[test]
    fn test_clip_grad_norm() {
        // Gradients whose global norm is below `max_norm` are left unchanged.
        let mut small_grads: FlattenedModuleParam = HashMap::new();
        small_grads.insert("first.a".into(), array!([0.1, 0.2]));
        small_grads.insert("first.b".into(), array!(0.1));
        small_grads.insert("second".into(), array!(0.3));

        let max_norm = 10.0;

        let (clipped_grads, _) = clip_grad_norm(&small_grads, max_norm).unwrap();
        for (key, value) in small_grads.iter() {
            assert_eq!(&*clipped_grads[key], value);
        }

        // Gradients whose global norm exceeds `max_norm` are rescaled so that
        // the clipped norm equals `max_norm`.
        let mut large_grads: FlattenedModuleParam = HashMap::new();
        large_grads.insert("first.a".into(), array!([10.0, 20.0]));
        large_grads.insert("first.b".into(), array!(10.0));
        large_grads.insert("second".into(), array!(30.0));

        let max_norm = 1.0;

        let (clipped_grads, total_norm) = clip_grad_norm(&large_grads, max_norm).unwrap();
        let clipped_values: Vec<_> = clipped_grads.values().map(|v| v.as_ref()).collect();
        let norm_of_clipped = clipped_values
            .into_iter()
            .map(|g| g.square().unwrap().sum(None).unwrap())
            .sum::<Array>()
            .sqrt()
            .unwrap();

        float_eq::assert_float_eq!(norm_of_clipped.item::<f32>(), max_norm, abs <= 1e-6);

        // Every clipped gradient should equal the original scaled by
        // `max_norm / total_norm`.
        let scale = max_norm / total_norm;
        let expected_grads: FlattenedModuleParam = large_grads
            .iter()
            .map(|(key, value)| (key.clone(), value * scale))
            .collect();
        for (key, value) in expected_grads.iter() {
            assert_eq!(&*clipped_grads[key], value);
        }
    }
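
    // A minimal round-trip sketch for `State<(Array, Array)>`: flattening
    // yields ".0"/".1"-suffixed keys, and unflattening recovers the base key.
    // Uses only items defined in this module; the test name is our own.
    #[test]
    fn test_paired_state_flatten_unflatten_round_trip() {
        use std::rc::Rc;

        use super::{OptimizerState, State};

        let mut state: State<(Array, Array)> = State::new();
        state.insert(Rc::from("weight"), (array!([1.0, 2.0]), array!([3.0, 4.0])));

        // Flattening one pair yields two suffixed entries.
        let flattened: Vec<(Rc<str>, Array)> =
            state.flatten().map(|(k, v)| (k, v.clone())).collect();
        let mut keys: Vec<&str> = flattened.iter().map(|(k, _)| k.as_ref()).collect();
        keys.sort_unstable();
        assert_eq!(keys, ["weight.0", "weight.1"]);

        // Unflattening the flattened pairs reconstructs the original key.
        let rebuilt = State::<(Array, Array)>::unflatten(flattened).unwrap();
        assert_eq!(rebuilt.len(), 1);
        assert!(rebuilt.contains_key("weight"));
    }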
}