Function mlx_rs::transforms::fallible_vjp

pub fn fallible_vjp<'a, F>(
    f: F,
    primals: &[Array],
    cotangents: &[Array],
) -> Result<(Vec<Array>, Vec<Array>)>
where F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,

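Similar to vjp, but accepts a closure that can fail. The closure f returns a Result, and any error it produces is propagated to the caller through the outer Result instead of being unwrapped. On success, the returned tuple holds the outputs of f evaluated at primals and the vector-Jacobian products with respect to primals, weighted by the given cotangents.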
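The contract is easier to see with a concrete closure. What follows is a hypothetical, dependency-free sketch, not mlx-rs code: Vec<f64> stands in for Array, String stands in for the library's error type, and derivatives are approximated with central finite differences rather than MLX's automatic differentiation. The names fallible_vjp_fd and product are illustrative only. The sketch shows the two properties described above: an error from the closure surfaces through `?`, and the result pairs the closure's outputs with one cotangent-weighted gradient per primal.

```rust
// Hypothetical stand-ins: Vec<f64> plays the role of mlx_rs::Array,
// and String plays the role of the library's error type.
type Array = Vec<f64>;
type Result<T> = std::result::Result<T, String>;

/// Toy fallible VJP using central finite differences (NOT autodiff).
/// Returns (outputs of f at primals, one cotangent-weighted gradient per primal).
fn fallible_vjp_fd<F>(mut f: F, primals: &[Array], cotangents: &[Array]) -> Result<(Vec<Array>, Vec<Array>)>
where
    F: FnMut(&[Array]) -> Result<Vec<Array>>,
{
    // Any error returned by the closure is propagated to the caller by `?`.
    let outputs = f(primals)?;
    if outputs.len() != cotangents.len()
        || outputs.iter().zip(cotangents).any(|(o, c)| o.len() != c.len())
    {
        return Err("each cotangent must match the shape of its output".to_string());
    }
    let h = 1e-6;
    let mut vjps = Vec::with_capacity(primals.len());
    for (p, primal) in primals.iter().enumerate() {
        let mut grad = vec![0.0; primal.len()];
        for j in 0..primal.len() {
            // Perturb element j of primal p up and down.
            let mut plus = primals.to_vec();
            let mut minus = primals.to_vec();
            plus[p][j] += h;
            minus[p][j] -= h;
            let f_plus = f(&plus[..])?;
            let f_minus = f(&minus[..])?;
            // grad[j] = sum over outputs k and elements i of v_k[i] * d f_k[i] / d x_p[j].
            for ((v, fp), fm) in cotangents.iter().zip(&f_plus).zip(&f_minus) {
                for ((vi, a), b) in v.iter().zip(fp).zip(fm) {
                    grad[j] += vi * (a - b) / (2.0 * h);
                }
            }
        }
        vjps.push(grad);
    }
    Ok((outputs, vjps))
}

/// Illustrative closure: elementwise x * y, failing when y contains a zero
/// (an arbitrary error condition chosen for the example).
fn product(xs: &[Array]) -> Result<Vec<Array>> {
    let (x, y) = (&xs[0], &xs[1]);
    if y.iter().any(|&v| v == 0.0) {
        return Err("y must be nonzero".to_string());
    }
    Ok(vec![x.iter().zip(y).map(|(a, b)| a * b).collect()])
}

fn main() {
    let primals = vec![vec![2.0], vec![3.0]];
    let cotangents = vec![vec![1.0]];
    match fallible_vjp_fd(product, &primals, &cotangents) {
        Ok((out, vjps)) => println!("outputs = {:?}, vjps = {:?}", out, vjps),
        Err(e) => println!("error: {}", e),
    }
    // With y = 0 the closure fails, and the error is returned rather than panicking.
    let bad = vec![vec![2.0], vec![0.0]];
    assert!(fallible_vjp_fd(product, &bad, &cotangents).is_err());
}
```

For f(x, y) = x * y at (2, 3) with cotangent 1, the gradients approximate 3 (with respect to x) and 2 (with respect to y). A library implementation would compute these exactly by reverse-mode differentiation; the finite-difference loop here only mirrors the input/output shape of the API.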