mlx_rs/optimizers/
adagrad.rs1use std::{convert::Infallible, rc::Rc};
2
3use crate::{array, ops::square, utils::Updatable, Array};
4use mlx_internal_macros::{generate_builder, Buildable};
5
6use crate::utils::get_mut_or_insert_with;
7
8use super::*;
9
generate_builder! {
    /// The AdaGrad optimizer.
    ///
    /// For every parameter key it keeps a running sum of squared gradients in
    /// `state` and uses the square root of that sum (plus `eps`) to scale the
    /// learning-rate step applied in `update_single`.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adagrad,
        root = crate
    )]
    pub struct AdaGrad {
        /// Learning rate; multiplied with the gradient on every update.
        ///
        /// Stored as an `Array`; `build_adagrad` wraps the builder's value
        /// with `array!`.
        #[builder(ty_override = f32)]
        pub lr: Array,

        /// Term added to the square root of the accumulated squared gradients
        /// before dividing. The builder default is [`AdaGrad::DEFAULT_EPS`].
        #[builder(optional, ty_override = f32, default = AdaGrad::DEFAULT_EPS)]
        pub eps: Array,

        /// Per-parameter accumulated squared gradients, keyed by parameter
        /// name. `build_adagrad` starts it from `State::new()`.
        #[builder(ignore)]
        pub state: State,
    }
}
38
39fn build_adagrad(builder: AdaGradBuilder) -> Result<AdaGrad, Infallible> {
41 let eps = array!(builder.eps);
42
43 Ok(AdaGrad {
44 lr: array!(builder.lr),
45 eps,
46 state: State::new(),
47 })
48}
49
impl AdaGrad {
    /// Default value for `eps`, used by the builder when no value is supplied.
    pub const DEFAULT_EPS: f32 = 1e-8;
}
54
55impl Optimizer for AdaGrad {
56 type State = State;
57
58 fn state(&self) -> &Self::State {
59 &self.state
60 }
61
62 fn state_mut(&mut self) -> &mut Self::State {
63 &mut self.state
64 }
65
66 fn update_single(
67 &mut self,
68 key: &Rc<str>,
69 gradient: &Array,
70 parameter: &mut Array,
71 ) -> crate::error::Result<()> {
72 let state = get_mut_or_insert_with(&mut self.state, key, || array!(0.0));
73
74 let v = state.add(square(gradient)?)?;
75
76 let num = self.lr.multiply(gradient)?;
77 let den = v.sqrt()?.add(&self.eps)?;
78 let new_param = parameter.subtract(num.divide(&den)?)?;
79
80 *state = v;
81 *parameter = new_param;
82
83 Ok(())
84 }
85}
86
87impl Updatable for AdaGrad {
88 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
89 use itertools::Itertools;
90
91 self.state
92 .iter()
93 .sorted_by(|a, b| a.0.cmp(b.0))
94 .map(|(_, v)| v)
95 }
96
97 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
98 use itertools::Itertools;
99
100 self.state
101 .iter_mut()
102 .sorted_by(|a, b| a.0.cmp(b.0))
103 .map(|(_, v)| v)
104 }
105}
106
// NOTE(review): this macro is defined outside this file; going by its name it
// presumably provides an `Updatable` impl for `&mut AdaGrad` — confirm against
// the macro definition.
impl_updatable_for_mut_optimizer!(AdaGrad);