mlx_rs/optimizers/adamw.rs

use std::convert::Infallible;

use mlx_internal_macros::{generate_builder, Buildable};

use crate::{
    array,
    utils::{get_mut_or_insert_with, Updatable},
    Array,
};

use super::*;

generate_builder! {
    /// The AdamW optimizer [1].
    ///
    /// Following [`Adam`], and in contrast with [1], we do not use bias
    /// correction in the first and second moments. The weights are updated
    /// with a `weight_decay` (lambda) value:
    ///
    /// `m_{t+1} = beta1 * m_t + (1 - beta1) * g_t`
    ///
    /// `v_{t+1} = beta2 * v_t + (1 - beta2) * g_t^2`
    ///
    /// `w_{t+1} = (1 - lr * lambda) * w_t - lr * m_{t+1} / (sqrt(v_{t+1}) + eps)`
    ///
    /// [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay regularization. ICLR 2019.
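    ///
    /// # Example
    ///
    /// A minimal construction sketch; it assumes the derived builder exposes
    /// setter methods named after the fields and a `build()` method (the
    /// exact generated API may differ):
    ///
    /// ```rust,ignore
    /// let optimizer = AdamW::builder()
    ///     .lr(1e-3)
    ///     .weight_decay(0.01)
    ///     .build()?;
    /// ```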
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adamw,
        root = crate
    )]
    pub struct AdamW {
        /// The learning rate.
        #[builder(ty_override = f32)]
        pub lr: Array,

        /// The coefficients used for computing running averages of the gradient and its square.
        ///
        /// Defaults to [`AdamW::DEFAULT_BETAS`].
        #[builder(optional, ty_override = Betas, default = AdamW::DEFAULT_BETAS)]
        pub betas: (Array, Array),

        /// The epsilon added to the denominator to improve numerical stability.
        ///
        /// Defaults to [`AdamW::DEFAULT_EPS`].
        #[builder(optional, ty_override = f32, default = AdamW::DEFAULT_EPS)]
        pub eps: Array,

        /// The weight decay.
        ///
        /// Defaults to [`AdamW::DEFAULT_WEIGHT_DECAY`].
        #[builder(optional, ty_override = f32, default = AdamW::DEFAULT_WEIGHT_DECAY)]
        pub weight_decay: Array,

        /// Inner state.
        #[builder(ignore)]
        pub state: State<(Array, Array)>,
    }
}

/// Builds a new [`AdamW`] optimizer.
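///
/// Scalar hyperparameters are promoted to single-element [`Array`]s once here,
/// so the per-step update can use them directly in array arithmetic without
/// converting on every call.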
fn build_adamw(builder: AdamWBuilder) -> Result<AdamW, Infallible> {
    let lr = builder.lr;
    let betas = builder.betas;
    let eps = builder.eps;
    let weight_decay = builder.weight_decay;

    Ok(AdamW {
        lr: array!(lr),
        betas: (array!(betas.0), array!(betas.1)),
        eps: array!(eps),
        weight_decay: array!(weight_decay),
        state: State::new(),
    })
}

impl AdamW {
    /// Default value for `betas`.
    pub const DEFAULT_BETAS: (f32, f32) = super::Adam::DEFAULT_BETAS;

    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = super::Adam::DEFAULT_EPS;

    /// Default value for `weight_decay`.
    pub const DEFAULT_WEIGHT_DECAY: f32 = 0.01;
}
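
// For reference, the re-exported Adam defaults follow the values conventional
// for Adam-family optimizers (beta1 = 0.9, beta2 = 0.999, eps = 1e-8 in
// upstream MLX); only the 0.01 weight decay is AdamW-specific.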

impl Optimizer for AdamW {
    type State = State<(Array, Array)>;

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn update_single(
        &mut self,
        key: &std::rc::Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> Result<(), crate::error::Exception> {
        let betas = &self.betas;
        let state = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0)));

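        // Decoupled weight decay (the "W" in AdamW): the parameter is scaled
        // by (1 - lr * weight_decay) *before* the Adam moment update, instead
        // of folding the decay into the gradient. E.g. with lr = 1e-3 and
        // weight_decay = 0.01 the parameter is multiplied by 0.99999 each step.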
        // SAFETY: These are all single-element arrays and won't panic.
        let one_minus_lr_wd = array!(1.0) - (&self.lr * &self.weight_decay);
        let decayed_parameter = &*parameter * &one_minus_lr_wd;

        let (new_parameter, new_states) = super::adam_apply_single(
            &self.lr,
            betas,
            &self.eps,
            gradient,
            &decayed_parameter,
            state,
        )?;

        *state = new_states;
        *parameter = new_parameter;

        Ok(())
    }
}

impl Updatable for AdamW {
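    // The state map has no guaranteed iteration order, so both accessors sort
    // by parameter key to yield each entry's two moment arrays in a stable,
    // reproducible order.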
    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
        use itertools::Itertools;

        self.state
            .iter()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .flat_map(|(_, (v, u))| vec![v, u])
    }

    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
        use itertools::Itertools;

        self.state
            .iter_mut()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .flat_map(|(_, (v, u))| vec![v, u])
    }
}

impl_updatable_for_mut_optimizer!(AdamW);
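
// A minimal smoke-test sketch, not part of the original file: it constructs
// the optimizer through its public fields (avoiding assumptions about the
// generated builder API) and drives `update_single` directly. The use of
// `Array::item::<f32>()` to read back the scalar parameter is assumed.
#[cfg(test)]
mod adamw_sketch_tests {
    use super::*;
    use std::rc::Rc;

    #[test]
    fn update_single_decays_and_steps() -> Result<(), crate::error::Exception> {
        let mut opt = AdamW {
            lr: array!(0.1),
            betas: (array!(0.9), array!(0.999)),
            eps: array!(1e-8),
            weight_decay: array!(0.01),
            state: State::new(),
        };

        let key: Rc<str> = Rc::from("w");
        let mut parameter = array!(1.0);
        let gradient = array!(0.5);

        opt.update_single(&key, &gradient, &mut parameter)?;

        // After one step the parameter must have moved: it is first decayed
        // to 1.0 * (1 - 0.1 * 0.01) = 0.999 and then stepped by Adam.
        assert_ne!(parameter.item::<f32>(), 1.0);
        Ok(())
    }
}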