1use std::{borrow::Cow, collections::HashMap, rc::Rc};
2
3use mlx_internal_macros::{generate_builder, Buildable};
4
5use crate::{
6 array,
7 error::AdafactorBuildError,
8 ops::{matmul, maximum, mean, minimum, rsqrt, sqrt, square, zeros_dtype, zeros_like},
9 utils::Updatable,
10 Array,
11};
12
13use super::*;
14
15fn rms(inputs: &Array) -> crate::error::Result<Array> {
16 sqrt(&mean(&square(inputs)?, None, None)?)
17}
18
19fn approvate_exp_moving_avg(
20 exp_avg_sq_row: &Array,
21 exp_avg_sq_col: &Array,
22) -> crate::error::Result<Array> {
23 let rfactor = rsqrt(&exp_avg_sq_row.divide(&mean(exp_avg_sq_row, &[-1], true)?)?)?;
24 let cfactor = rsqrt(exp_avg_sq_col)?;
25 matmul(&rfactor.expand_dims(&[-1])?, &cfactor.expand_dims(&[0])?)
26}
27
/// Epsilon pair used by the Adafactor builder: the first value is added to the
/// squared gradient for numerical stability, the second is the lower bound on
/// the parameter scale when `scale_parameter` is enabled.
pub type AdafactorEps = (f32, f32);
30
/// Per-parameter state tracked by the Adafactor optimizer.
#[derive(Debug, Clone)]
pub struct AdafactorState {
    /// Number of update steps applied to this parameter so far.
    pub(crate) step: Array,
    /// Row-wise second-moment running average; allocated only for parameters
    /// with `ndim >= 2` (see `AdafactorState::new`).
    pub(crate) exp_avg_sq_row: Option<Array>,
    /// Column-wise second-moment running average; allocated only for
    /// parameters with `ndim >= 2`.
    pub(crate) exp_avg_sq_col: Option<Array>,
    /// Unfactored second-moment running average; allocated only for
    /// parameters with `ndim < 2`.
    pub(crate) exp_avg_sq: Option<Array>,
    /// First-moment running average; allocated only when `beta1` is set.
    pub(crate) exp_avg: Option<Array>,
}
40
41impl OptimizerState for State<AdafactorState> {
42 type UnflattenError = UnflattenError;
43
44 fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
45 self.iter().flat_map(|(k, v)| {
46 let mut iter = vec![(Rc::from(format!("{}.step", k)), &v.step)];
47
48 if let Some(exp_avg_sq_row) = &v.exp_avg_sq_row {
49 iter.push((Rc::from(format!("{}.exp_avg_sq_row", k)), exp_avg_sq_row));
50 }
51
52 if let Some(exp_avg_sq_col) = &v.exp_avg_sq_col {
53 iter.push((Rc::from(format!("{}.exp_avg_sq_col", k)), exp_avg_sq_col));
54 }
55
56 if let Some(exp_avg_sq) = &v.exp_avg_sq {
57 iter.push((Rc::from(format!("{}.exp_avg_sq", k)), exp_avg_sq));
58 }
59
60 if let Some(exp_avg) = &v.exp_avg {
61 iter.push((Rc::from(format!("{}.exp_avg", k)), exp_avg));
62 }
63
64 iter
65 })
66 }
67
68 fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
69 self.iter_mut().flat_map(|(k, v)| {
70 let mut iter = vec![(Rc::from(format!("{}.step", k)), &mut v.step)];
71
72 if let Some(exp_avg_sq_row) = &mut v.exp_avg_sq_row {
73 iter.push((Rc::from(format!("{}.exp_avg_sq_row", k)), exp_avg_sq_row));
74 }
75
76 if let Some(exp_avg_sq_col) = &mut v.exp_avg_sq_col {
77 iter.push((Rc::from(format!("{}.exp_avg_sq_col", k)), exp_avg_sq_col));
78 }
79
80 if let Some(exp_avg_sq) = &mut v.exp_avg_sq {
81 iter.push((Rc::from(format!("{}.exp_avg_sq", k)), exp_avg_sq));
82 }
83
84 if let Some(exp_avg) = &mut v.exp_avg {
85 iter.push((Rc::from(format!("{}.exp_avg", k)), exp_avg));
86 }
87
88 iter
89 })
90 }
91
92 fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
93 where
94 Self: Sized,
95 I: IntoIterator<Item = (K, Array)>,
96 K: Ord + AsRef<str> + Into<Rc<str>>,
97 {
98 let mut state = State::new();
99 let iter = input
100 .into_iter()
101 .sorted_by(|a, b| a.0.as_ref().cmp(b.0.as_ref()));
102
103 for (k, v) in iter {
104 let key = k.into();
105 let mut parts = key.rsplit('.');
106 let suffix = parts.next().ok_or(UnflattenError::InvalidKey)?;
107 let prefix = parts.next().ok_or(UnflattenError::InvalidKey)?;
108
109 let prefix = Rc::from(prefix);
110 let state = state.entry(prefix).or_insert_with(|| AdafactorState {
111 step: array!(AdafactorState::DEFAULT_STEP),
112 exp_avg_sq_row: None,
113 exp_avg_sq_col: None,
114 exp_avg_sq: None,
115 exp_avg: None,
116 });
117
118 match suffix {
119 "step" => state.step = v,
120 "exp_avg_sq_row" => state.exp_avg_sq_row = Some(v),
121 "exp_avg_sq_col" => state.exp_avg_sq_col = Some(v),
122 "exp_avg_sq" => state.exp_avg_sq = Some(v),
123 "exp_avg" => state.exp_avg = Some(v),
124 _ => return Err(UnflattenError::InvalidKey),
125 }
126 }
127
128 Ok(state)
129 }
130}
131
132impl AdafactorState {
133 pub const DEFAULT_STEP: i32 = 0;
135
136 fn new(parameter: &Array, beta1_is_some: bool) -> crate::error::Result<Self> {
137 let step = array!(Self::DEFAULT_STEP);
138 let mut exp_avg_sq_row = None;
139 let mut exp_avg_sq_col = None;
140 let mut exp_avg_sq = None;
141 let mut exp_avg = None;
142
143 if parameter.ndim() >= 2 {
144 let shape = parameter.shape();
145 let dtype = parameter.dtype();
146
147 let row_shape = &shape[..shape.len() - 1];
148 exp_avg_sq_row = Some(zeros_dtype(row_shape, dtype)?);
149
150 let mut col_shape = shape[..shape.len() - 2].to_vec();
151 col_shape.push(*shape.last().unwrap());
152 exp_avg_sq_col = Some(zeros_dtype(&col_shape, dtype)?);
153 } else {
154 exp_avg_sq = Some(zeros_like(parameter)?);
155 };
156
157 if beta1_is_some {
158 exp_avg = Some(zeros_like(parameter)?);
159 }
160
161 Ok(Self {
162 step,
163 exp_avg_sq_row,
164 exp_avg_sq_col,
165 exp_avg_sq,
166 exp_avg,
167 })
168 }
169}
170
/// Builder-side learning rate; `None` requires `relative_step` to be enabled.
pub type AdafactorBuilderLr = Option<f32>;

