mlx_rs/optimizers/
rmsprop.rs1use std::rc::Rc;
2
3use crate::{
4 array,
5 ops::{sqrt, square},
6 Array,
7};
8use mlx_internal_macros::{generate_builder, Buildable};
9
10use crate::{error::RmsPropBuildError, utils::get_mut_or_insert_with};
11
12use super::*;
13
14generate_builder! {
15 #[derive(Debug, Clone, Buildable)]
20 #[buildable(root = crate)]
21 #[builder(
22 build_with = build_rmdprop,
23 err = RmsPropBuildError,
24 root = crate
25 )]
26 pub struct RmsProp {
27 #[builder(ty_override = f32)]
29 pub lr: Array,
30
31 #[builder(optional, ty_override = f32, default = RmsProp::DEFAULT_ALPHA)]
33 pub alpha: Array,
34
35 #[builder(optional, ty_override = f32, default = RmsProp::DEFAULT_EPSILON)]
38 pub epsilon: Array,
39
40 #[builder(ignore)]
42 pub state: State,
43 }
44}
45
46fn build_rmdprop(builder: RmsPropBuilder) -> Result<RmsProp, RmsPropBuildError> {
47 let lr = builder.lr;
48 let alpha = builder.alpha;
49 let epsilon = builder.epsilon;
50
51 if alpha < 0.0 {
52 return Err(RmsPropBuildError::NegativeAlpha);
53 }
54
55 if epsilon < 0.0 {
56 return Err(RmsPropBuildError::NegativeEpsilon);
57 }
58
59 Ok(RmsProp {
60 lr: array!(lr),
61 alpha: array!(alpha),
62 epsilon: array!(epsilon),
63 state: State::new(),
64 })
65}
66
67impl RmsProp {
68 pub const DEFAULT_ALPHA: f32 = 0.99;
70
71 pub const DEFAULT_EPSILON: f32 = 1e-8;
73}
74
75impl Optimizer for RmsProp {
76 type State = State;
77
78 fn state(&self) -> &Self::State {
79 &self.state
80 }
81
82 fn state_mut(&mut self) -> &mut Self::State {
83 &mut self.state
84 }
85
86 fn update_single(
87 &mut self,
88 key: &Rc<str>,
89 gradient: &Array,
90 parameter: &mut Array,
91 ) -> crate::error::Result<()> {
92 let state = get_mut_or_insert_with(&mut self.state, key, || array!(0.0));
93
94 let lr = &self.lr;
95 let alpha = &self.alpha;
96 let eps = &self.epsilon;
97
98 let one_minus_alpha = array!(1.0).subtract(alpha)?;
99 let first_term = alpha.multiply(&*state)?;
100 let second_term = one_minus_alpha.multiply(square(gradient)?)?;
101 let v = first_term.add(&second_term)?;
102
103 let num = lr.multiply(gradient)?;
104 let den = sqrt(&v)?.add(eps)?;
105 let new_param = parameter.subtract(num.divide(&den)?)?;
106
107 *parameter = new_param;
108 *state = v;
109
110 Ok(())
111 }
112}
113
114impl Updatable for RmsProp {
115 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
116 use itertools::Itertools;
117
118 self.state
119 .iter()
120 .sorted_by(|a, b| a.0.cmp(b.0))
121 .map(|(_, v)| v)
122 }
123
124 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
125 use itertools::Itertools;
126
127 self.state
128 .iter_mut()
129 .sorted_by(|a, b| a.0.cmp(b.0))
130 .map(|(_, v)| v)
131 }
132}
133
134impl_updatable_for_mut_optimizer!(RmsProp);