1use std::{borrow::Cow, collections::HashMap, rc::Rc};
2
3use mlx_internal_macros::{generate_builder, Buildable};
4
5use crate::{
6 array,
7 error::AdafactorBuildError,
8 ops::{
9 matmul, maximum, mean, mean_axes, minimum, rsqrt, sqrt, square, zeros_dtype, zeros_like,
10 },
11 utils::Updatable,
12 Array,
13};
14
15use super::*;
16
17fn rms(inputs: &Array) -> crate::error::Result<Array> {
18 sqrt(&mean(&square(inputs)?, None)?)
19}
20
21fn approvate_exp_moving_avg(
22 exp_avg_sq_row: &Array,
23 exp_avg_sq_col: &Array,
24) -> crate::error::Result<Array> {
25 let rfactor = rsqrt(&exp_avg_sq_row.divide(&mean_axes(exp_avg_sq_row, &[-1], true)?)?)?;
26 let cfactor = rsqrt(exp_avg_sq_col)?;
27 matmul(&rfactor.expand_dims(-1)?, &cfactor.expand_dims(0)?)
28}
29
/// Epsilon pair `(eps1, eps2)` for Adafactor: `eps1` is added to the squared
/// gradient for regularization, `eps2` is the floor applied to the
/// parameter-RMS scale.
pub type AdafactorEps = (f32, f32);
32
/// Per-parameter optimizer state for Adafactor.
#[derive(Debug, Clone)]
pub struct AdafactorState {
    /// Number of update steps applied to this parameter so far.
    pub(crate) step: Array,
    /// Row-wise running average of squared gradients (allocated only when the
    /// parameter has `ndim >= 2`).
    pub(crate) exp_avg_sq_row: Option<Array>,
    /// Column-wise running average of squared gradients (allocated only when
    /// the parameter has `ndim >= 2`).
    pub(crate) exp_avg_sq_col: Option<Array>,
    /// Unfactored running average of squared gradients (allocated only when
    /// the parameter has `ndim < 2`).
    pub(crate) exp_avg_sq: Option<Array>,
    /// First-moment (momentum) accumulator; present only when `beta1` is set.
    pub(crate) exp_avg: Option<Array>,
}
42
43impl OptimizerState for State<AdafactorState> {
44 type UnflattenError = UnflattenError;
45
46 fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
47 self.iter().flat_map(|(k, v)| {
48 let mut iter = vec![(Rc::from(format!("{}.step", k)), &v.step)];
49
50 if let Some(exp_avg_sq_row) = &v.exp_avg_sq_row {
51 iter.push((Rc::from(format!("{}.exp_avg_sq_row", k)), exp_avg_sq_row));
52 }
53
54 if let Some(exp_avg_sq_col) = &v.exp_avg_sq_col {
55 iter.push((Rc::from(format!("{}.exp_avg_sq_col", k)), exp_avg_sq_col));
56 }
57
58 if let Some(exp_avg_sq) = &v.exp_avg_sq {
59 iter.push((Rc::from(format!("{}.exp_avg_sq", k)), exp_avg_sq));
60 }
61
62 if let Some(exp_avg) = &v.exp_avg {
63 iter.push((Rc::from(format!("{}.exp_avg", k)), exp_avg));
64 }
65
66 iter
67 })
68 }
69
70 fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
71 self.iter_mut().flat_map(|(k, v)| {
72 let mut iter = vec![(Rc::from(format!("{}.step", k)), &mut v.step)];
73
74 if let Some(exp_avg_sq_row) = &mut v.exp_avg_sq_row {
75 iter.push((Rc::from(format!("{}.exp_avg_sq_row", k)), exp_avg_sq_row));
76 }
77
78 if let Some(exp_avg_sq_col) = &mut v.exp_avg_sq_col {
79 iter.push((Rc::from(format!("{}.exp_avg_sq_col", k)), exp_avg_sq_col));
80 }
81
82 if let Some(exp_avg_sq) = &mut v.exp_avg_sq {
83 iter.push((Rc::from(format!("{}.exp_avg_sq", k)), exp_avg_sq));
84 }
85
86 if let Some(exp_avg) = &mut v.exp_avg {
87 iter.push((Rc::from(format!("{}.exp_avg", k)), exp_avg));
88 }
89
90 iter
91 })
92 }
93
94 fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
95 where
96 Self: Sized,
97 I: IntoIterator<Item = (K, Array)>,
98 K: Ord + AsRef<str> + Into<Rc<str>>,
99 {
100 let mut state = State::new();
101 let iter = input
102 .into_iter()
103 .sorted_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
104
105 for (k, v) in iter {
106 let key = k.into();
107 let mut parts = key.rsplit('.');
108 let suffix = parts.next().ok_or(UnflattenError::InvalidKey)?;
109 let prefix = parts.next().ok_or(UnflattenError::InvalidKey)?;
110
111 let prefix = Rc::from(prefix);
112 let state = state.entry(prefix).or_insert_with(|| AdafactorState {
113 step: array!(AdafactorState::DEFAULT_STEP),
114 exp_avg_sq_row: None,
115 exp_avg_sq_col: None,
116 exp_avg_sq: None,
117 exp_avg: None,
118 });
119
120 match suffix {
121 "step" => state.step = v,
122 "exp_avg_sq_row" => state.exp_avg_sq_row = Some(v),
123 "exp_avg_sq_col" => state.exp_avg_sq_col = Some(v),
124 "exp_avg_sq" => state.exp_avg_sq = Some(v),
125 "exp_avg" => state.exp_avg = Some(v),
126 _ => return Err(UnflattenError::InvalidKey),
127 }
128 }
129
130 Ok(state)
131 }
132}
133
134impl AdafactorState {
135 pub const DEFAULT_STEP: i32 = 0;
137
138 fn new(parameter: &Array, beta1_is_some: bool) -> crate::error::Result<Self> {
139 let step = array!(Self::DEFAULT_STEP);
140 let mut exp_avg_sq_row = None;
141 let mut exp_avg_sq_col = None;
142 let mut exp_avg_sq = None;
143 let mut exp_avg = None;
144
145 if parameter.ndim() >= 2 {
146 let shape = parameter.shape();
147 let dtype = parameter.dtype();
148
149 let row_shape = &shape[..shape.len() - 1];
150 exp_avg_sq_row = Some(zeros_dtype(row_shape, dtype)?);
151
152 let mut col_shape = shape[..shape.len() - 2].to_vec();
153 col_shape.push(*shape.last().unwrap());
154 exp_avg_sq_col = Some(zeros_dtype(&col_shape, dtype)?);
155 } else {
156 exp_avg_sq = Some(zeros_like(parameter)?);
157 };
158
159 if beta1_is_some {
160 exp_avg = Some(zeros_like(parameter)?);
161 }
162
163 Ok(Self {
164 step,
165 exp_avg_sq_row,
166 exp_avg_sq_col,
167 exp_avg_sq,
168 exp_avg,
169 })
170 }
171}
172
/// Builder-side learning rate; `None` means the relative step size is used.
pub type AdafactorBuilderLr = Option<f32>;

