// mlx_rs/optimizers/lion.rs

use mlx_internal_macros::{generate_builder, Buildable};

use crate::{
    array,
    utils::{get_mut_or_insert_with, Updatable},
    Array,
};

use super::*;

generate_builder! {
    /// The Lion optimizer [1].
    ///
    /// Since updates are computed through the sign operation, they tend to have a larger norm
    /// than updates from other optimizers such as SGD and Adam. We recommend a learning rate
    /// that is 3-10x smaller than for AdamW and a weight decay 3-10x larger than for AdamW to
    /// maintain the update strength `(lr * wd)`. Our Lion implementation follows the original
    /// paper. In detail, the update rule is:
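    ///
    /// ```text
    /// c_{t+1} = b1 * m_t + (1 - b1) * g_t
    /// m_{t+1} = b2 * m_t + (1 - b2) * g_t
    /// w_{t+1} = (1 - lr * wd) * w_t - lr * sign(c_{t+1})
    /// ```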
    ///
    /// [1]: Chen, X. et al., 2023. Symbolic Discovery of Optimization Algorithms. arXiv preprint
    ///     arXiv:2302.06675.
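    ///
    /// # Example
    ///
    /// A minimal usage sketch, assuming the builder macro generates a `Lion::new(lr)`
    /// convenience constructor as for the other optimizers in this crate; `model` and
    /// `gradients` are placeholders for a module and its computed gradients:
    ///
    /// ```rust,ignore
    /// use mlx_rs::optimizers::{Lion, Optimizer};
    ///
    /// // Use a learning rate 3-10x smaller than you would for AdamW.
    /// let mut optimizer = Lion::new(1e-4);
    /// optimizer.update(&mut model, gradients)?;
    /// ```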
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_lion,
        root = crate
    )]
    pub struct Lion {
        /// The learning rate.
        pub lr: f32,

        /// The coefficients `(b1, b2)` used for blending the gradient into the update
        /// direction and into the running average of the gradient, respectively.
        /// Defaults to [`Lion::DEFAULT_BETAS`].
        #[builder(optional, ty_override = Betas, default = Lion::DEFAULT_BETAS)]
        pub betas: (Array, Array),

        /// The weight decay. Defaults to [`Lion::DEFAULT_WEIGHT_DECAY`].
        #[builder(optional, default = Lion::DEFAULT_WEIGHT_DECAY)]
        pub weight_decay: f32,

        /// Inner state.
        #[builder(ignore)]
        pub state: State,
    }
}

/// Builds a [`Lion`] optimizer from its builder, converting the `f32` betas into
/// single-element [`Array`]s.
fn build_lion(builder: LionBuilder) -> Result<Lion, std::convert::Infallible> {
    let lr = builder.lr;
    let betas = builder.betas;
    let weight_decay = builder.weight_decay;

    Ok(Lion {
        lr,
        betas: (array!(betas.0), array!(betas.1)),
        weight_decay,
        state: State::new(),
    })
}

impl Lion {
    /// Default values for `betas` (the defaults from the Lion paper)
    pub const DEFAULT_BETAS: (f32, f32) = (0.9, 0.99);

    /// Default value for `weight_decay`
    pub const DEFAULT_WEIGHT_DECAY: f32 = 0.0;
}

impl Optimizer for Lion {
    type State = State;

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    fn update_single(
        &mut self,
        key: &std::rc::Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> Result<(), crate::error::Exception> {
        use crate::ops::sign;

        let (b1, b2) = &self.betas;
        let m = get_mut_or_insert_with(&mut self.state, key, || array!(0.0));

        let one_minus_b1 = array!(1.0).subtract(b1)?;
        let one_minus_b2 = array!(1.0).subtract(b2)?;

        // Update direction: c_{t+1} = b1 * m_t + (1 - b1) * g_t
        let c = b1.multiply(&*m)?.add(&one_minus_b1.multiply(gradient)?)?;
        // Momentum: m_{t+1} = b2 * m_t + (1 - b2) * g_t
        *m = b2.multiply(&*m)?.add(&one_minus_b2.multiply(gradient)?)?;

        // Decoupled weight decay: w <- (1 - lr * wd) * w
        if self.weight_decay > 0.0 {
            // SAFETY: These coeffs are all single-element arrays and won't panic.
            *parameter = array!(1.0 - self.lr * self.weight_decay) * &*parameter;
        }

        // Sign update: w_{t+1} = w_t - lr * sign(c_{t+1})
        let lr = array!(self.lr);
        *parameter = parameter.subtract(lr.multiply(sign(&c)?)?)?;

        Ok(())
    }
}
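
// Sanity check for the update rule above: with zero weight decay and a fresh
// state, the first step has c = (1 - b1) * g, so the parameter moves by
// exactly -lr * sign(g). `Array::item` is assumed to evaluate and extract the
// scalar value, as elsewhere in this crate.
#[cfg(test)]
mod lion_update_tests {
    use super::*;

    #[test]
    fn first_step_moves_parameter_by_lr_against_gradient_sign() {
        // Construct from public fields to avoid relying on generated constructors.
        let mut opt = Lion {
            lr: 0.1,
            betas: (array!(0.9), array!(0.99)),
            weight_decay: 0.0,
            state: State::new(),
        };

        let key: std::rc::Rc<str> = std::rc::Rc::from("w");
        let gradient = array!(2.0); // positive, so sign(c) == 1 on the first step
        let mut parameter = array!(0.0);

        opt.update_single(&key, &gradient, &mut parameter).unwrap();

        // Expect w_1 = w_0 - lr * sign(c) = -0.1.
        let value: f32 = parameter.item();
        assert!((value + 0.1).abs() < 1e-6);
    }
}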

impl Updatable for Lion {
    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
        use itertools::Itertools;

        // Sort by parameter key so the traversal order is deterministic.
        self.state
            .iter()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .map(|(_, v)| v)
    }

    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
        use itertools::Itertools;

        // Sort by parameter key so the traversal order is deterministic.
        self.state
            .iter_mut()
            .sorted_by(|a, b| a.0.cmp(b.0))
            .map(|(_, v)| v)
    }
}

impl_updatable_for_mut_optimizer!(Lion);