//! Unofficial Rust bindings for the [MLX
//! framework](https://github.com/ml-explore/mlx).
//!
//! # Table of Contents
//!
//! - [Quick Start](#quick-start)
//! - [Lazy Evaluation](#lazy-evaluation)
//! - [Unified Memory](#unified-memory)
//! - [Indexing Arrays](#indexing-arrays)
//! - [Saving and Loading](#saving-and-loading)
//! - [Function Transforms](#function-transforms)
//! - [Compilation](#compilation)
//!
//! # Quick Start
//!
//! See also the [MLX Python
//! documentation](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html).
//!
//! ## Basics
//!
//! ```rust
//! use mlx_rs::{array, Dtype};
//!
//! let a = array!([1, 2, 3, 4]);
//! assert_eq!(a.shape(), &[4]);
//! assert_eq!(a.dtype(), Dtype::Int32);
//!
//! let b = array!([1.0, 2.0, 3.0, 4.0]);
//! assert_eq!(b.dtype(), Dtype::Float32);
//! ```
//!
//! Operations in MLX are lazy. Use [`Array::eval`] to evaluate the output
//! of an operation. Operations are also automatically evaluated when inspecting
//! an array with [`Array::item`], printing an array, or attempting to obtain
//! the underlying data with [`Array::as_slice`].
//!
//! ```rust
//! use mlx_rs::{array, transforms::eval};
//!
//! let a = array!([1, 2, 3, 4]);
//! let b = array!([1.0, 2.0, 3.0, 4.0]);
//!
//! let c = &a + &b; // c is not evaluated
//! c.eval().unwrap(); // evaluates c
//!
//! let d = &a + &b;
//! println!("{:?}", d); // evaluates d
//!
//! let e = &a + &b;
//! let e_slice: &[f32] = e.as_slice(); // evaluates e
//! ```
//!
//! See [Lazy Evaluation](#lazy-evaluation) for more details.
//!
//! ## Function and Graph Transformations
//!
//! TODO: <https://github.com/oxideai/mlx-rs/issues/214>
//!
//! TODO: also document that all `Array` in the args for function
//!       transformations
//!
//! # Lazy Evaluation
//!
//! See also the [MLX Python
//! documentation](https://ml-explore.github.io/mlx/build/html/usage/lazy_evaluation.html).
//!
//! ## Why Lazy Evaluation
//!
//! When you perform operations in MLX, no computation actually happens.
//! Instead, a compute graph is recorded. The actual computation only happens
//! when an [`Array::eval`] is performed.
//!
//! MLX uses lazy evaluation because it has some nice features, some of which we
//! describe below.
//!
//! ## Transforming Compute Graphs
//!
//! Lazy evaluation lets us record a compute graph without actually doing any
//! computations. This is useful for function transformations like
//! [`transforms::grad`] and graph optimizations.
//!
//! Currently, MLX does not compile and rerun compute graphs. They are all
//! generated dynamically. However, lazy evaluation makes it much easier to
//! integrate compilation for future performance enhancements.
//!
//! ## Only Compute What You Use
//!
//! In MLX you do not need to worry as much about computing outputs that are
//! never used. For example:
//!
//! ```rust,ignore
//! fn fun(x: &Array) -> (Array, Array) {
//!     let a = cheap_fun(x);
//!     let b = expensive_fun(x);
//!     (a, b)
//! }
//!
//! let (y, _) = fun(&x);
//! ```
//!
//! Here, we never actually compute the output of `expensive_fun`. Use this
//! pattern with care though, as the graph of `expensive_fun` is still built,
//! and that has some cost associated with it.
//!
//! Similarly, lazy evaluation can be beneficial for saving memory while keeping
//! code simple. Say you have a very large model `Model` implementing
//! [`module::Module`]. You can instantiate this model with `let model =
//! Model::new()`. Typically, this will initialize all of the weights as
//! `float32`, but the initialization does not actually compute anything until
//! you perform an `eval()`. If you update the model with `float16` weights,
//! your maximum consumed memory will be half of what eager computation would
//! have required.
//!
//! This pattern is simple to do in MLX thanks to lazy computation:
//!
//! ```rust,ignore
//! let mut model = Model::new();
//! model.load_safetensors("model.safetensors").unwrap();
//! ```
//!
//! ## When to Evaluate
//!
//! A common question is when to use `eval()`. The trade-off is between letting
//! graphs get too large and not batching enough useful work.
//!
//! For example:
//!
//! ```rust,ignore
//! let mut a = array!([1, 2, 3, 4]);
//! let mut b = array!([1.0, 2.0, 3.0, 4.0]);
//!
//! for _ in 0..100 {
//!     a = &a + &b;
//!     a.eval()?;
//!     b = b * 2.0;
//!     b.eval()?;
//! }
//! ```
//!
//! This is a bad idea because there is some fixed overhead with each graph
//! evaluation. On the other hand, there is some slight overhead that grows
//! with the compute graph size, so extremely large graphs (while
//! computationally correct) can be costly.
//!
//! Luckily, a wide range of compute graph sizes work pretty well with MLX:
//! anything from a few tens of operations to many thousands of operations per
//! evaluation should be okay.
//!
//! Most numerical computations have an iterative outer loop (e.g. the iteration
//! in stochastic gradient descent). A natural and usually efficient place to
//! use `eval()` is at each iteration of this outer loop.
//!
//! Here is a concrete example:
//!
//! ```rust,ignore
//! for batch in dataset {
//!     // Nothing has been evaluated yet
//!     let (loss, grad) = value_and_grad_fn(&mut model, batch)?;
//!
//!     // Still nothing has been evaluated
//!     optimizer.update(&mut model, grad)?;
//!
//!     // Evaluate the loss and the new parameters which will
//!     // run the full gradient computation and optimizer update
//!     eval_params(model.parameters())?;
//! }
//! ```
//!
//! An important behavior to be aware of is when the graph will be implicitly
//! evaluated. Any time you print an array (e.g. with `println!`), or otherwise
//! access its memory via [`Array::as_slice`], the graph will be evaluated.
//! Saving arrays via [`Array::save_numpy`] or [`Array::save_safetensors`] (or
//! any other MLX saving functions) will also evaluate the array.
//!
//! Calling [`Array::item`] on a scalar array will also evaluate it. In the
//! example above, printing the loss (`println!("{:?}", loss)`) or pushing the
//! loss scalar to a [`Vec`] (`losses.push(loss.item::<f32>())`) would cause a
//! graph evaluation. If these lines are placed before evaluating the loss and
//! module parameters, then this will be a partial evaluation, computing only
//! the forward pass.
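//!
//! For illustration, here is a sketch of that situation based on the training
//! loop above (`value_and_grad_fn`, `eval_params`, `optimizer`, and `model` are
//! the same placeholders used in that example):
//!
//! ```rust,ignore
//! for batch in dataset {
//!     let (loss, grad) = value_and_grad_fn(&mut model, batch)?;
//!
//!     // This triggers a partial evaluation: only the forward pass is
//!     // computed so that `loss` can be printed.
//!     println!("{:?}", loss);
//!
//!     optimizer.update(&mut model, grad)?;
//!
//!     // The gradient computation and optimizer update run here.
//!     eval_params(model.parameters())?;
//! }
//! ```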
//!
//! Also, calling `eval()` on an array or set of arrays multiple times is
//! perfectly fine. This is effectively a no-op.
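//!
//! A minimal sketch of this, reusing only the `array!` macro and the
//! `Array::eval` calls shown in the earlier examples:
//!
//! ```rust
//! use mlx_rs::array;
//!
//! let a = array!([1, 2, 3, 4]);
//! let b = array!([1.0, 2.0, 3.0, 4.0]);
//!
//! let c = &a + &b;
//! c.eval().unwrap(); // evaluates c
//! c.eval().unwrap(); // effectively a no-op; c is already evaluated
//! ```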
//!
//! **Warning**: Using scalar arrays for control flow will cause an evaluation.
//!
//! ```rust,ignore
//! fn fun(x: &Array) -> Array {
//!     let (h, y) = first_layer(x);
//!
//!     if y.gt(array!(0.5)).unwrap().item() {
//!         second_layer_a(h)
//!     } else {
//!         second_layer_b(h)
//!     }
//! }
//! ```
//!
//! Using arrays for control flow should be done with care. The above example
//! works and can even be used with gradient transformations. However, this can
//! be very inefficient if evaluations are done too frequently.
//!
//! # Unified Memory
//!
//! See also the [MLX Python
//! documentation](https://ml-explore.github.io/mlx/build/html/usage/unified_memory.html).
//!
//! Apple silicon has a unified memory architecture. The CPU and GPU have direct
//! access to the same memory pool. MLX is designed to take advantage of that.
//!
//! Concretely, when you make an array in MLX you don’t have to specify its
//! location:
//!
//! ```rust
//! // let a = mlx_rs::random::normal(&[100], None, None, None, None).unwrap();
//! // let b = mlx_rs::random::normal(&[100], None, None, None, None).unwrap();
//!
//! let a = mlx_rs::normal!(shape=&[100]).unwrap();
//! let b = mlx_rs::normal!(shape=&[100]).unwrap();
//! ```
//!
//! Both `a` and `b` live in unified memory.
//!
//! In MLX, rather than moving arrays to devices, you specify the device when
//! you run the operation. Any device can perform any operation on `a` and `b`
//! without needing to move them from one memory location to another. For
//! example:
//!
//! ```rust,ignore
//! // mlx_rs::ops::add_device(&a, &b, StreamOrDevice::cpu()).unwrap();
//! // mlx_rs::ops::add_device(&a, &b, StreamOrDevice::gpu()).unwrap();
//!
//! mlx_rs::add!(&a, &b, stream=StreamOrDevice::cpu()).unwrap();
//! mlx_rs::add!(&a, &b, stream=StreamOrDevice::gpu()).unwrap();
//! ```
//!
//! In the above, both the CPU and the GPU will perform the same add operation.
//!
//! TODO: The remaining Python documentation states that the stream can be used
//! to parallelize operations without worrying about race conditions. We should
//! check whether this holds, given that we have already observed data races
//! when executing unit tests in parallel.
//!
//! # Indexing Arrays
//!
//! See also the [MLX Python
//! documentation](https://ml-explore.github.io/mlx/build/html/usage/indexing.html).
//!
//! Please refer to the indexing module ([`ops::indexing`]) for more details.
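//!
//! As a rough sketch of what indexing looks like (this sketch assumes the
//! `IndexOp` trait and its `index` method from [`ops::indexing`]; see that
//! module for the exact API and the full set of supported index types):
//!
//! ```rust,ignore
//! use mlx_rs::{array, ops::indexing::IndexOp};
//!
//! let a = array!([1, 2, 3, 4]);
//!
//! // Index a single element; the result is a scalar array.
//! let third = a.index(2);
//!
//! // Index with a range to take a slice.
//! let middle = a.index(1..3);
//! ```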
//!
//! # Saving and Loading
//!
//! See also the [MLX Python
//! documentation](https://ml-explore.github.io/mlx/build/html/usage/saving_and_loading.html).
//!
//! `mlx-rs` supports loading from `.npy` and `.safetensors` files and saving to
//! `.safetensors` files. Module parameters and optimizer states can also be
//! saved to and loaded from `.safetensors` files.
//!
//! | Type | Load function | Save function |
//! |------|---------------|---------------|
//! | [`Array`] | [`Array::load_numpy`] | [`Array::save_numpy`] |
//! | `HashMap<String, Array>` | [`Array::load_safetensors`] | [`Array::save_safetensors`] |
//! | [`module::Module`] | [`module::ModuleParametersExt::load_safetensors`] | [`module::ModuleParametersExt::save_safetensors`] |
//! | [`optimizers::Optimizer`] | [`optimizers::OptimizerState::load_safetensors`] | [`optimizers::OptimizerState::save_safetensors`] |
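//!
//! For example, a model's parameters can be saved after training and restored
//! later. This is a sketch that mirrors the `load_safetensors` call shown
//! earlier; `Model` is a placeholder for your own [`module::Module`]
//! implementation, and the file-path argument to `save_safetensors` is an
//! assumption:
//!
//! ```rust,ignore
//! // Save the parameters of a trained model to a safetensors file.
//! model.save_safetensors("model.safetensors")?;
//!
//! // Later, instantiate the model and load the saved parameters back in.
//! let mut restored = Model::new();
//! restored.load_safetensors("model.safetensors")?;
//! ```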
//!
//! # Function Transforms
//!
//! See also the [MLX Python
//! documentation](https://ml-explore.github.io/mlx/build/html/usage/function_transforms.html).
//!
//! Please refer to the transforms module ([`transforms`]) for more details.
//!
//! # Compilation

#![deny(unused_unsafe, missing_debug_implementations, missing_docs)]
#![cfg_attr(test, allow(clippy::approx_constant))]

#[macro_use]
pub mod macros; // Must be first to ensure the other modules can use the macros

mod array;
pub mod builder;
mod device;
mod dtype;
pub mod error;
pub mod fast;
pub mod fft;
pub mod linalg;
pub mod losses;
pub mod module;
pub mod nested;
pub mod nn;
pub mod ops;
pub mod optimizers;
pub mod quantization;
pub mod random;
mod stream;
pub mod transforms;
pub mod utils;

pub use array::*;
pub use device::*;
pub use dtype::*;
pub use stream::*;

pub(crate) mod constants {
    /// The default length of the stack-allocated vector in `SmallVec<[T; DEFAULT_STACK_VEC_LEN]>`
    pub(crate) const DEFAULT_STACK_VEC_LEN: usize = 4;
}

pub(crate) mod sealed {
    /// A marker trait to prevent external implementations of the `Sealed` trait.
    pub trait Sealed {}

    impl Sealed for () {}

    impl<A> Sealed for (A,) where A: Sealed {}
    impl<A, B> Sealed for (A, B)
    where
        A: Sealed,
        B: Sealed,
    {
    }
    impl<A, B, C> Sealed for (A, B, C)
    where
        A: Sealed,
        B: Sealed,
        C: Sealed,
    {
    }
}