mlx_rs::optimizers

Trait OptimizerState

pub trait OptimizerState: Sized {
    type UnflattenError: Error + Into<IoError>;

    // Required methods
    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)>;
    fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)>;
    fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
       where I: IntoIterator<Item = (K, Array)>,
             K: Ord + AsRef<str> + Into<Rc<str>>;

    // Provided methods
    fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), IoError> { ... }
    fn load_safetensors(
        &mut self,
        path: impl AsRef<Path>,
    ) -> Result<(), IoError> { ... }
}

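Trait for optimizer states that can be flattened into named arrays, rebuilt from those arrays, and saved to or loaded from safetensors files.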

Required Associated Types

type UnflattenError: Error + Into<IoError>

Error type returned by unflatten. It must implement Error and be convertible into IoError.
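A minimal sketch of an error type that satisfies the `Error + Into<IoError>` bound. `IoError` here is a hypothetical stand-in (the real mlx-rs `IoError` has a different definition), and `MyUnflattenError` is a made-up name. Implementing `From<MyUnflattenError> for IoError` is enough, because the standard library derives `Into` from `From`.

```rust
use std::error::Error;
use std::fmt;

// Hypothetical stand-in for mlx_rs's IoError; the real type differs.
#[derive(Debug)]
pub struct IoError {
    pub message: String,
}

// Hypothetical error an implementor might return from `unflatten`.
#[derive(Debug)]
pub enum MyUnflattenError {
    MissingKey(String),
}

impl fmt::Display for MyUnflattenError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MyUnflattenError::MissingKey(k) => write!(f, "missing key in optimizer state: {k}"),
        }
    }
}

// Satisfies the `Error` half of the bound.
impl Error for MyUnflattenError {}

// Implementing `From` gives `Into<IoError>` for free, satisfying the other half.
impl From<MyUnflattenError> for IoError {
    fn from(e: MyUnflattenError) -> Self {
        IoError { message: e.to_string() }
    }
}

// Compiles only if `E` meets the same bounds as `UnflattenError`.
fn assert_bounds<E: Error + Into<IoError>>() {}

fn main() {
    assert_bounds::<MyUnflattenError>();
    let io: IoError = MyUnflattenError::MissingKey("step".into()).into();
    println!("{}", io.message);
}
```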

Required Methods

fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)>

Flatten the optimizer state into an iterator of (key, &Array) pairs, one pair per array in the state.
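A minimal sketch of a method with the same shape as flatten, written against hypothetical stand-ins: `Array` is a plain `Vec<f32>` wrapper rather than the mlx-rs `Array`, and `MomentumState` (one buffer per parameter name) is an invented example state. Only the `Rc<str>` key is cloned, which is a reference-count bump; the arrays themselves are borrowed.

```rust
use std::collections::HashMap;
use std::rc::Rc;

// Hypothetical stand-in for mlx_rs::Array: a flat buffer of f32 values.
#[derive(Debug, Clone, PartialEq)]
pub struct Array(pub Vec<f32>);

// Hypothetical optimizer state: one momentum buffer per parameter name.
pub struct MomentumState {
    pub velocity: HashMap<Rc<str>, Array>,
}

impl MomentumState {
    // Same signature shape as `OptimizerState::flatten`: yields (key, &Array)
    // pairs without copying array data.
    pub fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.velocity.iter().map(|(k, v)| (Rc::clone(k), v))
    }
}

fn main() {
    let mut velocity: HashMap<Rc<str>, Array> = HashMap::new();
    velocity.insert(Rc::from("layer1.weight"), Array(vec![0.1, 0.2]));
    velocity.insert(Rc::from("layer1.bias"), Array(vec![0.0]));
    let state = MomentumState { velocity };

    // HashMap iteration order is unspecified, so sort keys for stable output.
    let mut keys: Vec<Rc<str>> = state.flatten().map(|(k, _)| k).collect();
    keys.sort();
    println!("{keys:?}");
}
```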

fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)>

Flatten the optimizer state into an iterator of (key, &mut Array) pairs, so that each array can be modified in place.
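A minimal sketch of why a mutable flattening is useful: a caller can update every array in a state without knowing its concrete layout. `Array`, `MomentumState`, and the `reset` helper are hypothetical stand-ins invented for illustration, not mlx-rs items.

```rust
use std::collections::HashMap;
use std::rc::Rc;

// Hypothetical stand-in for mlx_rs::Array.
#[derive(Debug, Clone, PartialEq)]
pub struct Array(pub Vec<f32>);

// Hypothetical optimizer state: one momentum buffer per parameter name.
pub struct MomentumState {
    pub velocity: HashMap<Rc<str>, Array>,
}

impl MomentumState {
    // Same signature shape as `OptimizerState::flatten_mut`.
    pub fn flatten_mut(&mut self) -> impl Iterator<Item = (Rc<str>, &mut Array)> {
        self.velocity.iter_mut().map(|(k, v)| (Rc::clone(k), v))
    }
}

// Hypothetical helper: zero every buffer in place, e.g. to reset momentum.
pub fn reset(state: &mut MomentumState) {
    for (_key, array) in state.flatten_mut() {
        array.0.iter_mut().for_each(|x| *x = 0.0);
    }
}

fn main() {
    let mut velocity: HashMap<Rc<str>, Array> = HashMap::new();
    velocity.insert(Rc::from("layer1.weight"), Array(vec![0.5, -0.25]));
    let mut state = MomentumState { velocity };
    reset(&mut state);
    println!("{:?}", state.velocity["layer1.weight"]);
}
```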

fn unflatten<I, K>(input: I) -> Result<Self, Self::UnflattenError>
where I: IntoIterator<Item = (K, Array)>, K: Ord + AsRef<str> + Into<Rc<str>>,

Build an optimizer state from an iterator of owned (key, Array) pairs; this is the inverse of flatten.
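A minimal round-trip sketch with the same generic bounds as unflatten, using hypothetical stand-ins (`Array`, `MomentumState`, and a made-up `UnflattenError` enum). The `AsRef<str> + Into<Rc<str>>` bounds let callers pass `String`, `&str`, or `Rc<str>` keys. Using the `Ord` bound to collect keys into a `BTreeMap` and reject duplicates is an assumption chosen for illustration; it is not necessarily what mlx-rs does.

```rust
use std::collections::{BTreeMap, HashMap};
use std::rc::Rc;

// Hypothetical stand-in for mlx_rs::Array.
#[derive(Debug, Clone, PartialEq)]
pub struct Array(pub Vec<f32>);

#[derive(Debug, PartialEq)]
pub struct MomentumState {
    pub velocity: HashMap<Rc<str>, Array>,
}

// Hypothetical error type for this sketch.
#[derive(Debug, PartialEq)]
pub enum UnflattenError {
    DuplicateKey(String),
}

impl MomentumState {
    pub fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.velocity.iter().map(|(k, v)| (Rc::clone(k), v))
    }

    // Same generic bounds as `OptimizerState::unflatten`.
    pub fn unflatten<I, K>(input: I) -> Result<Self, UnflattenError>
    where
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        // Assumed use of `Ord`: an ordered map that detects duplicate keys.
        let mut sorted: BTreeMap<K, Array> = BTreeMap::new();
        for (k, v) in input {
            if sorted.contains_key(&k) {
                return Err(UnflattenError::DuplicateKey(k.as_ref().to_string()));
            }
            sorted.insert(k, v);
        }
        let velocity: HashMap<Rc<str>, Array> =
            sorted.into_iter().map(|(k, v)| (k.into(), v)).collect();
        Ok(MomentumState { velocity })
    }
}

fn main() {
    let pairs = vec![
        ("layer1.weight".to_string(), Array(vec![0.1, 0.2])),
        ("layer1.bias".to_string(), Array(vec![0.0])),
    ];
    let state = MomentumState::unflatten(pairs).expect("valid state");

    // flatten borrows the arrays, so clone them to feed unflatten again.
    let owned: Vec<(Rc<str>, Array)> = state.flatten().map(|(k, v)| (k, v.clone())).collect();
    let rebuilt = MomentumState::unflatten(owned).expect("valid state");
    assert_eq!(state, rebuilt);
    println!("round trip ok");
}
```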

Provided Methods

fn save_safetensors(&self, path: impl AsRef<Path>) -> Result<(), IoError>

Save the optimizer state to a safetensors file.

fn load_safetensors(&mut self, path: impl AsRef<Path>) -> Result<(), IoError>

Load the optimizer state from a safetensors file.
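Because save_safetensors and load_safetensors are provided methods, an implementor only writes flatten, flatten_mut, and unflatten. The sketch below shows how provided methods can be layered on the required ones, under loud assumptions: `MockOptimizerState` is a reduced hypothetical trait, and its `save`/`load` write to an in-memory `BTreeMap` (`Store`) rather than a safetensors file. It does not reproduce the real mlx-rs implementation. Requires Rust 1.75 or later for `impl Trait` in trait method return position.

```rust
use std::collections::{BTreeMap, HashMap};
use std::rc::Rc;

// Hypothetical stand-in for mlx_rs::Array.
#[derive(Debug, Clone, PartialEq)]
pub struct Array(pub Vec<f32>);

// Hypothetical in-memory store standing in for a safetensors file.
pub type Store = BTreeMap<String, Array>;

// Reduced mock trait: same required-method shapes, in-memory provided methods.
pub trait MockOptimizerState: Sized {
    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)>;
    fn unflatten<I, K>(input: I) -> Result<Self, String>
    where
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>;

    // Provided: copy every named array produced by `flatten` into the store.
    fn save(&self, store: &mut Store) {
        for (key, array) in self.flatten() {
            store.insert(key.to_string(), array.clone());
        }
    }

    // Provided: rebuild a state with `unflatten` and overwrite `self`.
    fn load(&mut self, store: &Store) -> Result<(), String> {
        *self = Self::unflatten(store.iter().map(|(k, v)| (k.clone(), v.clone())))?;
        Ok(())
    }
}

#[derive(Debug, PartialEq)]
pub struct MomentumState {
    pub velocity: HashMap<Rc<str>, Array>,
}

impl MockOptimizerState for MomentumState {
    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.velocity.iter().map(|(k, v)| (Rc::clone(k), v))
    }

    fn unflatten<I, K>(input: I) -> Result<Self, String>
    where
        I: IntoIterator<Item = (K, Array)>,
        K: Ord + AsRef<str> + Into<Rc<str>>,
    {
        let velocity: HashMap<Rc<str>, Array> =
            input.into_iter().map(|(k, v)| (k.into(), v)).collect();
        Ok(MomentumState { velocity })
    }
}

fn main() {
    let mut velocity: HashMap<Rc<str>, Array> = HashMap::new();
    velocity.insert(Rc::from("layer1.weight"), Array(vec![0.1, 0.2]));
    let original = MomentumState { velocity };

    let mut store = Store::new();
    original.save(&mut store);

    let mut restored = MomentumState { velocity: HashMap::new() };
    restored.load(&store).expect("load failed");
    assert_eq!(original, restored);
    println!("restored {} arrays", store.len());
}
```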

Dyn Compatibility

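This trait is not dyn compatible: it has a Sized supertrait, unflatten has no self receiver, and flatten and flatten_mut return impl Iterator. Use it as a generic bound rather than as dyn OptimizerState.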

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
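A minimal sketch of the usual workaround: write code against a generic bound, which the compiler monomorphizes for each concrete state type, instead of `Box<dyn ...>`. `MockOptimizerState` is a hypothetical reduced trait keeping only the `Sized` supertrait and the `impl Iterator` return that block dyn compatibility; `Array`, `MomentumState`, `EmptyState`, and `total_elements` are invented for illustration.

```rust
use std::collections::HashMap;
use std::rc::Rc;

// Hypothetical stand-in for mlx_rs::Array.
#[derive(Debug, Clone, PartialEq)]
pub struct Array(pub Vec<f32>);

// Reduced mock trait: `Sized` supertrait plus a return-position `impl Trait`,
// so `dyn MockOptimizerState` is not a valid type.
pub trait MockOptimizerState: Sized {
    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)>;
}

pub struct MomentumState {
    pub velocity: HashMap<Rc<str>, Array>,
}

impl MockOptimizerState for MomentumState {
    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        self.velocity.iter().map(|(k, v)| (Rc::clone(k), v))
    }
}

// A state with no arrays, e.g. for an optimizer that keeps no buffers.
pub struct EmptyState;

impl MockOptimizerState for EmptyState {
    fn flatten(&self) -> impl Iterator<Item = (Rc<str>, &Array)> {
        std::iter::empty()
    }
}

// Generic bound instead of `&dyn MockOptimizerState`, which would not compile.
pub fn total_elements<S: MockOptimizerState>(state: &S) -> usize {
    state.flatten().map(|(_, a)| a.0.len()).sum()
}

fn main() {
    let mut velocity: HashMap<Rc<str>, Array> = HashMap::new();
    velocity.insert(Rc::from("w"), Array(vec![1.0, 2.0, 3.0]));
    let state = MomentumState { velocity };
    println!("{}", total_elements(&state));
    println!("{}", total_elements(&EmptyState));
}
```

If heterogeneous states must live in one collection, an enum over the concrete state types is one common alternative to trait objects.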

Implementors