mlx_rs/optimizers/
adamw.rsuse std::convert::Infallible;
use mlx_internal_macros::{generate_builder, Buildable};
use crate::{
array,
utils::{get_mut_or_insert_with, Updatable},
Array,
};
use super::*;
generate_builder! {
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(
build_with = build_adamw,
root = crate
)]
pub struct AdamW {
#[builder(ty_override = f32)]
pub lr: Array,
#[builder(optional, ty_override = Betas, default = AdamW::DEFAULT_BETAS)]
pub betas: (Array, Array),
#[builder(optional, ty_override = f32, default = AdamW::DEFAULT_EPS)]
pub eps: Array,
#[builder(optional, ty_override = f32, default = AdamW::DEFAULT_WEIGHT_DECAY)]
pub weight_decay: Array,
#[builder(ignore)]
pub state: State<(Array, Array)>,
}
}
fn build_adamw(builder: AdamWBuilder) -> Result<AdamW, Infallible> {
let lr = builder.lr;
let betas = builder.betas;
let eps = builder.eps;
let weight_decay = builder.weight_decay;
Ok(AdamW {
lr: array!(lr),
betas: (array!(betas.0), array!(betas.1)),
eps: array!(eps),
weight_decay: array!(weight_decay),
state: State::new(),
})
}
impl AdamW {
pub const DEFAULT_BETAS: (f32, f32) = super::Adam::DEFAULT_BETAS;
pub const DEFAULT_EPS: f32 = super::Adam::DEFAULT_EPS;
pub const DEFAULT_WEIGHT_DECAY: f32 = 0.01;
}
impl Optimizer for AdamW {
type State = State<(Array, Array)>;
fn state(&self) -> &Self::State {
&self.state
}
fn state_mut(&mut self) -> &mut Self::State {
&mut self.state
}
fn update_single(
&mut self,
key: &std::rc::Rc<str>,
gradient: &Array,
parameter: &mut Array,
) -> Result<(), crate::error::Exception> {
let betas = &self.betas;
let state = get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0)));
let one_minus_lr_wd = array!(1.0) - (&self.lr * &self.weight_decay);
let decayed_parameter = &*parameter * &one_minus_lr_wd;
let (new_parameter, new_states) = super::adam_apply_single(
&self.lr,
betas,
&self.eps,
gradient,
&decayed_parameter,
state,
)?;
*state = new_states;
*parameter = new_parameter;
Ok(())
}
}
impl Updatable for AdamW {
fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
use itertools::Itertools;
self.state
.iter()
.sorted_by(|a, b| a.0.cmp(b.0))
.flat_map(|(_, (v, u))| vec![v, u])
}
fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
use itertools::Itertools;
self.state
.iter_mut()
.sorted_by(|a, b| a.0.cmp(b.0))
.flat_map(|(_, (v, u))| vec![v, u])
}
}
impl_updatable_for_mut_optimizer!(AdamW);