// mlx_rs/optimizers/rmsprop.rs

1use std::rc::Rc;
2
3use crate::{
4    array,
5    ops::{sqrt, square},
6    Array,
7};
8use mlx_internal_macros::{generate_builder, Buildable};
9
10use crate::{error::RmsPropBuildError, utils::get_mut_or_insert_with};
11
12use super::*;
13
generate_builder! {
    /// The RMSprop optimizer [1].
    ///
    /// Keeps a per-parameter exponential moving average of squared gradients
    /// (held in [`RmsProp::state`]) and divides each gradient step by the
    /// square root of that average (plus `epsilon`) before applying it.
    ///
    /// [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for
    ///     machine learning
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_rmdprop,
        err = RmsPropBuildError,
        root = crate
    )]
    pub struct RmsProp {
        /// Learning rate
        // Stored as a scalar `Array` so it can participate in lazy graph ops;
        // the builder accepts a plain `f32` (see `ty_override`).
        #[builder(ty_override = f32)]
        pub lr: Array,

        /// The smoothing constant. Default to [`RmsProp::DEFAULT_ALPHA`] if not specified.
        #[builder(optional, ty_override = f32, default = RmsProp::DEFAULT_ALPHA)]
        pub alpha: Array,

        /// The epsilon added to the denominator to improve numerical stability. Default to
        /// [`RmsProp::DEFAULT_EPSILON`] if not specified.
        #[builder(optional, ty_override = f32, default = RmsProp::DEFAULT_EPSILON)]
        pub epsilon: Array,

        /// Inner state
        // One moving-average-of-squared-gradients entry per parameter key,
        // created lazily on first update; not settable through the builder.
        #[builder(ignore)]
        pub state: State,
    }
}
45
46fn build_rmdprop(builder: RmsPropBuilder) -> Result<RmsProp, RmsPropBuildError> {
47    let lr = builder.lr;
48    let alpha = builder.alpha;
49    let epsilon = builder.epsilon;
50
51    if alpha < 0.0 {
52        return Err(RmsPropBuildError::NegativeAlpha);
53    }
54
55    if epsilon < 0.0 {
56        return Err(RmsPropBuildError::NegativeEpsilon);
57    }
58
59    Ok(RmsProp {
60        lr: array!(lr),
61        alpha: array!(alpha),
62        epsilon: array!(epsilon),
63        state: State::new(),
64    })
65}
66
impl RmsProp {
    /// Default alpha (smoothing constant for the squared-gradient moving
    /// average) if not specified: `0.99`.
    pub const DEFAULT_ALPHA: f32 = 0.99;

    /// Default epsilon (added to the denominator for numerical stability)
    /// if not specified: `1e-8`.
    pub const DEFAULT_EPSILON: f32 = 1e-8;
}
74
75impl Optimizer for RmsProp {
76    type State = State;
77
78    fn state(&self) -> &Self::State {
79        &self.state
80    }
81
82    fn state_mut(&mut self) -> &mut Self::State {
83        &mut self.state
84    }
85
86    fn update_single(
87        &mut self,
88        key: &Rc<str>,
89        gradient: &Array,
90        parameter: &mut Array,
91    ) -> crate::error::Result<()> {
92        let state = get_mut_or_insert_with(&mut self.state, key, || array!(0.0));
93
94        let lr = &self.lr;
95        let alpha = &self.alpha;
96        let eps = &self.epsilon;
97
98        let one_minus_alpha = array!(1.0).subtract(alpha)?;
99        let first_term = alpha.multiply(&*state)?;
100        let second_term = one_minus_alpha.multiply(square(gradient)?)?;
101        let v = first_term.add(&second_term)?;
102
103        let num = lr.multiply(gradient)?;
104        let den = sqrt(&v)?.add(eps)?;
105        let new_param = parameter.subtract(num.divide(&den)?)?;
106
107        *parameter = new_param;
108        *state = v;
109
110        Ok(())
111    }
112}
113
114impl Updatable for RmsProp {
115    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
116        use itertools::Itertools;
117
118        self.state
119            .iter()
120            .sorted_by(|a, b| a.0.cmp(b.0))
121            .map(|(_, v)| v)
122    }
123
124    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
125        use itertools::Itertools;
126
127        self.state
128            .iter_mut()
129            .sorted_by(|a, b| a.0.cmp(b.0))
130            .map(|(_, v)| v)
131    }
132}
133
134impl_updatable_for_mut_optimizer!(RmsProp);