Function transforms
This module provides functions for automatic differentiation and other transformations on functions.
WARN: Because function transforms (including compilation) operate on
the computation graph, the user must ensure that all Arrays are passed
as inputs to the function/closure. Closures that capture Arrays may
not work as expected and may lead to undefined behavior.
§Automatic Differentiation
Automatic differentiation in MLX works on functions rather than on implicit graphs.
NOTE: If you are coming to MLX from PyTorch, you no longer need functions like backward, zero_grad, and detach, or properties like requires_grad.
You can use the grad() and value_and_grad() functions to compute
gradients of more complex functions. These functions compute the gradient
with respect to the first argument. To manually specify the
argument to compute the gradient with respect to, use
grad_with_argnums() or value_and_grad_with_argnums().
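As a minimal, library-free sketch of what value_and_grad conceptually returns (a pair of the function's value and its gradient at a point), the helper below uses a central finite difference as a stand-in for the real transform; `value_and_grad_fd` is hypothetical and not part of mlx-rs, which differentiates the computation graph exactly rather than numerically:

```rust
// Hypothetical stand-in for value_and_grad: returns (f(x), f'(x)).
// A central finite difference approximates the derivative so this
// sketch runs without mlx-rs.
fn value_and_grad_fd(f: impl Fn(f64) -> f64, x: f64) -> (f64, f64) {
    let h = 1e-5;
    let value = f(x);
    let grad = (f(x + h) - f(x - h)) / (2.0 * h);
    (value, grad)
}

fn main() {
    // Same f(x) = x^2 as the example below.
    let (value, grad) = value_and_grad_fd(|x| x * x, 1.5);
    assert!((value - 2.25).abs() < 1e-9); // f(1.5) = 2.25
    assert!((grad - 3.0).abs() < 1e-6); // d/dx x^2 = 2x = 3.0 at x = 1.5
    println!("value = {value}, grad = {grad}");
}
```

The real transform returns both quantities from a single traversal of the graph, which is why it is preferred over calling the function and grad() separately.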
TODO: update the example once https://github.com/oxideai/mlx-rs/pull/218 is merged
```rust
use mlx_rs::{Array, error::Result, transforms::grad};

fn f(x: &Array) -> Result<Array> {
    x.square()
}

fn calculate_grad(func: impl Fn(&Array) -> Result<Array>, arg: &Array) -> Result<Array> {
    grad(&func, &[0])(arg)
}

let x = Array::from(1.5);
let dfdx = calculate_grad(f, &x).unwrap();
assert_eq!(dfdx.item::<f32>(), 2.0 * 1.5);

let dfdx2 = calculate_grad(|args| calculate_grad(f, args), &x).unwrap();
assert_eq!(dfdx2.item::<f32>(), 2.0);
```

Modules§
- compile
- Compilation of functions.
Traits§
- IntoGrad 
- Trait for functions/closures that can be converted into a closure that computes the gradient.
- IntoKeyedValueAndGrad
- Similar to [IntoValueAndGrad] but for functions that take a hashmap of parameters.
- IntoValueAndGrad
- Trait for functions/closures that can be converted into a closure that computes the value and gradient.
Functions§
- async_eval 
- Asynchronously evaluate an iterator of Arrays.
- async_eval_params
- Asynchronously evaluate a module’s parameters.
- eval
- Evaluate an iterator of Arrays.
- eval_params 
- Evaluate a module’s parameters.
- fallible_jvp 
- Similar to jvp but handles closures that can return an error.
- fallible_vjp 
- Similar to vjp but handles closures that can return an error.
- grad
- Returns a function which computes the gradient of f with the default argument numbers &[0].
- grad_with_argnums
- Returns a function which computes the gradient of f.
- jvp
- Compute the Jacobian-vector product.
- keyed_value_and_grad
- Returns a function which computes the value and gradient of f with keyed parameters.
- value_and_grad
- Returns a function which computes the value and gradient of f with the default argument numbers &[0].
- value_and_grad_with_argnums
- Returns a function which computes the value and gradient of f.
- vjp
- Compute the vector-Jacobian product.
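To make the jvp/vjp semantics above concrete: jvp maps a tangent v to the Jacobian-vector product J·v, while vjp maps a cotangent u to the vector-Jacobian product uᵀ·J. The following is a plain-Rust numeric sketch (not the mlx-rs API) for the example function f(x, y) = (x·y, x + y), whose Jacobian is J = [[y, x], [1, 1]]:

```rust
// Jacobian of f(x, y) = (x*y, x + y) evaluated at (x, y).
fn jacobian(x: f64, y: f64) -> [[f64; 2]; 2] {
    [[y, x], [1.0, 1.0]]
}

// jvp: push a tangent vector v forward through f, producing J·v.
fn jvp_example(x: f64, y: f64, v: [f64; 2]) -> [f64; 2] {
    let j = jacobian(x, y);
    [
        j[0][0] * v[0] + j[0][1] * v[1],
        j[1][0] * v[0] + j[1][1] * v[1],
    ]
}

// vjp: pull a cotangent vector u back through f, producing uᵀ·J.
fn vjp_example(x: f64, y: f64, u: [f64; 2]) -> [f64; 2] {
    let j = jacobian(x, y);
    [
        u[0] * j[0][0] + u[1] * j[1][0],
        u[0] * j[0][1] + u[1] * j[1][1],
    ]
}

fn main() {
    // At (x, y) = (2, 3): J = [[3, 2], [1, 1]].
    assert_eq!(jvp_example(2.0, 3.0, [1.0, 1.0]), [5.0, 2.0]);
    assert_eq!(vjp_example(2.0, 3.0, [1.0, 1.0]), [4.0, 3.0]);
}
```

The real transforms never materialize J: jvp costs one forward pass per tangent and vjp one backward pass per cotangent, which is what makes reverse-mode gradients of scalar losses cheap.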
Type Aliases§
- KeyedGrad 
- Type alias for a hashmap of gradients.
- KeyedParameters 
- Type alias for a hashmap of parameters.