/// Optional learning rate represented as an `Array`.
pub type AdafactorLr = Option<Array>;

/// Builder-side first-moment decay; `None` disables momentum.
pub type AdafactorBuilderBeta1 = Option<f32>;

/// Optional first-moment decay stored on the optimizer as an `Array`.
pub type AdafactorBeta1 = Option<Array>;
186
generate_builder! {
    /// The Adafactor optimizer.
    ///
    /// Maintains factored second-moment estimates for matrix-shaped
    /// parameters to reduce memory, with optional momentum and
    /// relative (step-count-based) learning-rate scheduling.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adafactor,
        err = AdafactorBuildError,
        root = crate
    )]
    pub struct Adafactor {
        /// Learning rate. Default: [`Adafactor::DEFAULT_LR`] (use relative step).
        #[builder(optional, default = Adafactor::DEFAULT_LR)]
        pub lr: Option<f32>,

        /// `(eps1, eps2)`: `eps1` regularizes the squared gradient, `eps2`
        /// floors the parameter-RMS scale. Default: [`Adafactor::DEFAULT_EPS`].
        #[builder(optional, ty_override = AdafactorEps, default = Adafactor::DEFAULT_EPS)]
        pub eps: (Array, Array),

        /// Maximum allowed RMS of the unscaled update; larger updates are
        /// scaled down. Default: [`Adafactor::DEFAULT_CLIP_THRESHOLD`].
        #[builder(optional, ty_override = f32, default = Adafactor::DEFAULT_CLIP_THRESHOLD)]
        pub clip_threshold: Array,

        /// Exponent used as `step^decay_rate` in the second-moment decay
        /// schedule. Default: [`Adafactor::DEFAULT_DECAY_RATE`].
        #[builder(optional, ty_override = f32, default = Adafactor::DEFAULT_DECAY_RATE)]
        pub decay_rate: Array,

        /// Optional first-moment decay; momentum is enabled when set.
        /// Default: [`Adafactor::DEFAULT_BETA1`].
        #[builder(optional, ty_override = AdafactorBuilderBeta1, default = Adafactor::DEFAULT_BETA1)]
        pub beta1: AdafactorBeta1,

        /// Weight decay coefficient. Default: [`Adafactor::DEFAULT_WEIGHT_DECAY`].
        #[builder(optional, default = Adafactor::DEFAULT_WEIGHT_DECAY)]
        pub weight_decay: f32,

        /// If `true`, the learning rate is scaled by the parameter RMS
        /// (floored by `eps.1`). Default: [`Adafactor::DEFAULT_SCALE_PARAMETER`].
        #[builder(optional, default = Adafactor::DEFAULT_SCALE_PARAMETER)]
        pub scale_parameter: bool,

        /// If `true`, the step size is derived from the step count instead of
        /// `lr`. Default: [`Adafactor::DEFAULT_RELATIVE_STEP`].
        #[builder(optional, ty_override = bool, default = Adafactor::DEFAULT_RELATIVE_STEP)]
        pub relative_step: bool,

        /// If `true`, the relative-step cap grows linearly with the step
        /// count (warmup). Default: [`Adafactor::DEFAULT_WARMUP_INIT`].
        #[builder(optional, default = Adafactor::DEFAULT_WARMUP_INIT)]
        pub warmup_init: bool,

        /// Per-parameter optimizer state, created lazily on first update.
        #[builder(ignore)]
        pub state: State<AdafactorState>,
    }
}
247
248fn build_adafactor(builder: AdafactorBuilder) -> Result<Adafactor, AdafactorBuildError> {
250 let eps = builder.eps;
251 let clip_threshold = builder.clip_threshold;
252 let decay_rate = builder.decay_rate;
253 let weight_decay = builder.weight_decay;
254 let scale_parameter = builder.scale_parameter;
255 let relative_step = builder.relative_step;
256 let warmup_init = builder.warmup_init;
257
258 if builder.lr.is_none() && !relative_step {
259 return Err(AdafactorBuildError::LrIsNoneAndRelativeStepIsFalse);
260 }
261
262 Ok(Adafactor {
263 lr: builder.lr,
264 eps: (array!(eps.0), array!(eps.1)),
265 clip_threshold: array!(clip_threshold),
266 decay_rate: array!(decay_rate),
267 beta1: builder.beta1.map(Array::from),
268 weight_decay,
269 scale_parameter,
270 relative_step,
271 warmup_init,
272 state: State::new(),
273 })
274}
275
impl Adafactor {
    /// Default learning rate: none, meaning the relative step size is used.
    pub const DEFAULT_LR: Option<f32> = None;

