// mlx_rs/optimizers/sgd.rs

1use std::{borrow::Cow, rc::Rc};
2
3use crate::{array, utils::get_mut_or_insert_with, Array};
4use mlx_internal_macros::{generate_builder, Buildable};
5
6use super::*;
7
generate_builder! {
    /// Stochastic gradient descent optimizer.
    ///
    /// Updates each parameter `p` from its gradient `g` via a momentum buffer
    /// `v`: `v = momentum * v + (1 - dampening) * g`, then `p = p - lr * v`
    /// (nesterov uses `p = p - lr * (g + momentum * v)`; weight decay folds
    /// `weight_decay * p` into the gradient first). With `momentum <= 0.0`
    /// this reduces to plain gradient descent.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_sgd,
        root = crate
    )]
    pub struct Sgd {
        /// Learning rate
        pub lr: f32,

        /// Momentum strength. Default to [`Sgd::DEFAULT_MOMENTUM`] if not specified.
        /// Values `<= 0.0` disable the momentum buffer entirely.
        #[builder(optional, default = Sgd::DEFAULT_MOMENTUM)]
        pub momentum: f32,

        /// Weight decay (L2 penalty). Default to [`Sgd::DEFAULT_WEIGHT_DECAY`] if not specified.
        #[builder(optional, default = Sgd::DEFAULT_WEIGHT_DECAY)]
        pub weight_decay: f32,

        /// Dampening for momentum. Default to [`Sgd::DEFAULT_DAMPENING`] if not specified.
        /// Scales the gradient contribution to the momentum buffer by `1 - dampening`.
        #[builder(optional, default = Sgd::DEFAULT_DAMPENING)]
        pub dampening: f32,

        /// Enables nesterov momentum. Default to [`Sgd::DEFAULT_NESTEROV`] if not specified.
        #[builder(optional, ty_override = bool, default = Sgd::DEFAULT_NESTEROV)]
        pub nesterov: bool,

        /// Inner state: one momentum buffer per parameter key. Not settable
        /// through the builder; starts empty and is populated lazily on the
        /// first update of each parameter.
        #[builder(ignore)]
        pub state: State,
    }
}
41
42fn build_sgd(builder: SgdBuilder) -> Result<Sgd, std::convert::Infallible> {
43    let lr = builder.lr;
44    let momentum = builder.momentum;
45    let weight_decay = builder.weight_decay;
46    let dampening = builder.dampening;
47    let nesterov = builder.nesterov;
48
49    Ok(Sgd {
50        lr,
51        momentum,
52        weight_decay,
53        dampening,
54        nesterov,
55        state: State::new(),
56    })
57}
58
impl Sgd {
    /// Default momentum if not specified (`0.0`, i.e. momentum disabled).
    pub const DEFAULT_MOMENTUM: f32 = 0.0;

    /// Default weight decay if not specified (`0.0`, i.e. no L2 penalty).
    pub const DEFAULT_WEIGHT_DECAY: f32 = 0.0;

    /// Default dampening if not specified (`0.0`, i.e. the full gradient is
    /// added to the momentum buffer).
    pub const DEFAULT_DAMPENING: f32 = 0.0;

    /// Default nesterov if not specified (`false`, standard momentum update).
    pub const DEFAULT_NESTEROV: bool = false;
}
72
impl Optimizer for Sgd {
    type State = State;

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    /// Apply SGD to a single parameter. Returns the updated parameter and the updated state.
    ///
    /// Update rule (weight decay is folded into the gradient first):
    ///
    /// ```text
    /// g = g + weight_decay * p            (if weight_decay != 0)
    /// v = momentum * v + (1 - dampening) * g
    /// p = p - lr * v                      (standard momentum)
    /// p = p - lr * (g + momentum * v)     (nesterov)
    /// ```
    ///
    /// With `momentum <= 0.0` the momentum buffer is skipped entirely and
    /// this performs a plain gradient descent step: `p = p - lr * g`.
    #[inline]
    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()> {
        // Fetch this parameter's momentum buffer, lazily initializing it to a
        // scalar zero (broadcasts against the gradient on first use).
        // NOTE(review): this inserts a state entry even when momentum is
        // disabled below and the buffer is never read — confirm intended.
        let state = get_mut_or_insert_with(&mut self.state, key, || array!(0.0));
        // Borrow the incoming gradient; only allocate a new array when weight
        // decay must modify it.
        let mut gradient = Cow::Borrowed(gradient);

        if self.weight_decay != 0.0 {
            // L2 penalty: g = weight_decay * p + g
            let weight_decay = array!(self.weight_decay);
            gradient = Cow::Owned(weight_decay.multiply(&*parameter)?.add(&*gradient)?);
        }

        if self.momentum <= 0.0 {
            // No momentum: plain gradient descent, p = p - lr * g.
            let lr = array!(self.lr);
            *parameter = parameter.subtract(lr.multiply(gradient)?)?;
            return Ok(());
        }

        // Decay the previous momentum buffer: v = momentum * v_prev.
        let mut v = &*state * self.momentum;

        if self.dampening > 0.0 {
            // Dampened accumulation: v = v + (1 - dampening) * g.
            let dampening = array!(self.dampening);
            let one_minus_dampening = array!(1.0).subtract(dampening)?;
            v = v.add(&one_minus_dampening.multiply(&gradient)?)?;
        } else {
            // No dampening: v = v + g.
            v = v.add(&gradient)?;
        }

        match self.nesterov {
            true => {
                // Nesterov look-ahead step: p = p - lr * (g + momentum * v).
                let momentum = array!(self.momentum);
                let lr = array!(self.lr);
                let update = gradient.add(momentum.multiply(&v)?)?;
                *parameter = parameter.subtract(lr.multiply(&update)?)?;
                *state = v;
            }
            false => {
                // Standard momentum step: p = p - lr * v.
                let update = &v;
                let lr = array!(self.lr);
                *parameter = parameter.subtract(lr.multiply(update)?)?;
                *state = v;
            }
        }

        Ok(())
    }
}
135
136impl Updatable for Sgd {
137    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
138        use itertools::Itertools;
139
140        self.state
141            .iter()
142            .sorted_by(|a, b| a.0.cmp(b.0))
143            .map(|(_, v)| v)
144    }
145
146    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
147        use itertools::Itertools;
148
149        self.state
150            .iter_mut()
151            .sorted_by(|a, b| a.0.cmp(b.0))
152            .map(|(_, v)| v)
153    }
154}
155
// Presumably forwards the `Updatable` impl to `&mut Sgd` so a mutable borrow
// can be passed where an updatable optimizer is expected — TODO(review):
// confirm against the macro definition.
impl_updatable_for_mut_optimizer!(Sgd);