mlx_rs/optimizers/adafactor.rs

use std::{borrow::Cow, collections::HashMap, rc::Rc};

use mlx_internal_macros::{generate_builder, Buildable};

use crate::{
    array,
    error::AdafactorBuildError,
    ops::{
        matmul, maximum, mean, mean_axes, minimum, rsqrt, sqrt, square, zeros_dtype, zeros_like,
    },
    utils::Updatable,
    Array,
};

use super::*;

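/// Root mean square over all elements: `sqrt(mean(x^2))`.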
fn rms(inputs: &Array) -> crate::error::Result<Array> {
    sqrt(&mean(&square(inputs)?, None)?)
}

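/// Reconstructs an inverse-square-root factor of the second-moment estimate from its
/// factored row and column statistics: the outer product of `rsqrt(row / mean(row))` and
/// `rsqrt(col)`, following the factorization in the Adafactor paper.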
fn approximate_exp_moving_avg(
    exp_avg_sq_row: &Array,
    exp_avg_sq_col: &Array,
) -> crate::error::Result<Array> {
    let rfactor = rsqrt(&exp_avg_sq_row.divide(&mean_axes(exp_avg_sq_row, &[-1], true)?)?)?;
    let cfactor = rsqrt(exp_avg_sq_col)?;
    matmul(&rfactor.expand_dims(-1)?, &cfactor.expand_dims(0)?)
}

/// Type alias for the epsilon values used in the Adafactor builder
pub type AdafactorEps = (f32, f32);

/// State of the Adafactor optimizer.
#[derive(Debug, Clone)]
pub struct AdafactorState {
    /// Number of update steps taken so far.
    pub(crate) step: Array,
    /// Row statistics of the factored second moment (parameters with `ndim >= 2`).
    pub(crate) exp_avg_sq_row: Option<Array>,
    /// Column statistics of the factored second moment (parameters with `ndim >= 2`).
    pub(crate) exp_avg_sq_col: Option<Array>,
    /// Unfactored second moment (parameters with `ndim < 2`).
    pub(crate) exp_avg_sq: Option<Array>,
    /// First moment, tracked only when `beta1` is set.
    pub(crate) exp_avg: Option<Array>,
}

impl OptimizerState for State<AdafactorState> {
    type UnflattenError = UnflattenError;

    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.iter().flat_map(|(k, v)| {
            let mut entries = vec![(Rc::from(format!("{}.step", k)), &v.step)];

            if let Some(exp_avg_sq_row) = &v.exp_avg_sq_row {
                entries.push((Rc::from(format!("{}.exp_avg_sq_row", k)), exp_avg_sq_row));
            }

            if let Some(exp_avg_sq_col) = &v.exp_avg_sq_col {
                entries.push((Rc::from(format!("{}.exp_avg_sq_col", k)), exp_avg_sq_col));
            }

            if let Some(exp_avg_sq) = &v.exp_avg_sq {
                entries.push((Rc::from(format!("{}.exp_avg_sq", k)), exp_avg_sq));
            }

            if let Some(exp_avg) = &v.exp_avg {
                entries.push((Rc::from(format!("{}.exp_avg", k)), exp_avg));
            }

            entries
        })
    }

    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
        self.iter_mut().flat_map(|(k, v)| {
            let mut entries = vec![(Rc::from(format!("{}.step", k)), &mut v.step)];

            if let Some(exp_avg_sq_row) = &mut v.exp_avg_sq_row {
                entries.push((Rc::from(format!("{}.exp_avg_sq_row", k)), exp_avg_sq_row));
            }

            if let Some(exp_avg_sq_col) = &mut v.exp_avg_sq_col {
                entries.push((Rc::from(format!("{}.exp_avg_sq_col", k)), exp_avg_sq_col));
            }

            if let Some(exp_avg_sq) = &mut v.exp_avg_sq {
                entries.push((Rc::from(format!("{}.exp_avg_sq", k)), exp_avg_sq));
            }

            if let Some(exp_avg) = &mut v.exp_avg {
                entries.push((Rc::from(format!("{}.exp_avg", k)), exp_avg));
            }

            entries
        })
    }

    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
    where
        Self: Sized,
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        let mut state = State::new();
        let iter = input
            .into_iter()
            .sorted_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));

        for (k, v) in iter {
            let key = k.into();
            // Split on the last '.' only: parameter keys themselves may contain dots
            // (e.g. "layers.0.weight.step" -> prefix "layers.0.weight", suffix "step").
            let (prefix, suffix) = key.rsplit_once('.').ok_or(UnflattenError::InvalidKey)?;

            let prefix = Rc::from(prefix);
            let state = state.entry(prefix).or_insert_with(|| AdafactorState {
                step: array!(AdafactorState::DEFAULT_STEP),
                exp_avg_sq_row: None,
                exp_avg_sq_col: None,
                exp_avg_sq: None,
                exp_avg: None,
            });

            match suffix {
                "step" => state.step = v,
                "exp_avg_sq_row" => state.exp_avg_sq_row = Some(v),
                "exp_avg_sq_col" => state.exp_avg_sq_col = Some(v),
                "exp_avg_sq" => state.exp_avg_sq = Some(v),
                "exp_avg" => state.exp_avg = Some(v),
                _ => return Err(UnflattenError::InvalidKey),
            }
        }

        Ok(state)
    }
}

impl AdafactorState {
    /// Default value for `step`
    pub const DEFAULT_STEP: i32 = 0;

    fn new(parameter: &Array, beta1_is_some: bool) -> crate::error::Result<Self> {
        let step = array!(Self::DEFAULT_STEP);
        let mut exp_avg_sq_row = None;
        let mut exp_avg_sq_col = None;
        let mut exp_avg_sq = None;
        let mut exp_avg = None;

        if parameter.ndim() >= 2 {
            let shape = parameter.shape();
            let dtype = parameter.dtype();

            // Row statistics drop the last axis; column statistics drop the second-to-last.
            let row_shape = &shape[..shape.len() - 1];
            exp_avg_sq_row = Some(zeros_dtype(row_shape, dtype)?);

            let mut col_shape = shape[..shape.len() - 2].to_vec();
            col_shape.push(*shape.last().unwrap());
            exp_avg_sq_col = Some(zeros_dtype(&col_shape, dtype)?);
        } else {
            exp_avg_sq = Some(zeros_like(parameter)?);
        }

        if beta1_is_some {
            exp_avg = Some(zeros_like(parameter)?);
        }

        Ok(Self {
            step,
            exp_avg_sq_row,
            exp_avg_sq_col,
            exp_avg_sq,
            exp_avg,
        })
    }
}

/// `Option<f32>`. Type alias for the learning rate used in the Adafactor builder due to a
/// limitation in the `generate_builder` macro
pub type AdafactorBuilderLr = Option<f32>;

/// Type alias for the learning rate used in Adafactor
pub type AdafactorLr = Option<Array>;

/// `Option<f32>`. Type alias for the beta1 used in the Adafactor builder due to a limitation
/// in the `generate_builder` macro
pub type AdafactorBuilderBeta1 = Option<f32>;

/// Type alias for the beta1 used in Adafactor
pub type AdafactorBeta1 = Option<Array>;

generate_builder! {
    /// The Adafactor optimizer.
    ///
    /// Our Adafactor implementation follows the original paper:
    /// [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235).
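    ///
    /// # Example
    ///
    /// A minimal usage sketch; the `builder()` entry point is assumed to be generated by
    /// the `Buildable` derive, so the block is marked `ignore` rather than doc-tested.
    ///
    /// ```rust,ignore
    /// use mlx_rs::optimizers::Adafactor;
    ///
    /// // With `relative_step = true` (the default) no explicit learning rate is needed.
    /// let mut opt = Adafactor::builder().build()?;
    ///
    /// // Or disable the relative step schedule and supply a fixed learning rate.
    /// let mut opt = Adafactor::builder()
    ///     .lr(1e-3)
    ///     .relative_step(false)
    ///     .build()?;
    /// ```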
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adafactor,
        err = AdafactorBuildError,
        root = crate
    )]
    pub struct Adafactor {
        /// The learning rate.
        #[builder(optional, default = Adafactor::DEFAULT_LR)]
        pub lr: Option<f32>,

        /// The pair `(eps_1, eps_2)`: `eps_1` is added to the square of the gradients to
        /// improve numerical stability, and `eps_2` is the lower bound for the parameter
        /// scale when `scale_parameter` is `true`. Default to [`Adafactor::DEFAULT_EPS`].
        #[builder(optional, ty_override = AdafactorEps, default = Adafactor::DEFAULT_EPS)]
        pub eps: (Array, Array),

        /// Clips the unscaled update. Default to [`Adafactor::DEFAULT_CLIP_THRESHOLD`].
        #[builder(optional, ty_override = f32, default = Adafactor::DEFAULT_CLIP_THRESHOLD)]
        pub clip_threshold: Array,

        /// Coefficient for the running average of the squared gradient. Default to
        /// [`Adafactor::DEFAULT_DECAY_RATE`].
        #[builder(optional, ty_override = f32, default = Adafactor::DEFAULT_DECAY_RATE)]
        pub decay_rate: Array,

        /// If set, the first moment will be used.
        #[builder(optional, ty_override = AdafactorBuilderBeta1, default = Adafactor::DEFAULT_BETA1)]
        pub beta1: AdafactorBeta1,

        /// The weight decay. Default to [`Adafactor::DEFAULT_WEIGHT_DECAY`].
        #[builder(optional, default = Adafactor::DEFAULT_WEIGHT_DECAY)]
        pub weight_decay: f32,

        /// If `true`, the learning rate will be scaled by `max(eps.1, rms(parameter))`.
        /// Default to [`Adafactor::DEFAULT_SCALE_PARAMETER`].
        #[builder(optional, default = Adafactor::DEFAULT_SCALE_PARAMETER)]
        pub scale_parameter: bool,

        /// If `true`, the learning rate will be ignored and the relative step size will be
        /// computed instead. Default to [`Adafactor::DEFAULT_RELATIVE_STEP`].
        #[builder(optional, ty_override = bool, default = Adafactor::DEFAULT_RELATIVE_STEP)]
        pub relative_step: bool,

        /// If `true`, the relative step size will be calculated from the current step.
        /// Default to [`Adafactor::DEFAULT_WARMUP_INIT`].
        #[builder(optional, default = Adafactor::DEFAULT_WARMUP_INIT)]
        pub warmup_init: bool,

        /// Inner state.
        #[builder(ignore)]
        pub state: State<AdafactorState>,
    }
}

/// Builds a new [`Adafactor`] optimizer.
fn build_adafactor(builder: AdafactorBuilder) -> Result<Adafactor, AdafactorBuildError> {
    let eps = builder.eps;
    let clip_threshold = builder.clip_threshold;
    let decay_rate = builder.decay_rate;
    let weight_decay = builder.weight_decay;
    let scale_parameter = builder.scale_parameter;
    let relative_step = builder.relative_step;
    let warmup_init = builder.warmup_init;

    if builder.lr.is_none() && !relative_step {
        return Err(AdafactorBuildError::LrIsNoneAndRelativeStepIsFalse);
    }

    Ok(Adafactor {
        lr: builder.lr,
        eps: (array!(eps.0), array!(eps.1)),
        clip_threshold: array!(clip_threshold),
        decay_rate: array!(decay_rate),
        beta1: builder.beta1.map(Array::from),
        weight_decay,
        scale_parameter,
        relative_step,
        warmup_init,
        state: State::new(),
    })
}

impl Adafactor {
    /// Default value for `lr`
    pub const DEFAULT_LR: Option<f32> = None;

    /// Default values for `eps`
    pub const DEFAULT_EPS: (f32, f32) = (1e-30, 1e-3);

    /// Default value for `clip_threshold`
    pub const DEFAULT_CLIP_THRESHOLD: f32 = 1.0;

    /// Default value for `decay_rate`
    pub const DEFAULT_DECAY_RATE: f32 = -0.8;

    /// Default value for `weight_decay`
    pub const DEFAULT_WEIGHT_DECAY: f32 = 0.0;

    /// Default value for `scale_parameter`
    pub const DEFAULT_SCALE_PARAMETER: bool = true;

    /// Default value for `relative_step`
    pub const DEFAULT_RELATIVE_STEP: bool = true;

    /// Default value for `warmup_init`
    pub const DEFAULT_WARMUP_INIT: bool = false;

    /// Default value for `beta1`
    pub const DEFAULT_BETA1: Option<f32> = None;
}

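/// Returns a mutable reference to `map[key]`, first inserting the value produced by the
/// fallible initializer `f` when the key is absent.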
fn get_mut_or_insert_with<'a, T, E>(
    map: &'a mut HashMap<Rc<str>, T>,
    key: &Rc<str>,
    f: impl FnOnce() -> Result<T, E>,
) -> Result<&'a mut T, E> {
    if !map.contains_key(key) {
        map.insert(key.clone(), f()?);
    }

    Ok(map.get_mut(key).unwrap())
}

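/// Computes the effective learning rate. With `relative_step`, the step size is
/// `min(1e-2, 1/sqrt(t))`, or `min(1e-6 * t, 1/sqrt(t))` when `warmup_init` is set; with
/// `scale_parameter`, it is additionally scaled by `max(eps.1, rms(parameter))`.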
fn compute_lr(
    relative_step: bool,
    warmup_init: bool,
    lr: Option<f32>,
    scale_parameter: bool,
    eps: &(Array, Array),
    step: &Array,
    parameter_rms: &Array,
) -> crate::error::Result<Array> {
    let relative_step_size = if relative_step {
        let min_step = if warmup_init {
            // SAFETY: `step` is a single-element array and won't panic.
            array!(1e-6) * step
        } else {
            array!(1e-2)
        };
        // SAFETY: `step` is a single-element array and won't panic.
        minimum(min_step, array!(1.0) / sqrt(step)?)?
    } else {
        // SAFETY: This is already checked in the `build` stage.
        array!(lr.expect("The learning rate should be set if the relative step is not enabled"))
    };

    let mut parameter_scale = array!(1.0);
    if scale_parameter {
        parameter_scale = maximum(&eps.1, parameter_rms)?;
    }

    parameter_scale.multiply(relative_step_size)
}

impl Optimizer for Adafactor {
    type State = State<AdafactorState>;

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn update_single(
        &mut self,
        key: &std::rc::Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()> {
        let beta1_is_some = self.beta1.is_some();
        let state = get_mut_or_insert_with(&mut self.state, key, || {
            AdafactorState::new(parameter, beta1_is_some)
        })?;

        state.step = state.step.add(array!(1))?;

        let gradient_shape = gradient.shape();
        let factored = gradient_shape.len() >= 2;
        let step = &state.step;

        let parameter_rms = rms(parameter)?;
        let lr = compute_lr(
            self.relative_step,
            self.warmup_init,
            self.lr,
            self.scale_parameter,
            &self.eps,
            step,
            &parameter_rms,
        )?;
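        // Decay schedule from the paper: beta2_t = 1 - t^c with c = decay_rate < 0
        // (default -0.8), so beta2_t approaches 1 as the step count grows.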
        let beta2 = array!(1.0).subtract(&step.power(&self.decay_rate)?)?;

        let mut update: Cow<Array> = Cow::Owned(gradient.square()?.add(&self.eps.0)?);

        let one_minus_beta2 = array!(1.0).subtract(&beta2)?;
        if factored {
            // SAFETY: These fields are created in `AdafactorState::new` when ndim >= 2 and
            // won't panic.
            let exp_avg_sq_row = state.exp_avg_sq_row.as_mut().unwrap();
            let exp_avg_sq_col = state.exp_avg_sq_col.as_mut().unwrap();

            *exp_avg_sq_row = beta2
                .multiply(&*exp_avg_sq_row)?
                .add(&one_minus_beta2.multiply(&update.mean_axes(&[-1], None)?)?)?;
            *exp_avg_sq_col = beta2
                .multiply(&*exp_avg_sq_col)?
                .add(&one_minus_beta2.multiply(&update.mean_axes(&[-2], None)?)?)?;

            update = Cow::Owned(approximate_exp_moving_avg(
                &*exp_avg_sq_row,
                &*exp_avg_sq_col,
            )?);
            update = Cow::Owned(update.multiply(gradient)?);
        } else {
            // SAFETY: This field is created in `AdafactorState::new` when ndim < 2 and
            // won't panic.
            let exp_avg_sq = state.exp_avg_sq.as_mut().unwrap();

            *exp_avg_sq = beta2
                .multiply(&*exp_avg_sq)?
                .add(&one_minus_beta2.multiply(&update)?)?;
            update = Cow::Owned(rsqrt(&*exp_avg_sq)?.multiply(gradient)?);
        }

        // Clip the unscaled update: update /= max(1, rms(update) / clip_threshold).
        let update_rms = rms(&update)?;
        let max = maximum(array!(1.0), update_rms.divide(&self.clip_threshold)?)?;
        update = Cow::Owned(update.divide(max)?);
        update = Cow::Owned(lr.multiply(update)?);

        if let Some(beta1) = &self.beta1 {
            // SAFETY: This field is created in `AdafactorState::new` when beta1 is set and
            // won't panic.
            let exp_avg = state.exp_avg.as_mut().unwrap();
            let one_minus_beta1 = array!(1.0).subtract(beta1)?;
            *exp_avg = beta1
                .multiply(&*exp_avg)?
                .add(&one_minus_beta1.multiply(&update)?)?;
            update = Cow::Borrowed(&*exp_avg);
        }

        if self.weight_decay != 0.0 {
            // Decoupled weight decay, scaled by the effective learning rate.
            let rhs = parameter.multiply(array!(-self.weight_decay).multiply(lr)?)?;
            *parameter = parameter.add(rhs)?;
        }

        *parameter = parameter.subtract(&update)?;

        Ok(())
    }
}

impl Updatable for Adafactor {
    fn updatable_states_len(&self) -> usize {
        self.updatable_states().into_iter().count()
    }

    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
        use itertools::Itertools;

        self.state
            .iter()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .flat_map(|(_, v)| {
                // [exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq, exp_avg]
                [
                    &v.exp_avg_sq_row,
                    &v.exp_avg_sq_col,
                    &v.exp_avg_sq,
                    &v.exp_avg,
                ]
                .into_iter()
                .filter_map(|v| v.as_ref())
                .collect::<Vec<_>>()
            })
    }

    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
        use itertools::Itertools;

        self.state
            .iter_mut()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .flat_map(|(_, v)| {
                // [exp_avg_sq_row, exp_avg_sq_col, exp_avg_sq, exp_avg]
                [
                    &mut v.exp_avg_sq_row,
                    &mut v.exp_avg_sq_col,
                    &mut v.exp_avg_sq,
                    &mut v.exp_avg,
                ]
                .into_iter()
                .filter_map(|v| v.as_mut())
                .collect::<Vec<_>>()
            })
    }
}

impl_updatable_for_mut_optimizer!(Adafactor);
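
// A minimal usage sketch, not part of the original test suite. It assumes the `builder()`
// constructor generated by the `Buildable` derive and the `array!` literal syntax used
// elsewhere in this crate.
#[cfg(test)]
mod usage_sketch {
    use super::*;

    #[test]
    fn single_update_preserves_parameter_shape() {
        // Relative step is enabled by default, so no explicit learning rate is required.
        let mut optimizer = Adafactor::builder().build().unwrap();

        let key: Rc<str> = Rc::from("w");
        let mut parameter = array!([1.0, 2.0, 3.0]);
        let gradient = array!([0.1, 0.2, 0.3]);

        optimizer
            .update_single(&key, &gradient, &mut parameter)
            .unwrap();

        // A 1-D parameter takes the unfactored (`exp_avg_sq`) path; its shape is unchanged.
        assert_eq!(parameter.shape(), &[3]);
    }
}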