mlx_rs/optimizers/adadelta.rs

use std::rc::Rc;

use crate::{
    array,
    ops::sqrt,
    utils::{get_mut_or_insert_with, Updatable},
    Array,
};
use mlx_internal_macros::{generate_builder, Buildable};

use crate::error::AdaDeltaBuildError;

use super::*;

generate_builder! {
    /// The AdaDelta optimizer with a learning rate
    ///
    /// Please refer to the original paper for more details:
    ///
    /// [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.
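    ///
    /// AdaDelta keeps two running averages per parameter: `v` for squared
    /// gradients and `u` for squared updates. For a gradient `g`, learning
    /// rate `lr`, and parameter `w`, [`Optimizer::update_single`] applies,
    /// in order:
    ///
    /// ```text
    /// v = rho * v + (1 - rho) * g^2
    /// d = sqrt(u + eps) / sqrt(v + eps) * g
    /// u = rho * u + (1 - rho) * d^2
    /// w = w - lr * d
    /// ```
    ///
    /// # Example
    ///
    /// A minimal construction sketch. The builder is generated by
    /// [`generate_builder!`], so the setter names below are assumptions;
    /// consult the generated `AdaDeltaBuilder` for the actual API.
    ///
    /// ```rust,ignore
    /// // Hypothetical builder calls: `lr` is required, while `rho` and `eps`
    /// // default to `AdaDelta::DEFAULT_RHO` and `AdaDelta::DEFAULT_EPS`.
    /// let mut optimizer = AdaDelta::builder()
    ///     .lr(0.1)
    ///     .rho(0.9)
    ///     .build()?;
    /// ```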
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adadelta,
        err = AdaDeltaBuildError,
        root = crate
    )]
    pub struct AdaDelta {
        /// The learning rate
        #[builder(ty_override = f32)]
        pub lr: Array,

        /// The coefficient used for computing a running average of squared gradients. Defaults to
        /// [`AdaDelta::DEFAULT_RHO`].
        #[builder(optional, ty_override = f32, default = AdaDelta::DEFAULT_RHO)]
        pub rho: Array,

        /// The epsilon added to the denominator to improve numerical stability. Defaults to
        /// [`AdaDelta::DEFAULT_EPS`].
        #[builder(optional, ty_override = f32, default = AdaDelta::DEFAULT_EPS)]
        pub eps: Array,

        /// Inner state: the per-parameter running averages `(v, u)` of squared
        /// gradients and squared updates, respectively
        #[builder(ignore)]
        pub state: State<(Array, Array)>,
    }
}

/// Builds a new [`AdaDelta`] optimizer, returning an error if `rho` or `eps`
/// is negative
fn build_adadelta(builder: AdaDeltaBuilder) -> Result<AdaDelta, AdaDeltaBuildError> {
    let rho = builder.rho;
    let eps = builder.eps;

    if rho < 0.0 {
        return Err(AdaDeltaBuildError::NegativeRho);
    }

    if eps < 0.0 {
        return Err(AdaDeltaBuildError::NegativeEps);
    }

    Ok(AdaDelta {
        lr: array!(builder.lr),
        rho: array!(rho),
        eps: array!(eps),
        state: State::new(),
    })
}

impl AdaDelta {
    /// Default value for `rho`
    pub const DEFAULT_RHO: f32 = 0.99;

    /// Default value for `eps`
    pub const DEFAULT_EPS: f32 = 1e-6;
}

impl Optimizer for AdaDelta {
    type State = State<(Array, Array)>;

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()> {
        // `v` accumulates squared gradients and `u` accumulates squared
        // updates; both start at zero for a previously unseen parameter.
        let (v, u) = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0)));

        // v_new = rho * v + (1 - rho) * g^2
        let one_minus_rho = array!(1.0).subtract(&self.rho)?;
        let first_term = self.rho.multiply(&v)?;
        let second_term = one_minus_rho.multiply(gradient.square()?)?;
        let v_new = first_term.add(&second_term)?;

        // d = sqrt(u + eps) / sqrt(v_new + eps) * g, the rescaled update
        let num = sqrt(&u.add(&self.eps)?)?;
        let den = sqrt(&v_new.add(&self.eps)?)?;
        let d = num.divide(&den)?.multiply(gradient)?;

        // u_new = rho * u + (1 - rho) * d^2
        let first_term = self.rho.multiply(&u)?;
        let second_term = one_minus_rho.multiply(d.square()?)?;
        let u_new = first_term.add(&second_term)?;

        // w = w - lr * d
        let param_new = parameter.subtract(self.lr.multiply(d)?)?;

        *parameter = param_new;
        *v = v_new;
        *u = u_new;

        Ok(())
    }
}

impl Updatable for AdaDelta {
    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
        use itertools::Itertools;

        // Sort by parameter key so the state arrays are yielded in a
        // deterministic order.
        self.state
            .iter()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .flat_map(|(_, (v, u))| vec![v, u])
    }

    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
        use itertools::Itertools;

        self.state
            .iter_mut()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .flat_map(|(_, (v, u))| vec![v, u])
    }
}

impl_updatable_for_mut_optimizer!(AdaDelta);
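
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal sketch of a single AdaDelta step driven through
    // `update_single` directly. The optimizer is constructed field-by-field
    // (all fields are public) to avoid assuming the exact API generated by
    // `generate_builder!`; the key and values are illustrative only.
    #[test]
    fn single_update_step_populates_state() {
        let mut opt = AdaDelta {
            lr: array!(0.1),
            rho: array!(AdaDelta::DEFAULT_RHO),
            eps: array!(AdaDelta::DEFAULT_EPS),
            state: State::new(),
        };

        let key: Rc<str> = Rc::from("weight");
        let mut parameter = array!(1.0);
        let gradient = array!(0.5);

        opt.update_single(&key, &gradient, &mut parameter).unwrap();

        // One parameter was updated, so exactly one (v, u) pair is tracked.
        assert_eq!(opt.state.iter().count(), 1);
    }
}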