mlx_rs::transforms

Function fallible_jvp

pub fn fallible_jvp<'a, F>(
    f: F,
    primals: &[Array],
    tangents: &[Array],
) -> Result<(Vec<Array>, Vec<Array>)>
where F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,

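Similar to jvp, but accepts a closure f that can fail. f returns Result<Vec<Array>>, and any error it returns is propagated to the caller through the Result returned by fallible_jvp. On success, the returned tuple holds the outputs of f evaluated at primals and the Jacobian-vector products along tangents.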
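The sketch below is not mlx-rs code and does not call the mlx-rs API. It is a hypothetical, standard-library-only model of what a fallible JVP computes, using forward-mode dual numbers over f64 scalars in place of Array. The names Dual, checked_sqrt, and fallible_jvp_sketch are illustrative assumptions. It shows the two ideas the description relies on: tangents are pushed through f alongside the primals to produce (outputs, jvps), and an error returned by f short-circuits the whole call via `?`.

```rust
use std::ops::{Add, Mul};

/// Hypothetical dual number: a primal value paired with its tangent.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    val: f64, // primal value
    tan: f64, // tangent (directional derivative)
}

impl Add for Dual {
    type Output = Dual;
    // Sum rule: d(a + b) = a' + b'
    fn add(self, o: Dual) -> Dual {
        Dual { val: self.val + o.val, tan: self.tan + o.tan }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule: d(a * b) = a' * b + a * b'
    fn mul(self, o: Dual) -> Dual {
        Dual { val: self.val * o.val, tan: self.tan * o.val + self.val * o.tan }
    }
}

/// A fallible operation: square root that rejects non-positive inputs.
fn checked_sqrt(x: Dual) -> Result<Dual, String> {
    if x.val <= 0.0 {
        return Err(format!("sqrt requires a positive input, got {}", x.val));
    }
    let s = x.val.sqrt();
    // d(sqrt(x)) = x' / (2 * sqrt(x))
    Ok(Dual { val: s, tan: x.tan / (2.0 * s) })
}

/// Hypothetical stand-in for fallible_jvp: evaluates f on dual inputs and
/// splits the result into (outputs, jacobian-vector products).
fn fallible_jvp_sketch<F>(
    mut f: F,
    primals: &[f64],
    tangents: &[f64],
) -> Result<(Vec<f64>, Vec<f64>), String>
where
    F: FnMut(&[Dual]) -> Result<Vec<Dual>, String>,
{
    // Assumption: one tangent per primal, as in a standard JVP.
    if primals.len() != tangents.len() {
        return Err(format!(
            "expected {} tangents, got {}",
            primals.len(),
            tangents.len()
        ));
    }
    let inputs: Vec<Dual> = primals
        .iter()
        .zip(tangents)
        .map(|(&val, &tan)| Dual { val, tan })
        .collect();
    // Any error returned by f is propagated to the caller here.
    let outputs = f(&inputs)?;
    let vals = outputs.iter().map(|d| d.val).collect();
    let tans = outputs.iter().map(|d| d.tan).collect();
    Ok((vals, tans))
}

fn main() {
    // f(x, y) = sqrt(x * y); the closure captures nothing, so it is Copy.
    let f = |xs: &[Dual]| -> Result<Vec<Dual>, String> {
        Ok(vec![checked_sqrt(xs[0] * xs[1])?])
    };

    // At (2, 8) along (1, 0): output sqrt(16) = 4, JVP = y / (2 * sqrt(xy)) = 1.
    let (out, jvp) = fallible_jvp_sketch(f, &[2.0, 8.0], &[1.0, 0.0]).unwrap();
    println!("{:?} {:?}", out, jvp); // prints "[4.0] [1.0]"

    // At (-2, 8), x * y is negative, so f fails and the error is returned.
    match fallible_jvp_sketch(f, &[-2.0, 8.0], &[1.0, 0.0]) {
        Ok(_) => println!("unexpected success"),
        Err(e) => println!("error: {e}"),
    }
}
```

In the real function, primals and tangents are Array values and f must return mlx_rs's Result<Vec<Array>>; the scalar dual-number model here is only meant to make the data flow and the error path concrete.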