mlx_rs/optimizers/adagrad.rs

use std::{convert::Infallible, rc::Rc};

use crate::{array, ops::square, utils::Updatable, Array};
use mlx_internal_macros::{generate_builder, Buildable};

use crate::utils::get_mut_or_insert_with;

use super::*;

generate_builder! {
    /// The Adagrad optimizer.
    ///
    /// Please refer to the original paper for more details:
    ///
    /// [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods for online
    ///     learning and stochastic optimization. JMLR 2011.
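    ///
    /// As implemented in `update_single` below, given a gradient `g`, the accumulated
    /// state `v`, the learning rate `lr`, and `eps`, each parameter `w` is updated as:
    ///
    /// ```text
    /// v = v + g^2
    /// w = w - lr * g / (sqrt(v) + eps)
    /// ```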
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adagrad,
        root = crate
    )]
    pub struct AdaGrad {
        /// Learning rate
        #[builder(ty_override = f32)]
        pub lr: Array,

        /// The epsilon added to the denominator to improve numerical stability. Defaults to
        /// [`AdaGrad::DEFAULT_EPS`].
        #[builder(optional, ty_override = f32, default = AdaGrad::DEFAULT_EPS)]
        pub eps: Array,

        /// Inner state
        #[builder(ignore)]
        pub state: State,
    }
}

/// Builds a new [`AdaGrad`].
fn build_adagrad(builder: AdaGradBuilder) -> Result<AdaGrad, Infallible> {
    let eps = array!(builder.eps);

    Ok(AdaGrad {
        lr: array!(builder.lr),
        eps,
        state: State::new(),
    })
}

impl AdaGrad {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-8;
}

impl Optimizer for AdaGrad {
    type State = State;

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()> {
        let state = get_mut_or_insert_with(&mut self.state, key, || array!(0.0));

        // Accumulate the sum of squared gradients.
        let v = state.add(square(gradient)?)?;

        // AdaGrad step: lr * g / (sqrt(v) + eps).
        let num = self.lr.multiply(gradient)?;
        let den = v.sqrt()?.add(&self.eps)?;
        let new_param = parameter.subtract(num.divide(&den)?)?;

        // Persist the new state and write back the updated parameter.
        *state = v;
        *parameter = new_param;

        Ok(())
    }
}

impl Updatable for AdaGrad {
    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
        use itertools::Itertools;

        // Yield states in a deterministic, key-sorted order.
        self.state
            .iter()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .map(|(_, v)| v)
    }

    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
        use itertools::Itertools;

        self.state
            .iter_mut()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .map(|(_, v)| v)
    }
}

impl_updatable_for_mut_optimizer!(AdaGrad);
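
// What follows is a minimal, dependency-free sketch of the same per-parameter update on
// scalar `f32` values, added purely to illustrate the algorithm implemented above; the
// module, function, and test names are hypothetical and not part of the mlx-rs API.
#[cfg(test)]
mod adagrad_scalar_sketch {
    /// Scalar reference version of the step performed in `update_single`.
    fn adagrad_step(param: &mut f32, state: &mut f32, grad: f32, lr: f32, eps: f32) {
        *state += grad * grad; // accumulate squared gradients
        *param -= lr * grad / (state.sqrt() + eps); // scale the step by the running magnitude
    }

    #[test]
    fn first_step_matches_hand_computation() {
        let (mut w, mut v) = (1.0_f32, 0.0_f32);
        adagrad_step(&mut w, &mut v, 0.5, 0.1, 1e-8);
        // v = 0.25, so w = 1.0 - 0.1 * 0.5 / (0.5 + 1e-8) ≈ 0.9
        assert!((v - 0.25).abs() < 1e-6);
        assert!((w - 0.9).abs() < 1e-6);
    }
}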