1use std::{borrow::Cow, rc::Rc};
2
3use crate::{array, utils::get_mut_or_insert_with, Array};
4use mlx_internal_macros::{generate_builder, Buildable};
5
6use super::*;
7
generate_builder! {
    /// Stochastic gradient descent (SGD) optimizer, with optional momentum,
    /// weight decay, dampening, and nesterov acceleration.
    ///
    /// Construct via the generated builder; `build_sgd` below is the terminal
    /// build function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_sgd,
        root = crate
    )]
    pub struct Sgd {
        /// Learning rate. Required — no builder default.
        pub lr: f32,

        /// Momentum strength. Defaults to [`Sgd::DEFAULT_MOMENTUM`] (`0.0`,
        /// i.e. no momentum).
        #[builder(optional, default = Sgd::DEFAULT_MOMENTUM)]
        pub momentum: f32,

        /// Weight decay (L2 penalty folded into the gradient). Defaults to
        /// [`Sgd::DEFAULT_WEIGHT_DECAY`].
        #[builder(optional, default = Sgd::DEFAULT_WEIGHT_DECAY)]
        pub weight_decay: f32,

        /// Dampening applied to the gradient's contribution to the velocity.
        /// Defaults to [`Sgd::DEFAULT_DAMPENING`].
        #[builder(optional, default = Sgd::DEFAULT_DAMPENING)]
        pub dampening: f32,

        /// Enables nesterov momentum. Defaults to [`Sgd::DEFAULT_NESTEROV`].
        #[builder(optional, ty_override = bool, default = Sgd::DEFAULT_NESTEROV)]
        pub nesterov: bool,

        /// Per-parameter velocity state, keyed by parameter key (`Rc<str>`).
        /// Ignored by the builder; initialized empty in `build_sgd`.
        #[builder(ignore)]
        pub state: State,
    }
}
41
42fn build_sgd(builder: SgdBuilder) -> Result<Sgd, std::convert::Infallible> {
43 let lr = builder.lr;
44 let momentum = builder.momentum;
45 let weight_decay = builder.weight_decay;
46 let dampening = builder.dampening;
47 let nesterov = builder.nesterov;
48
49 Ok(Sgd {
50 lr,
51 momentum,
52 weight_decay,
53 dampening,
54 nesterov,
55 state: State::new(),
56 })
57}
58
impl Sgd {
    /// Default momentum: `0.0` — the velocity term is skipped entirely
    /// (see the early return in `update_single`).
    pub const DEFAULT_MOMENTUM: f32 = 0.0;

    /// Default weight decay: `0.0` — no L2 penalty added to the gradient.
    pub const DEFAULT_WEIGHT_DECAY: f32 = 0.0;

    /// Default dampening: `0.0` — the full gradient is added to the velocity.
    pub const DEFAULT_DAMPENING: f32 = 0.0;

    /// Default nesterov flag: `false` — standard (non-lookahead) momentum.
    pub const DEFAULT_NESTEROV: bool = false;
}
72
73impl Optimizer for Sgd {
74 type State = State;
75
76 fn state(&self) -> &Self::State {
77 &self.state
78 }
79
80 fn state_mut(&mut self) -> &mut Self::State {
81 &mut self.state
82 }
83
84 #[inline]
86 fn update_single(
87 &mut self,
88 key: &Rc<str>,
89 gradient: &Array,
90 parameter: &mut Array,
91 ) -> crate::error::Result<()> {
92 let state = get_mut_or_insert_with(&mut self.state, key, || array!(0.0));
93 let mut gradient = Cow::Borrowed(gradient);
94
95 if self.weight_decay != 0.0 {
96 let weight_decay = array!(self.weight_decay);
97 gradient = Cow::Owned(weight_decay.multiply(&*parameter)?.add(&*gradient)?);
98 }
99
100 if self.momentum <= 0.0 {
101 let lr = array!(self.lr);
102 *parameter = parameter.subtract(lr.multiply(gradient)?)?;
103 return Ok(());
104 }
105
106 let mut v = &*state * self.momentum;
107
108 if self.dampening > 0.0 {
109 let dampening = array!(self.dampening);
110 let one_minus_dampening = array!(1.0).subtract(dampening)?;
111 v = v.add(&one_minus_dampening.multiply(&gradient)?)?;
112 } else {
113 v = v.add(&gradient)?;
114 }
115
116 match self.nesterov {
117 true => {
118 let momentum = array!(self.momentum);
119 let lr = array!(self.lr);
120 let update = gradient.add(momentum.multiply(&v)?)?;
121 *parameter = parameter.subtract(lr.multiply(&update)?)?;
122 *state = v;
123 }
124 false => {
125 let update = &v;
126 let lr = array!(self.lr);
127 *parameter = parameter.subtract(lr.multiply(update)?)?;
128 *state = v;
129 }
130 }
131
132 Ok(())
133 }
134}
135
136impl Updatable for Sgd {
137 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
138 use itertools::Itertools;
139
140 self.state
141 .iter()
142 .sorted_by(|a, b| a.0.cmp(b.0))
143 .map(|(_, v)| v)
144 }
145
146 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
147 use itertools::Itertools;
148
149 self.state
150 .iter_mut()
151 .sorted_by(|a, b| a.0.cmp(b.0))
152 .map(|(_, v)| v)
153 }
154}
155
// NOTE(review): macro body not visible here — presumably implements `Updatable`
// for `&mut Sgd` by delegating to the impl above; confirm against the macro
// definition.
impl_updatable_for_mut_optimizer!(Sgd);