mlx_rs/optimizers/mod.rs

//! Trait and implementations for optimizers.

#![deny(missing_docs)]

use std::{
    borrow::{Borrow, Cow},
    collections::HashMap,
    path::Path,
    rc::Rc,
};

use itertools::Itertools;

use crate::{
    array,
    error::{IoError, UnflattenError},
    module::{FlattenedModuleParam, ModuleParameters},
    utils::Updatable,
    Array,
};

mod adadelta;
mod adafactor;
mod adagrad;
mod adam;
mod adamax;
mod adamw;
mod lion;
mod rmsprop;
mod sgd;

pub use adadelta::*;
pub use adafactor::*;
pub use adagrad::*;
pub use adam::*;
pub use adamax::*;
pub use adamw::*;
pub use lion::*;
pub use rmsprop::*;
pub use sgd::*;

// Unfortunate workaround to implement `Updatable` for mutable references to
// optimizers. This is needed because of the orphan rule and the lack of
// negative trait bounds; otherwise we would have to implement `Updatable` for
// every `Module`.
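//
// Example invocation (a sketch, assuming an optimizer type named `Sgd`
// exported from the `sgd` submodule):
//
//     impl_updatable_for_mut_optimizer!(Sgd);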
macro_rules! impl_updatable_for_mut_optimizer {
    ($optimizer:ty) => {
        impl Updatable for &'_ mut $optimizer {
            fn updatable_states_len(&self) -> usize {
                <$optimizer as Updatable>::updatable_states_len(&**self)
            }

            fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
                <$optimizer as Updatable>::updatable_states(&**self)
            }

            fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
                <$optimizer as Updatable>::updatable_states_mut(&mut **self)
            }
        }
    };
}
use impl_updatable_for_mut_optimizer;

/// Type alias for common optimizer state.
pub type State<T = Array> = HashMap<Rc<str>, T>;

/// Trait for optimizer states.
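///
/// # Example
///
/// A minimal sketch of persisting and restoring optimizer state through the
/// [`Optimizer`] trait; `optimizer` and the file name are illustrative:
///
/// ```rust,ignore
/// optimizer.state().save_safetensors("optimizer.safetensors")?;
/// // ... later, e.g. when resuming training ...
/// optimizer.state_mut().load_safetensors("optimizer.safetensors")?;
/// ```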
pub trait OptimizerState: Sized {
    /// Error type returned by `unflatten`.
    type UnflattenError: std::error::Error + Into<IoError>;

    /// Flatten the optimizer state.
    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)>;

    /// Flatten the optimizer state, yielding mutable references to the arrays.
    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)>;

    /// Unflatten an iterator of key-value pairs into the optimizer state.
    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>;

    /// Save the optimizer state to a safetensors file.
    fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
        let state = self.flatten();
        Array::save_safetensors(state, None, path)
    }

    /// Load the optimizer state from a safetensors file.
    fn load_safetensors(&mut self, path: impl AsRef<Path>) -> Result<(), IoError> {
        let loaded = Array::load_safetensors(path)?;
        let unflattened = Self::unflatten(loaded).map_err(Into::into)?;

        *self = unflattened;

        Ok(())
    }
}

impl OptimizerState for State {
    type UnflattenError = std::convert::Infallible;

    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.iter().map(|(k, v)| (k.clone(), v))
    }

    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
        self.iter_mut().map(|(k, v)| (k.clone(), v))
    }

    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        Self: Sized,
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        Ok(input.into_iter().map(|(k, v)| (k.into(), v)).collect())
    }
}

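// For states that track two arrays per parameter (e.g. first and second
// moments), each pair is flattened by appending `.0` and `.1` to the parameter
// key: a (hypothetical) key `"layers.0.weight"` becomes `"layers.0.weight.0"`
// and `"layers.0.weight.1"`. `unflatten` reverses this encoding.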
impl OptimizerState for State<(Array, Array)> {
    type UnflattenError = UnflattenError;

    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.iter().flat_map(|(k, (first, second))| {
            let first_k: Rc<str> = Rc::from(format!("{}.0", k));
            let second_k: Rc<str> = Rc::from(format!("{}.1", k));

            [(first_k, first), (second_k, second)]
        })
    }

    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
        self.iter_mut().flat_map(|(k, (first, second))| {
            let first_k: Rc<str> = Rc::from(format!("{}.0", k));
            let second_k: Rc<str> = Rc::from(format!("{}.1", k));

            [(first_k, first), (second_k, second)]
        })
    }

    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        Self: Sized,
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        let mut state = State::new();
        let iter = input
            .into_iter()
            .sorted_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()))
            .chunks(2);

        for mut chunk in &iter {
            let first = chunk.next().ok_or(UnflattenError::ExpectingNextPair)?;
            let second = chunk.next().ok_or(UnflattenError::ExpectingNextPair)?;

            // Check that the keys match up to the last dot and that the suffixes are
            // `.0` and `.1` (the input has already been sorted above).
            let first_key = first.0.as_ref();
            let second_key = second.0.as_ref();
            if !first_key.ends_with(".0") || !second_key.ends_with(".1") {
                return Err(UnflattenError::InvalidKey);
            }
            if first_key[..first_key.len() - 2] != second_key[..second_key.len() - 2] {
                return Err(UnflattenError::InvalidKey);
            }

            let key = &first_key[..first_key.len() - 2];
            let key: Rc<str> = Rc::from(key);
            state.insert(key, (first.1, second.1));
        }
        Ok(state)
    }
}

/// Trait for optimizers.
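///
/// # Example
///
/// A minimal sketch of one training step. `model`, `optimizer`, and
/// `gradients` are assumed to already exist; `gradients` is a
/// [`FlattenedModuleParam`] typically produced by a value-and-grad transform.
///
/// ```rust,ignore
/// // Apply the gradients to the matching parameters of the model.
/// optimizer.update(&mut model, &gradients)?;
/// ```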
pub trait Optimizer: Updatable {
    /// State of the optimizer.
    type State: OptimizerState;

    /// Get the state of the optimizer.
    fn state(&self) -> &Self::State;

    /// Get the mutable state of the optimizer.
    fn state_mut(&mut self) -> &mut Self::State;

    /// Update a single parameter with the given gradient.
    ///
    /// The implementation should look up the state for the parameter using the key and update the
    /// state and the parameter accordingly. The key is provided instead of the state because it
    /// would otherwise create a mutable borrow conflict with the rest of the optimizer fields.
    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()>;

    /// Apply the gradients to the parameters of the model and update the model with the new
    /// parameters.
    fn update<M>(
        &mut self,
        model: &mut M,
        gradients: impl Borrow<FlattenedModuleParam>,
    ) -> crate::error::Result<()>
    where
        M: ModuleParameters,
    {
        let mut parameters = model.parameters_mut().flatten();

        for (key, gradient) in gradients.borrow().iter() {
            if let Some(parameter) = parameters.get_mut(key) {
                self.update_single(key, gradient, parameter)?;
            }
        }

        Ok(())
    }
}

/// Type alias for the possibly-clipped gradients returned by `clip_grad_norm`.
pub type MaybeClippedGrads<'a> = HashMap<Rc<str>, Cow<'a, Array>>;

/// Clips the global norm of the gradients.
///
/// This function ensures that the global norm of the gradients does not exceed
/// `max_norm`. If the total norm is greater than `max_norm`, every gradient is
/// scaled by `max_norm / (total_norm + 1e-6)`; otherwise the gradients are
/// returned unchanged (borrowed). The total norm is returned alongside the
/// gradients.
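///
/// # Example
///
/// A minimal sketch, assuming `gradients` is a [`FlattenedModuleParam`]:
///
/// ```rust,ignore
/// let (clipped, total_norm) = clip_grad_norm(&gradients, 1.0)?;
/// ```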
pub fn clip_grad_norm(
    gradients: &FlattenedModuleParam,
    max_norm: f32,
) -> crate::error::Result<(MaybeClippedGrads, f32)> {
    let total_norm: f32 = gradients
        .values()
        .try_fold(array!(0.0), |acc, grad| acc.add(&grad.square()?.sum(None)?))?
        .sqrt()?
        .item();
    let normalizer = array!(max_norm / (total_norm + 1e-6));

    let clipped_gradients: HashMap<_, _> = gradients
        .iter()
        .map(|(key, grad)| {
            let clipped_grad = if total_norm < max_norm {
                Cow::Borrowed(grad)
            } else {
                Cow::Owned(grad * &normalizer)
            };
            (key.clone(), clipped_grad)
        })
        .collect();
    Ok((clipped_gradients, total_norm))
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::{array, module::FlattenedModuleParam, Array};

    use super::clip_grad_norm;

    #[test]
    fn test_clip_grad_norm() {
        // Test with small gradients that do not require clipping
        let mut small_grads: FlattenedModuleParam = HashMap::new();
        small_grads.insert("first.a".into(), array!([0.1, 0.2]));
        small_grads.insert("first.b".into(), array!(0.1));
        small_grads.insert("second".into(), array!(0.3));

        let max_norm = 10.0;

        let (clipped_grads, _) = clip_grad_norm(&small_grads, max_norm).unwrap();
        for (key, value) in small_grads.iter() {
            assert_eq!(&*clipped_grads[key], value);
        }

        // Test with large gradients that require clipping
        let mut large_grads: FlattenedModuleParam = HashMap::new();
        large_grads.insert("first.a".into(), array!([10.0, 20.0]));
        large_grads.insert("first.b".into(), array!(10.0));
        large_grads.insert("second".into(), array!(30.0));

        let max_norm = 1.0;

        let (clipped_grads, total_norm) = clip_grad_norm(&large_grads, max_norm).unwrap();
        let clipped_values: Vec<_> = clipped_grads.values().map(|v| v.as_ref()).collect();
        let norm_of_clipped = clipped_values
            .into_iter()
            .map(|g| g.square().unwrap().sum(None).unwrap())
            .sum::<Array>()
            .sqrt()
            .unwrap();

        float_eq::assert_float_eq!(norm_of_clipped.item::<f32>(), max_norm, abs <= 1e-6);

        // Ensures that the scaling was done correctly
        let scale = max_norm / total_norm;
        let expected_grads: FlattenedModuleParam = large_grads
            .iter()
            .map(|(key, value)| (key.clone(), value * scale))
            .collect();
        for (key, value) in expected_grads.iter() {
            assert_eq!(&*clipped_grads[key], value);
        }
    }
}