Function mlx_rs::nn::value_and_grad

pub fn value_and_grad<'a, F, M, Args, Val, Err>(
    f: F,
) -> impl FnMut(&mut M, Args) -> Result<(Val, FlattenedModuleParam), Exception> + 'a
where
    M: ModuleParameters + 'a,
    F: IntoModuleValueAndGrad<'a, M, Args, Val, Err>,
    Args: Clone,

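Transforms the passed function f(model, args) into a function that computes both the value of f and the gradients of f with respect to the model's trainable parameters. The returned closure takes the model and the arguments; on success it yields the value together with the gradients as a FlattenedModuleParam, and failures are reported as an Exception.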
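The following is a minimal, dependency-free sketch of the idea, not the mlx-rs implementation. It assumes hypothetical stand-ins: TrainableParams for ModuleParameters, FlatParams (a map from parameter name to a scalar) for FlattenedModuleParam, a toy Linear model, and a value_and_grad_sketch function. Gradients here are approximated with central finite differences, whereas MLX computes them by automatic differentiation, and the sketch omits the Result/Exception error path.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for FlattenedModuleParam: parameter name -> scalar value.
type FlatParams = BTreeMap<String, f64>;

/// Hypothetical stand-in for ModuleParameters: exposes trainable parameters by name.
trait TrainableParams {
    fn trainable(&self) -> FlatParams;
    fn set_param(&mut self, name: &str, value: f64);
}

/// Hypothetical toy model: y = w * x + b.
struct Linear {
    w: f64,
    b: f64,
}

impl TrainableParams for Linear {
    fn trainable(&self) -> FlatParams {
        let mut p = FlatParams::new();
        p.insert("weight".to_string(), self.w);
        p.insert("bias".to_string(), self.b);
        p
    }
    fn set_param(&mut self, name: &str, value: f64) {
        match name {
            "weight" => self.w = value,
            "bias" => self.b = value,
            _ => {}
        }
    }
}

/// Hypothetical sketch: turns f(model, args) into a closure returning
/// (value, gradients keyed by parameter name). Gradients use central
/// finite differences, NOT automatic differentiation.
fn value_and_grad_sketch<M, Args, F>(mut f: F) -> impl FnMut(&mut M, Args) -> (f64, FlatParams)
where
    M: TrainableParams,
    Args: Clone,
    F: FnMut(&M, Args) -> f64,
{
    const EPS: f64 = 1e-6;
    move |model: &mut M, args: Args| {
        // Value of f at the current parameters.
        let value = f(&*model, args.clone());
        let mut grads = FlatParams::new();
        for (name, original) in model.trainable() {
            // Perturb one parameter at a time in both directions.
            model.set_param(&name, original + EPS);
            let plus = f(&*model, args.clone());
            model.set_param(&name, original - EPS);
            let minus = f(&*model, args.clone());
            // Restore the parameter before moving on.
            model.set_param(&name, original);
            grads.insert(name, (plus - minus) / (2.0 * EPS));
        }
        (value, grads)
    }
}

fn main() {
    // Squared-error loss of the model on a single (x, y) pair.
    let loss = |m: &Linear, (x, y): (f64, f64)| (m.w * x + m.b - y).powi(2);
    let mut vg = value_and_grad_sketch(loss);
    let mut model = Linear { w: 2.0, b: 1.0 };
    let (value, grads) = vg(&mut model, (3.0, 4.0));
    println!("loss = {value}, grads = {grads:?}");
}
```

In this sketch, Args: Clone is what lets one set of arguments be reused for every evaluation of f inside the returned closure.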