/// Optimizer-side learning rate.
pub type AdafactorLr = Option<Array>;

/// Builder-side beta1; `None` disables the first-moment estimate.
pub type AdafactorBuilderBeta1 = Option<f32>;

/// Optimizer-side beta1.
pub type AdafactorBeta1 = Option<Array>;
184
generate_builder! {
    /// The Adafactor optimizer.
    ///
    /// Maintains factored (row/column) second-moment estimates for matrices
    /// and an optional first moment, with an optionally step-dependent
    /// ("relative") learning rate. See `update_single` for the update rule.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        build_with = build_adafactor,
        err = AdafactorBuildError,
        root = crate
    )]
    pub struct Adafactor {
        /// Learning rate. If `None`, `relative_step` must be `true`
        /// (enforced by `build_adafactor`).
        #[builder(optional, default = Adafactor::DEFAULT_LR)]
        pub lr: Option<f32>,

        /// Epsilon pair: `.0` is added to the squared gradient for numerical
        /// stability, `.1` is the minimum parameter scale used when
        /// `scale_parameter` is enabled.
        #[builder(optional, ty_override = AdafactorEps, default = Adafactor::DEFAULT_EPS)]
        pub eps: (Array, Array),

        /// Updates whose RMS exceeds this threshold are scaled down.
        #[builder(optional, ty_override = f32, default = Adafactor::DEFAULT_CLIP_THRESHOLD)]
        pub clip_threshold: Array,

        /// Exponent for the second-moment decay: `beta2 = 1 - step^decay_rate`.
        #[builder(optional, ty_override = f32, default = Adafactor::DEFAULT_DECAY_RATE)]
        pub decay_rate: Array,

        /// Optional first-moment coefficient; when set, the update is smoothed
        /// through an exponential moving average.
        #[builder(optional, ty_override = AdafactorBuilderBeta1, default = Adafactor::DEFAULT_BETA1)]
        pub beta1: AdafactorBeta1,

        /// Weight decay coefficient; `0.0` disables decay.
        #[builder(optional, default = Adafactor::DEFAULT_WEIGHT_DECAY)]
        pub weight_decay: f32,

        /// If `true`, the learning rate is scaled by the parameter RMS
        /// (bounded below by `eps.1`).
        #[builder(optional, default = Adafactor::DEFAULT_SCALE_PARAMETER)]
        pub scale_parameter: bool,

        /// If `true`, the step size is computed from the step count instead
        /// of a fixed `lr`.
        #[builder(optional, ty_override = bool, default = Adafactor::DEFAULT_RELATIVE_STEP)]
        pub relative_step: bool,

        /// If `true` (and `relative_step` is on), the minimum step size grows
        /// linearly with the step count for a warm-up effect.
        #[builder(optional, default = Adafactor::DEFAULT_WARMUP_INIT)]
        pub warmup_init: bool,

        /// Per-parameter optimizer state; populated lazily on first update.
        #[builder(ignore)]
        pub state: State<AdafactorState>,
    }
}
245
246fn build_adafactor(builder: AdafactorBuilder) -> Result<Adafactor, AdafactorBuildError> {
248 let eps = builder.eps;
249 let clip_threshold = builder.clip_threshold;
250 let decay_rate = builder.decay_rate;
251 let weight_decay = builder.weight_decay;
252 let scale_parameter = builder.scale_parameter;
253 let relative_step = builder.relative_step;
254 let warmup_init = builder.warmup_init;
255
256 if builder.lr.is_none() && !relative_step {
257 return Err(AdafactorBuildError::LrIsNoneAndRelativeStepIsFalse);
258 }
259
260 Ok(Adafactor {
261 lr: builder.lr,
262 eps: (array!(eps.0), array!(eps.1)),
263 clip_threshold: array!(clip_threshold),
264 decay_rate: array!(decay_rate),
265 beta1: builder.beta1.map(Array::from),
266 weight_decay,
267 scale_parameter,
268 relative_step,
269 warmup_init,
270 state: State::new(),
271 })
272}
273
impl Adafactor {
    /// Default learning rate (`None`: rely on relative step sizes).
    pub const DEFAULT_LR: Option<f32> = None;

