mlx_rs/optimizers/
adamax.rs
1use std::{convert::Infallible, rc::Rc};
2
3use mlx_internal_macros::{generate_builder, Buildable};
4
5use crate::{
6 array,
7 ops::{abs, maximum},
8 utils::{get_mut_or_insert_with, Updatable},
9 Array,
10};
11
12use super::*;
13
14generate_builder! {
15 #[derive(Debug, Clone, Buildable)]
22 #[buildable(root = crate)]
23 #[builder(
24 build_with = build_adamax,
25 root = crate
26 )]
27 pub struct Adamax {
28 #[builder(ty_override = f32)]
30 pub lr: Array,
31
32 #[builder(optional, ty_override = Betas, default = Adamax::DEFAULT_BETAS)]
34 pub betas: (Array, Array),
35
36 #[builder(optional, ty_override = f32, default = Adamax::DEFAULT_EPS)]
38 pub eps: Array,
39
40 #[builder(ignore)]
42 pub state: State<(Array, Array)>,
43 }
44}
45
46fn build_adamax(builder: AdamaxBuilder) -> Result<Adamax, Infallible> {
47 let lr = builder.lr;
48 let betas = builder.betas;
49 let eps = builder.eps;
50
51 Ok(Adamax {
52 lr: array!(lr),
53 betas: (array!(betas.0), array!(betas.1)),
54 eps: array!(eps),
55 state: State::new(),
56 })
57}
58
59impl Adamax {
60 pub const DEFAULT_BETAS: (f32, f32) = (0.9, 0.999);
62
63 pub const DEFAULT_EPS: f32 = 1e-8;
65}
66
67impl Optimizer for Adamax {
68 type State = State<(Array, Array)>;
69
70 fn state(&self) -> &Self::State {
71 &self.state
72 }
73
74 fn state_mut(&mut self) -> &mut Self::State {
75 &mut self.state
76 }
77
78 fn update_single(
79 &mut self,
80 key: &Rc<str>,
81 gradient: &Array,
82 parameter: &mut Array,
83 ) -> crate::error::Result<()> {
84 let (b1, b2) = &self.betas;
85 let (m, v) = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0)));
86
87 let one_minus_b1 = array!(1.0).subtract(b1)?;
88 let new_m = b1.multiply(&*m)?.add(&one_minus_b1.multiply(gradient)?)?;
89 let new_v = maximum(b2.multiply(&*v)?, abs(gradient)?)?;
90
91 let new_parameter =
92 parameter.subtract(self.lr.multiply(&new_m)?.divide(&new_v.add(&self.eps)?)?)?;
93
94 *m = new_m;
95 *v = new_v;
96 *parameter = new_parameter;
97
98 Ok(())
99 }
100}
101
102impl Updatable for Adamax {
103 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
104 use itertools::Itertools;
105
106 self.state
107 .iter()
108 .sorted_by(|a, b| a.0.cmp(b.0))
109 .flat_map(|(_, (v, u))| vec![v, u])
110 }
111
112 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
113 use itertools::Itertools;
114
115 self.state
116 .iter_mut()
117 .sorted_by(|a, b| a.0.cmp(b.0))
118 .flat_map(|(_, (v, u))| vec![v, u])
119 }
120}
121
122impl_updatable_for_mut_optimizer!(Adamax);