mlx_rs/optimizers/adamax.rs

use std::{convert::Infallible, rc::Rc};

use mlx_internal_macros::{generate_builder, Buildable};

use crate::{
    array,
    ops::{abs, maximum},
    utils::{get_mut_or_insert_with, Updatable},
    Array,
};

use super::*;

generate_builder! {
    /// The Adamax optimizer, a variant of Adam based on the infinity norm [1].
    ///
    /// Our Adam implementation follows the original paper and omits the bias
    /// correction in the first and second moment estimates. In detail,
    ///
    /// m_{t+1} = beta1 * m_t + (1 - beta1) * g_t
    /// v_{t+1} = max(beta2 * v_t, |g_t|)
    /// w_{t+1} = w_t - lr * m_{t+1} / (v_{t+1} + eps)
    ///
    /// [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic optimization. ICLR 2015.
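    ///
    /// # Example
    ///
    /// A minimal usage sketch. The constructor surface is produced by
    /// `generate_builder!`, so the `Adamax::new` signature shown here (taking
    /// only the required learning rate) is an assumption rather than a
    /// guarantee, and the snippet is not compiled as a doctest:
    ///
    /// ```rust,ignore
    /// // Build an Adamax optimizer with the default betas and eps.
    /// let mut optimizer = Adamax::new(0.01)?;
    /// ```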
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adamax,
        root = crate
    )]
    pub struct Adamax {
        /// The learning rate.
        #[builder(ty_override = f32)]
        pub lr: Array,

        /// The beta coefficients `(b1, b2)` used for the running average of
        /// the gradient and the exponentially weighted infinity norm.
        #[builder(optional, ty_override = Betas, default = Adamax::DEFAULT_BETAS)]
        pub betas: (Array, Array),

        /// The epsilon added to the denominator to improve numerical stability.
        #[builder(optional, ty_override = f32, default = Adamax::DEFAULT_EPS)]
        pub eps: Array,

        /// Inner state.
        #[builder(ignore)]
        pub state: State<(Array, Array)>,
    }
}

fn build_adamax(builder: AdamaxBuilder) -> Result<Adamax, Infallible> {
    let lr = builder.lr;
    let betas = builder.betas;
    let eps = builder.eps;

    Ok(Adamax {
        lr: array!(lr),
        betas: (array!(betas.0), array!(betas.1)),
        eps: array!(eps),
        state: State::new(),
    })
}

impl Adamax {
    /// Default value for `betas`.
    pub const DEFAULT_BETAS: (f32, f32) = (0.9, 0.999);

    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-8;
}

impl Optimizer for Adamax {
    type State = State<(Array, Array)>;

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()> {
        let (b1, b2) = &self.betas;
        let (m, v) = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0)));

        // First moment estimate: m_{t+1} = b1 * m_t + (1 - b1) * g_t
        let one_minus_b1 = array!(1.0).subtract(b1)?;
        let new_m = b1.multiply(&*m)?.add(&one_minus_b1.multiply(gradient)?)?;

        // Exponentially weighted infinity norm: v_{t+1} = max(b2 * v_t, |g_t|)
        let new_v = maximum(b2.multiply(&*v)?, abs(gradient)?)?;

        // Parameter update: w_{t+1} = w_t - lr * m_{t+1} / (v_{t+1} + eps).
        // A single-step numeric sketch of this rule lives in the test module
        // at the end of this file.
        let new_parameter =
            parameter.subtract(self.lr.multiply(&new_m)?.divide(&new_v.add(&self.eps)?)?)?;

        *m = new_m;
        *v = new_v;
        *parameter = new_parameter;

        Ok(())
    }
}

impl Updatable for Adamax {
    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
        use itertools::Itertools;

        // Yield the (m, v) pair for every parameter in deterministic,
        // key-sorted order.
        self.state
            .iter()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .flat_map(|(_, (m, v))| vec![m, v])
    }

    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
        use itertools::Itertools;

        // Same order as `updatable_states`, but with mutable references.
        self.state
            .iter_mut()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .flat_map(|(_, (m, v))| vec![m, v])
    }
}

impl_updatable_for_mut_optimizer!(Adamax);
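
// A minimal test sketch (not part of the original file): it constructs
// `Adamax` directly through its public fields, mirroring `build_adamax`, and
// assumes `Array::item::<f32>()` is available for reading a scalar back out.
#[cfg(test)]
mod tests {
    use super::*;

    fn adamax_with_defaults(lr: f32) -> Adamax {
        Adamax {
            lr: array!(lr),
            betas: (
                array!(Adamax::DEFAULT_BETAS.0),
                array!(Adamax::DEFAULT_BETAS.1),
            ),
            eps: array!(Adamax::DEFAULT_EPS),
            state: State::new(),
        }
    }

    // Single step from zero state: m = (1 - b1) * g, v = |g|,
    // w' = w - lr * m / (v + eps).
    #[test]
    fn single_step_matches_update_rule() {
        let mut opt = adamax_with_defaults(0.1);
        let key: Rc<str> = Rc::from("w");
        let mut parameter = array!(1.0);
        let gradient = array!(0.5);

        opt.update_single(&key, &gradient, &mut parameter).unwrap();

        // m = 0.1 * 0.5 = 0.05, v = 0.5, so w' ≈ 1.0 - 0.1 * 0.05 / 0.5 = 0.99.
        assert!((parameter.item::<f32>() - 0.99).abs() < 1e-5);
    }

    // `updatable_states` yields the (m, v) pair for each parameter, so two
    // parameters produce four state arrays.
    #[test]
    fn two_parameters_yield_four_state_arrays() {
        let mut opt = adamax_with_defaults(0.1);
        for key in ["b", "a"] {
            let key: Rc<str> = Rc::from(key);
            let mut parameter = array!(0.0);
            opt.update_single(&key, &array!(1.0), &mut parameter).unwrap();
        }

        let states: Vec<&Array> = opt.updatable_states().into_iter().collect();
        assert_eq!(states.len(), 4);
    }
}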