mlx_rs/optimizers/
adamw.rs1use std::convert::Infallible;
2
3use mlx_internal_macros::{generate_builder, Buildable};
4
5use crate::{
6 array,
7 utils::{get_mut_or_insert_with, Updatable},
8 Array,
9};
10
11use super::*;
12
generate_builder! {
    /// The AdamW optimizer: Adam with decoupled weight decay.
    ///
    /// On every update the parameter is first multiplied by
    /// `1 - lr * weight_decay`, and the decayed value is then passed to the
    /// shared Adam step (`adam_apply_single`) together with `lr`, `betas`,
    /// `eps` and the per-parameter state.
    ///
    /// Instances are created through the generated builder, which calls
    /// `build_adamw`.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adamw,
        root = crate
    )]
    pub struct AdamW {
        /// Learning rate.
        ///
        /// Required in the builder, where it is given as an `f32`;
        /// `build_adamw` wraps it in an `Array`.
        #[builder(ty_override = f32)]
        pub lr: Array,

        /// Beta coefficients forwarded to the Adam step.
        ///
        /// Optional in the builder (given as a `Betas` value) and defaults to
        /// [`AdamW::DEFAULT_BETAS`]. Each component is wrapped in its own
        /// `Array` by `build_adamw`.
        #[builder(optional, ty_override = Betas, default = AdamW::DEFAULT_BETAS)]
        pub betas: (Array, Array),

        /// Epsilon term forwarded to the Adam step.
        ///
        /// Optional in the builder (given as an `f32`) and defaults to
        /// [`AdamW::DEFAULT_EPS`].
        #[builder(optional, ty_override = f32, default = AdamW::DEFAULT_EPS)]
        pub eps: Array,

        /// Weight-decay coefficient.
        ///
        /// Each update scales the parameter by `1 - lr * weight_decay` before
        /// the Adam step. Optional in the builder (given as an `f32`) and
        /// defaults to [`AdamW::DEFAULT_WEIGHT_DECAY`].
        #[builder(optional, ty_override = f32, default = AdamW::DEFAULT_WEIGHT_DECAY)]
        pub weight_decay: Array,

        /// Per-parameter optimizer state, keyed by the `key` passed to
        /// `update_single`.
        ///
        /// An entry is created as `(0.0, 0.0)` the first time its key is
        /// updated, and after each step it is replaced by the state returned
        /// from the Adam step. Not settable through the builder; `build_adamw`
        /// starts it empty.
        #[builder(ignore)]
        pub state: State<(Array, Array)>,
    }
}
55
56fn build_adamw(builder: AdamWBuilder) -> Result<AdamW, Infallible> {
58 let lr = builder.lr;
59 let betas = builder.betas;
60 let eps = builder.eps;
61 let weight_decay = builder.weight_decay;
62
63 Ok(AdamW {
64 lr: array!(lr),
65 betas: (array!(betas.0), array!(betas.1)),
66 eps: array!(eps),
67 weight_decay: array!(weight_decay),
68 state: State::new(),
69 })
70}
71
72impl AdamW {
73 pub const DEFAULT_BETAS: (f32, f32) = super::Adam::DEFAULT_BETAS;
75
76 pub const DEFAULT_EPS: f32 = super::Adam::DEFAULT_EPS;
78
79 pub const DEFAULT_WEIGHT_DECAY: f32 = 0.01;
81}
82
83impl Optimizer for AdamW {
84 type State = State<(Array, Array)>;
85
86 fn state(&self) -> &Self::State {
87 &self.state
88 }
89
90 fn state_mut(&mut self) -> &mut Self::State {
91 &mut self.state
92 }
93
94 fn update_single(
95 &mut self,
96 key: &std::rc::Rc<str>,
97 gradient: &Array,
98 parameter: &mut Array,
99 ) -> Result<(), crate::error::Exception> {
100 let betas = &self.betas;
101 let state = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0)));
102
103 let one_minus_lr_wd = array!(1.0) - (&self.lr * &self.weight_decay);
105 let decayed_parameter = &*parameter * &one_minus_lr_wd;
106
107 let (new_parameter, new_states) = super::adam_apply_single(
108 &self.lr,
109 betas,
110 &self.eps,
111 gradient,
112 &decayed_parameter,
113 state,
114 )?;
115
116 *state = new_states;
117 *parameter = new_parameter;
118
119 Ok(())
120 }
121}
122
123impl Updatable for AdamW {
124 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
125 use itertools::Itertools;
126
127 self.state
128 .iter()
129 .sorted_by(|a, b| a.0.cmp(b.0))
130 .flat_map(|(_, (v, u))| vec![v, u])
131 }
132
133 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
134 use itertools::Itertools;
135
136 self.state
137 .iter_mut()
138 .sorted_by(|a, b| a.0.cmp(b.0))
139 .flat_map(|(_, (v, u))| vec![v, u])
140 }
141}
142
// NOTE(review): the macro is defined outside this file. Judging by its name,
// it presumably provides an `Updatable` impl for `&mut AdamW` that forwards to
// the impl above — confirm against the macro definition.
impl_updatable_for_mut_optimizer!(AdamW);