    /// Default `(eps1, eps2)`: squared-gradient regularizer and
    /// parameter-RMS floor.
    pub const DEFAULT_EPS: (f32, f32) = (1e-30, 1e-3);

    /// Default cap on the RMS of the unscaled update.
    pub const DEFAULT_CLIP_THRESHOLD: f32 = 1.0;

    /// Default exponent for the `step^decay_rate` second-moment schedule;
    /// negative so the decay approaches 1 as steps accumulate.
    pub const DEFAULT_DECAY_RATE: f32 = -0.8;

    /// Default weight decay: disabled.
    pub const DEFAULT_WEIGHT_DECAY: f32 = 0.0;

    /// By default the learning rate is scaled by the parameter RMS.
    pub const DEFAULT_SCALE_PARAMETER: bool = true;

    /// By default the step size is derived from the step count.
    pub const DEFAULT_RELATIVE_STEP: bool = true;

    /// By default the relative-step warmup schedule is disabled.
    pub const DEFAULT_WARMUP_INIT: bool = false;

    /// Default first-moment decay: none (momentum disabled).
    pub const DEFAULT_BETA1: Option<f32> = None;
}
304
/// Returns a mutable reference to `map[key]`, lazily inserting the value
/// produced by `f` when the key is absent.
///
/// An error from `f` is propagated and leaves the map unchanged. Two lookups
/// are used instead of the entry API because `f` is fallible: nothing may be
/// inserted when it errors.
fn get_mut_or_insert_with<'a, T, E>(
    map: &'a mut HashMap<Rc<str>, T>,
    key: &Rc<str>,
    f: impl FnOnce() -> Result<T, E>,
) -> Result<&'a mut T, E> {
    if !map.contains_key(key) {
        let value = f()?;
        map.insert(Rc::clone(key), value);
    }

    Ok(map.get_mut(key).expect("key was just ensured to exist"))
}
316
317fn compute_lr(
318 relative_step: bool,
319 warmup_init: bool,
320 lr: Option<f32>,
321 scale_parameter: bool,
322 eps: &(Array, Array),
323 step: &Array,
324 parameter_rms: &Array,
325) -> crate::error::Result<Array> {
326 let relative_step_size = if relative_step {
327 let min_step = if warmup_init {
328 array!(1e-6) * step
330 } else {
331 array!(1e-2)
332 };
333 minimum(min_step, array!(1.0) / sqrt(step)?)?
335 } else {
336 array!(lr.expect("The learning rate should be set if the relative step is not enabled"))
338 };
339
340 let mut parameter_scale = array!(1.0);
341 if scale_parameter {
342 parameter_scale = maximum(&eps.1, parameter_rms)?;
343 }
344
345 parameter_scale.multiply(relative_step_size)
346}
347
348impl Optimizer for Adafactor {
349 type State = State<AdafactorState>;
350
351 fn state(&self) -> &Self::State {
352 &self.state
353 }
354
355 fn state_mut(&mut self) -> &mut Self::State {
356 &mut self.state
357 }
358
359 fn update_single(
360 &mut self,
361 key: &std::rc::Rc<str>,
362 gradient: &Array,
363 parameter: &mut Array,
364 ) -> crate::error::Result<()> {
365 let beta1_is_some = self.beta1.is_some();
366 let state = get_mut_or_insert_with(&mut self.state, key, || {
367 AdafactorState::new(parameter, beta1_is_some)
368 })?;
369
370 state.step = state.step.add(array!(1))?;
371
372 let gradient_shape = gradient.shape();
373 let factored = gradient_shape.len() >= 2;
374 let step = &state.step;
375
376 let parameter_rms = rms(parameter)?;
377 let lr = compute_lr(
378 self.relative_step,
379 self.warmup_init,
380 self.lr,
381 self.scale_parameter,
382 &self.eps,
383 step,
384 ¶meter_rms,
385 )?;
386 let beta2 = array!(1.0).subtract(&step.power(&self.decay_rate)?)?;
387
388 let mut update: Cow<Array> = Cow::Owned(gradient.square()?.add(&self.eps.0)?);
389
390 let one_minus_beta2 = array!(1.0).subtract(&beta2)?;
391 if factored {
392 let exp_avg_sq_row = state.exp_avg_sq_row.as_mut().unwrap();
394 let exp_avg_sq_col = state.exp_avg_sq_col.as_mut().unwrap();
395
396 *exp_avg_sq_row = beta2
397 .multiply(&*exp_avg_sq_row)?
398 .add(&one_minus_beta2.multiply(&update.mean_axes(&[-1], None)?)?)?;
399 *exp_avg_sq_col = beta2
400 .multiply(&*exp_avg_sq_col)?
401 .add(&one_minus_beta2.multiply(&update.mean_axes(&[-2], None)?)?)?;
402
403 update = Cow::Owned(approvate_exp_moving_avg(
404 &*exp_avg_sq_row,
405 &*exp_avg_sq_col,
406 )?);
407 update = Cow::Owned(update.multiply(gradient)?);
408 } else {
409 let exp_avg_sq = state.exp_avg_sq.as_mut().unwrap();
411
412 *exp_avg_sq = beta2
413 .multiply(&*exp_avg_sq)?
414 .add(&one_minus_beta2.multiply(&update)?)?;
415 update = Cow::Owned(rsqrt(&*exp_avg_sq)?.multiply(gradient)?);
416 }
417
418 let update_rms = rms(&update)?;
419 let max = maximum(array!(1.0), update_rms.divide(&self.clip_threshold)?)?;
420 update = Cow::Owned(update.divide(max)?);
421 update = Cow::Owned(lr.multiply(update)?);
422
423 if let Some(beta1) = &self.beta1 {
424 let exp_avg = state.exp_avg.as_mut().unwrap();
426 let one_minus_beta1 = array!(1.0).subtract(beta1)?;
427 *exp_avg = beta1
428 .multiply(&*exp_avg)?
429 .add(&one_minus_beta1.multiply(&update)?)?;
430 update = Cow::Borrowed(&*exp_avg);
431 }
432
433 if self.weight_decay != 0.0 {
434 let rhs = parameter.multiply(array!(-self.weight_decay).multiply(lr)?)?;
435 *parameter = parameter.add(rhs)?;
436 }
437
438 *parameter = parameter.subtract(&update)?;
439
440 Ok(())
441 }
442}
443
444impl Updatable for Adafactor {
445 fn updatable_states_len(&self) -> usize {
446 self.updatable_states().into_iter().count()
447 }
448
449 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
450 use itertools::Itertools;
451
452 self.state
453 .iter()
454 .sorted_by(|a, b| a.0.cmp(b.0))
455 .flat_map(|(_, v)| {
456 [
458 &v.exp_avg_sq_row,
459 &v.exp_avg_sq_col,
460 &v.exp_avg_sq,
461 &v.exp_avg,
462 ]
463 .into_iter()
464 .filter_map(|v| v.as_ref())
465 .collect::<Vec<_>>()
466 })
467 }
468
469 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
470 use itertools::Itertools;
471
472 self.state
473 .iter_mut()
474 .sorted_by(|a, b| a.0.cmp(b.0))
475 .flat_map(|(_, v)| {
476 [
478 &mut v.exp_avg_sq_row,
479 &mut v.exp_avg_sq_col,
480 &mut v.exp_avg_sq,
481 &mut v.exp_avg,
482 ]
483 .into_iter()
484 .filter_map(|v| v.as_mut())
485 .collect::<Vec<_>>()
486 })
487 }
488}
489
490impl_updatable_for_mut_optimizer!(Adafactor);