mlx_rs/optimizers/
adam.rs

1use std::convert::Infallible;
2
3use mlx_internal_macros::{generate_builder, Buildable};
4
5use crate::{array, utils::get_mut_or_insert_with};
6
7use super::*;
8
/// `(f32, f32)`. Type alias for betas in the Adam/AdamW/Adamax optimizer builders due to
/// limitation in the `generate_builder` macro
pub type Betas = (f32, f32); // The macro right now can't handle raw tuple types
12
generate_builder! {
    /// The Adam optimizer.
    ///
    /// Please refer to the original paper for more details:
    ///
    /// [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic optimization. ICLR 2015.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adam,
        root = crate
    )]
    pub struct Adam {
        /// The learning rate
        // Accepted as `f32` by the builder and converted to an `Array` in `build_adam`.
        #[builder(ty_override = f32)]
        pub lr: Array,

        /// The coefficients used for computing running averages of the gradient and its square
        ///
        /// Default to [`Adam::DEFAULT_BETAS`]
        // Stored as a pair of `Array`s (converted from the builder's `(f32, f32)` in
        // `build_adam`) so they can be used directly in array arithmetic each step.
        #[builder(optional, ty_override = Betas, default = Adam::DEFAULT_BETAS)]
        pub betas: (Array, Array),

        /// The epsilon added to the denominator to improve numerical stability
        ///
        /// Default to [`Adam::DEFAULT_EPS`]
        #[builder(optional, ty_override = f32, default = Adam::DEFAULT_EPS)]
        pub eps: Array,

        /// Inner state: per-parameter `(m, v)` first/second moment estimates, keyed by
        /// parameter key. Entries are created lazily on the first update of each parameter.
        #[builder(ignore)]
        pub state: State<(Array, Array)>,
    }
}
47
48/// Builds a new [`Adam`].
49fn build_adam(builder: AdamBuilder) -> Result<Adam, Infallible> {
50    let lr = array!(builder.lr);
51    let betas = builder.betas;
52    let eps = array!(builder.eps);
53
54    Ok(Adam {
55        lr,
56        betas: (array!(betas.0), array!(betas.1)),
57        eps,
58        state: State::new(),
59    })
60}
61
impl Adam {
    /// Default values for `betas`: `(0.9, 0.999)`
    pub const DEFAULT_BETAS: (f32, f32) = (0.9, 0.999);

    /// Default value for `eps`: `1e-8`
    pub const DEFAULT_EPS: f32 = 1e-8;
}
69
70impl Optimizer for Adam {
71    type State = State<(Array, Array)>;
72
73    fn state(&self) -> &Self::State {
74        &self.state
75    }
76
77    fn state_mut(&mut self) -> &mut Self::State {
78        &mut self.state
79    }
80
81    fn update_single(
82        &mut self,
83        key: &Rc<str>,
84        gradient: &Array,
85        parameter: &mut Array,
86    ) -> crate::error::Result<()> {
87        let betas = &self.betas;
88        let state = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0)));
89
90        let (new_parameter, new_state) =
91            adam_apply_single(&self.lr, betas, &self.eps, gradient, parameter, state)?;
92
93        *state = new_state;
94        *parameter = new_parameter;
95
96        Ok(())
97    }
98}
99
100// Returns (new_parameter, (new_m, new_v))
101pub(super) fn adam_apply_single(
102    lr: &Array,
103    betas: &(Array, Array),
104    eps: &Array,
105    gradient: &Array,
106    parameter: &Array,
107    state: &(Array, Array),
108) -> crate::error::Result<(Array, (Array, Array))> {
109    let (b1, b2) = betas;
110    let (m, v) = state;
111
112    let one_minus_b1 = array!(1.0).subtract(b1)?;
113    let one_minus_b2 = array!(1.0).subtract(b2)?;
114
115    let new_m = b1.multiply(m)?.add(&one_minus_b1.multiply(gradient)?)?;
116    let new_v = b2
117        .multiply(v)?
118        .add(&one_minus_b2.multiply(gradient.square()?)?)?;
119
120    let new_parameter =
121        parameter.subtract(&lr.multiply(&new_m.divide(&new_v.sqrt()?.add(eps)?)?)?)?;
122
123    Ok((new_parameter, (new_m, new_v)))
124}
125
126impl Updatable for Adam {
127    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
128        use itertools::Itertools;
129
130        self.state
131            .iter()
132            .sorted_by(|a, b| a.0.cmp(b.0))
133            .flat_map(|(_, (v, u))| vec![v, u])
134    }
135
136    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
137        use itertools::Itertools;
138
139        self.state
140            .iter_mut()
141            .sorted_by(|a, b| a.0.cmp(b.0))
142            .flat_map(|(_, (v, u))| vec![v, u])
143    }
144}
145
// NOTE(review): macro defined elsewhere in this crate — presumably forwards the
// `Updatable` impl to `&mut Adam` so a mutable reference can be used where an
// updatable value is expected; confirm against the macro definition.
impl_updatable_for_mut_optimizer!(Adam);