mlx_rs/optimizers/
adafactor.rs

1use std::{borrow::Cow, collections::HashMap, rc::Rc};
2
3use mlx_internal_macros::{generate_builder, Buildable};
4
5use crate::{
6    array,
7    error::AdafactorBuildError,
8    ops::{matmul, maximum, mean, minimum, rsqrt, sqrt, square, zeros_dtype, zeros_like},
9    utils::Updatable,
10    Array,
11};
12
13use super::*;
14
15fn rms(inputs: &Array) -> crate::error::Result<Array> {
16    sqrt(&mean(&square(inputs)?, None, None)?)
17}
18
19fn approvate_exp_moving_avg(
20    exp_avg_sq_row: &Array,
21    exp_avg_sq_col: &Array,
22) -> crate::error::Result<Array> {
23    let rfactor = rsqrt(&exp_avg_sq_row.divide(&mean(exp_avg_sq_row, &[-1], true)?)?)?;
24    let cfactor = rsqrt(exp_avg_sq_col)?;
25    matmul(&rfactor.expand_dims(&[-1])?, &cfactor.expand_dims(&[0])?)
26}
27
/// Type alias for the epsilon values used in Adafactor builder.
///
/// The first value is added to the squared gradients for numerical stability;
/// the second is the lower bound applied when scaling the learning rate by the
/// parameter RMS.
pub type AdafactorEps = (f32, f32);
30
/// State of the Adafactor optimizer.
#[derive(Debug, Clone)]
pub struct AdafactorState {
    /// Number of updates applied to this parameter (starts at 0, incremented
    /// once per `update_single` call).
    pub(crate) step: Array,
    /// Factored second-moment accumulator over rows; allocated only for
    /// parameters with `ndim >= 2`.
    pub(crate) exp_avg_sq_row: Option<Array>,
    /// Factored second-moment accumulator over columns; allocated only for
    /// parameters with `ndim >= 2`.
    pub(crate) exp_avg_sq_col: Option<Array>,
    /// Full (non-factored) second-moment accumulator; allocated only for
    /// parameters with `ndim < 2`.
    pub(crate) exp_avg_sq: Option<Array>,
    /// First-moment accumulator; allocated only when `beta1` is set.
    pub(crate) exp_avg: Option<Array>,
}
40
41impl OptimizerState for State<AdafactorState> {
42    type UnflattenError = UnflattenError;
43
44    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
45        self.iter().flat_map(|(k, v)| {
46            let mut iter = vec![(Rc::from(format!("{}.step", k)), &v.step)];
47
48            if let Some(exp_avg_sq_row) = &v.exp_avg_sq_row {
49                iter.push((Rc::from(format!("{}.exp_avg_sq_row", k)), exp_avg_sq_row));
50            }
51
52            if let Some(exp_avg_sq_col) = &v.exp_avg_sq_col {
53                iter.push((Rc::from(format!("{}.exp_avg_sq_col", k)), exp_avg_sq_col));
54            }
55
56            if let Some(exp_avg_sq) = &v.exp_avg_sq {
57                iter.push((Rc::from(format!("{}.exp_avg_sq", k)), exp_avg_sq));
58            }
59
60            if let Some(exp_avg) = &v.exp_avg {
61                iter.push((Rc::from(format!("{}.exp_avg", k)), exp_avg));
62            }
63
64            iter
65        })
66    }
67
68    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
69        self.iter_mut().flat_map(|(k, v)| {
70            let mut iter = vec![(Rc::from(format!("{}.step", k)), &mut v.step)];
71
72            if let Some(exp_avg_sq_row) = &mut v.exp_avg_sq_row {
73                iter.push((Rc::from(format!("{}.exp_avg_sq_row", k)), exp_avg_sq_row));
74            }
75
76            if let Some(exp_avg_sq_col) = &mut v.exp_avg_sq_col {
77                iter.push((Rc::from(format!("{}.exp_avg_sq_col", k)), exp_avg_sq_col));
78            }
79
80            if let Some(exp_avg_sq) = &mut v.exp_avg_sq {
81                iter.push((Rc::from(format!("{}.exp_avg_sq", k)), exp_avg_sq));
82            }
83
84            if let Some(exp_avg) = &mut v.exp_avg {
85                iter.push((Rc::from(format!("{}.exp_avg", k)), exp_avg));
86            }
87
88            iter
89        })
90    }
91
92    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
93    where
94        Self: Sized,
95        I: IntoIterator<Item = (K, Array)>,
96        K: Ord + AsRef<str> + Into<Rc<str>>,
97    {
98        let mut state = State::new();
99        let iter = input
100            .into_iter()
101            .sorted_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
102
103        for (k, v) in iter {
104            let key = k.into();
105            let mut parts = key.rsplit('.');
106            let suffix = parts.next().ok_or(UnflattenError::InvalidKey)?;
107            let prefix = parts.next().ok_or(UnflattenError::InvalidKey)?;
108
109            let prefix = Rc::from(prefix);
110            let state = state.entry(prefix).or_insert_with(|| AdafactorState {
111                step: array!(AdafactorState::DEFAULT_STEP),
112                exp_avg_sq_row: None,
113                exp_avg_sq_col: None,
114                exp_avg_sq: None,
115                exp_avg: None,
116            });
117
118            match suffix {
119                "step" => state.step = v,
120                "exp_avg_sq_row" => state.exp_avg_sq_row = Some(v),
121                "exp_avg_sq_col" => state.exp_avg_sq_col = Some(v),
122                "exp_avg_sq" => state.exp_avg_sq = Some(v),
123                "exp_avg" => state.exp_avg = Some(v),
124                _ => return Err(UnflattenError::InvalidKey),
125            }
126        }
127
128        Ok(state)
129    }
130}
131
132impl AdafactorState {
133    /// Default value for `step`
134    pub const DEFAULT_STEP: i32 = 0;
135
136    fn new(parameter: &Array, beta1_is_some: bool) -> crate::error::Result<Self> {
137        let step = array!(Self::DEFAULT_STEP);
138        let mut exp_avg_sq_row = None;
139        let mut exp_avg_sq_col = None;
140        let mut exp_avg_sq = None;
141        let mut exp_avg = None;
142
143        if parameter.ndim() >= 2 {
144            let shape = parameter.shape();
145            let dtype = parameter.dtype();
146
147            let row_shape = &shape[..shape.len() - 1];
148            exp_avg_sq_row = Some(zeros_dtype(row_shape, dtype)?);
149
150            let mut col_shape = shape[..shape.len() - 2].to_vec();
151            col_shape.push(*shape.last().unwrap());
152            exp_avg_sq_col = Some(zeros_dtype(&col_shape, dtype)?);
153        } else {
154            exp_avg_sq = Some(zeros_like(parameter)?);
155        };
156
157        if beta1_is_some {
158            exp_avg = Some(zeros_like(parameter)?);
159        }
160
161        Ok(Self {
162            step,
163            exp_avg_sq_row,
164            exp_avg_sq_col,
165            exp_avg_sq,
166            exp_avg,
167        })
168    }
169}
170
/// `Option<f32>`. Type alias for the learning rate used in Adafactor builder due to limitation in
/// the `generate_builder` macro
pub type AdafactorBuilderLr = Option<f32>;

