Function transforms
This module provides functions for automatic differentiation and other transformations on functions.
WARN: Because function transforms, including compilation, work on the computation graph, the user must ensure that all Arrays are passed as inputs to the function/closure. Closures with captured Arrays may not work as expected and may lead to undefined behavior.
§Automatic Differentiation
Automatic differentiation in MLX works on functions rather than on implicit graphs.
NOTE: If you are coming to MLX from PyTorch, you no longer need functions like backward, zero_grad, and detach, or properties like requires_grad.
You can use the grad() and value_and_grad() functions to compute gradients of more complex functions. These functions compute the gradient with respect to the first argument. To manually specify the argument to compute the gradient with respect to, use grad_with_argnums() or value_and_grad_with_argnums().
TODO: update the example once https://github.com/oxideai/mlx-rs/pull/218 is merged
```rust
use mlx_rs::{Array, error::Result, transforms::grad};

fn f(x: &Array) -> Result<Array> {
    x.square()
}

fn calculate_grad(func: impl Fn(&Array) -> Result<Array>, arg: &Array) -> Result<Array> {
    grad(&func, &[0])(arg)
}

let x = Array::from(1.5);
let dfdx = calculate_grad(f, &x).unwrap();
assert_eq!(dfdx.item::<f32>(), 2.0 * 1.5);

let dfdx2 = calculate_grad(|args| calculate_grad(f, args), &x).unwrap();
assert_eq!(dfdx2.item::<f32>(), 2.0);
```
Modules§
- compile - Compilation of functions.
Traits§
- IntoGrad - Trait for functions/closures that can be converted into a closure that computes the gradient.
- IntoKeyedValueAndGrad - Similar to [IntoValueAndGrad] but for functions that take a hashmap of parameters.
- IntoValueAndGrad - Trait for functions/closures that can be converted into a closure that computes the value and gradient.
Functions§
- async_eval - Asynchronously evaluate an iterator of Arrays.
- async_eval_params - Asynchronously evaluate a module’s parameters.
- eval - Evaluate an iterator of Arrays.
- eval_params - Evaluate a module’s parameters.
- fallible_jvp - Similar to jvp but handles closures that can return an error.
- fallible_vjp - Similar to vjp but handles closures that can return an error.
- grad - Returns a function which computes the gradient of f with the default argument numbers &[0].
- grad_with_argnums - Returns a function which computes the gradient of f.
- jvp - Compute the Jacobian-vector product.
- keyed_value_and_grad - Returns a function which computes the value and gradient of f with keyed parameters.
- value_and_grad - Returns a function which computes the value and gradient of f with a default argument number &[0].
- value_and_grad_with_argnums - Returns a function which computes the value and gradient of f.
- vjp - Compute the vector-Jacobian product.
Type Aliases§
- KeyedGrad - Type alias for a hashmap of gradients.
- KeyedParameters - Type alias for a hashmap of parameters.