mlx_rs/optimizers/
lion.rs1use mlx_internal_macros::{generate_builder, Buildable};
2
3use crate::{
4 array,
5 utils::{get_mut_or_insert_with, Updatable},
6 Array,
7};
8
9use super::*;
10
generate_builder! {
    /// The Lion optimizer ("Symbolic Discovery of Optimization Algorithms", Chen et al., 2023).
    ///
    /// Every step (see `update_single`) moves a parameter by `lr * sign(c)`. Here `c` interpolates
    /// between the stored running average and the current gradient using `betas.0`. The stored
    /// average is then refreshed with `betas.1`. When `weight_decay > 0`, the parameter is first
    /// scaled by `1 - lr * weight_decay`.
    ///
    /// Build it through the generated [`LionBuilder`]. `lr` is the only field not marked
    /// `optional` or `ignore`, so it is the only value that must be supplied.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_lion,
        root = crate
    )]
    pub struct Lion {
        /// The learning rate. It scales the sign-based step and, together with `weight_decay`, the
        /// decay factor.
        pub lr: f32,

        /// `(beta1, beta2)` stored as arrays.
        ///
        /// `beta1` weights the stored average against the current gradient when forming the step
        /// direction. `beta2` weights it when updating the stored average itself. The builder takes
        /// this as a `Betas` value, which `build_lion` converts with `array!`. It defaults to
        /// [`Lion::DEFAULT_BETAS`].
        #[builder(optional, ty_override = Betas, default = Lion::DEFAULT_BETAS)]
        pub betas: (Array, Array),

        /// Weight-decay factor. A positive value makes each step first multiply the parameter by
        /// `1 - lr * weight_decay`; zero or negative values skip decay entirely. Defaults to
        /// [`Lion::DEFAULT_WEIGHT_DECAY`].
        #[builder(optional, default = Lion::DEFAULT_WEIGHT_DECAY)]
        pub weight_decay: f32,

        /// Optimizer state: one running-average array (`m` in `update_single`) per parameter key.
        /// It cannot be set through the builder (`#[builder(ignore)]`); `build_lion` starts it as
        /// `State::new()`.
        #[builder(ignore)]
        pub state: State,
    }
}
45
46fn build_lion(builder: LionBuilder) -> Result<Lion, std::convert::Infallible> {
47 let lr = builder.lr;
48 let betas = builder.betas;
49 let weight_decay = builder.weight_decay;
50
51 Ok(Lion {
52 lr,
53 betas: (array!(betas.0), array!(betas.1)),
54 weight_decay,
55 state: State::new(),
56 })
57}
58
59impl Lion {
60 pub const DEFAULT_BETAS: (f32, f32) = (0.9, 0.999);
62
63 pub const DEFAULT_WEIGHT_DECAY: f32 = 0.0;
65}
66
67impl Optimizer for Lion {
68 type State = State;
69
70 fn state(&self) -> &Self::State {
71 &self.state
72 }
73
74 fn state_mut(&mut self) -> &mut Self::State {
75 &mut self.state
76 }
77
78 fn update_single(
79 &mut self,
80 key: &std::rc::Rc<str>,
81 gradient: &Array,
82 parameter: &mut Array,
83 ) -> Result<(), crate::error::Exception> {
84 use crate::ops::sign;
85
86 let (b1, b2) = &self.betas;
87 let m = get_mut_or_insert_with(&mut self.state, key, || array!(0.0));
88
89 let one_minus_b1 = array!(1.0).subtract(b1)?;
90 let one_minus_b2 = array!(1.0).subtract(b2)?;
91
92 let c = b1.multiply(&m)?.add(&one_minus_b1.multiply(gradient)?)?;
93 *m = b2.multiply(&m)?.add(&one_minus_b2.multiply(gradient)?)?;
94
95 if self.weight_decay > 0.0 {
96 *parameter = array!(1.0 - self.lr * self.weight_decay) * &*parameter;
98 }
99
100 let lr = array!(self.lr);
101 *parameter = parameter.subtract(lr.multiply(sign(&c)?)?)?;
102
103 Ok(())
104 }
105}
106
107impl Updatable for Lion {
108 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
109 use itertools::Itertools;
110
111 self.state
112 .iter()
113 .sorted_by(|a, b| a.0.cmp(b.0))
114 .map(|(_, v)| v)
115 }
116
117 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
118 use itertools::Itertools;
119
120 self.state
121 .iter_mut()
122 .sorted_by(|a, b| a.0.cmp(b.0))
123 .map(|(_, v)| v)
124 }
125}
126
// NOTE(review): the macro is defined outside this file. Judging by its name, it presumably also
// implements `Updatable` for `&mut Lion`, so a mutable borrow can be used wherever `Updatable` is
// expected. Confirm against the macro definition.
impl_updatable_for_mut_optimizer!(Lion);