/// `Option<Array>`. Type alias for the learning rate used in Adafactor
pub type AdafactorLr = Option<Array>;

/// `Option<f32>`. Type alias for the beta1 used in Adafactor builder due to limitation in the
/// `generate_builder` macro
pub type AdafactorBuilderBeta1 = Option<f32>;

/// `Option<Array>`. Type alias for the beta1 used in Adafactor
pub type AdafactorBeta1 = Option<Array>;
184
generate_builder! {
    /// The Adafactor optimizer.
    ///
    /// Our Adafactor implementation follows the original paper: [Adafactor:
    /// Adaptive Learning Rates with Sublinear Memory
    /// Cost](https://arxiv.org/abs/1804.04235)
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adafactor,
        err = AdafactorBuildError,
        root = crate
    )]
    pub struct Adafactor {
        /// The learning rate. When `None`, `relative_step` must be `true`
        /// (checked at build time). Default to [`Adafactor::DEFAULT_LR`].
        #[builder(optional, default = Adafactor::DEFAULT_LR)]
        pub lr: Option<f32>,

        /// The first term is added to the square of the gradients to improve numerical stability.
        /// Default to [`Adafactor::DEFAULT_EPS`].
        #[builder(optional, ty_override = AdafactorEps, default = Adafactor::DEFAULT_EPS)]
        pub eps: (Array, Array),

        /// Clips the unscaled update. Default to [`Adafactor::DEFAULT_CLIP_THRESHOLD`].
        #[builder(optional, ty_override = f32, default = Adafactor::DEFAULT_CLIP_THRESHOLD)]
        pub clip_threshold: Array,

        /// Coefficient for the running average of the squared gradient. Default to
        /// [`Adafactor::DEFAULT_DECAY_RATE`].
        #[builder(optional, ty_override = f32, default = Adafactor::DEFAULT_DECAY_RATE)]
        pub decay_rate: Array,

        /// If set then the first moment will be used.
        #[builder(optional, ty_override = AdafactorBuilderBeta1, default = Adafactor::DEFAULT_BETA1)]
        pub beta1: AdafactorBeta1,

        /// The weight decay. Default to [`Adafactor::DEFAULT_WEIGHT_DECAY`].
        #[builder(optional, default = Adafactor::DEFAULT_WEIGHT_DECAY)]
        pub weight_decay: f32,

        /// If `true` the `learningRate` will be scaled by `max(eps.1, RMS(parameter))`. Default to
        /// [`Adafactor::DEFAULT_SCALE_PARAMETER`].
        #[builder(optional, default = Adafactor::DEFAULT_SCALE_PARAMETER)]
        pub scale_parameter: bool,

        /// If `true` the `learningRate` will be ignored and the relative step size will be
        /// computed. Default to [`Adafactor::DEFAULT_RELATIVE_STEP`].
        #[builder(optional, ty_override = bool, default = Adafactor::DEFAULT_RELATIVE_STEP)]
        pub relative_step: bool,

        /// If `true` the relative step size will be calculated by the current step. Default to
        /// [`Adafactor::DEFAULT_WARMUP_INIT`].
        #[builder(optional, default = Adafactor::DEFAULT_WARMUP_INIT)]
        pub warmup_init: bool,

        /// Inner state.
        #[builder(ignore)]
        pub state: State<AdafactorState>,
    }
}
245
246/// Builds a new [`Adafactor`] optimizer.
247fn build_adafactor(builder: AdafactorBuilder) -> Result<Adafactor, AdafactorBuildError> {
248    let eps = builder.eps;
249    let clip_threshold = builder.clip_threshold;
250    let decay_rate = builder.decay_rate;
251    let weight_decay = builder.weight_decay;
252    let scale_parameter = builder.scale_parameter;
253    let relative_step = builder.relative_step;
254    let warmup_init = builder.warmup_init;
255
256    if builder.lr.is_none() && !relative_step {
257        return Err(AdafactorBuildError::LrIsNoneAndRelativeStepIsFalse);
258    }
259
260    Ok(Adafactor {
261        lr: builder.lr,
262        eps: (array!(eps.0), array!(eps.1)),
263        clip_threshold: array!(clip_threshold),
264        decay_rate: array!(decay_rate),
265        beta1: builder.beta1.map(Array::from),
266        weight_decay,
267        scale_parameter,
268        relative_step,
269        warmup_init,
270        state: State::new(),
271    })
272}
273
impl Adafactor {
    /// Default value for `lr`. `None` requires `relative_step` to be `true`
    /// (enforced in `build_adafactor`).
    pub const DEFAULT_LR: Option<f32> = None;

