// mlx_rs/optimizers/adadelta.rs
use std::rc::Rc;
use crate::{
array,
ops::sqrt,
utils::{get_mut_or_insert_with, Updatable},
Array,
};
use mlx_internal_macros::{generate_builder, Buildable};
use crate::error::AdaDeltaBuildError;
use super::*;
generate_builder! {
/// The AdaDelta optimizer.
///
/// For every parameter key it keeps a pair of running averages in `state`:
/// one of the squared gradients and one of the squared update steps. Each
/// call to `update_single` refreshes both averages and moves the parameter
/// by `lr` times the rescaled gradient.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(
build_with = build_adadelta,
err = AdaDeltaBuildError,
root = crate
)]
pub struct AdaDelta {
/// Learning rate. Given to the builder as an `f32` and stored as an
/// [`Array`].
#[builder(ty_override = f32)]
pub lr: Array,
/// Decay coefficient applied to both running averages. Optional in the
/// builder (as an `f32`); defaults to [`AdaDelta::DEFAULT_RHO`].
#[builder(optional, ty_override = f32, default = AdaDelta::DEFAULT_RHO)]
pub rho: Array,
/// Constant added to each running average before its square root is
/// taken. Optional in the builder (as an `f32`); defaults to
/// [`AdaDelta::DEFAULT_EPS`].
#[builder(optional, ty_override = f32, default = AdaDelta::DEFAULT_EPS)]
pub eps: Array,
/// Per-parameter state: `(average of squared gradients, average of
/// squared update steps)`. Marked `#[builder(ignore)]`; `build_adadelta`
/// initializes it with `State::new()`.
#[builder(ignore)]
pub state: State<(Array, Array)>,
}
}
/// Validates the builder's hyper-parameters and assembles an [`AdaDelta`].
///
/// The scalar settings are converted to arrays and the per-parameter state
/// starts out empty.
///
/// # Errors
///
/// * [`AdaDeltaBuildError::NegativeRho`] when `rho < 0.0`.
/// * [`AdaDeltaBuildError::NegativeEps`] when `eps < 0.0`.
///
/// If both values are negative, the `rho` error is the one reported.
fn build_adadelta(builder: AdaDeltaBuilder) -> Result<AdaDelta, AdaDeltaBuildError> {
    let (rho, eps) = (builder.rho, builder.eps);

    // The first arm keeps `rho` ahead of `eps` in the validation order.
    match (rho < 0.0, eps < 0.0) {
        (true, _) => Err(AdaDeltaBuildError::NegativeRho),
        (false, true) => Err(AdaDeltaBuildError::NegativeEps),
        (false, false) => Ok(AdaDelta {
            lr: array!(builder.lr),
            rho: array!(rho),
            eps: array!(eps),
            state: State::new(),
        }),
    }
}
impl AdaDelta {
/// Value the builder uses for `rho` when none is supplied.
// NOTE(review): upstream Python MLX appears to document `rho = 0.9` as the
// AdaDelta default — confirm that 0.99 is intentional here.
pub const DEFAULT_RHO: f32 = 0.99;
/// Value the builder uses for `eps` when none is supplied.
pub const DEFAULT_EPS: f32 = 1e-6;
}
impl Optimizer for AdaDelta {
    /// Per-parameter pair `(average of squared gradients, average of squared
    /// update steps)`.
    type State = State<(Array, Array)>;

    /// Shared access to the per-parameter running averages.
    fn state(&self) -> &Self::State {
        &self.state
    }

    /// Exclusive access to the per-parameter running averages.
    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.state
    }

    /// Applies one AdaDelta step to `parameter` using `gradient`.
    ///
    /// With `v` and `u` the averages stored under `key` (both start as the
    /// scalar `0.0` the first time a key is seen):
    ///
    /// ```text
    /// v' = rho * v + (1 - rho) * gradient^2
    /// d  = sqrt(u + eps) / sqrt(v' + eps) * gradient
    /// u' = rho * u + (1 - rho) * d^2
    /// parameter' = parameter - lr * d
    /// ```
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any of the array operations. In
    /// that case neither `parameter` nor the stored averages are modified,
    /// because all writes happen after the last fallible call.
    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> crate::error::Result<()> {
        let (grad_sq_avg, step_sq_avg) =
            get_mut_or_insert_with(&mut self.state, key, || (array!(0.0), array!(0.0)));

        // Weight given to the newest sample in both running averages.
        let complement = array!(1.0).subtract(&self.rho)?;

        // v' = rho * v + (1 - rho) * g^2
        let next_grad_sq_avg = self
            .rho
            .multiply(&grad_sq_avg)?
            .add(&complement.multiply(gradient.square()?)?)?;

        // d = sqrt(u + eps) / sqrt(v' + eps) * g; note that the *old* `u`
        // and the *new* `v'` are used here.
        let step_scale = sqrt(&step_sq_avg.add(&self.eps)?)?;
        let grad_scale = sqrt(&next_grad_sq_avg.add(&self.eps)?)?;
        let step = step_scale.divide(&grad_scale)?.multiply(gradient)?;

        // u' = rho * u + (1 - rho) * d^2
        let next_step_sq_avg = self
            .rho
            .multiply(&step_sq_avg)?
            .add(&complement.multiply(step.square()?)?)?;

        let next_parameter = parameter.subtract(self.lr.multiply(step)?)?;

        // Commit only once every fallible operation has succeeded.
        *grad_sq_avg = next_grad_sq_avg;
        *step_sq_avg = next_step_sq_avg;
        *parameter = next_parameter;

        Ok(())
    }
}
impl Updatable for AdaDelta {
    /// Yields every state array by shared reference.
    ///
    /// Entries are visited in ascending key order; for each key the squared
    /// gradient average comes first, followed by the squared step average.
    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
        use itertools::Itertools;

        self.state
            .iter()
            .sorted_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs))
            .flat_map(|(_, (grad_sq_avg, step_sq_avg))| [grad_sq_avg, step_sq_avg])
    }

    /// Yields every state array by mutable reference, in the same order as
    /// [`Updatable::updatable_states`]: ascending by key, squared gradient
    /// average before squared step average.
    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
        use itertools::Itertools;

        self.state
            .iter_mut()
            .sorted_by(|(lhs, _), (rhs, _)| lhs.cmp(rhs))
            .flat_map(|(_, (grad_sq_avg, step_sq_avg))| [grad_sq_avg, step_sq_avg])
    }
}
// Judging by the macro's name, this presumably also provides `Updatable` for
// `&mut AdaDelta` — NOTE(review): confirm against the macro definition.
impl_updatable_for_mut_optimizer!(AdaDelta);