Module compile

Compilation of functions.

See also the MLX Python documentation.

MLX has a compile() function transformation which compiles computation graphs. Function compilation results in smaller graphs by merging common work and fusing certain operations. In many cases this can lead to big improvements in run-time and memory use.
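To make "fusing certain operations" concrete, here is a plain-Rust sketch of the idea. It uses no MLX code, and the names `unfused` and `fused` are hypothetical. It evaluates `exp(-x) + y` two ways. The first runs separate element-wise passes, and each pass materializes its own intermediate buffer. The second uses a single fused loop. It only illustrates why fusion saves memory traffic and allocations. It is not the code MLX generates.

```rust
/// Unfused: each operation is its own pass with its own output buffer,
/// so two temporaries (`neg` and `e`) are materialized before the final add.
pub fn unfused(x: &[f32], y: &[f32]) -> Vec<f32> {
    let neg: Vec<f32> = x.iter().map(|v| -v).collect();
    let e: Vec<f32> = neg.iter().map(|v| v.exp()).collect();
    e.iter().zip(y).map(|(a, b)| a + b).collect()
}

/// Fused: a single pass over the inputs computes the whole expression
/// element by element, with no intermediate buffers.
pub fn fused(x: &[f32], y: &[f32]) -> Vec<f32> {
    x.iter().zip(y).map(|(a, b)| (-a).exp() + b).collect()
}
```

Both versions perform the same floating-point operations in the same order, so they return identical results. Only the number of passes and temporary allocations differs.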

Getting started with compile() is simple, but there are some edge cases that are good to be aware of for more complex graphs and advanced usage.

WARN: Because function transforms, including compilation, work on the computation graph, the user must ensure that every Array the function uses is passed to it as an input. Closures that capture Arrays may not work as expected and may lead to undefined behavior.

§Basic usage

use mlx_rs::{Array, array, transforms::compile::compile, error::Exception};

let fun = |(x, y): (&Array, &Array)| -> Result<Array, Exception> {
   mlx_rs::exp!(x.negative()?)?.add(y)
};

let x = array!(1.0);
let y = array!(2.0);

// Regular call, no compilation
let result = fun((&x, &y)).unwrap();
// Prints: array(2.36788, dtype=float32)
println!("{:?}", result);

// Compile the function
let mut compiled_fun = compile(fun, None);
let result = compiled_fun((&x, &y)).unwrap();
// Prints: array(2.36788, dtype=float32)
println!("{:?}", result);

The output of both the regular function and the compiled function is the same up to numerical precision.

The first time you call a compiled function, MLX will build the compute graph, optimize it, and generate and compile code. This can be relatively slow. However, MLX will cache compiled functions, so calling a compiled function multiple times will not initiate a new compilation. This means you should typically compile functions that you plan to use more than once.
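The caching behaviour described above can be modelled with a toy cache. This is a hypothetical sketch, not the mlx-rs implementation, and `ToyCache` and its string identifiers are invented for illustration. Entries are keyed by a function identifier, so only the first call pays the "compilation" cost. A second wrapper around the same function also hits the cache.

```rust
use std::collections::HashSet;

/// Hypothetical toy compilation cache (not mlx-rs internals).
/// Entries are keyed by a function identifier.
#[derive(Default)]
pub struct ToyCache {
    compiled: HashSet<&'static str>,
    pub compilations: usize,
}

impl ToyCache {
    /// Runs `f` under the identifier `id`, "compiling" only on a cache miss.
    pub fn call(&mut self, id: &'static str, f: impl Fn(f32, f32) -> f32, x: f32, y: f32) -> f32 {
        // `insert` returns true only when `id` was not already present.
        if self.compiled.insert(id) {
            // Cache miss: the slow trace/optimize/codegen step would happen here.
            self.compilations += 1;
        }
        f(x, y)
    }
}
```

In this model, calling the same function three times through the cache records exactly one compilation.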

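use mlx_rs::{Array, array, transforms::compile::compile, error::Exception};

let fun = |(x, y): (&Array, &Array)| -> Result<Array, Exception> {
   mlx_rs::exp!(x.negative()?)?.add(y)
};

let x = array!(1.0);
let y = array!(2.0);

let mut compiled_fun = compile(fun, None);

// Compiled here (first call)
let result = compiled_fun((&x, &y)).unwrap();

// Not compiled again: same function, same input shapes and types
let result = compiled_fun((&x, &y)).unwrap();

// Not compiled again: compiling the same function a second time reuses the cache
let mut compiled_fun2 = compile(fun, None);
let result = compiled_fun2((&x, &y)).unwrap();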

There are some important cases to be aware of that can cause a function to be recompiled:

  • Changing the shape or number of dimensions of any input
  • Changing the type of any of the inputs
  • Changing the number of inputs to the function
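One way to picture these three triggers is a cache whose key is the full input signature: each input's shape and dtype, plus the number of inputs. The sketch below is a hypothetical toy model, and `DType`, `InputSpec` and `SignatureCache` are invented names, not mlx-rs types. It treats every change as a full miss, so it does not model the partial recompilation described next.

```rust
use std::collections::HashSet;

/// Hypothetical stand-in for an array's element type.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DType {
    F32,
    I32,
}

/// Toy description of one input: its shape and dtype.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct InputSpec {
    pub shape: Vec<usize>,
    pub dtype: DType,
}

/// Toy cache keyed on the whole input signature; the number of inputs
/// is part of the key because it is the length of the Vec.
#[derive(Default)]
pub struct SignatureCache {
    seen: HashSet<Vec<InputSpec>>,
    pub compilations: usize,
}

impl SignatureCache {
    /// Returns true if a call with these inputs would (re)compile.
    pub fn lookup(&mut self, inputs: &[InputSpec]) -> bool {
        let miss = self.seen.insert(inputs.to_vec());
        if miss {
            self.compilations += 1;
        }
        miss
    }
}
```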

In certain cases only some of the compilation stack will be rerun (for example when changing the shapes) and in other cases the full compilation stack will be rerun (for example when changing the types). In general you should avoid compiling functions too frequently.

Another idiom to watch out for is compiling functions that are created and destroyed frequently. This can happen, for example, when compiling a closure in a loop.
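The cost of that pattern can be shown with a toy wrapper. In this hypothetical sketch, each new `FreshWrapper` stands for a compiled function with a fresh identity, so it starts with an empty cache. The sketch does not describe how mlx-rs actually identifies Rust closures for caching. Wrapping a new closure on every iteration pays the compilation cost every time. Wrapping once outside the loop pays it once.

```rust
/// Hypothetical toy wrapper (not mlx-rs internals): a new wrapper models a
/// function with a fresh identity, so its first call always "compiles".
pub struct FreshWrapper<F: Fn(f32) -> f32> {
    f: F,
    warm: bool,
}

impl<F: Fn(f32) -> f32> FreshWrapper<F> {
    pub fn new(f: F) -> Self {
        Self { f, warm: false }
    }

    /// Calls `f`, counting a compilation on this wrapper's first call only.
    pub fn call(&mut self, x: f32, compilations: &mut usize) -> f32 {
        if !self.warm {
            *compilations += 1;
            self.warm = true;
        }
        (self.f)(x)
    }
}

/// Anti-pattern: a new closure is wrapped on every iteration.
pub fn compile_inside_loop(n: usize) -> usize {
    let mut compilations = 0;
    for _ in 0..n {
        let mut w = FreshWrapper::new(|x: f32| x * 2.0);
        w.call(1.0, &mut compilations);
    }
    compilations
}

/// Preferred: wrap once and reuse it; changing values are passed as inputs.
pub fn compile_outside_loop(n: usize) -> usize {
    let mut compilations = 0;
    let mut w = FreshWrapper::new(|x: f32| x * 2.0);
    for i in 0..n {
        w.call(i as f32, &mut compilations);
    }
    compilations
}
```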

§Pure Functions

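Compiled functions are intended to be pure; that is, they should not have side effects or read Arrays that are not passed in as inputs. For example, the following closure captures the Array c instead of taking it as an argument: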

use mlx_rs::{Array, array, transforms::compile::compile};

let c = array!(0.5);

// `c` is captured by the closure rather than passed as an input
let fun = |(x, y): (&Array, &Array)| {
    let z = (x + y) * &c;
    mlx_rs::exp!(z)
};

let mut compiled = compile(fun, None);

let x = array!(1.0);
let y = array!(2.0);

// This may lead to undefined behavior
let result = compiled((&x, &y)).unwrap();
println!("{:?}", result);
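One plausible failure mode with captured values is shown by the toy tracer below. It is a hypothetical model, not a description of what mlx-rs or MLX does internally, and `ToyTraced` is an invented name. On its first call, the tracer records the current value of the captured scale as a constant in its cached "graph". Later calls reuse that constant, so a change to the captured value is silently ignored.

```rust
use std::cell::Cell;

/// Hypothetical toy "tracing compiler" (not mlx-rs internals): the captured
/// scale is read once, at trace time, and then baked into the cached result.
pub struct ToyTraced<'a> {
    captured_scale: &'a Cell<f32>,
    baked_scale: Option<f32>,
}

impl<'a> ToyTraced<'a> {
    pub fn new(captured_scale: &'a Cell<f32>) -> Self {
        Self { captured_scale, baked_scale: None }
    }

    /// Computes exp((x + y) * scale) using the scale seen on the first call.
    pub fn call(&mut self, x: f32, y: f32) -> f32 {
        let scale = *self.baked_scale.get_or_insert(self.captured_scale.get());
        ((x + y) * scale).exp()
    }
}
```

If the scale were an explicit argument instead of a capture, it would be part of every call's inputs and could never go stale. This is why captured Arrays should be passed as inputs instead.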

Use compile_with_state() to compile functions that have side effects, passing the state in as a mutable reference.

use mlx_rs::{Array, array, transforms::compile::compile_with_state};
let mut state = vec![];

let fun = |state: &mut Vec<Array>, (x, y): (&Array, &Array)| {
    let z = x + y;
    let result = mlx_rs::exp!(&z);
    state.push(z);
    result
};

let x = array!(1.0);
let y = array!(2.0);

let mut compiled = compile_with_state(fun, None);
let result = compiled(&mut state, (&x, &y)).unwrap();
println!("{:?}", result);
// println!("{:?}", state); // TODO: this currently doesn't work somehow

This is particularly useful for compiling a function which includes an update to a container of arrays, as is commonly done when training the parameters of a crate::module::Module.
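The state-passing pattern itself can be sketched in plain Rust. The training step below reads and updates its parameters only through the `&mut` state argument, so all of its effects are visible in its signature. The argument order (state first, then inputs) mirrors the closure passed to compile_with_state() above. However, `sgd_step` is a hypothetical helper that uses `f32` slices instead of Arrays. It is not how mlx-rs trains a Module.

```rust
/// Hypothetical plain-Rust sketch of an explicit-state update step:
/// `params` is the mutable state, `grads` and `lr` are ordinary inputs.
pub fn sgd_step(params: &mut Vec<f32>, grads: &[f32], lr: f32) -> f32 {
    assert_eq!(params.len(), grads.len(), "one gradient per parameter");
    // In-place update of the state: p <- p - lr * g.
    for (p, g) in params.iter_mut().zip(grads) {
        *p -= lr * g;
    }
    // A regular return value alongside the state update:
    // the sum of squared gradients.
    grads.iter().map(|g| g * g).sum()
}
```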

See mlx-rs/mlx-tests/tests/test_compile_with_state.rs for more examples.

Structs§

Compiled
A compiled function that can be called.

Traits§

CallMut
A trait for a compiled function that can be called.
CallMutWithState
A trait for functions that can be called with state.
Compile
A trait for functions that can be compiled.
CompileWithState
A trait for functions that can be compiled with state.

Functions§

clear_cache
Clear the memory cache.
compile
Returns a compiled function that produces the same output as f.
compile_with_state
Similar to crate::transforms::compile but allows for functions that take a mutable reference to a state U.
disable_compile
Globally disable the compilation of functions.
enable_compile
Globally enable the compilation of functions.