    /// Default epsilon pair `(stability eps, minimum parameter scale)`.
    pub const DEFAULT_EPS: (f32, f32) = (1e-30, 1e-3);

    /// Default threshold for clipping the update RMS.
    pub const DEFAULT_CLIP_THRESHOLD: f32 = 1.0;

    /// Default exponent for the second-moment decay (`beta2 = 1 - step^rate`).
    pub const DEFAULT_DECAY_RATE: f32 = -0.8;

    /// Default weight decay (disabled).
    pub const DEFAULT_WEIGHT_DECAY: f32 = 0.0;

    /// By default the learning rate is scaled by the parameter RMS.
    pub const DEFAULT_SCALE_PARAMETER: bool = true;

    /// By default the step size is derived from the step count.
    pub const DEFAULT_RELATIVE_STEP: bool = true;

    /// By default warm-up initialization of the step size is off.
    pub const DEFAULT_WARMUP_INIT: bool = false;

    /// Default beta1 (`None`: first-moment smoothing disabled).
    pub const DEFAULT_BETA1: Option<f32> = None;
}
302
/// Returns a mutable reference to `map[key]`, inserting the value produced by
/// `f` first if the key is absent.
///
/// Uses the entry API so the map is hashed and probed only once (the previous
/// `contains_key` + `insert` + `get_mut` sequence performed up to three
/// lookups). Cloning the `Rc<str>` key is just a refcount bump.
///
/// # Errors
///
/// Propagates the error from `f` when the key is absent and the initializer
/// fails; in that case nothing is inserted.
fn get_mut_or_insert_with<'a, T, E>(
    map: &'a mut HashMap<Rc<str>, T>,
    key: &Rc<str>,
    f: impl FnOnce() -> Result<T, E>,
) -> Result<&'a mut T, E> {
    use std::collections::hash_map::Entry;

    match map.entry(key.clone()) {
        Entry::Occupied(entry) => Ok(entry.into_mut()),
        // `f()?` returns early on failure; the vacant entry is simply dropped.
        Entry::Vacant(entry) => Ok(entry.insert(f()?)),
    }
}
314
315fn compute_lr(
316 relative_step: bool,
317 warmup_init: bool,
318 lr: Option<f32>,
319 scale_parameter: bool,
320 eps: &(Array, Array),
321 step: &Array,
322 parameter_rms: &Array,
323) -> crate::error::Result<Array> {
324 let relative_step_size = if relative_step {
325 let min_step = if warmup_init {
326 array!(1e-6) * step
328 } else {
329 array!(1e-2)
330 };
331 minimum(min_step, array!(1.0) / sqrt(step)?)?
333 } else {
334 array!(lr.expect("The learning rate should be set if the relative step is not enabled"))
336 };
337
338 let mut parameter_scale = array!(1.0);
339 if scale_parameter {
340 parameter_scale = maximum(&eps.1, parameter_rms)?;
341 }
342
343 parameter_scale.multiply(relative_step_size)
344}
345
impl Optimizer for Adafactor {
    type State = State<AdafactorState>;

