mlx_rs::optimizers

Trait Optimizer

pub trait Optimizer: Updatable {
    type State: OptimizerState;

    // Required methods
    fn state(&self) -> &Self::State;
    fn state_mut(&mut self) -> &mut Self::State;
    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> Result<()>;

    // Provided method
    fn update<M>(
        &mut self,
        model: &mut M,
        gradients: impl Borrow<FlattenedModuleParam>,
    ) -> Result<()>
       where M: ModuleParameters { ... }
}

Trait for optimizers.

Required Associated Types

type State: OptimizerState

State of the optimizer.

Required Methods

fn state(&self) -> &Self::State

Get the state of the optimizer.

fn state_mut(&mut self) -> &mut Self::State

Get the mutable state of the optimizer.
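
The associated `State` type lets each optimizer choose its own state layout, and `state`/`state_mut` expose it. The sketch below is a reduced, hypothetical model of that part of the trait only. It assumes per-parameter state keyed by the flattened parameter name, uses `Vec<f32>` as a stand-in for `Array`, and omits the `OptimizerState` bound and the `Updatable` supertrait. `ToyOptimizer` and `ToyMomentum` are illustrative names, not mlx_rs API.

```rust
use std::collections::HashMap;
use std::rc::Rc;

// Stand-in for mlx_rs::Array (assumption): a flat vector of f32 values.
type Array = Vec<f32>;

// Hypothetical, reduced version of the trait: only the associated `State`
// type and its two accessors are modelled. The real trait additionally
// requires `State: OptimizerState` and the `Updatable` supertrait.
trait ToyOptimizer {
    type State;
    fn state(&self) -> &Self::State;
    fn state_mut(&mut self) -> &mut Self::State;
}

// Hypothetical momentum optimizer: one velocity buffer per parameter key.
struct ToyMomentum {
    velocities: HashMap<Rc<str>, Array>,
}

impl ToyOptimizer for ToyMomentum {
    type State = HashMap<Rc<str>, Array>;

    fn state(&self) -> &Self::State {
        &self.velocities
    }

    fn state_mut(&mut self) -> &mut Self::State {
        &mut self.velocities
    }
}

fn main() {
    let mut opt = ToyMomentum { velocities: HashMap::new() };
    // Seed an entry through the mutable accessor...
    opt.state_mut().insert(Rc::from("linear.weight"), vec![0.0; 3]);
    // ...and read it back through the shared one.
    println!("{} state entries", opt.state().len());
}
```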

fn update_single(&mut self, key: &Rc<str>, gradient: &Array, parameter: &mut Array) -> Result<()>

Update a single parameter with the given gradient.

The implementation should use `key` to look up the state for this parameter, then update both that state and the parameter. The key is passed instead of the state itself because handing `update_single` a mutable reference into the optimizer's state while it also receives `&mut self` would be a conflicting mutable borrow of the optimizer.
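
The borrow argument above can be illustrated with a small, hypothetical SGD-with-momentum optimizer whose `update_single` has the same shape: it receives the key, looks up (or lazily creates) the velocity buffer inside `self`, and updates the parameter in place. `Vec<f32>` stands in for `Array` and `Result<(), String>` for the crate's `Result`; `ToySgdMomentum` and its update rule are assumptions for illustration, not the mlx_rs SGD implementation.

```rust
use std::collections::HashMap;
use std::rc::Rc;

// Stand-in for mlx_rs::Array (assumption): a flat vector of f32 values.
type Array = Vec<f32>;

// Hypothetical SGD-with-momentum optimizer; not the mlx_rs implementation.
struct ToySgdMomentum {
    lr: f32,
    momentum: f32,
    // Per-parameter state, keyed the same way as the flattened parameters.
    velocities: HashMap<Rc<str>, Array>,
}

impl ToySgdMomentum {
    // Same shape as `update_single`: it receives the key, not `&mut` state.
    // A variant taking `state: &mut Array` could not be called as
    // `self.update_single(self.velocities.get_mut(key).unwrap(), ...)`,
    // because `self` would then be mutably borrowed twice.
    fn update_single(
        &mut self,
        key: &Rc<str>,
        gradient: &Array,
        parameter: &mut Array,
    ) -> Result<(), String> {
        if gradient.len() != parameter.len() {
            return Err(format!("shape mismatch for `{key}`"));
        }
        // Look up (or lazily create) this parameter's velocity buffer.
        let v = self
            .velocities
            .entry(key.clone())
            .or_insert_with(|| vec![0.0; gradient.len()]);
        // Disjoint field borrows: `self.lr` and `self.momentum` stay readable
        // while `self.velocities` is borrowed mutably through `v`.
        for ((v_i, g_i), p_i) in v.iter_mut().zip(gradient).zip(parameter.iter_mut()) {
            *v_i = self.momentum * *v_i + g_i;
            *p_i -= self.lr * *v_i;
        }
        Ok(())
    }
}

fn main() {
    let mut opt = ToySgdMomentum { lr: 0.1, momentum: 0.9, velocities: HashMap::new() };
    let key: Rc<str> = Rc::from("linear.weight");
    let mut weight = vec![1.0, 2.0];
    let grad = vec![1.0, -1.0];
    // The second step moves further than the first because the velocity
    // stored under `key` persists between calls.
    opt.update_single(&key, &grad, &mut weight).unwrap();
    opt.update_single(&key, &grad, &mut weight).unwrap();
    println!("{weight:?}");
}
```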

Provided Methods

fn update<M>(&mut self, model: &mut M, gradients: impl Borrow<FlattenedModuleParam>) -> Result<()> where M: ModuleParameters

Apply the gradients to the parameters of the model and update the model with the new parameters.
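
One way a provided `update` can be layered on `update_single` is to walk the flattened gradients and update each parameter with the same key. The sketch below assumes the model exposes its parameters as a flat map keyed like the gradients; `ToyParameters`, `params_mut`, `ToyModel`, and `PlainSgd` are hypothetical stand-ins for `ModuleParameters` and a concrete optimizer, not mlx_rs API. Skipping gradients that have no matching parameter is also an assumption of the sketch. Taking `impl Borrow<...>` lets callers pass either an owned map or a reference to one.

```rust
use std::borrow::Borrow;
use std::collections::HashMap;
use std::rc::Rc;

// Stand-ins (assumptions): a flat f32 vector for `Array` and a flat,
// key-addressed map for the flattened parameters and gradients.
type Array = Vec<f32>;
type Flattened = HashMap<Rc<str>, Array>;

// Hypothetical stand-in for `ModuleParameters`.
trait ToyParameters {
    fn params_mut(&mut self) -> &mut Flattened;
}

trait ToyOptimizer {
    fn update_single(&mut self, key: &Rc<str>, gradient: &Array, parameter: &mut Array)
        -> Result<(), String>;

    // Provided method: apply every gradient to the parameter with the same key.
    fn update<M: ToyParameters>(
        &mut self,
        model: &mut M,
        gradients: impl Borrow<Flattened>,
    ) -> Result<(), String> {
        let gradients: &Flattened = gradients.borrow();
        let params = model.params_mut();
        for (key, grad) in gradients {
            // Parameters without a gradient are left untouched; gradients
            // without a matching parameter are skipped (sketch assumption).
            if let Some(param) = params.get_mut(key) {
                self.update_single(key, grad, param)?;
            }
        }
        Ok(())
    }
}

// Plain SGD: parameter -= lr * gradient.
struct PlainSgd {
    lr: f32,
}

impl ToyOptimizer for PlainSgd {
    fn update_single(&mut self, _key: &Rc<str>, gradient: &Array, parameter: &mut Array)
        -> Result<(), String> {
        for (p, g) in parameter.iter_mut().zip(gradient) {
            *p -= self.lr * g;
        }
        Ok(())
    }
}

struct ToyModel {
    params: Flattened,
}

impl ToyParameters for ToyModel {
    fn params_mut(&mut self) -> &mut Flattened {
        &mut self.params
    }
}

fn main() {
    let mut model = ToyModel { params: HashMap::new() };
    model.params.insert(Rc::from("w"), vec![1.0, 2.0]);
    model.params.insert(Rc::from("b"), vec![0.5]);

    let mut grads: Flattened = HashMap::new();
    grads.insert(Rc::from("w"), vec![1.0, 1.0]);

    let mut opt = PlainSgd { lr: 0.5 };
    opt.update(&mut model, &grads).unwrap(); // borrowed gradients
    opt.update(&mut model, grads).unwrap(); // owned gradients
    println!("{:?}", model.params["w"]);
}
```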

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
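
Because the trait is not dyn compatible, `Box<dyn Optimizer>` will not compile, so code that should work with any optimizer is typically written with generics. A generic method such as `update<M>` is one thing that rules out `dyn`, since a vtable cannot hold an entry for every possible `M` (this page does not state which of the trait's items is the cause here). The sketch below uses a hypothetical `ToyOptimizer` with the same kind of generic method, a generic `train` helper for static dispatch, and an enum for choosing an optimizer at run time. None of these names come from mlx_rs.

```rust
// Hypothetical stand-in trait. Its generic `update<M>` method makes it not
// dyn compatible, so `Box<dyn ToyOptimizer>` would be rejected by the compiler.
trait ToyOptimizer {
    fn step(&mut self, gradient: &[f32], parameter: &mut [f32]);

    fn update<M: AsMut<[f32]>>(&mut self, model: &mut M, gradient: &[f32]) {
        self.step(gradient, model.as_mut());
    }
}

// Plain SGD: p -= lr * g.
struct Sgd {
    lr: f32,
}

impl ToyOptimizer for Sgd {
    fn step(&mut self, gradient: &[f32], parameter: &mut [f32]) {
        for (p, g) in parameter.iter_mut().zip(gradient) {
            *p -= self.lr * g;
        }
    }
}

// Sign SGD: p -= lr * sign(g).
struct SignSgd {
    lr: f32,
}

impl ToyOptimizer for SignSgd {
    fn step(&mut self, gradient: &[f32], parameter: &mut [f32]) {
        for (p, g) in parameter.iter_mut().zip(gradient) {
            *p -= self.lr * g.signum();
        }
    }
}

// Static dispatch: monomorphized once per concrete optimizer type.
fn train<O: ToyOptimizer>(opt: &mut O, params: &mut Vec<f32>, grad: &[f32], steps: usize) {
    for _ in 0..steps {
        opt.update(params, grad);
    }
}

// Run-time choice without `dyn`: an enum over the concrete optimizers.
enum AnyOptimizer {
    Sgd(Sgd),
    SignSgd(SignSgd),
}

impl AnyOptimizer {
    fn update(&mut self, params: &mut Vec<f32>, grad: &[f32]) {
        match self {
            AnyOptimizer::Sgd(o) => o.update(params, grad),
            AnyOptimizer::SignSgd(o) => o.update(params, grad),
        }
    }
}

fn main() {
    let mut params = vec![1.0, -1.0];
    train(&mut Sgd { lr: 0.5 }, &mut params, &[1.0, -1.0], 2);

    let use_sign = true;
    let mut opt = if use_sign {
        AnyOptimizer::SignSgd(SignSgd { lr: 0.25 })
    } else {
        AnyOptimizer::Sgd(Sgd { lr: 0.25 })
    };
    opt.update(&mut params, &[4.0, -3.0]);
    println!("{params:?}");
}
```

The enum keeps the set of optimizers closed but avoids trait objects entirely; the generic helper is the more common choice when the optimizer type is known at compile time.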

Implementors