mlx_rs/optimizers/
mod.rs

//! Trait and implementations for optimizers.
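//!
//! # Example
//!
//! A minimal sketch of a single training step (the model, optimizer, and the
//! `compute_gradients` helper are placeholders; any optimizer in this module
//! can be used the same way):
//!
//! ```rust,ignore
//! use mlx_rs::module::FlattenedModuleParam;
//! use mlx_rs::optimizers::{clip_grad_norm, Optimizer};
//!
//! // Gradients flattened to the same keys as the model parameters
//! // (`compute_gradients` is a hypothetical helper).
//! let gradients: FlattenedModuleParam = compute_gradients(&model)?;
//!
//! // Optionally clip the global gradient norm before applying the update.
//! let (clipped, _norm) = clip_grad_norm(&gradients, 1.0)?;
//! let gradients: FlattenedModuleParam = clipped
//!     .into_iter()
//!     .map(|(key, grad)| (key, grad.into_owned()))
//!     .collect();
//!
//! // Apply the gradients to the model parameters in place.
//! optimizer.update(&mut model, gradients)?;
//! ```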

#![deny(missing_docs)]

use std::{
    borrow::{Borrow, Cow},
    collections::HashMap,
    path::Path,
    rc::Rc,
};

use crate::{
    array,
    error::{IoError, UnflattenError},
    module::{FlattenedModuleParam, ModuleParameters},
    utils::Updatable,
    Array,
};

mod adadelta;
mod adafactor;
mod adagrad;
mod adam;
mod adamax;
mod adamw;
mod lion;
mod rmsprop;
mod sgd;

pub use adadelta::*;
pub use adafactor::*;
pub use adagrad::*;
pub use adam::*;
pub use adamax::*;
pub use adamw::*;
use itertools::Itertools;
pub use lion::*;
pub use rmsprop::*;
pub use sgd::*;

// Unfortunate workaround to implement `Updatable` for mutable references of
// optimizers. This is needed because of the orphan rule and the lack of
// negative trait bounds; otherwise we would have to implement `Updatable` for
// every `Module`.
macro_rules! impl_updatable_for_mut_optimizer {
    ($optimizer:ty) => {
        impl Updatable for &'_ mut $optimizer {
            fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
                <$optimizer as Updatable>::updatable_states(&**self)
            }

            fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
                <$optimizer as Updatable>::updatable_states_mut(&mut **self)
            }
        }
    };
}
use impl_updatable_for_mut_optimizer;

/// Type alias for common optimizer state.
pub type State<T = Array> = HashMap<Rc<str>, T>;

/// Trait for optimizer states.
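///
/// # Example
///
/// A minimal sketch of flattening and rebuilding the common `State` type
/// defined in this module (the key and values are placeholders):
///
/// ```rust,ignore
/// use mlx_rs::array;
/// use mlx_rs::optimizers::{OptimizerState, State};
///
/// let mut state = State::new();
/// state.insert("weight".into(), array!([1.0, 2.0]));
///
/// // Flatten into `(key, &Array)` pairs, e.g. before saving to safetensors.
/// let keys: Vec<_> = state.flatten().map(|(key, _)| key).collect();
///
/// // Rebuild a state from owned `(key, Array)` pairs.
/// let restored = State::unflatten([("weight", array!([1.0, 2.0]))]).unwrap();
/// assert_eq!(restored.len(), 1);
/// ```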
pub trait OptimizerState: Sized {
    /// Error type returned by `unflatten`.
    type UnflattenError: std::error::Error + Into<IoError>;

    /// Flatten the optimizer state.
    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)>;

    /// Flatten the optimizer state, yielding mutable references to the arrays.
    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)>;

    /// Unflatten an iterator of key-value pairs into the optimizer state.
    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>;

    /// Save the optimizer state to a safetensors file.
    fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), IoError> {
        let state = self.flatten();
        Array::save_safetensors(state, None, path)
    }

    /// Load the optimizer state from a safetensors file.
    fn load_safetensors(&mut self, path: impl AsRef<Path>) -> Result<(), IoError> {
        let loaded = Array::load_safetensors(path)?;
        let unflattened = Self::unflatten(loaded).map_err(Into::into)?;

        *self = unflattened;

        Ok(())
    }
}

impl OptimizerState for State {
    type UnflattenError = std::convert::Infallible;

    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.iter().map(|(k, v)| (k.clone(), v))
    }

    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
        self.iter_mut().map(|(k, v)| (k.clone(), v))
    }

    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        Self: Sized,
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        Ok(input.into_iter().map(|(k, v)| (k.into(), v)).collect())
    }
}

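// Pair states (e.g. the two moment estimates kept by Adam-style optimizers) are
// flattened by appending ".0" and ".1" to the parameter key.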
impl OptimizerState for State<(Array, Array)> {
    type UnflattenError = UnflattenError;

    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.iter().flat_map(|(k, (first, second))| {
            let first_k: Rc<str> = Rc::from(format!("{}.0", k));
            let second_k: Rc<str> = Rc::from(format!("{}.1", k));

            [(first_k, first), (second_k, second)]
        })
    }

    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
        self.iter_mut().flat_map(|(k, (first, second))| {
            let first_k: Rc<str> = Rc::from(format!("{}.0", k));
            let second_k: Rc<str> = Rc::from(format!("{}.1", k));

            [(first_k, first), (second_k, second)]
        })
    }

    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        Self: Sized,
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        let mut state = State::new();
        let iter = input
            .into_iter()
            .sorted_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()))
            .chunks(2);

        for mut chunk in &iter {
            let first = chunk.next().ok_or(UnflattenError::ExpectingNextPair)?;
            let second = chunk.next().ok_or(UnflattenError::ExpectingNextPair)?;

            // Check that the keys match up to the last dot and that the suffixes are
            // ".0" and ".1" (the pairs are already sorted by key).
            let first_key = first.0.as_ref();
            let second_key = second.0.as_ref();
            if !first_key.ends_with(".0") || !second_key.ends_with(".1") {
                return Err(UnflattenError::InvalidKey);
            }
            if first_key[..first_key.len() - 2] != second_key[..second_key.len() - 2] {
                return Err(UnflattenError::InvalidKey);
            }

            let key = &first_key[..first_key.len() - 2];
            let key: Rc<str> = Rc::from(key);
            state.insert(key, (first.1, second.1));
        }
        Ok(state)
    }
}

/// Trait for optimizers.
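///
/// # Example
///
/// A minimal sketch of persisting and restoring an optimizer's state through
/// `OptimizerState` (the optimizer and the file path are placeholders):
///
/// ```rust,ignore
/// use mlx_rs::optimizers::{Optimizer, OptimizerState};
///
/// optimizer.state().save_safetensors("optimizer.safetensors")?;
/// // ... later, on a freshly constructed optimizer of the same kind ...
/// optimizer.state_mut().load_safetensors("optimizer.safetensors")?;
/// ```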
pub trait Optimizer: Updatable {
    /// State of the optimizer.
    type State: OptimizerState;

    /// Get the state of the optimizer.
    fn state(&self) -> &Self::State;

    /// Get the mutable state of the optimizer.
    fn state_mut(&mut self) -> &mut Self::State;

    /// Update a single parameter with the given gradient.
    ///
    /// The implementation should look up the state for the parameter using the key and update the
    /// state and the parameter accordingly. The key is provided instead of the state itself
    /// because passing the state would otherwise create a mutable borrow conflict with the rest
    /// of the optimizer's fields.
    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()>;

    /// Apply the gradients to the parameters of the model and update the model with the new
    /// parameters.
    fn update<M>(
        &mut self,
        model: &mut M,
        gradients: impl Borrow<FlattenedModuleParam>,
    ) -> crate::error::Result<()>
    where
        M: ModuleParameters,
    {
        let mut parameters = model.parameters_mut().flatten();

        for (key, gradient) in gradients.borrow().iter() {
            if let Some(parameter) = parameters.get_mut(key) {
                self.update_single(key, gradient, parameter)?;
            }
        }

        Ok(())
    }
}

/// Type alias for the possibly clipped gradients returned by `clip_grad_norm`.
pub type MaybeClippedGrads<'a> = HashMap<Rc<str>, Cow<'a, Array>>;

/// Clips the global norm of the gradients.
///
/// This function ensures that the global norm of the gradients does not exceed
/// `max_norm`. It scales down the gradients proportionally if their norm is
/// greater than `max_norm`.
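///
/// # Example
///
/// A minimal sketch (the gradient values are placeholders, mirroring the test
/// at the bottom of this module):
///
/// ```rust,ignore
/// use std::collections::HashMap;
///
/// use mlx_rs::array;
/// use mlx_rs::module::FlattenedModuleParam;
/// use mlx_rs::optimizers::clip_grad_norm;
///
/// let mut gradients: FlattenedModuleParam = HashMap::new();
/// gradients.insert("w".into(), array!([10.0, 20.0]));
///
/// // Gradients whose global norm exceeds `max_norm` are scaled down; the
/// // returned norm is the norm *before* clipping.
/// let (clipped, total_norm) = clip_grad_norm(&gradients, 1.0)?;
/// ```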
pub fn clip_grad_norm(
    gradients: &FlattenedModuleParam,
    max_norm: f32,
) -> crate::error::Result<(MaybeClippedGrads, f32)> {
    let total_norm: f32 = gradients
        .values()
        .try_fold(array!(0.0), |acc, grad| {
            acc.add(&grad.square()?.sum(None, None)?)
        })?
        .sqrt()?
        .item();
    let normalizer = array!(max_norm / (total_norm + 1e-6));

    let clipped_gradients: HashMap<_, _> = gradients
        .iter()
        .map(|(key, grad)| {
            let clipped_grad = if total_norm < max_norm {
                Cow::Borrowed(grad)
            } else {
                Cow::Owned(grad * &normalizer)
            };
            (key.clone(), clipped_grad)
        })
        .collect();
    Ok((clipped_gradients, total_norm))
}

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use crate::{array, module::FlattenedModuleParam, Array};

    use super::clip_grad_norm;

    #[test]
    fn test_clip_grad_norm() {
        // Test with small gradients that do not require clipping
        let mut small_grads: FlattenedModuleParam = HashMap::new();
        small_grads.insert("first.a".into(), array!([0.1, 0.2]));
        small_grads.insert("first.b".into(), array!(0.1));
        small_grads.insert("second".into(), array!(0.3));

        let max_norm = 10.0;

        let (clipped_grads, _) = clip_grad_norm(&small_grads, max_norm).unwrap();
        for (key, value) in small_grads.iter() {
            assert_eq!(&*clipped_grads[key], value);
        }

        // Test with large gradients that require clipping
        let mut large_grads: FlattenedModuleParam = HashMap::new();
        large_grads.insert("first.a".into(), array!([10.0, 20.0]));
        large_grads.insert("first.b".into(), array!(10.0));
        large_grads.insert("second".into(), array!(30.0));

        let max_norm = 1.0;

        let (clipped_grads, total_norm) = clip_grad_norm(&large_grads, max_norm).unwrap();
        let clipped_values: Vec<_> = clipped_grads.values().map(|v| v.as_ref()).collect();
        let norm_of_clipped = clipped_values
            .into_iter()
            .map(|g| g.square().unwrap().sum(None, None).unwrap())
            .sum::<Array>()
            .sqrt()
            .unwrap();

        float_eq::assert_float_eq!(norm_of_clipped.item::<f32>(), max_norm, abs <= 1e-6);

        // Ensures that the scaling was done correctly
        let scale = max_norm / total_norm;
        let expected_grads: FlattenedModuleParam = large_grads
            .iter()
            .map(|(key, value)| (key.clone(), value * scale))
            .collect();
        for (key, value) in expected_grads.iter() {
            assert_eq!(&*clipped_grads[key], value);
        }
    }
}