    fn state(&self) -> &Self::State {
        &self.state
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    /// Applies one Adafactor update to `parameter` given `gradient`,
    /// creating and mutating the per-parameter state keyed by `key`.
    fn update_single(
        &mut self,
        key: &std::rc::Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()> {
        // Captured up front so the closure below doesn't borrow `self`
        // while `self.state` is mutably borrowed.
        let beta1_is_some = self.beta1.is_some();
        let state = get_mut_or_insert_with(&mut self.state, key, || {
            AdafactorState::new(parameter, beta1_is_some)
        })?;

        state.step = state.step.add(array!(1))?;

        let gradient_shape = gradient.shape();
        // Matrices (2+ dims) use the factored row/column second moment.
        let factored = gradient_shape.len() >= 2;
        let step = &state.step;

        // Parameter RMS is taken BEFORE the update and feeds the lr scaling.
        let parameter_rms = rms(parameter)?;
        let lr = compute_lr(
            self.relative_step,
            self.warmup_init,
            self.lr,
            self.scale_parameter,
            &self.eps,
            step,
            &parameter_rms,
        )?;
        // beta2 = 1 - step^decay_rate (decay_rate is negative by default,
        // so beta2 approaches 1 as training progresses).
        let beta2 = array!(1.0).subtract(&step.power(&self.decay_rate)?)?;

        // `update` starts as grad^2 + eps.0 and is progressively transformed
        // below; Cow lets the beta1 branch borrow from state without cloning.
        let mut update: Cow<Array> = Cow::Owned(gradient.square()?.add(&self.eps.0)?);

        let one_minus_beta2 = array!(1.0).subtract(&beta2)?;
        if factored {
            // Allocated in AdafactorState::new whenever ndim >= 2.
            let exp_avg_sq_row = state.exp_avg_sq_row.as_mut().unwrap();
            let exp_avg_sq_col = state.exp_avg_sq_col.as_mut().unwrap();

            // EMA of the squared gradient reduced over the last axis (rows)
            // and the second-to-last axis (columns).
            *exp_avg_sq_row = beta2
                .multiply(&*exp_avg_sq_row)?
                .add(&one_minus_beta2.multiply(&update.mean(&[-1], None)?)?)?;
            *exp_avg_sq_col = beta2
                .multiply(&*exp_avg_sq_col)?
                .add(&one_minus_beta2.multiply(&update.mean(&[-2], None)?)?)?;

            // Reconstruct the preconditioner from the factored moments and
            // apply it to the raw gradient.
            update = Cow::Owned(approvate_exp_moving_avg(
                &*exp_avg_sq_row,
                &*exp_avg_sq_col,
            )?);
            update = Cow::Owned(update.multiply(gradient)?);
        } else {
            // Vector/scalar parameters keep a single unfactored moment.
            let exp_avg_sq = state.exp_avg_sq.as_mut().unwrap();

            *exp_avg_sq = beta2
                .multiply(&*exp_avg_sq)?
                .add(&one_minus_beta2.multiply(&update)?)?;
            update = Cow::Owned(rsqrt(&*exp_avg_sq)?.multiply(gradient)?);
        }

        // Scale the update down when its RMS exceeds `clip_threshold`,
        // then apply the effective learning rate.
        let update_rms = rms(&update)?;
        let max = maximum(array!(1.0), update_rms.divide(&self.clip_threshold)?)?;
        update = Cow::Owned(update.divide(max)?);
        update = Cow::Owned(lr.multiply(update)?);

        // Optional first-moment smoothing when beta1 is configured; the
        // final update then aliases the stored moving average.
        if let Some(beta1) = &self.beta1 {
            let exp_avg = state.exp_avg.as_mut().unwrap();
            let one_minus_beta1 = array!(1.0).subtract(beta1)?;
            *exp_avg = beta1
                .multiply(&*exp_avg)?
                .add(&one_minus_beta1.multiply(&update)?)?;
            update = Cow::Borrowed(&*exp_avg);
        }

        // Weight decay applied multiplicatively: p += p * (-weight_decay * lr),
        // before the main update is subtracted.
        if self.weight_decay != 0.0 {
            let rhs = parameter.multiply(array!(-self.weight_decay).multiply(lr)?)?;
            *parameter = parameter.add(rhs)?;
        }

        *parameter = parameter.subtract(&update)?;

        Ok(())
    }
}
441
442impl Updatable for Adafactor {
443 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
444 use itertools::Itertools;
445
446 self.state
447 .iter()
448 .sorted_by(|a, b| a.0.cmp(b.0))
449 .flat_map(|(_, v)| {
450 [
452 &v.exp_avg_sq_row,
453 &v.exp_avg_sq_col,
454 &v.exp_avg_sq,
455 &v.exp_avg,
456 ]
457 .into_iter()
458 .filter_map(|v| v.as_ref())
459 .collect::<Vec<_>>()
460 })
461 }
462
463 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
464 use itertools::Itertools;
465
466 self.state
467 .iter_mut()
468 .sorted_by(|a, b| a.0.cmp(b.0))
469 .flat_map(|(_, v)| {
470 [
472 &mut v.exp_avg_sq_row,
473 &mut v.exp_avg_sq_col,
474 &mut v.exp_avg_sq,
475 &mut v.exp_avg,
476 ]
477 .into_iter()
478 .filter_map(|v| v.as_mut())
479 .collect::<Vec<_>>()
480 })
481 }
482}
483
// Provides the `Updatable` implementation for the mutable-reference form of
// the optimizer as well (see `impl_updatable_for_mut_optimizer`'s definition
// for the expansion).
impl_updatable_for_mut_optimizer!(Adafactor);