Module transforms

Function transforms

This module provides functions for automatic differentiation and other transformations on functions.

WARN: Because function transforms, including compilation, work on the computation graph, the user must ensure that all Arrays are passed as inputs to the function/closure. Closures that capture Arrays may not work as expected and may lead to undefined behavior.
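The std-only toy below is not mlx-rs code. It uses a hypothetical forward-mode `value_and_grad` built on a made-up `Dual` number type, with `f64` standing in for `Array`. It only illustrates the tracking side of the warning, not the compilation side: a transform sees only the values that flow in through the closure's arguments, so a captured value is silently frozen as a constant and gets no gradient.

```rust
use std::ops::Mul;

/// Hypothetical dual number `re + eps·ε` with `ε² = 0`; `eps` carries a derivative.
#[derive(Clone, Copy)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule: (a + a'ε)(b + b'ε) = ab + (a'b + ab')ε
    fn mul(self, o: Dual) -> Dual {
        Dual { re: self.re * o.re, eps: self.eps * o.re + self.re * o.eps }
    }
}

/// Hypothetical transform: returns a function computing `f` and its gradient
/// with respect to every input, using one forward-mode pass per input.
fn value_and_grad(f: impl Fn(&[Dual]) -> Dual) -> impl Fn(&[f64]) -> (f64, Vec<f64>) {
    move |xs: &[f64]| {
        // Value: evaluate with every tangent set to zero.
        let consts: Vec<Dual> = xs.iter().map(|&v| Dual { re: v, eps: 0.0 }).collect();
        let value = f(consts.as_slice()).re;
        // Gradient: seed input `i` with tangent 1 and all other inputs with 0.
        let grads: Vec<f64> = (0..xs.len())
            .map(|i| {
                let seeded: Vec<Dual> = xs
                    .iter()
                    .enumerate()
                    .map(|(j, &v)| Dual { re: v, eps: if i == j { 1.0 } else { 0.0 } })
                    .collect();
                f(seeded.as_slice()).eps
            })
            .collect();
        (value, grads)
    }
}

fn main() {
    // `w` is captured: the transform never sees it, so no gradient is reported for it.
    let w = Dual { re: 3.0, eps: 0.0 };
    let captured = move |xs: &[Dual]| xs[0] * w;
    let (v1, g1) = value_and_grad(captured)(&[2.0][..]);

    // `w` passed as an input: the transform tracks it and reports d/dw as well.
    let explicit = |xs: &[Dual]| xs[0] * xs[1];
    let (v2, g2) = value_and_grad(explicit)(&[2.0, 3.0][..]);

    println!("captured: value = {v1}, grads = {g1:?}");
    println!("explicit: value = {v2}, grads = {g2:?}");
}
```

Both closures compute the same value, but only the version that takes `w` as an argument lets the transform report a derivative for it.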

Automatic Differentiation

Automatic differentiation in MLX works on functions rather than on implicit graphs.

NOTE: If you are coming to MLX from PyTorch, you no longer need functions like backward, zero_grad, and detach, or properties like requires_grad.
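Here is a std-only toy to make the contrast concrete. It is not mlx-rs code. `grad` is an ordinary higher-order function: it takes a function and returns a new function. `Dual` and `grad` are hypothetical names, the derivative is computed with forward-mode dual numbers, and `f64` stands in for `Array`. No gradient state is stored anywhere, so there is nothing to call `backward` on and nothing to reset with `zero_grad`.

```rust
use std::ops::{Add, Mul};

/// Hypothetical dual number `re + eps·ε` with `ε² = 0`.
#[derive(Clone, Copy)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, o: Dual) -> Dual {
        Dual { re: self.re + o.re, eps: self.eps + o.eps }
    }
}

impl Mul for Dual {
    type Output = Dual;
    // Product rule on the tangent part.
    fn mul(self, o: Dual) -> Dual {
        Dual { re: self.re * o.re, eps: self.eps * o.re + self.re * o.eps }
    }
}

/// Hypothetical transform: function in, derivative function out.
fn grad(f: impl Fn(Dual) -> Dual) -> impl Fn(f64) -> f64 {
    // Seed the input's tangent with 1 and read off the output's tangent.
    move |x: f64| f(Dual { re: x, eps: 1.0 }).eps
}

fn main() {
    // f(x) = x^2 + x, so f'(x) = 2x + 1.
    let df = grad(|x: Dual| x * x + x);
    // Calling the derivative twice gives the same answer: nothing accumulates
    // between calls, so there is no equivalent of `zero_grad`.
    println!("{} {}", df(1.5), df(1.5));
}
```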

You can use the grad() and value_and_grad() functions to compute gradients of more complex functions. By default, these functions compute the gradient with respect to the first argument. To specify which arguments to compute the gradient with respect to, use grad_with_argnums() or value_and_grad_with_argnums().
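The std-only toy below is a hypothetical sketch, not the mlx-rs API. It shows how a default-argnum function can relate to its `_with_argnums` variant. The arity is fixed at two inputs, `f64` stands in for `Array`, and the partial derivatives are computed with forward-mode dual numbers. `Dual`, `value_and_grad`, and `value_and_grad_with_argnums` here are made-up stand-ins, and the default version simply forwards `&[0]`.

```rust
use std::ops::{Add, Mul};

/// Hypothetical dual number `re + eps·ε` with `ε² = 0`.
#[derive(Clone, Copy)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, o: Dual) -> Dual {
        Dual { re: self.re + o.re, eps: self.eps + o.eps }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, o: Dual) -> Dual {
        Dual { re: self.re * o.re, eps: self.eps * o.re + self.re * o.eps }
    }
}

/// Hypothetical: the value of `f` plus one partial derivative per entry in `argnums`.
fn value_and_grad_with_argnums<'a>(
    f: impl Fn(Dual, Dual) -> Dual + 'a,
    argnums: &'a [usize],
) -> impl Fn(f64, f64) -> (f64, Vec<f64>) + 'a {
    move |x: f64, y: f64| {
        let value = f(Dual { re: x, eps: 0.0 }, Dual { re: y, eps: 0.0 }).re;
        let grads: Vec<f64> = argnums
            .iter()
            .map(|&i| {
                // Seed only argument `i`; the other argument is held constant.
                let (dx, dy) = match i {
                    0 => (1.0, 0.0),
                    1 => (0.0, 1.0),
                    _ => panic!("argnum {i} out of range for a two-argument function"),
                };
                f(Dual { re: x, eps: dx }, Dual { re: y, eps: dy }).eps
            })
            .collect();
        (value, grads)
    }
}

/// Hypothetical default variant: differentiate with respect to argument 0 only.
fn value_and_grad<'a>(
    f: impl Fn(Dual, Dual) -> Dual + 'a,
) -> impl Fn(f64, f64) -> (f64, Vec<f64>) + 'a {
    value_and_grad_with_argnums(f, &[0])
}

fn main() {
    // f(x, y) = x*y + y, so df/dx = y and df/dy = x + 1.
    let f = |x: Dual, y: Dual| x * y + y;
    println!("{:?}", value_and_grad(f)(2.0, 5.0));
    println!("{:?}", value_and_grad_with_argnums(f, &[0, 1])(2.0, 5.0));
}
```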

TODO: update the example once https://github.com/oxideai/mlx-rs/pull/218 is merged

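```rust
use mlx_rs::{Array, error::Result, transforms::grad};

// f(x) = x^2, so f'(x) = 2x and f''(x) = 2.
fn f(x: &Array) -> Result<Array> {
    x.square()
}

// Build the gradient function of `func` (with respect to argument 0)
// and evaluate it at `arg`.
fn calculate_grad(func: impl Fn(&Array) -> Result<Array>, arg: &Array) -> Result<Array> {
    grad(&func, &[0])(arg)
}

let x = Array::from(1.5);

// First derivative at x = 1.5: 2 * 1.5 = 3.0
let dfdx = calculate_grad(f, &x).unwrap();
assert_eq!(dfdx.item::<f32>(), 2.0 * 1.5);

// Second derivative: differentiate the gradient function itself.
let dfdx2 = calculate_grad(|args| calculate_grad(f, args), &x).unwrap();
assert_eq!(dfdx2.item::<f32>(), 2.0);
```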

Modules

compile
Compilation of functions.

Traits

IntoGrad
Trait for functions/closures that can be converted into a closure that computes the gradient.
IntoKeyedValueAndGrad
Similar to IntoValueAndGrad, but for functions that take a hashmap of parameters.
IntoValueAndGrad
Trait for functions/closures that can be converted into a closure that computes the value and gradient.
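A trait like IntoGrad is useful because it lets a single `grad` entry point accept closures of different arities. The std-only toy below is a hypothetical sketch of that pattern, not the real IntoGrad definition. `IntoGradSketch`, `Dual`, and `grad` are made-up names, `f64` stands in for `Array`, and gradients come from forward-mode dual numbers. The trait is generic over the argument type, so the two blanket impls for one- and two-argument closures do not overlap, and the compiler selects the impl that matches the closure's signature.

```rust
use std::ops::{Add, Mul};

#[derive(Clone, Copy)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Dual {
    /// An input being differentiated (tangent 1).
    fn var(re: f64) -> Self {
        Dual { re, eps: 1.0 }
    }
    /// An input held constant (tangent 0).
    fn constant(re: f64) -> Self {
        Dual { re, eps: 0.0 }
    }
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, o: Dual) -> Dual {
        Dual { re: self.re + o.re, eps: self.eps + o.eps }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, o: Dual) -> Dual {
        Dual { re: self.re * o.re, eps: self.eps * o.re + self.re * o.eps }
    }
}

/// Hypothetical trait: turn a closure into a boxed gradient function
/// that returns one partial derivative per argument.
trait IntoGradSketch<Args> {
    fn into_grad(self) -> Box<dyn Fn(Args) -> Vec<f64>>;
}

// One-argument closures: the gradient has a single entry.
impl<F> IntoGradSketch<f64> for F
where
    F: Fn(Dual) -> Dual + 'static,
{
    fn into_grad(self) -> Box<dyn Fn(f64) -> Vec<f64>> {
        let f = self;
        Box::new(move |x: f64| vec![f(Dual::var(x)).eps])
    }
}

// Two-argument closures: one forward pass per argument.
impl<F> IntoGradSketch<(f64, f64)> for F
where
    F: Fn(Dual, Dual) -> Dual + 'static,
{
    fn into_grad(self) -> Box<dyn Fn((f64, f64)) -> Vec<f64>> {
        let f = self;
        Box::new(move |(x, y): (f64, f64)| {
            vec![
                f(Dual::var(x), Dual::constant(y)).eps,
                f(Dual::constant(x), Dual::var(y)).eps,
            ]
        })
    }
}

/// Single entry point; `Args` is inferred from whichever impl the closure fits.
fn grad<Args, F: IntoGradSketch<Args>>(f: F) -> Box<dyn Fn(Args) -> Vec<f64>> {
    f.into_grad()
}

fn main() {
    let g1 = grad(|x: Dual| x * x);
    let g2 = grad(|x: Dual, y: Dual| x * y + y);
    println!("{:?}", g1(1.5));
    println!("{:?}", g2((2.0, 5.0)));
}
```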

Functions

async_eval
Asynchronously evaluate an iterator of Arrays.
async_eval_params
Asynchronously evaluate a module’s parameters.
eval
Evaluate an iterator of Arrays.
eval_params
Evaluate a module’s parameters.
fallible_jvp
Similar to jvp but handles closures that can return an error.
fallible_vjp
Similar to vjp but handles closures that can return an error.
grad
Returns a function which computes the gradient of f with respect to the default argument numbers &[0].
grad_with_argnums
Returns a function which computes the gradient of f with respect to the given argument numbers.
jvp
Compute the Jacobian-vector product.
keyed_value_and_grad
Returns a function which computes the value and gradient of f with keyed parameters.
value_and_grad
Returns a function which computes the value and gradient of f with respect to the default argument numbers &[0].
value_and_grad_with_argnums
Returns a function which computes the value and gradient of f with respect to the given argument numbers.
vjp
Compute the vector-Jacobian product.
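The difference between jvp and vjp is easiest to see on a small function with a known Jacobian. The std-only toy below is not the mlx-rs API. `Dual`, `jvp`, and `vjp` are hypothetical and work on `f64`. It computes J·v in one forward-mode pass with dual numbers. For the vjp it assembles uᵀJ from one jvp per input column. This gives the right numbers but costs one forward pass per input, whereas a reverse-mode implementation obtains uᵀJ in a single backward pass.

```rust
use std::ops::Mul;

#[derive(Clone, Copy)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, o: Dual) -> Dual {
        Dual { re: self.re * o.re, eps: self.eps * o.re + self.re * o.eps }
    }
}

/// Example function f: R^2 -> R^2, f(x, y) = (x * y, x * x).
/// Its Jacobian is [[y, x], [2x, 0]].
fn f(x: Dual, y: Dual) -> [Dual; 2] {
    [x * y, x * x]
}

/// Jacobian-vector product J(p) · v in one forward-mode pass:
/// seed each input's tangent with the matching component of `v`.
fn jvp(p: [f64; 2], v: [f64; 2]) -> [f64; 2] {
    let out = f(Dual { re: p[0], eps: v[0] }, Dual { re: p[1], eps: v[1] });
    [out[0].eps, out[1].eps]
}

/// Vector-Jacobian product uᵀ · J(p), assembled from one jvp per input:
/// column j of J is J · e_j, and entry j of the result is u · (J · e_j).
fn vjp(p: [f64; 2], u: [f64; 2]) -> [f64; 2] {
    let mut out = [0.0; 2];
    for j in 0..2 {
        let mut e = [0.0; 2];
        e[j] = 1.0;
        let col = jvp(p, e);
        out[j] = u[0] * col[0] + u[1] * col[1];
    }
    out
}

fn main() {
    // At p = (2, 3) the Jacobian is [[3, 2], [4, 0]].
    let p = [2.0, 3.0];
    println!("jvp = {:?}", jvp(p, [1.0, 0.0])); // first column of J
    println!("vjp = {:?}", vjp(p, [1.0, 0.0])); // first row of J
}
```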

Type Aliases

KeyedGrad
Type alias for a hashmap of gradients.
KeyedParameters
Type alias for a hashmap of parameters.
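KeyedParameters and KeyedGrad pair parameter names with values and with gradients. The std-only toy below is a hypothetical illustration of that shape, not mlx-rs code. The aliases `ParamsSketch` and `GradsSketch` use `HashMap<String, f64>` in place of the real key and `Array` types, `Dual` is a made-up dual-number type, and `keyed_value_and_grad` here finds each key's partial derivative with its own forward-mode pass.

```rust
use std::collections::HashMap;
use std::ops::{Add, Mul};

#[derive(Clone, Copy)]
struct Dual {
    re: f64,
    eps: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, o: Dual) -> Dual {
        Dual { re: self.re + o.re, eps: self.eps + o.eps }
    }
}

impl Mul for Dual {
    type Output = Dual;
    fn mul(self, o: Dual) -> Dual {
        Dual { re: self.re * o.re, eps: self.eps * o.re + self.re * o.eps }
    }
}

// Hypothetical stand-ins for the keyed aliases, with f64 instead of Array.
type ParamsSketch = HashMap<String, f64>;
type GradsSketch = HashMap<String, f64>;

/// Hypothetical keyed value-and-grad: returns the loss and one gradient per key.
fn keyed_value_and_grad(
    f: impl Fn(&HashMap<String, Dual>) -> Dual,
) -> impl Fn(&ParamsSketch) -> (f64, GradsSketch) {
    move |params: &ParamsSketch| {
        // Copy the parameters into dual numbers; only `seed` (if any) gets tangent 1.
        let lift = |seed: Option<&String>| -> HashMap<String, Dual> {
            params
                .iter()
                .map(|(k, &v)| {
                    let eps = if Some(k) == seed { 1.0 } else { 0.0 };
                    (k.clone(), Dual { re: v, eps })
                })
                .collect()
        };
        let value = f(&lift(None)).re;
        let grads: GradsSketch = params
            .keys()
            .map(|k| (k.clone(), f(&lift(Some(k))).eps))
            .collect();
        (value, grads)
    }
}

fn main() {
    // loss(w, b) = (3w + b)^2, so d/dw = 6(3w + b) and d/db = 2(3w + b).
    let loss = |p: &HashMap<String, Dual>| {
        let pred = p["w"] * Dual { re: 3.0, eps: 0.0 } + p["b"];
        pred * pred
    };
    let params: ParamsSketch = HashMap::from([("w".to_string(), 2.0), ("b".to_string(), 1.0)]);
    let (value, grads) = keyed_value_and_grad(loss)(&params);
    println!("loss = {value}, d/dw = {}, d/db = {}", grads["w"], grads["b"]);
}
```

The gradient map has exactly the same keys as the parameter map, so it can be matched back to the parameters by name.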