mlx_rs/transforms/compile/mod.rs
1//! Compilation of functions.
2//!
3//! See also [MLX python
4//! documentation](https://ml-explore.github.io/mlx/build/html/usage/compile.html).
5//!
6//! MLX has a [`compile()`] function transformation which compiles computation
7//! graphs. Function compilation results in smaller graphs by merging common
8//! work and fusing certain operations. In many cases this can lead to big
9//! improvements in run-time and memory use.
10//!
11//! Getting started with compile() is simple, but there are some edge cases that
12//! are good to be aware of for more complex graphs and advanced usage.
13//!
//! **WARN**: Because function transforms, including compilation, work on the
//! computation graph, the user must ensure that all `Array`s are passed as
//! inputs to the function/closure. Closures with captured `Array`s may not work
//! as expected and may lead to undefined behavior.
18//!
19//! # Basic usage
20//!
21//! ```rust
22//! use mlx_rs::{Array, array, transforms::compile::compile, error::Exception};
23//!
24//! let fun = |(x, y): (&Array, &Array)| -> Result<Array, Exception> {
25//! mlx_rs::exp!(x.negative()?)?.add(y)
26//! };
27//!
28//! let x = array!(1.0);
29//! let y = array!(2.0);
30//!
31//! // Regular call, no compilation
32//! let result = fun((&x, &y)).unwrap();
33//! // Prints: array(2.36788, dtype=float32)
34//! println!("{:?}", result);
35//!
36//! // Compile the function
37//! let mut compiled_fun = compile(fun, None);
38//! let result = compiled_fun((&x, &y)).unwrap();
39//! // Prints: array(2.36788, dtype=float32)
40//! println!("{:?}", result);
41//! ```
42//!
43//! The output of both the regular function and the compiled function is the
44//! same up to numerical precision.
45//!
46//! The first time you call a compiled function, MLX will build the compute
47//! graph, optimize it, and generate and compile code. This can be relatively
48//! slow. However, MLX will cache compiled functions, so calling a compiled
49//! function multiple times will not initiate a new compilation. This means you
50//! should typically compile functions that you plan to use more than once.
51//!
52//! ```rust
53//! use mlx_rs::{Array, array, transforms::compile::compile};
54//!
55//! let fun = |(x, y): (&Array, &Array)| {
56//! mlx_rs::exp!(x.negative()?)?.add(y)
57//! };
58//!
59//! let x = array!(1.0);
60//! let y = array!(2.0);
61//!
62//! let mut compiled_fun = compile(fun, None);
63//!
64//! // Compiled here
65//! let result = compiled_fun((&x, &y)).unwrap();
66//!
67//! // Not compiled again
68//! let result = compiled_fun((&x, &y)).unwrap();
69//!
70//! // Not compiled again
71//! let compiled_fun2 = compile(fun, None);
72//! ```
73//!
74//! There are some important cases to be aware of that can cause a function to
75//! be recompiled:
76//!
77//! - Changing the shape or number of dimensions
78//! - Changing the type of any of the inputs
79//! - Changing the number of inputs to the function
80//!
81//! In certain cases only some of the compilation stack will be rerun (for
82//! example when changing the shapes) and in other cases the full compilation
83//! stack will be rerun (for example when changing the types). In general you
84//! should avoid compiling functions too frequently.
85//!
86//! Another idiom to watch out for is compiling functions which get created and
//! destroyed frequently. This can happen, for example, when compiling a
//! closure in a loop.
89//!
90//! # Pure Functions
91//!
//! Compiled functions are intended to be pure; that is, they should not have
//! side effects. For example:
94//!
95//! ```rust,ignore
96//! use mlx_rs::{Array, array, transforms::compile::compile};
97//!
98//! let mut c = array!(0.5);
99//!
100//! let fun = |(x, y): (&Array, &Array)| {
101//! let z = (x + y) * c;
102//! mlx_rs::exp!(z)
103//! };
104//!
105//! let mut compiled = compile(fun, None);
106//!
107//! let x = array!(1.0);
108//! let y = array!(2.0);
109//!
110//! // This may lead to undefined behavior
111//! let result = compiled((&x, &y)).unwrap();
112//! println!("{:?}", result);
113//! ```
114//!
//! Use [`compile_with_state()`] to compile functions that have side effects and
//! pass the state as a mutable reference.
117//!
118//! ```rust
119//! use mlx_rs::{Array, array, transforms::compile::compile_with_state};
120//! let mut state = vec![];
121//!
122//! let fun = |state: &mut Vec<Array>, (x, y): (&Array, &Array)| {
123//! let z = x + y;
124//! let result = mlx_rs::exp!(&z);
125//! state.push(z);
126//! result
127//! };
128//!
129//! let x = array!(1.0);
130//! let y = array!(2.0);
131//!
132//! let mut compiled = compile_with_state(fun, None);
133//! let result = compiled(&mut state, (&x, &y)).unwrap();
134//! println!("{:?}", result);
135//! // println!("{:?}", state); // TODO: this currently doesn't work somehow
136//! ```
137//!
138//! This is particularly useful for compiling a function which includes an
139//! update to a container of arrays, as is commonly done when training the
140//! parameters of a [`crate::module::Module`].
141//!
//! See `mlx-rs/mlx-tests/tests/test_compile_with_state.rs` for more examples.
143//!
144
145use std::collections::hash_map::DefaultHasher;
146use std::hash::{Hash, Hasher};
147
148use super::{Closure, Guarded, VectorArray};
149use crate::Array;
150
151#[allow(clippy::module_inception)]
152mod compile;
153mod compile_with_state;
154
155pub use compile::*;
156pub use compile_with_state::*;
157
158/// Globally enable the compilation of functions.
159///
160/// Default is enabled.
161pub fn enable_compile() {
162 unsafe {
163 mlx_sys::mlx_enable_compile();
164 }
165}
166
167/// Globally disable the compilation of functions.
168///
169/// Default is enabled.
170pub fn disable_compile() {
171 unsafe {
172 mlx_sys::mlx_disable_compile();
173 }
174}
175
176/// Clear the memory cache.
177pub fn clear_cache() {
178 unsafe {
179 mlx_sys::mlx_detail_compile_clear_cache();
180 }
181}
182
183/// A compiled function that can be called.
184#[derive(Debug, Clone)]
185pub struct Compiled<F, G> {
186 f_marker: std::marker::PhantomData<F>,
187 state: CompiledState<G>,
188}
189
/// Backend-facing state of a compiled function.
///
/// The `Drop` impl for this type erases the backend entry registered under
/// `id`.
#[derive(Clone, Debug)]
struct CompiledState<F> {
    /// The user function/closure being compiled.
    f: F,
    /// Shapeless-compilation flag (presumably forwarded to MLX's `shapeless`
    /// compile option — confirm in the `compile` submodules).
    shapeless: bool,
    /// Key identifying this function to the backend; erased on drop.
    id: usize,
}
196
197impl<F> Drop for CompiledState<F> {
198 fn drop(&mut self) {
199 unsafe {
200 // remove the compiled structure from the back end
201 mlx_sys::mlx_detail_compile_erase(self.id);
202 }
203 }
204}
205
/// Map the *type* `T` (not the value) to a `usize` key.
///
/// The [`TypeId`](std::any::TypeId) of `T` is fed through a fresh
/// [`DefaultHasher`] and the resulting 64-bit digest is returned as `usize`.
/// Only the type of `_val` matters; the argument exists so callers can name
/// otherwise unnameable types such as closures.
fn type_id_to_usize<T: 'static>(_val: &T) -> usize {
    let mut state = DefaultHasher::new();
    std::any::TypeId::of::<T>().hash(&mut state);
    // `as` truncates the 64-bit digest on 32-bit targets; acceptable for a
    // hash-derived key.
    state.finish() as usize
}
216
217fn update_by_replace_with_ref_to_new_array(src: &mut Array, new_array: &Array) {
218 unsafe {
219 mlx_sys::mlx_array_set(&mut src.as_ptr() as *mut _, new_array.as_ptr());
220 }
221}