    /// Default values for `eps`: `(stability term, parameter-scale floor)`
    pub const DEFAULT_EPS: (f32, f32) = (1e-30, 1e-3);

    /// Default value for `clip_threshold`
    pub const DEFAULT_CLIP_THRESHOLD: f32 = 1.0;

    /// Default value for `decay_rate` (negative: used as the exponent of the
    /// step count when computing `beta2`)
    pub const DEFAULT_DECAY_RATE: f32 = -0.8;

    /// Default value for `weight_decay` (0.0 disables weight decay)
    pub const DEFAULT_WEIGHT_DECAY: f32 = 0.0;

    /// Default value for `scale_parameter`
    pub const DEFAULT_SCALE_PARAMETER: bool = true;

    /// Default value for `relative_step`
    pub const DEFAULT_RELATIVE_STEP: bool = true;

    /// Default value for `warmup_init`
    pub const DEFAULT_WARMUP_INIT: bool = false;

    /// Default value for `beta1` (`None` disables the first moment)
    pub const DEFAULT_BETA1: Option<f32> = None;
}
302
/// Returns a mutable reference to `map[key]`, inserting the value produced by
/// the fallible factory `f` when the key is absent.
///
/// Uses the `Entry` API so the key is hashed and located once, instead of the
/// `contains_key` + `insert` + `get_mut` triple lookup. `f` runs only on a
/// miss; its error is propagated and nothing is inserted in that case.
fn get_mut_or_insert_with<'a, T, E>(
    map: &'a mut HashMap<Rc<str>, T>,
    key: &Rc<str>,
    f: impl FnOnce() -> Result<T, E>,
) -> Result<&'a mut T, E> {
    use std::collections::hash_map::Entry;

    // `entry` needs an owned key; cloning an `Rc<str>` is a cheap refcount bump.
    match map.entry(key.clone()) {
        Entry::Occupied(entry) => Ok(entry.into_mut()),
        Entry::Vacant(entry) => Ok(entry.insert(f()?)),
    }
}
314
315fn compute_lr(
316    relative_step: bool,
317    warmup_init: bool,
318    lr: Option<f32>,
319    scale_parameter: bool,
320    eps: &(Array, Array),
321    step: &Array,
322    parameter_rms: &Array,
323) -> crate::error::Result<Array> {
324    let relative_step_size = if relative_step {
325        let min_step = if warmup_init {
326            // SAFETY: `step` is a single-element array and won't panic.
327            array!(1e-6) * step
328        } else {
329            array!(1e-2)
330        };
331        // SAFETY: `step` is a single-element array and won't panic.
332        minimum(min_step, array!(1.0) / sqrt(step)?)?
333    } else {
334        // SAFETY: This is already checked in the `build` stage.
335        array!(lr.expect("The learning rate should be set if the relative step is not enabled"))
336    };
337
338    let mut parameter_scale = array!(1.0);
339    if scale_parameter {
340        parameter_scale = maximum(&eps.1, parameter_rms)?;
341    }
342
343    parameter_scale.multiply(relative_step_size)
344}
345
impl Optimizer for Adafactor {
    type State = State<AdafactorState>;

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    /// Applies one Adafactor update to `parameter` given its `gradient`.
    ///
    /// Gradients with `ndim >= 2` use the factored row/column second-moment
    /// accumulators; lower-rank gradients use the full accumulator. See
    /// "Adafactor: Adaptive Learning Rates with Sublinear Memory Cost"
    /// (<https://arxiv.org/abs/1804.04235>).
    fn update_single(
        &mut self,
        key: &std::rc::Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()> {
        // Captured before the mutable borrow of `self.state` below.
        let beta1_is_some = self.beta1.is_some();
        // Lazily create the per-parameter state on first sight of this key.
        let state = get_mut_or_insert_with(&mut self.state, key, || {
            AdafactorState::new(parameter, beta1_is_some)
        })?;

        state.step = state.step.add(array!(1))?;

        let gradient_shape = gradient.shape();
        let factored = gradient_shape.len() >= 2;
        let step = &state.step;

        let parameter_rms = rms(parameter)?;
        let lr = compute_lr(
            self.relative_step,
            self.warmup_init,
            self.lr,
            self.scale_parameter,
            &self.eps,
            step,
            &parameter_rms,
        )?;
        // beta2_t = 1 - t^decay_rate (decay_rate is negative by default).
        let beta2 = array!(1.0).subtract(&step.power(&self.decay_rate)?)?;

        // Start from grad^2 + eps.0 (eps.0 stabilizes the square).
        let mut update: Cow<Array> = Cow::Owned(gradient.square()?.add(&self.eps.0)?);

        let one_minus_beta2 = array!(1.0).subtract(&beta2)?;
        if factored {
            // SAFETY: These fields are created in the `new` when ndim >= 2 and won't panic.
            let exp_avg_sq_row = state.exp_avg_sq_row.as_mut().unwrap();
            let exp_avg_sq_col = state.exp_avg_sq_col.as_mut().unwrap();

            // EMA of the squared gradient, reduced over the last axis (rows)
            // and over the second-to-last axis (columns).
            *exp_avg_sq_row = beta2
                .multiply(&*exp_avg_sq_row)?
                .add(&one_minus_beta2.multiply(&update.mean(&[-1], None)?)?)?;
            *exp_avg_sq_col = beta2
                .multiply(&*exp_avg_sq_col)?
                .add(&one_minus_beta2.multiply(&update.mean(&[-2], None)?)?)?;

            // Rebuild the (approximate) inverse-sqrt second moment from the
            // two factors, then apply it to the raw gradient.
            update = Cow::Owned(approvate_exp_moving_avg(
                &*exp_avg_sq_row,
                &*exp_avg_sq_col,
            )?);
            update = Cow::Owned(update.multiply(gradient)?);
        } else {
            // SAFETY: This field is created in the `new` when ndim < 2 and won't panic.
            let exp_avg_sq = state.exp_avg_sq.as_mut().unwrap();

            // Full (non-factored) EMA of the squared gradient.
            *exp_avg_sq = beta2
                .multiply(&*exp_avg_sq)?
                .add(&one_minus_beta2.multiply(&update)?)?;
            update = Cow::Owned(rsqrt(&*exp_avg_sq)?.multiply(gradient)?);
        }

        // Clip the unscaled update: divide by max(1, RMS(update)/clip_threshold),
        // then apply the effective learning rate.
        let update_rms = rms(&update)?;
        let max = maximum(array!(1.0), update_rms.divide(&self.clip_threshold)?)?;
        update = Cow::Owned(update.divide(max)?);
        update = Cow::Owned(lr.multiply(update)?);

        if let Some(beta1) = &self.beta1 {
            // SAFETY: This field is created in the `new` when beta1 is set and won't panic.
            let exp_avg = state.exp_avg.as_mut().unwrap();
            // First-moment EMA; the smoothed value becomes the applied update.
            let one_minus_beta1 = array!(1.0).subtract(beta1)?;
            *exp_avg = beta1
                .multiply(&*exp_avg)?
                .add(&one_minus_beta1.multiply(&update)?)?;
            update = Cow::Borrowed(&*exp_avg);
        }

        // Decoupled weight decay: parameter += parameter * (-weight_decay * lr).
        if self.weight_decay != 0.0 {
            let rhs = parameter.multiply(array!(-self.weight_decay).multiply(lr)?)?;
            *parameter = parameter.add(rhs)?;
        }

        *parameter = parameter.subtract(&update)?;

        Ok(())
    }
}
441
442impl Updatable for Adafactor {
443    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
444        use itertools::Itertools;
445
446        self.state
447            .iter()
448            .sorted_by(|a, b| a.0.cmp(b.0))
449            .flat_map(|(_, v)| {
450                // [expAvgSqRow, expAvgSqCol, expAvgSq, expAvg]
451                [
452                    &v.exp_avg_sq_row,
453                    &v.exp_avg_sq_col,
454                    &v.exp_avg_sq,
455                    &v.exp_avg,
456                ]
457                .into_iter()
458                .filter_map(|v| v.as_ref())
459                .collect::<Vec<_>>()
460            })
461    }
462
463    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
464        use itertools::Itertools;
465
466        self.state
467            .iter_mut()
468            .sorted_by(|a, b| a.0.cmp(b.0))
469            .flat_map(|(_, v)| {
470                // [expAvgSqRow, expAvgSqCol, expAvgSq, expAvg]
471                [
472                    &mut v.exp_avg_sq_row,
473                    &mut v.exp_avg_sq_col,
474                    &mut v.exp_avg_sq,
475                    &mut v.exp_avg,
476                ]
477                .into_iter()
478                .filter_map(|v| v.as_mut())
479                .collect::<Vec<_>>()
480            })
481    }
482}
483
484impl_updatable_for_mut_optimizer!(Adafactor);