mlx_rs/optimizers/
adam.rs1use std::convert::Infallible;
2
3use mlx_internal_macros::{generate_builder, Buildable};
4
5use crate::{array, utils::get_mut_or_insert_with};
6
7use super::*;
8
/// Pair of smoothing coefficients `(beta1, beta2)` for the first- and
/// second-moment running averages.
pub type Betas = (f32, f32);

generate_builder! {
    /// The Adam optimizer.
    ///
    /// Constructed through the generated `AdamBuilder`; `build_adam` converts
    /// the plain `f32` builder inputs into `Array`s.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adam,
        root = crate
    )]
    pub struct Adam {
        /// The learning rate.
        #[builder(ty_override = f32)]
        pub lr: Array,

        /// Coefficients used for the running averages of the gradient and
        /// its square. Default: [`Adam::DEFAULT_BETAS`].
        #[builder(optional, ty_override = Betas, default = Adam::DEFAULT_BETAS)]
        pub betas: (Array, Array),

        /// Term added to the denominator for numerical stability.
        /// Default: [`Adam::DEFAULT_EPS`].
        #[builder(optional, ty_override = f32, default = Adam::DEFAULT_EPS)]
        pub eps: Array,

        /// Per-parameter `(m, v)` moment estimates, keyed by parameter name.
        /// Not settable through the builder; starts empty.
        #[builder(ignore)]
        pub state: State<(Array, Array)>,
    }
}
47
48fn build_adam(builder: AdamBuilder) -> Result<Adam, Infallible> {
50 let lr = array!(builder.lr);
51 let betas = builder.betas;
52 let eps = array!(builder.eps);
53
54 Ok(Adam {
55 lr,
56 betas: (array!(betas.0), array!(betas.1)),
57 eps,
58 state: State::new(),
59 })
60}
61
impl Adam {
    /// Default value for `betas`: `(0.9, 0.999)`.
    pub const DEFAULT_BETAS: (f32, f32) = (0.9, 0.999);

    /// Default value for `eps`: `1e-8`.
    pub const DEFAULT_EPS: f32 = 1e-8;
}
69
70impl Optimizer for Adam {
71 type State = State<(Array, Array)>;
72
73 fn state(&self) -> &Self::State {
74 &self.state
75 }
76
77 fn state_mut(&mut self) -> &mut Self::State {
78 &mut self.state
79 }
80
81 fn update_single(
82 &mut self,
83 key: &Rc<str>,
84 gradient: &Array,
85 parameter: &mut Array,
86 ) -> crate::error::Result<()> {
87 let betas = &self.betas;
88 let state = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0)));
89
90 let (new_parameter, new_state) =
91 adam_apply_single(&self.lr, betas, &self.eps, gradient, parameter, state)?;
92
93 *state = new_state;
94 *parameter = new_parameter;
95
96 Ok(())
97 }
98}
99
100pub(super) fn adam_apply_single(
102 lr: &Array,
103 betas: &(Array, Array),
104 eps: &Array,
105 gradient: &Array,
106 parameter: &Array,
107 state: &(Array, Array),
108) -> crate::error::Result<(Array, (Array, Array))> {
109 let (b1, b2) = betas;
110 let (m, v) = state;
111
112 let one_minus_b1 = array!(1.0).subtract(b1)?;
113 let one_minus_b2 = array!(1.0).subtract(b2)?;
114
115 let new_m = b1.multiply(m)?.add(&one_minus_b1.multiply(gradient)?)?;
116 let new_v = b2
117 .multiply(v)?
118 .add(&one_minus_b2.multiply(gradient.square()?)?)?;
119
120 let new_parameter =
121 parameter.subtract(&lr.multiply(&new_m.divide(&new_v.sqrt()?.add(eps)?)?)?)?;
122
123 Ok((new_parameter, (new_m, new_v)))
124}
125
126impl Updatable for Adam {
127 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
128 use itertools::Itertools;
129
130 self.state
131 .iter()
132 .sorted_by(|a, b| a.0.cmp(b.0))
133 .flat_map(|(_, (v, u))| vec![v, u])
134 }
135
136 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
137 use itertools::Itertools;
138
139 self.state
140 .iter_mut()
141 .sorted_by(|a, b| a.0.cmp(b.0))
142 .flat_map(|(_, (v, u))| vec![v, u])
143 }
144}
145
146impl_updatable_for_mut_optimizer!(Adam);