// mlx_rs/optimizers/adam.rs
use std::convert::Infallible;
use mlx_internal_macros::{generate_builder, Buildable};
use crate::{array, utils::get_mut_or_insert_with};
use super::*;
/// Pair of `f32` coefficients `(beta1, beta2)`; this is the type named by the
/// `ty_override` on [`Adam::betas`] in the builder attributes below.
pub type Betas = (f32, f32); generate_builder! {
/// State and hyper-parameters of the Adam optimizer.
///
/// Instances are assembled by `build_adam` (see `build_with` below), which wraps
/// the scalar hyper-parameters with `array!` and initialises `state` with
/// `State::new()`. The per-parameter update rule lives in `adam_apply_single`.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(
build_with = build_adam,
root = crate
)]
pub struct Adam {
/// Learning rate: the factor that scales the step subtracted from the parameter.
///
/// Declared to the builder as `f32` and stored as an `Array`.
#[builder(ty_override = f32)]
pub lr: Array,
/// Coefficients `(beta1, beta2)` weighting the previous first and second
/// state entries against the new gradient (and squared gradient).
///
/// Optional in the builder; the default is [`Adam::DEFAULT_BETAS`].
#[builder(optional, ty_override = Betas, default = Adam::DEFAULT_BETAS)]
pub betas: (Array, Array),
/// Term added to the square root of the second state entry in the
/// denominator of the update.
///
/// Optional in the builder; the default is [`Adam::DEFAULT_EPS`].
#[builder(optional, ty_override = f32, default = Adam::DEFAULT_EPS)]
pub eps: Array,
/// Per-parameter optimizer state: the `(m, v)` pair read and replaced by
/// `update_single`, keyed by parameter key.
///
/// Ignored by the builder; `build_adam` initialises it with `State::new()`.
#[builder(ignore)]
pub state: State<(Array, Array)>,
}
}
/// Builds an [`Adam`] optimizer from the values collected by [`AdamBuilder`].
///
/// Each scalar hyper-parameter held by the builder is wrapped with `array!`,
/// and the optimizer state is initialised with `State::new()`.
///
/// The error type is [`Infallible`]: this function always returns `Ok`.
fn build_adam(builder: AdamBuilder) -> Result<Adam, Infallible> {
    let (beta1, beta2) = builder.betas;

    let optimizer = Adam {
        lr: array!(builder.lr),
        betas: (array!(beta1), array!(beta2)),
        eps: array!(builder.eps),
        state: State::new(),
    };
    Ok(optimizer)
}
impl Adam {
pub const DEFAULT_BETAS: (f32, f32) = (0.9, 0.999);
pub const DEFAULT_EPS: f32 = 1e-8;
}
impl Optimizer for Adam {
    type State = State<(Array, Array)>;

    /// Returns a shared reference to the optimizer's `state` field.
    fn state(&self) -> &Self::State {
        &self.state
    }

    /// Returns an exclusive reference to the optimizer's `state` field.
    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    /// Applies one Adam step to the parameter stored under `key`.
    ///
    /// The `(m, v)` state entry for `key` is obtained through
    /// `get_mut_or_insert_with`; the closure handed to it yields a pair of
    /// scalar `0.0` arrays (presumably it is only invoked when `key` has no
    /// entry yet). Both the entry and `parameter` are then overwritten with
    /// the values computed by `adam_apply_single`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `adam_apply_single`; in that case
    /// neither `parameter` nor the state entry is overwritten.
    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()> {
        let moments = get_mut_or_insert_with(&mut self.state, key, || {
            let zero_m = array!(0.0);
            let zero_v = array!(0.0);
            (zero_m, zero_v)
        });

        // Only shared reborrows are passed in, so both targets can be
        // assigned once the new values have been computed.
        let (updated_parameter, updated_moments) = adam_apply_single(
            &self.lr,
            &self.betas,
            &self.eps,
            gradient,
            &*parameter,
            &*moments,
        )?;

        *moments = updated_moments;
        *parameter = updated_parameter;
        Ok(())
    }
}
/// Computes one Adam update for a single parameter without mutating its inputs.
///
/// With `(b1, b2) = betas` and `(m, v) = state`:
///
/// ```text
/// m'     = b1 * m + (1 - b1) * gradient
/// v'     = b2 * v + (1 - b2) * gradient.square()
/// param' = parameter - lr * (m' / (sqrt(v') + eps))
/// ```
///
/// No bias-correction term is applied.
///
/// # Returns
///
/// `(param', (m', v'))`: the updated parameter and the state pair to store
/// for the next step.
///
/// # Errors
///
/// Propagates the first error returned by any of the array operations.
pub(super) fn adam_apply_single(
    lr: &Array,
    betas: &(Array, Array),
    eps: &Array,
    gradient: &Array,
    parameter: &Array,
    state: &(Array, Array),
) -> crate::error::Result<(Array, (Array, Array))> {
    let (beta1, beta2) = betas;
    let (first_moment, second_moment) = state;

    let one = array!(1.0);
    let grad_weight_m = one.subtract(beta1)?;
    let grad_weight_v = one.subtract(beta2)?;

    // m' = b1 * m + (1 - b1) * g
    let decayed_m = beta1.multiply(first_moment)?;
    let fresh_m = grad_weight_m.multiply(gradient)?;
    let next_m = decayed_m.add(&fresh_m)?;

    // v' = b2 * v + (1 - b2) * g^2
    let decayed_v = beta2.multiply(second_moment)?;
    let grad_squared = gradient.square()?;
    let fresh_v = grad_weight_v.multiply(&grad_squared)?;
    let next_v = decayed_v.add(&fresh_v)?;

    // param' = param - lr * m' / (sqrt(v') + eps)
    let denominator = next_v.sqrt()?.add(eps)?;
    let scaled_step = lr.multiply(&next_m.divide(&denominator)?)?;
    let next_parameter = parameter.subtract(&scaled_step)?;

    Ok((next_parameter, (next_m, next_v)))
}
impl Updatable for Adam {
    /// Yields shared references to every state array.
    ///
    /// Entries are visited in ascending key order, and each entry contributes
    /// the first array of its pair followed by the second.
    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
        use itertools::Itertools;

        self.state
            .iter()
            .sorted_by(|lhs, rhs| Ord::cmp(lhs.0, rhs.0))
            .flat_map(|(_, (m, v))| [m, v])
    }

    /// Yields exclusive references to every state array, in the same order
    /// as [`Updatable::updatable_states`].
    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
        use itertools::Itertools;

        self.state
            .iter_mut()
            .sorted_by(|lhs, rhs| Ord::cmp(lhs.0, rhs.0))
            .flat_map(|(_, (m, v))| [m, v])
    }
}
// NOTE(review): this macro is defined elsewhere in the crate; judging by its name it
// presumably provides `Updatable` for `&mut Adam` — confirm against its definition.
impl_updatable_for_mut_optimizer!(Adam);