Function mlx_rs::transforms::jvp

pub fn jvp<'a, F>(
    f: F,
    primals: &[Array],
    tangents: &[Array],
) -> Result<(Vec<Array>, Vec<Array>)>
where F: FnMut(&[Array]) -> Vec<Array> + 'a,

Compute the Jacobian-vector product.

This computes the product of the Jacobian of a function f evaluated at primals with the tangents.
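To make the definition concrete, here is a small std-only Rust sketch that does not use mlx_rs. It picks an illustrative function f(x) = [x0 * x1, x0 + x1^2], writes its Jacobian out by hand, and multiplies that Jacobian by a tangent vector. The names `f`, `jacobian` and `jvp_explicit` are hypothetical and only illustrate what the product is, not how mlx_rs computes it.

```rust
/// Illustrative function f: R^2 -> R^2, f(x) = [x0 * x1, x0 + x1^2].
fn f(x: &[f64]) -> Vec<f64> {
    vec![x[0] * x[1], x[0] + x[1] * x[1]]
}

/// Hand-derived Jacobian of `f` at `x`; row i is the gradient of output i.
fn jacobian(x: &[f64]) -> [[f64; 2]; 2] {
    [[x[1], x[0]], [1.0, 2.0 * x[1]]]
}

/// Jacobian-vector product J(x) * v, computed by forming J explicitly.
fn jvp_explicit(x: &[f64], v: &[f64]) -> Vec<f64> {
    let j = jacobian(x);
    j.iter()
        .map(|row| row.iter().zip(v).map(|(a, b)| a * b).sum::<f64>())
        .collect()
}

fn main() {
    let primals = [2.0, 3.0];
    let tangents = [1.0, 0.0];
    println!("f(x)     = {:?}", f(&primals));
    println!("J(x) * v = {:?}", jvp_explicit(&primals, &tangents));
}
```

At x = (2, 3) the Jacobian is [[3, 2], [1, 6]], so the tangent [1, 0] selects its first column, [3, 1], and the tangent [0, 1] selects the second, [2, 6].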

Params:

  • f: function that takes a slice of Array (&[Array]) and returns a Vec<Array>
  • primals: slice of Array at which to evaluate the Jacobian
  • tangents: slice of Array forming the “vector” in the Jacobian-vector product. The tangents should match the inputs of f, i.e. the primals, in number, shape and type

Returns:

A Result wrapping a tuple (outputs, jvps): the outputs of f evaluated at primals, and the Jacobian-vector products, which are the same in number, shape and type as the outputs of f
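The signature returns a pair of vectors, assumed here to be the outputs of f followed by the Jacobian-vector products. The std-only sketch below mirrors that contract with plain f64 scalars instead of Array values. The hypothetical `jvp_scalar` evaluates f once on dual numbers (forward-mode differentiation), so the tangents are pushed through f without ever forming the Jacobian. It illustrates the technique only; it is not how mlx_rs implements jvp, and `Dual` and `jvp_scalar` are names invented for this example.

```rust
use std::ops::{Add, Mul};

/// A dual number `value + tangent * eps` with eps^2 = 0. Carrying the tangent
/// through arithmetic is forward-mode differentiation.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Dual {
    value: f64,
    tangent: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, rhs: Dual) -> Dual {
        // Sum rule: tangents add.
        Dual { value: self.value + rhs.value, tangent: self.tangent + rhs.tangent }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, rhs: Dual) -> Dual {
        // Product rule: (a + a' eps)(b + b' eps) = ab + (a'b + ab') eps
        Dual {
            value: self.value * rhs.value,
            tangent: self.tangent * rhs.value + self.value * rhs.tangent,
        }
    }
}

/// Hypothetical scalar analogue of jvp: returns (outputs of `f` at `primals`, J * tangents).
fn jvp_scalar<F>(mut f: F, primals: &[f64], tangents: &[f64]) -> (Vec<f64>, Vec<f64>)
where
    F: FnMut(&[Dual]) -> Vec<Dual>,
{
    // One tangent per primal, mirroring the "same in number" requirement.
    assert_eq!(primals.len(), tangents.len(), "one tangent per primal");
    let inputs: Vec<Dual> = primals
        .iter()
        .zip(tangents)
        .map(|(&value, &tangent)| Dual { value, tangent })
        .collect();
    // A single evaluation of f yields both the outputs and the JVPs.
    let outputs = f(inputs.as_slice());
    let values: Vec<f64> = outputs.iter().map(|d| d.value).collect();
    let jvps: Vec<f64> = outputs.iter().map(|d| d.tangent).collect();
    (values, jvps)
}

fn main() {
    // f(x) = [x0 * x1, x0 + x1 * x1]
    let f = |x: &[Dual]| vec![x[0] * x[1], x[0] + x[1] * x[1]];
    let (outputs, jvps) = jvp_scalar(f, &[2.0, 3.0], &[1.0, 0.0]);
    println!("outputs = {:?}, jvps = {:?}", outputs, jvps);
}
```

With primals (2, 3) and tangents [1, 0] this yields outputs [6, 11] and JVPs [3, 1], the first column of the Jacobian of f at that point; recovering the full Jacobian this way would take one call per input direction.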