mlx_rs/optimizers/adagrad.rs

use std::{convert::Infallible, rc::Rc};

use crate::{array, ops::square, utils::Updatable, Array};
use mlx_internal_macros::{generate_builder, Buildable};

use crate::utils::get_mut_or_insert_with;

use super::*;

generate_builder! {
    /// The AdaGrad optimizer.
    ///
    /// Please refer to the original paper for more details:
    ///
    /// [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods for online
    ///     learning and stochastic optimization. JMLR 2011.
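    ///
    /// For a parameter `w` with gradient `g` and accumulated squared-gradient state `v`,
    /// the `update_single` implementation below applies:
    ///
    /// ```text
    /// v = v + g^2
    /// w = w - lr * g / (sqrt(v) + eps)
    /// ```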
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adagrad,
        root = crate
    )]
    pub struct AdaGrad {
        /// Learning rate
        #[builder(ty_override = f32)]
        pub lr: Array,

        /// The epsilon added to the denominator to improve numerical stability. Defaults to
        /// [`AdaGrad::DEFAULT_EPS`].
        #[builder(optional, ty_override = f32, default = AdaGrad::DEFAULT_EPS)]
        pub eps: Array,

        /// Inner state
        #[builder(ignore)]
        pub state: State,
    }
}

/// Builds a new [`AdaGrad`].
fn build_adagrad(builder: AdaGradBuilder) -> Result<AdaGrad, Infallible> {
    let eps = array!(builder.eps);

    Ok(AdaGrad {
        lr: array!(builder.lr),
        eps,
        state: State::new(),
    })
}

impl AdaGrad {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-8;
}

impl Optimizer for AdaGrad {
    type State = State;

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()> {
        // Fetch the accumulated sum of squared gradients for this parameter,
        // initializing it to zero on first use.
        let state = get_mut_or_insert_with(&mut self.state, key, || array!(0.0));

        // v = v + g^2
        let v = state.add(square(gradient)?)?;

        // w = w - lr * g / (sqrt(v) + eps)
        let num = self.lr.multiply(gradient)?;
        let den = v.sqrt()?.add(&self.eps)?;
        let new_param = parameter.subtract(num.divide(&den)?)?;

        *state = v;
        *parameter = new_param;

        Ok(())
    }
}

impl Updatable for AdaGrad {
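    // Iterate the state map in sorted key order so that the returned set of
    // state arrays has a deterministic order.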
    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
        use itertools::Itertools;

        self.state
            .iter()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .map(|(_, v)| v)
    }

    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
        use itertools::Itertools;

        self.state
            .iter_mut()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .map(|(_, v)| v)
    }
}

impl_updatable_for_mut_optimizer!(AdaGrad);
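
// A minimal, self-contained numeric sketch of the update rule documented on
// `AdaGrad`, written against plain `f32` values so it does not depend on any
// particular `Array` construction API; it only mirrors the arithmetic in
// `update_single` and does not exercise the optimizer itself.
#[cfg(test)]
mod adagrad_rule_sketch {
    /// One scalar AdaGrad step: returns the updated `(state, parameter)` pair.
    fn scalar_step(lr: f32, eps: f32, g: f32, v: f32, w: f32) -> (f32, f32) {
        let v_next = v + g * g;
        let w_next = w - lr * g / (v_next.sqrt() + eps);
        (v_next, w_next)
    }

    #[test]
    fn scalar_step_matches_hand_computation() {
        let (lr, eps) = (0.1, super::AdaGrad::DEFAULT_EPS);
        let (v, w) = scalar_step(lr, eps, 2.0, 0.0, 1.0);
        // v = 0 + 2^2 = 4, w = 1 - 0.1 * 2 / (sqrt(4) + eps) ≈ 0.9
        assert!((v - 4.0).abs() < 1e-6);
        assert!((w - 0.9).abs() < 1e-6);
    }
}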