#![deny(missing_docs)]

//! Optimizers (e.g. SGD, Adam, AdamW, Lion) and the shared [`Optimizer`]
//! and [`OptimizerState`] traits used to update module parameters from
//! gradients.
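//!
//! A minimal training-loop sketch: the `Sgd` constructor and the gradient
//! computation are illustrative placeholders, not exact signatures.
//!
//! ```rust,ignore
//! let mut optimizer = Sgd::new(1e-2);
//! for batch in data {
//!     // Compute `gradients: FlattenedModuleParam` from `model` and
//!     // `batch` via a value-and-grad step (not shown).
//!     optimizer.update(&mut model, gradients)?;
//! }
//! ```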

use std::{
    borrow::{Borrow, Cow},
    collections::HashMap,
    path::Path,
    rc::Rc,
};

use crate::{
    array,
    error::{IoError, UnflattenError},
    module::{FlattenedModuleParam, ModuleParameters},
    utils::Updatable,
    Array,
};

mod adadelta;
mod adafactor;
mod adagrad;
mod adam;
mod adamax;
mod adamw;
mod lion;
mod rmsprop;
mod sgd;

use itertools::Itertools;

pub use adadelta::*;
pub use adafactor::*;
pub use adagrad::*;
pub use adam::*;
pub use adamax::*;
pub use adamw::*;
pub use lion::*;
pub use rmsprop::*;
pub use sgd::*;

// Forwards `Updatable` through `&mut O` so an optimizer can also be updated
// behind a mutable reference.
macro_rules! impl_updatable_for_mut_optimizer {
    ($optimizer:ty) => {
        impl Updatable for &'_ mut $optimizer {
            fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
                <$optimizer as Updatable>::updatable_states(&**self)
            }

            fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
                <$optimizer as Updatable>::updatable_states_mut(&mut **self)
            }
        }
    };
}
use impl_updatable_for_mut_optimizer;

/// Common optimizer state: a flat map from parameter path to per-parameter
/// state.
pub type State<T = Array> = HashMap<Rc<str>, T>;

/// Optimizer state that can be flattened into `(key, Array)` pairs and
/// rebuilt from them, e.g. for (de)serialization to safetensors.
pub trait OptimizerState: Sized {
    /// Error type returned when [`OptimizerState::unflatten`] fails.
    type UnflattenError: std::error::Error + Into<IoError>;

    /// Flattens the state into `(key, &Array)` pairs.
    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)>;

    /// Flattens the state into `(key, &mut Array)` pairs.
    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)>;

    /// Rebuilds the state from flattened `(key, Array)` pairs.
    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>;

    /// Saves the flattened state to a safetensors file.
    fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
        let state = self.flatten();
        Array::save_safetensors(state, None, path)
    }
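
    /// Loads state from a safetensors file, replacing `self` with the
    /// unflattened contents.
    ///
    /// # Example
    ///
    /// A minimal save/load roundtrip sketch, assuming an optimizer whose
    /// state implements this trait; the file name is illustrative.
    ///
    /// ```rust,ignore
    /// optimizer.state().save_safetensors("optimizer.safetensors")?;
    /// // ... later, e.g. when resuming training ...
    /// optimizer.state_mut().load_safetensors("optimizer.safetensors")?;
    /// ```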
    fn load_safetensors(&mut self, path: impl AsRef<Path>) -> Result<(), IoError> {
        let loaded = Array::load_safetensors(path)?;
        let unflattened = Self::unflatten(loaded).map_err(Into::into)?;

        *self = unflattened;

        Ok(())
    }
}

impl OptimizerState for State {
    type UnflattenError = std::convert::Infallible;

    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.iter().map(|(k, v)| (k.clone(), v))
    }

    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
        self.iter_mut().map(|(k, v)| (k.clone(), v))
    }

    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        Self: Sized,
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        Ok(input.into_iter().map(|(k, v)| (k.into(), v)).collect())
    }
}

// Pair state (e.g. first/second moments in Adam-family optimizers): each
// `key` flattens to `"{key}.0"` and `"{key}.1"`.
impl OptimizerState for State<(Array, Array)> {
    type UnflattenError = UnflattenError;

    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.iter().flat_map(|(k, (first, second))| {
            let first_k: Rc<str> = Rc::from(format!("{}.0", k));
            let second_k: Rc<str> = Rc::from(format!("{}.1", k));

            [(first_k, first), (second_k, second)]
        })
    }

    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
        self.iter_mut().flat_map(|(k, (first, second))| {
            let first_k: Rc<str> = Rc::from(format!("{}.0", k));
            let second_k: Rc<str> = Rc::from(format!("{}.1", k));

            [(first_k, first), (second_k, second)]
        })
    }
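
    /// Rebuilds the paired state from keys of the form `"{key}.0"` and
    /// `"{key}.1"`, as produced by `flatten`.
    ///
    /// # Example
    ///
    /// A minimal sketch of the key scheme; the values are placeholders.
    ///
    /// ```rust,ignore
    /// let pairs = vec![("m.0", array!(1.0)), ("m.1", array!(2.0))];
    /// let state = State::<(Array, Array)>::unflatten(pairs)?;
    /// assert!(state.contains_key("m"));
    /// ```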
    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        Self: Sized,
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        let mut state = State::new();
        // Sort so `"{key}.0"` and `"{key}.1"` become adjacent, then consume
        // the entries in pairs.
        let iter = input
            .into_iter()
            .sorted_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()))
            .chunks(2);

        for mut chunk in &iter {
            let first = chunk.next().ok_or(UnflattenError::ExpectingNextPair)?;
            let second = chunk.next().ok_or(UnflattenError::ExpectingNextPair)?;

            let first_key = first.0.as_ref();
            let second_key = second.0.as_ref();
            if !first_key.ends_with(".0") || !second_key.ends_with(".1") {
                return Err(UnflattenError::InvalidKey);
            }
            if first_key[..first_key.len() - 2] != second_key[..second_key.len() - 2] {
                return Err(UnflattenError::InvalidKey);
            }

            let key = &first_key[..first_key.len() - 2];
            let key: Rc<str> = Rc::from(key);
            state.insert(key, (first.1, second.1));
        }
        Ok(state)
    }
}

/// Trait for optimizers that update module parameters from gradients.
pub trait Optimizer: Updatable {
    /// Internal optimizer state.
    type State: OptimizerState;

    /// Returns a reference to the optimizer state.
    fn state(&self) -> &Self::State;

    /// Returns a mutable reference to the optimizer state.
    fn state_mut(&mut self) -> &mut Self::State;

    /// Updates a single parameter in place using its gradient.
    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()>;
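
    /// Applies the gradients to the model's parameters in place, calling
    /// [`Optimizer::update_single`] for every gradient that has a matching
    /// parameter; entries without a match are skipped.
    ///
    /// # Example
    ///
    /// A minimal training-step sketch, assuming a `model` implementing
    /// [`ModuleParameters`] and an optimizer from this module; the gradient
    /// computation is omitted.
    ///
    /// ```rust,ignore
    /// // `gradients: FlattenedModuleParam` from a value-and-grad step.
    /// optimizer.update(&mut model, gradients)?;
    /// ```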
    fn update<M>(
        &mut self,
        model: &mut M,
        gradients: impl Borrow<FlattenedModuleParam>,
    ) -> crate::error::Result<()>
    where
        M: ModuleParameters,
    {
        let mut parameters = model.parameters_mut().flatten();

        for (key, gradient) in gradients.borrow().iter() {
            // Skip gradients that have no matching parameter.
            if let Some(parameter) = parameters.get_mut(key) {
                self.update_single(key, gradient, parameter)?;
            }
        }

        Ok(())
    }
}

/// Gradients returned by [`clip_grad_norm`]: borrowed when unchanged, owned
/// when rescaled.
pub type MaybeClippedGrads<'a> = HashMap<Rc<str>, Cow<'a, Array>>;
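
/// Clips gradients so that their global L2 norm does not exceed `max_norm`.
///
/// The total norm is computed over all gradients together as
/// `total_norm = sqrt(sum_i ||g_i||^2)`. If it reaches `max_norm`, every
/// gradient is rescaled by `max_norm / (total_norm + 1e-6)`; otherwise the
/// gradients are returned borrowed and unchanged. Returns the (possibly
/// clipped) gradients together with `total_norm`.
///
/// # Example
///
/// A minimal sketch; in practice the gradient map comes from a
/// value-and-grad step rather than being built by hand.
///
/// ```rust,ignore
/// let mut grads: FlattenedModuleParam = HashMap::new();
/// grads.insert("w".into(), array!([3.0, 4.0])); // L2 norm = 5.0
///
/// let (clipped, total_norm) = clip_grad_norm(&grads, 1.0)?;
/// assert!((total_norm - 5.0).abs() < 1e-6);
/// // `clipped["w"]` is now approximately [0.6, 0.8].
/// ```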
pub fn clip_grad_norm(
    gradients: &FlattenedModuleParam,
    max_norm: f32,
) -> crate::error::Result<(MaybeClippedGrads, f32)> {
    // Global L2 norm across all gradients.
    let total_norm: f32 = gradients
        .values()
        .try_fold(array!(0.0), |acc, grad| {
            acc.add(&grad.square()?.sum(None, None)?)
        })?
        .sqrt()?
        .item();
    let normalizer = array!(max_norm / (total_norm + 1e-6));

    let clipped_gradients: HashMap<_, _> = gradients
        .iter()
        .map(|(key, grad)| {
            let clipped_grad = if total_norm < max_norm {
                Cow::Borrowed(grad)
            } else {
                Cow::Owned(grad * &normalizer)
            };
            (key.clone(), clipped_grad)
        })
        .collect();
    Ok((clipped_gradients, total_norm))
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::{array, module::FlattenedModuleParam, Array};

    use super::clip_grad_norm;

    #[test]
    fn test_clip_grad_norm() {
        // Gradients whose total norm is below `max_norm` pass through unchanged.
        let mut small_grads: FlattenedModuleParam = HashMap::new();
        small_grads.insert("first.a".into(), array!([0.1, 0.2]));
        small_grads.insert("first.b".into(), array!(0.1));
        small_grads.insert("second".into(), array!(0.3));

        let max_norm = 10.0;

        let (clipped_grads, _) = clip_grad_norm(&small_grads, max_norm).unwrap();
        for (key, value) in small_grads.iter() {
            assert_eq!(&*clipped_grads[key], value);
        }

        // Gradients whose total norm exceeds `max_norm` are scaled down so
        // that the clipped norm equals `max_norm`.
        let mut large_grads: FlattenedModuleParam = HashMap::new();
        large_grads.insert("first.a".into(), array!([10.0, 20.0]));
        large_grads.insert("first.b".into(), array!(10.0));
        large_grads.insert("second".into(), array!(30.0));

        let max_norm = 1.0;

        let (clipped_grads, total_norm) = clip_grad_norm(&large_grads, max_norm).unwrap();
        let clipped_values: Vec<_> = clipped_grads.values().map(|v| v.as_ref()).collect();
        let norm_of_clipped = clipped_values
            .into_iter()
            .map(|g| g.square().unwrap().sum(None, None).unwrap())
            .sum::<Array>()
            .sqrt()
            .unwrap();

        float_eq::assert_float_eq!(norm_of_clipped.item::<f32>(), max_norm, abs <= 1e-6);

        // Each clipped gradient equals the original scaled by
        // `max_norm / total_norm`.
        let scale = max_norm / total_norm;
        let expected_grads: FlattenedModuleParam = large_grads
            .iter()
            .map(|(key, value)| (key.clone(), value * scale))
            .collect();
        for (key, value) in expected_grads.iter() {
            assert_eq!(&*clipped_grads[key], value);
        }
    }
}
305}