mlx_rs::nn

Trait IntoModuleValueAndGrad

pub trait IntoModuleValueAndGrad<'a, M, Args, Val, Err>
where
    M: ModuleParameters + 'a,
    Args: Clone,
{
    // Required method
    fn into_module_value_and_grad(
        self,
    ) -> impl FnMut(&mut M, Args) -> Result<(Val, FlattenedModuleParam), Exception> + 'a;
}

Helper trait for value_and_grad

Required Methods


fn into_module_value_and_grad( self, ) -> impl FnMut(&mut M, Args) -> Result<(Val, FlattenedModuleParam), Exception> + 'a

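Computes the value and gradient of the passed function f(model, args) with respect to the model's trainable parameters. The returned closure yields the value together with the gradients as a FlattenedModuleParam, or an Exception on failure.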
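To make the contract of the returned closure concrete, here is a self-contained sketch that mirrors its shape: it takes f(model, args) and produces a closure returning (value, gradients keyed by parameter name). Everything in it is a hypothetical stand-in, not mlx-rs API: Toy replaces a ModuleParameters type, FlatParams (a HashMap<String, f64>) replaces FlattenedModuleParam, and gradients are approximated with central finite differences, whereas MLX computes them with automatic differentiation.

```rust
use std::collections::HashMap;

// Hypothetical stand-in for FlattenedModuleParam: parameter name -> value.
type FlatParams = HashMap<String, f64>;

// Hypothetical toy "module" holding named trainable parameters.
struct Toy {
    params: FlatParams,
}

// Wraps f(model, args) into a closure returning (value, gradients by name).
// Gradients use central finite differences (a swapped-in technique for
// illustration; MLX uses automatic differentiation).
fn value_and_grad_fd<A: Clone>(
    mut f: impl FnMut(&mut Toy, A) -> f64,
) -> impl FnMut(&mut Toy, A) -> (f64, FlatParams) {
    move |model: &mut Toy, args: A| {
        let value = f(&mut *model, args.clone());
        let eps = 1e-6;
        let names: Vec<String> = model.params.keys().cloned().collect();
        let mut grads = FlatParams::new();
        for name in names {
            let original = model.params[&name];
            // Nudge one parameter up and down, re-evaluating f each time.
            model.params.insert(name.clone(), original + eps);
            let plus = f(&mut *model, args.clone());
            model.params.insert(name.clone(), original - eps);
            let minus = f(&mut *model, args.clone());
            // Restore the parameter so the model is left unchanged.
            model.params.insert(name.clone(), original);
            grads.insert(name, (plus - minus) / (2.0 * eps));
        }
        (value, grads)
    }
}

fn main() {
    let mut model = Toy {
        params: FlatParams::from([("w".to_string(), 3.0), ("b".to_string(), 1.0)]),
    };
    // Squared error of the prediction w * x + b against a target y.
    let loss = |m: &mut Toy, (x, y): (f64, f64)| {
        let pred = m.params["w"] * x + m.params["b"];
        (pred - y).powi(2)
    };
    let mut vg = value_and_grad_fd(loss);
    let (value, grads) = vg(&mut model, (2.0, 5.0));
    println!("loss = {value}, dL/dw = {:.3}, dL/db = {:.3}", grads["w"], grads["b"]);
}
```

As with the real method, Args must be Clone here because f is evaluated more than once per call; in this sketch that is needed for the repeated finite-difference evaluations.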

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
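A trait whose method returns an opaque impl Trait type, as into_module_value_and_grad does, cannot be used as a trait object, so it is consumed through generic bounds instead. The sketch below is a hypothetical trait of the same shape (IntoStepper, run_twice, and the i32 "state" are illustrative names, not mlx-rs API); it relies on impl Trait in trait return position, stable since Rust 1.75.

```rust
// Hypothetical trait with the same shape as IntoModuleValueAndGrad:
// a by-value method returning an opaque `impl FnMut(..)`.
trait IntoStepper {
    fn into_stepper(self) -> impl FnMut(&mut i32, i32) -> i32;
}

// Blanket impl for any closure with the right signature.
impl<F> IntoStepper for F
where
    F: FnMut(&mut i32, i32) -> i32,
{
    fn into_stepper(self) -> impl FnMut(&mut i32, i32) -> i32 {
        let mut f = self;
        move |state: &mut i32, arg: i32| {
            let out = f(&mut *state, arg);
            *state += 1; // count how many times the stepper ran
            out
        }
    }
}

// `Box<dyn IntoStepper>` would be rejected (E0038) because the return type
// is an opaque `impl Trait`; a generic bound works instead.
fn run_twice<T: IntoStepper>(t: T, state: &mut i32) -> i32 {
    let mut step = t.into_stepper();
    let first = step(&mut *state, 1);
    let second = step(&mut *state, 2);
    first + second
}

fn main() {
    let mut counter = 0;
    let total = run_twice(|s: &mut i32, x: i32| *s * 10 + x, &mut counter);
    println!("total = {total}, counter = {counter}");
}
```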

Implementors


impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Array, ()> for F
where M: ModuleParameters + 'a, F: FnMut(&mut M, Args) -> Array + 'a, Args: Clone,


impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Array, Exception> for F
where M: ModuleParameters + 'a, F: FnMut(&mut M, Args) -> Result<Array, Exception> + 'a, Args: Clone,


impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Vec<Array>, ()> for F
where M: ModuleParameters + 'a, F: FnMut(&mut M, Args) -> Vec<Array> + 'a, Args: Clone,


impl<'a, F, M, Args> IntoModuleValueAndGrad<'a, M, Args, Vec<Array>, Exception> for F
where M: ModuleParameters + 'a, F: FnMut(&mut M, Args) -> Result<Vec<Array>, Exception> + 'a, Args: Clone,
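The implementors above come in pairs that differ only in the Err type parameter: () for closures that return a plain value and Exception for closures that return a Result. This appears to be the usual workaround for Rust coherence: two blanket impls over F: FnMut(..) -> T and F: FnMut(..) -> Result<T, E> for the same trait instance would be rejected as overlapping, but they are accepted once they implement different instances of the trait. The source does not state this rationale, so treat it as an interpretation. The self-contained sketch below reproduces the pattern with hypothetical stand-ins (Value for Array, Failure for Exception, IntoFallible and call_once are illustrative names) and shows the compiler picking the right impl from the closure's return type.

```rust
// Hypothetical stand-ins for Array and Exception.
#[derive(Debug, Clone, PartialEq)]
struct Value(f64);
#[derive(Debug)]
struct Failure(String);

// `Err` is a marker parameter: the two blanket impls below implement
// different instances of the trait, so they do not overlap.
trait IntoFallible<Args, Err> {
    fn into_fallible(self) -> impl FnMut(Args) -> Result<Value, Failure>;
}

// Closures that cannot fail: wrap their result in `Ok`.
impl<F, Args> IntoFallible<Args, ()> for F
where
    F: FnMut(Args) -> Value,
{
    fn into_fallible(self) -> impl FnMut(Args) -> Result<Value, Failure> {
        let mut f = self;
        move |args: Args| -> Result<Value, Failure> { Ok(f(args)) }
    }
}

// Closures that already return a Result: pass them through unchanged.
impl<F, Args> IntoFallible<Args, Failure> for F
where
    F: FnMut(Args) -> Result<Value, Failure>,
{
    fn into_fallible(self) -> impl FnMut(Args) -> Result<Value, Failure> {
        self
    }
}

// Callers never name `Err`; it is inferred from the closure's return type.
fn call_once<F, Args, Err>(f: F, args: Args) -> Result<Value, Failure>
where
    F: IntoFallible<Args, Err>,
{
    let mut g = f.into_fallible();
    g(args)
}

fn main() {
    // Infallible closure: the `Err = ()` impl is selected.
    let a = call_once(|x: f64| Value(x * 2.0), 1.5);
    // Fallible closure: the `Err = Failure` impl is selected.
    let b = call_once(
        |x: f64| {
            if x >= 0.0 {
                Ok(Value(x.sqrt()))
            } else {
                Err(Failure("negative input".to_string()))
            }
        },
        -4.0,
    );
    println!("{a:?} {b:?}");
}
```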