mlx_rs/optimizers/adadelta.rs

use std::rc::Rc;

use crate::{
    array,
    ops::sqrt,
    utils::{get_mut_or_insert_with, Updatable},
    Array,
};
use mlx_internal_macros::{generate_builder, Buildable};

use crate::error::AdaDeltaBuildError;

use super::*;

generate_builder! {
    /// The AdaDelta optimizer [1].
    ///
    /// [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method.
    /// arXiv preprint arXiv:1212.5701.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adadelta,
        err = AdaDeltaBuildError,
        root = crate
    )]
    pub struct AdaDelta {
        /// The learning rate.
        #[builder(ty_override = f32)]
        pub lr: Array,

        /// The coefficient used for computing a running average of squared
        /// gradients. Defaults to [`AdaDelta::DEFAULT_RHO`].
        #[builder(optional, ty_override = f32, default = AdaDelta::DEFAULT_RHO)]
        pub rho: Array,

        /// The epsilon added to the denominators to improve numerical
        /// stability. Defaults to [`AdaDelta::DEFAULT_EPS`].
        #[builder(optional, ty_override = f32, default = AdaDelta::DEFAULT_EPS)]
        pub eps: Array,

        /// Inner optimizer state: a `(v, u)` pair of running averages
        /// (squared gradients and squared updates) per parameter.
        #[builder(ignore)]
        pub state: State<(Array, Array)>,
    }
}

/// Builds a new [`AdaDelta`], validating that `rho` and `eps` are
/// non-negative.
fn build_adadelta(builder: AdaDeltaBuilder) -> Result<AdaDelta, AdaDeltaBuildError> {
    let rho = builder.rho;
    let eps = builder.eps;

    if rho < 0.0 {
        return Err(AdaDeltaBuildError::NegativeRho);
    }

    if eps < 0.0 {
        return Err(AdaDeltaBuildError::NegativeEps);
    }

    Ok(AdaDelta {
        lr: array!(builder.lr),
        rho: array!(rho),
        eps: array!(eps),
        state: State::new(),
    })
}
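
// A minimal usage sketch (hedged: it assumes the `Buildable` derive generates
// an `AdaDelta::builder()` with `lr`/`rho` setters and a fallible `build()`,
// as the `#[builder(...)]` attributes above suggest, and that the `Optimizer`
// trait exposes an `update` over a model's parameters):
//
//     let mut optimizer = AdaDelta::builder().lr(0.1).rho(0.9).build()?;
//     optimizer.update(&mut model, &gradients)?;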

impl AdaDelta {
    /// Default value for [`AdaDelta::rho`].
    pub const DEFAULT_RHO: f32 = 0.99;

    /// Default value for [`AdaDelta::eps`].
    pub const DEFAULT_EPS: f32 = 1e-6;
}

impl Optimizer for AdaDelta {
    type State = State<(Array, Array)>;

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()> {
        let (v, u) = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0)));

        // Running average of squared gradients:
        //   v_new = rho * v + (1 - rho) * g^2
        let one_minus_rho = array!(1.0).subtract(&self.rho)?;
        let first_term = self.rho.multiply(&v)?;
        let second_term = one_minus_rho.multiply(gradient.square()?)?;
        let v_new = first_term.add(&second_term)?;

        // Rescaled step and running average of squared steps:
        //   d     = sqrt(u + eps) / sqrt(v_new + eps) * g
        //   u_new = rho * u + (1 - rho) * d^2
        let num = sqrt(&u.add(&self.eps)?)?;
        let den = sqrt(&v_new.add(&self.eps)?)?;
        let d = num.divide(&den)?.multiply(gradient)?;
        let first_term = self.rho.multiply(&u)?;
        let second_term = one_minus_rho.multiply(d.square()?)?;
        let u_new = first_term.add(&second_term)?;

        // Parameter update: p_new = p - lr * d
        let param_new = parameter.subtract(self.lr.multiply(d)?)?;

        *parameter = param_new;

        *v = v_new;
        *u = u_new;

        Ok(())
    }
}
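
// A hedged first-step walkthrough (hypothetical initial conditions): starting
// from freshly inserted state v = u = 0, a single `update_single` call above
// reduces to
//
//     v_new = (1 - rho) * g^2
//     d     = sqrt(eps) / sqrt(v_new + eps) * g
//     u_new = (1 - rho) * d^2
//
// and shifts the parameter by -lr * d, so early steps are damped by the
// sqrt(eps) numerator until `u` accumulates history.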

impl Updatable for AdaDelta {
    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
        use itertools::Itertools;

        // Sort by parameter key so the state arrays are yielded in a
        // deterministic order.
        self.state
            .iter()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .flat_map(|(_, (v, u))| vec![v, u])
    }

    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
        use itertools::Itertools;

        self.state
            .iter_mut()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .flat_map(|(_, (v, u))| vec![v, u])
    }
}

impl_updatable_for_mut_optimizer!(AdaDelta);