Crate mlx_rs


Unofficial rust bindings for the MLX framework.

§Quick Start

See also MLX python documentation

§Basics

use mlx_rs::{array, Dtype};

let a = array!([1, 2, 3, 4]);
assert_eq!(a.shape(), &[4]);
assert_eq!(a.dtype(), Dtype::Int32);

let b = array!([1.0, 2.0, 3.0, 4.0]);
assert_eq!(b.dtype(), Dtype::Float32);

Operations in MLX are lazy. Use Array::eval to evaluate the output of an operation. Operations are also evaluated automatically when you inspect an array with Array::item, print an array, or obtain the underlying data with Array::as_slice.

use mlx_rs::array;

let a = array!([1, 2, 3, 4]);
let b = array!([1.0, 2.0, 3.0, 4.0]);

let c = &a + &b; // c is not evaluated
c.eval().unwrap(); // evaluates c

let d = &a + &b;
println!("{:?}", d); // evaluates d

let e = &a + &b;
let e_slice: &[f32] = e.as_slice(); // evaluates e

See Lazy Evaluation for more details.

§Function and Graph Transformations

TODO: https://github.com/oxideai/mlx-rs/issues/214

TODO: also document that all Array in the args for function transformations

§Lazy Evaluation

See also MLX python documentation

§Why Lazy Evaluation

When you perform operations in MLX, no computation actually happens. Instead, a compute graph is recorded. The actual computation only happens when the graph is evaluated, for example with Array::eval.
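To make the idea concrete, here is a toy model in plain Rust (std only, not how mlx-rs implements Array): building an expression only records a node, and nothing is computed until eval walks the graph and caches each result. The Lazy type and its methods are hypothetical names for this sketch.

```rust
use std::cell::Cell;
use std::rc::Rc;

/// Toy lazy scalar (not mlx-rs's `Array`): building an expression only
/// records a node; `eval` walks the graph and caches results.
#[derive(Clone)]
pub struct Lazy(Rc<Node>);

struct Node {
    op: Op,
    value: Cell<Option<f64>>, // filled in on first evaluation
}

enum Op {
    Const(f64),
    Add(Lazy, Lazy),
    Mul(Lazy, Lazy),
}

impl Lazy {
    fn node(op: Op) -> Lazy {
        Lazy(Rc::new(Node { op, value: Cell::new(None) }))
    }

    pub fn constant(v: f64) -> Lazy {
        Lazy::node(Op::Const(v))
    }

    /// Records an addition; nothing is computed here.
    pub fn add(&self, other: &Lazy) -> Lazy {
        Lazy::node(Op::Add(self.clone(), other.clone()))
    }

    /// Records a multiplication; nothing is computed here.
    pub fn mul(&self, other: &Lazy) -> Lazy {
        Lazy::node(Op::Mul(self.clone(), other.clone()))
    }

    pub fn is_evaluated(&self) -> bool {
        self.0.value.get().is_some()
    }

    /// Computes this node (and any unevaluated inputs) and caches the result.
    pub fn eval(&self) -> f64 {
        if let Some(v) = self.0.value.get() {
            return v; // already evaluated: no recomputation
        }
        let v = match &self.0.op {
            Op::Const(c) => *c,
            Op::Add(a, b) => a.eval() + b.eval(),
            Op::Mul(a, b) => a.eval() * b.eval(),
        };
        self.0.value.set(Some(v));
        v
    }
}
```

For example, `Lazy::constant(2.0).add(&b).mul(&b)` with `b = Lazy::constant(3.0)` only builds three nodes; the value 15 exists only after `eval()` is called.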

MLX uses lazy evaluation because it has several useful properties, some of which are described below.

§Transforming Compute Graphs

Lazy evaluation lets us record a compute graph without actually doing any computations. This is useful for function transformations like transforms::grad and graph optimizations.

Currently, MLX does not compile and rerun compute graphs. They are all generated dynamically. However, lazy evaluation makes it much easier to integrate compilation for future performance enhancements.
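Because the graph exists before anything runs, a transformation can turn one graph into another. A minimal sketch, assuming a hypothetical one-variable Expr type (std-only Rust, unrelated to how transforms::grad is actually implemented): grad builds a new graph for the derivative without evaluating either graph.

```rust
/// Toy expression graph in one variable `x`; building it computes nothing.
#[derive(Clone, Debug)]
pub enum Expr {
    X,
    Const(f64),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

impl Expr {
    /// Evaluates the recorded graph at a concrete `x`.
    pub fn eval(&self, x: f64) -> f64 {
        match self {
            Expr::X => x,
            Expr::Const(c) => *c,
            Expr::Add(a, b) => a.eval(x) + b.eval(x),
            Expr::Mul(a, b) => a.eval(x) * b.eval(x),
        }
    }

    /// A graph transformation: returns a *new* graph computing d/dx,
    /// again without evaluating anything.
    pub fn grad(&self) -> Expr {
        match self {
            Expr::X => Expr::Const(1.0),
            Expr::Const(_) => Expr::Const(0.0),
            Expr::Add(a, b) => Expr::Add(Box::new(a.grad()), Box::new(b.grad())),
            // Product rule: (a * b)' = a' * b + a * b'
            Expr::Mul(a, b) => Expr::Add(
                Box::new(Expr::Mul(Box::new(a.grad()), b.clone())),
                Box::new(Expr::Mul(a.clone(), Box::new(b.grad()))),
            ),
        }
    }
}
```

For f(x) = x * x + 3 * x, `f.grad()` is just another unevaluated graph; evaluating it at x = 2 gives 2 * 2 + 3 = 7.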

§Only Compute What You Use

In MLX you do not need to worry as much about computing outputs that are never used. For example:

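use mlx_rs::Array;

// `cheap_fun` and `expensive_fun` are placeholders for any two
// array-producing functions; `x` is an existing `Array`.
fn fun(x: &Array) -> (Array, Array) {
    let a = cheap_fun(x);
    let b = expensive_fun(x);
    (a, b)
}

// Only the first output is kept, so only the graph behind `a` is evaluated
let (y, _) = fun(&x);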

Here, we never actually compute the output of expensive_fun. Use this pattern with care, though: the graph of expensive_fun is still built, and that has some cost associated with it.
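The same effect can be reproduced with a toy model in which each output is a deferred closure (std-only Rust; Pending, fun, and the expensive_runs counter are hypothetical). Discarding the second output means its closure never runs, although, like the graph in MLX, it is still constructed.

```rust
use std::cell::Cell;

/// A recorded but not yet evaluated computation (toy stand-in for a lazy array).
pub struct Pending<'a> {
    compute: Box<dyn Fn() -> f64 + 'a>,
}

impl<'a> Pending<'a> {
    /// Runs the deferred computation.
    pub fn eval(&self) -> f64 {
        (self.compute)()
    }
}

/// Mirrors `fun` above: both outputs are built, neither is computed.
/// `expensive_runs` counts how often the expensive closure actually runs.
pub fn fun<'a>(x: f64, expensive_runs: &'a Cell<u32>) -> (Pending<'a>, Pending<'a>) {
    // Stand-in for `cheap_fun`
    let cheap = Pending { compute: Box::new(move || x + 1.0) };
    // Stand-in for `expensive_fun`: allocated here, but only run if evaluated
    let expensive = Pending {
        compute: Box::new(move || {
            expensive_runs.set(expensive_runs.get() + 1);
            (0..1_000_000).map(|i| (i as f64 * x).sin()).sum::<f64>()
        }),
    };
    (cheap, expensive)
}
```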

Similarly, lazy evaluation can save memory while keeping code simple. Suppose you have a very large model Model that implements module::Module. You can instantiate it with let model = Model::new(). Typically, this initializes all of the weights as float32, but the initialization does not actually compute anything until you perform an eval(). If you then update the model with float16 weights, your peak memory consumption will be half of what eager computation would require, because the float32 initial weights are never materialized.
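A back-of-the-envelope check of the "half" claim, assuming a hypothetical model with one billion parameters and ignoring activations, optimizer state, and allocator overhead: float32 takes 4 bytes per value and float16 takes 2. The function name is illustrative only.

```rust
/// Bytes per element for the two dtypes discussed above.
const F32_BYTES: u64 = 4;
const F16_BYTES: u64 = 2;

/// Weight memory for a model with `n_params` weights under the two
/// strategies, ignoring everything but the weights. Returns (eager, lazy).
pub fn peak_weight_bytes(n_params: u64) -> (u64, u64) {
    // Eager: the float32 initial weights are materialized before loading,
    // so the peak is at least this much.
    let eager = n_params * F32_BYTES;
    // Lazy: the float32 initializers are replaced before any eval(),
    // so only the float16 weights are ever materialized.
    let lazy = n_params * F16_BYTES;
    (eager, lazy)
}
```

With one billion parameters this gives 4 GB for the eager path versus 2 GB for the lazy one.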

This pattern is simple to do in MLX thanks to lazy computation:

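// `Model` is the hypothetical model type from the paragraph above
let mut model = Model::new(); // records the float32 initialization, computes nothing
model.load_safetensors("model.safetensors").unwrap(); // replaces the weights before any eval()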

§When to Evaluate

A common question is when to use eval(). The trade-off is between letting graphs get too large and not batching enough useful work.

For example:

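use mlx_rs::array;

let mut a = array!([1, 2, 3, 4]);
let mut b = array!([1.0, 2.0, 3.0, 4.0]);

for _ in 0..100 {
    // Borrow instead of moving, so `a` and `b` remain usable
    a = &a + &b;
    a.eval().unwrap(); // one graph evaluation per operation
    b = &b * 2.0;
    b.eval().unwrap();
}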

This evaluates the graph 200 times (twice per iteration), which is a bad idea because each graph evaluation carries some fixed overhead. On the other hand, there is also some slight overhead that grows with the size of the compute graph, so extremely large graphs (while computationally correct) can be costly too.

Luckily, a wide range of compute graph sizes work pretty well with MLX: anything from a few tens of operations to many thousands of operations per evaluation should be okay.

Most numerical computations have an iterative outer loop (e.g. the iteration in stochastic gradient descent). A natural and usually efficient place to use eval() is at each iteration of this outer loop.

Here is a concrete example:

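// `dataset`, `model`, `optimizer`, and `value_and_grad_fn` are placeholders
// for your own training setup; this runs inside a function returning `Result`.
for batch in dataset {
    // Nothing has been evaluated yet
    let (loss, grad) = value_and_grad_fn(&mut model, batch)?;

    // Still nothing has been evaluated
    optimizer.update(&mut model, grad)?;

    // Evaluate the new parameters, which runs the full gradient
    // computation and optimizer update
    eval_params(model.parameters())?;
    // Evaluate the loss as well, e.g. for logging
    loss.eval()?;
}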

An important behavior to be aware of is when the graph will be evaluated implicitly. Any time you print an array or otherwise access its memory via Array::as_slice, the graph will be evaluated. Saving arrays via Array::save_numpy or Array::save_safetensors (or any other MLX saving function) will also evaluate them.

Calling Array::item on a scalar array will also evaluate it. In the example above, printing the loss (println!("{:?}", loss)) or pushing the loss scalar to a Vec (losses.push(loss.item::<f32>())) would cause a graph evaluation. If these lines come before the loss and model parameters are evaluated, this will be a partial evaluation that computes only the forward pass.

Calling eval() on an array or set of arrays multiple times is also perfectly fine: once an array has been evaluated, further calls are effectively no-ops.
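These implicit and repeated evaluation rules can be mimicked with a toy lazy scalar (std-only Rust, not mlx-rs's Array; the Lazy type and its run counter are hypothetical): formatting or reading the value forces evaluation, and later eval() calls reuse the cached result.

```rust
use std::cell::Cell;
use std::fmt;

/// Toy lazy scalar (not mlx-rs's `Array`) that counts real computations.
pub struct Lazy {
    compute: Box<dyn Fn() -> f64>,
    cache: Cell<Option<f64>>,
    runs: Cell<u32>, // how many times `compute` actually ran
}

impl Lazy {
    /// Records a computation without running it.
    pub fn new(compute: impl Fn() -> f64 + 'static) -> Self {
        Lazy { compute: Box::new(compute), cache: Cell::new(None), runs: Cell::new(0) }
    }

    /// Explicit evaluation; a no-op once the value is cached.
    pub fn eval(&self) {
        if self.cache.get().is_none() {
            self.runs.set(self.runs.get() + 1);
            self.cache.set(Some((self.compute)()));
        }
    }

    /// Reading the value evaluates implicitly, like `Array::item`.
    pub fn item(&self) -> f64 {
        self.eval();
        self.cache.get().expect("evaluated above")
    }

    pub fn runs(&self) -> u32 {
        self.runs.get()
    }
}

/// Printing also evaluates implicitly, like printing an `Array`.
impl fmt::Display for Lazy {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.item())
    }
}
```

Formatting `Lazy::new(|| 6.0 * 7.0)` prints `42` and runs the computation once; any number of subsequent `eval()` or `item()` calls leave the run count at 1.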

Warning: Using scalar arrays for control flow will cause an evaluation.

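use mlx_rs::{array, Array};

// `first_layer`, `second_layer_a`, and `second_layer_b` are placeholders.
fn fun(x: &Array) -> Array {
    let (h, y) = first_layer(x);

    // `item` needs a concrete value, so this line evaluates `y`
    if y.gt(array!(0.5)).unwrap().item::<bool>() {
        second_layer_a(h)
    } else {
        second_layer_b(h)
    }
}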

Using arrays for control flow should be done with care. The above example works and can even be used with gradient transformations. However, this can be very inefficient if evaluations are done too frequently.
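A toy version of the branch above (std-only Rust; the arithmetic "layers" and the evals counter are hypothetical stand-ins) shows where the cost comes from: the if needs a concrete bool, so the deferred y is computed inside every call.

```rust
use std::cell::Cell;

/// Toy version of `fun` above. `evals` counts how often `y` is computed.
pub fn fun(x: f64, evals: &Cell<u32>) -> f64 {
    let h = x * 2.0; // stand-in for `h` from `first_layer`
    // `y` is deferred: describing it computes nothing.
    let y = || {
        evals.set(evals.get() + 1);
        x / 10.0
    };
    // The condition needs a concrete bool, so `y` is computed right here,
    // on every call: the analogue of `item()` forcing an evaluation.
    if y() > 0.5 {
        h + 1.0 // stand-in for `second_layer_a`
    } else {
        h - 1.0 // stand-in for `second_layer_b`
    }
}
```

Called once per training step, this adds one forced evaluation per step, which is exactly the frequency the paragraph above warns about.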

§Unified Memory

See also MLX python documentation

Apple silicon has a unified memory architecture. The CPU and GPU have direct access to the same memory pool. MLX is designed to take advantage of that.

Concretely, when you make an array in MLX you don’t have to specify its location:

// let a = mlx_rs::random::normal(&[100], None, None, None, None).unwrap();
// let b = mlx_rs::random::normal(&[100], None, None, None, None).unwrap();

let a = mlx_rs::normal!(shape=&[100]).unwrap();
let b = mlx_rs::normal!(shape=&[100]).unwrap();

Both a and b live in unified memory.

In MLX, rather than moving arrays to devices, you specify the device when you run the operation. Any device can perform any operation on a and b without needing to move them from one memory location to another. For example:

use mlx_rs::StreamOrDevice;

// mlx_rs::ops::add_device(&a, &b, StreamOrDevice::cpu()).unwrap();
// mlx_rs::ops::add_device(&a, &b, StreamOrDevice::gpu()).unwrap();

mlx_rs::add!(&a, &b, stream=StreamOrDevice::cpu()).unwrap();
mlx_rs::add!(&a, &b, stream=StreamOrDevice::gpu()).unwrap();

In the above, both the CPU and the GPU will perform the same add operation.
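As an analogy only (plain std threads, not MLX devices or streams), the sketch below has two workers compute a + b from the same reference-counted buffers and checks that both see the caller's original allocation, i.e. the inputs are shared rather than copied or moved. The helper names are hypothetical.

```rust
use std::sync::Arc;
use std::thread::{self, JoinHandle};

/// Computes `a + b` on a separate thread from shared, read-only buffers and
/// reports whether the worker saw the caller's original allocation.
fn spawn_add(a: Arc<Vec<f32>>, b: Arc<Vec<f32>>, original: usize) -> JoinHandle<(Vec<f32>, bool)> {
    thread::spawn(move || {
        let same_buffer = a.as_ptr() as usize == original;
        let sum: Vec<f32> = a.iter().zip(b.iter()).map(|(x, y)| x + y).collect();
        (sum, same_buffer)
    })
}

/// Two workers run the same addition on the same inputs; the executor is
/// chosen per operation, and the data itself never moves.
pub fn add_on_two_workers(a: Vec<f32>, b: Vec<f32>) -> (Vec<f32>, Vec<f32>, bool) {
    let (a, b) = (Arc::new(a), Arc::new(b));
    let original = a.as_ptr() as usize;
    let w1 = spawn_add(Arc::clone(&a), Arc::clone(&b), original);
    let w2 = spawn_add(Arc::clone(&a), Arc::clone(&b), original);
    let (s1, same1) = w1.join().expect("worker 1 panicked");
    let (s2, same2) = w2.join().expect("worker 2 panicked");
    (s1, s2, same1 && same2)
}
```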

TODO: The rest of the Python documentation states that streams can be used to parallelize operations without worrying about race conditions. We should check whether this holds, given that we have already observed data races when running unit tests in parallel.

§Indexing Arrays

See also MLX python documentation

Please refer to the indexing module (ops::indexing) for more details.

§Saving and Loading

See also MLX python documentation

mlx-rs supports loading from .npy and .safetensors files and saving to .safetensors files. Module parameters and optimizer states can also be saved and loaded from .safetensors files.

§Function Transforms

See also MLX python documentation

Please refer to the transforms module (transforms) for more details.

§Compilation

Modules§

builder
Defines helper traits for the builder pattern
error
Custom error types and handler for the C FFI
fast
Fast implementations of commonly used multi-op functions.
fft
Fast Fourier Transform (FFT) and its inverse (IFFT) for one, two, and N dimensions.
linalg
Linear algebra operations.
losses
Loss functions
macros
Macros for mlx-rs.
module
This module defines the traits for neural network modules and parameters.
nested
Implements a nested hashmap
nn
Neural network support for MLX
ops
Operations
optimizers
Trait and implementations for optimizers.
quantization
Traits for quantization
random
Collection of functions related to random number generation
transforms
Function transforms
utils
Utility functions and types.

Macros§

abs
Macro generated for the function crate::ops::abs. See the function documentation for more details.
acos
Macro generated for the function crate::ops::acos. See the function documentation for more details.
acosh
Macro generated for the function crate::ops::acosh. See the function documentation for more details.
add
Macro generated for the function crate::ops::add. See the function documentation for more details.
addmm
Macro generated for the function crate::ops::addmm. See the function documentation for more details.
all
Macro generated for the function crate::ops::all. See the function documentation for more details.
all_close
Macro generated for the function crate::ops::all_close. See the function documentation for more details.
any
Macro generated for the function crate::ops::any. See the function documentation for more details.
arange
Macro generated for the function crate::ops::arange. See the function documentation for more details.
argmax
Macro generated for the function crate::ops::indexing::argmax. See the function documentation for more details.
argmax_all
Macro generated for the function crate::ops::indexing::argmax_all. See the function documentation for more details.
argmin
Macro generated for the function crate::ops::indexing::argmin. See the function documentation for more details.
argmin_all
Macro generated for the function crate::ops::indexing::argmin_all. See the function documentation for more details.
argpartition
Macro generated for the function crate::ops::argpartition. See the function documentation for more details.
argpartition_all
Macro generated for the function crate::ops::argpartition_all. See the function documentation for more details.
argsort
Macro generated for the function crate::ops::argsort. See the function documentation for more details.
argsort_all
Macro generated for the function crate::ops::argsort_all. See the function documentation for more details.
array
A helper macro to create an array with up to 3 dimensions.
array_eq
Macro generated for the function crate::ops::array_eq. See the function documentation for more details.
as_strided
Macro generated for the function crate::ops::as_strided. See the function documentation for more details.
asin
Macro generated for the function crate::ops::asin. See the function documentation for more details.
asinh
Macro generated for the function crate::ops::asinh. See the function documentation for more details.
assert_array_eq
Asserts that two arrays are equal.
at_least_1d
Macro generated for the function crate::ops::at_least_1d. See the function documentation for more details.
at_least_2d
Macro generated for the function crate::ops::at_least_2d. See the function documentation for more details.
at_least_3d
Macro generated for the function crate::ops::at_least_3d. See the function documentation for more details.
atan
Macro generated for the function crate::ops::atan. See the function documentation for more details.
atanh
Macro generated for the function crate::ops::atanh. See the function documentation for more details.
bernoulli
Macro generated for the function crate::random::bernoulli. See the function documentation for more details.
block_masked_mm
Macro generated for the function crate::ops::block_masked_mm. See the function documentation for more details.
broadcast_arrays
Macro generated for the function crate::ops::broadcast_arrays. See the function documentation for more details.
broadcast_to
Macro generated for the function crate::ops::broadcast_to. See the function documentation for more details.
categorical
Macro generated for the function crate::random::categorical. See the function documentation for more details.
ceil
Macro generated for the function crate::ops::ceil. See the function documentation for more details.
cholesky
Macro generated for the function crate::linalg::cholesky. See the function documentation for more details.
cholesky_inv
Macro generated for the function crate::linalg::cholesky_inv. See the function documentation for more details.
clip
Macro generated for the function crate::ops::clip. See the function documentation for more details.
concatenate
Macro generated for the function crate::ops::concatenate. See the function documentation for more details.
conv1d
Macro generated for the function crate::ops::conv1d. See the function documentation for more details.
conv2d
Macro generated for the function crate::ops::conv2d. See the function documentation for more details.
conv3d
Macro generated for the function crate::ops::conv3d. See the function documentation for more details.
conv_general
Macro generated for the function crate::ops::conv_general. See the function documentation for more details.
conv_transpose1d
Macro generated for the function crate::ops::conv_transpose1d. See the function documentation for more details.
conv_transpose2d
Macro generated for the function crate::ops::conv_transpose2d. See the function documentation for more details.
conv_transpose3d
Macro generated for the function crate::ops::conv_transpose3d. See the function documentation for more details.
cos
Macro generated for the function crate::ops::cos. See the function documentation for more details.
cosh
Macro generated for the function crate::ops::cosh. See the function documentation for more details.
cross
Macro generated for the function crate::linalg::cross. See the function documentation for more details.
cummax
Macro generated for the function crate::ops::cummax. See the function documentation for more details.
cummin
Macro generated for the function crate::ops::cummin. See the function documentation for more details.
cumprod
Macro generated for the function crate::ops::cumprod. See the function documentation for more details.
cumsum
Macro generated for the function crate::ops::cumsum. See the function documentation for more details.
degrees
Macro generated for the function crate::ops::degrees. See the function documentation for more details.
dequantize
Macro generated for the function crate::ops::dequantize. See the function documentation for more details.
diag
Macro generated for the function crate::ops::diag. See the function documentation for more details.
diagonal
Macro generated for the function crate::ops::diagonal. See the function documentation for more details.
divide
Macro generated for the function crate::ops::divide. See the function documentation for more details.
divmod
Macro generated for the function crate::ops::divmod. See the function documentation for more details.
eigh
Macro generated for the function crate::linalg::eigh. See the function documentation for more details.
eigvalsh
Macro generated for the function crate::linalg::eigvalsh. See the function documentation for more details.
einsum
Macro generated for the function crate::ops::einsum. See the function documentation for more details.
eq
Macro generated for the function crate::ops::eq. See the function documentation for more details.
erf
Macro generated for the function crate::ops::erf. See the function documentation for more details.
erfinv
Macro generated for the function crate::ops::erfinv. See the function documentation for more details.
exp
Macro generated for the function crate::ops::exp. See the function documentation for more details.
expand_dims
Macro generated for the function crate::ops::expand_dims. See the function documentation for more details.
expm1
Macro generated for the function crate::ops::expm1. See the function documentation for more details.
eye
Macro generated for the function crate::ops::eye. See the function documentation for more details.
fft
Macro generated for the function crate::fft::fft. See the function documentation for more details.
fft2
Macro generated for the function crate::fft::fft2. See the function documentation for more details.
fftn
Macro generated for the function crate::fft::fftn. See the function documentation for more details.
flatten
Macro generated for the function crate::ops::flatten. See the function documentation for more details.
floor
Macro generated for the function crate::ops::floor. See the function documentation for more details.
floor_divide
Macro generated for the function crate::ops::floor_divide. See the function documentation for more details.
full
Macro generated for the function crate::ops::full. See the function documentation for more details.
ge
Macro generated for the function crate::ops::ge. See the function documentation for more details.
gt
Macro generated for the function crate::ops::gt. See the function documentation for more details.
gumbel
Macro generated for the function crate::random::gumbel. See the function documentation for more details.
identity
Macro generated for the function crate::ops::identity. See the function documentation for more details.
ifft
Macro generated for the function crate::fft::ifft. See the function documentation for more details.
ifft2
Macro generated for the function crate::fft::ifft2. See the function documentation for more details.
ifftn
Macro generated for the function crate::fft::ifftn. See the function documentation for more details.
inner
Macro generated for the function crate::ops::inner. See the function documentation for more details.
inv
Macro generated for the function crate::linalg::inv. See the function documentation for more details.
irfft
Macro generated for the function crate::fft::irfft. See the function documentation for more details.
irfft2
Macro generated for the function crate::fft::irfft2. See the function documentation for more details.
irfftn
Macro generated for the function crate::fft::irfftn. See the function documentation for more details.
is_close
Macro generated for the function crate::ops::is_close. See the function documentation for more details.
is_inf
Macro generated for the function crate::ops::is_inf. See the function documentation for more details.
is_nan
Macro generated for the function crate::ops::is_nan. See the function documentation for more details.
is_neg_inf
Macro generated for the function crate::ops::is_neg_inf. See the function documentation for more details.
is_pos_inf
Macro generated for the function crate::ops::is_pos_inf. See the function documentation for more details.
kron
Macro generated for the function crate::ops::kron. See the function documentation for more details.
le
Macro generated for the function crate::ops::le. See the function documentation for more details.
linspace
Macro generated for the function crate::ops::linspace. See the function documentation for more details.
log
Macro generated for the function crate::ops::log. See the function documentation for more details.
log2
Macro generated for the function crate::ops::log2. See the function documentation for more details.
log1p
Macro generated for the function crate::ops::log1p. See the function documentation for more details.
log10
Macro generated for the function crate::ops::log10. See the function documentation for more details.
log_add_exp
Macro generated for the function crate::ops::log_add_exp. See the function documentation for more details.
log_sum_exp
Macro generated for the function crate::ops::log_sum_exp. See the function documentation for more details.
logical_and
Macro generated for the function crate::ops::logical_and. See the function documentation for more details.
logical_not
Macro generated for the function crate::ops::logical_not. See the function documentation for more details.
logical_or
Macro generated for the function crate::ops::logical_or. See the function documentation for more details.
lt
Macro generated for the function crate::ops::lt. See the function documentation for more details.
lu
Macro generated for the function crate::linalg::lu. See the function documentation for more details.
lu_factor
Macro generated for the function crate::linalg::lu_factor. See the function documentation for more details.
matmul
Macro generated for the function crate::ops::matmul. See the function documentation for more details.
max
Macro generated for the function crate::ops::max. See the function documentation for more details.
maximum
Macro generated for the function crate::ops::maximum. See the function documentation for more details.
mean
Macro generated for the function crate::ops::mean. See the function documentation for more details.
min
Macro generated for the function crate::ops::min. See the function documentation for more details.
minimum
Macro generated for the function crate::ops::minimum. See the function documentation for more details.
move_axis
Macro generated for the function crate::ops::move_axis. See the function documentation for more details.
multiply
Macro generated for the function crate::ops::multiply. See the function documentation for more details.
multivariate_normal
Macro generated for the function crate::random::multivariate_normal. See the function documentation for more details.
ne
Macro generated for the function crate::ops::ne. See the function documentation for more details.
negative
Macro generated for the function crate::ops::negative. See the function documentation for more details.
norm
Macro generated for the function crate::linalg::norm. See the function documentation for more details.
norm_ord
Macro generated for the function crate::linalg::norm_ord. See the function documentation for more details.
norm_p
Macro generated for the function crate::linalg::norm_p. See the function documentation for more details.
normal
Macro generated for the function crate::random::normal. See the function documentation for more details.
ones
Macro generated for the function crate::ops::ones. See the function documentation for more details.
ones_dtype
Macro generated for the function crate::ops::ones_dtype. See the function documentation for more details.
ones_like
Macro generated for the function crate::ops::ones_like. See the function documentation for more details.
outer
Macro generated for the function crate::ops::outer. See the function documentation for more details.
pad
Macro generated for the function crate::ops::pad. See the function documentation for more details.
partition
Macro generated for the function crate::ops::partition. See the function documentation for more details.
partition_all
Macro generated for the function crate::ops::partition_all. See the function documentation for more details.
pinv
Macro generated for the function crate::linalg::pinv. See the function documentation for more details.
power
Macro generated for the function crate::ops::power. See the function documentation for more details.
prod
Macro generated for the function crate::ops::prod. See the function documentation for more details.
put_along_axis
Macro generated for the function crate::ops::indexing::put_along_axis. See the function documentation for more details.
qr
Macro generated for the function crate::linalg::qr. See the function documentation for more details.
quantize
Macro generated for the function crate::ops::quantize. See the function documentation for more details.
quantized_matmul
Macro generated for the function crate::ops::quantized_matmul. See the function documentation for more details.
radians
Macro generated for the function crate::ops::radians. See the function documentation for more details.
randint
Macro generated for the function crate::random::randint. See the function documentation for more details.
reciprocal
Macro generated for the function crate::ops::reciprocal. See the function documentation for more details.
remainder
Macro generated for the function crate::ops::remainder. See the function documentation for more details.
repeat
Macro generated for the function crate::ops::repeat. See the function documentation for more details.
repeat_all
Macro generated for the function crate::ops::repeat_all. See the function documentation for more details.
reshape
Macro generated for the function crate::ops::reshape. See the function documentation for more details.
rfft
Macro generated for the function crate::fft::rfft. See the function documentation for more details.
rfft2
Macro generated for the function crate::fft::rfft2. See the function documentation for more details.
rfftn
Macro generated for the function crate::fft::rfftn. See the function documentation for more details.
round
Macro generated for the function crate::ops::round. See the function documentation for more details.
rsqrt
Macro generated for the function crate::ops::rsqrt. See the function documentation for more details.
sigmoid
Macro generated for the function crate::ops::sigmoid. See the function documentation for more details.
sign
Macro generated for the function crate::ops::sign. See the function documentation for more details.
sin
Macro generated for the function crate::ops::sin. See the function documentation for more details.
sinh
Macro generated for the function crate::ops::sinh. See the function documentation for more details.
softmax
Macro generated for the function crate::ops::softmax. See the function documentation for more details.
softmax_all
Macro generated for the function crate::ops::softmax_all. See the function documentation for more details.
solve
Macro generated for the function crate::linalg::solve. See the function documentation for more details.
solve_triangular
Macro generated for the function crate::linalg::solve_triangular. See the function documentation for more details.
sort
Macro generated for the function crate::ops::sort. See the function documentation for more details.
sort_all
Macro generated for the function crate::ops::sort_all. See the function documentation for more details.
split
Macro generated for the function crate::ops::split. See the function documentation for more details.
split_equal
Macro generated for the function crate::ops::split_equal. See the function documentation for more details.
sqrt
Macro generated for the function crate::ops::sqrt. See the function documentation for more details.
square
Macro generated for the function crate::ops::square. See the function documentation for more details.
squeeze
Macro generated for the function crate::ops::squeeze. See the function documentation for more details.
stack
Macro generated for the function crate::ops::stack. See the function documentation for more details.
stack_all
Macro generated for the function crate::ops::stack_all. See the function documentation for more details.
std
Macro generated for the function crate::ops::std. See the function documentation for more details.
subtract
Macro generated for the function crate::ops::subtract. See the function documentation for more details.
sum
Macro generated for the function crate::ops::sum. See the function documentation for more details.
svd
Macro generated for the function crate::linalg::svd. See the function documentation for more details.
swap_axes
Macro generated for the function crate::ops::swap_axes. See the function documentation for more details.
take
Macro generated for the function crate::ops::indexing::take. See the function documentation for more details.
take_all
Macro generated for the function crate::ops::indexing::take_all. See the function documentation for more details.
take_along_axis
Macro generated for the function crate::ops::indexing::take_along_axis. See the function documentation for more details.
tan
Macro generated for the function crate::ops::tan. See the function documentation for more details.
tanh
Macro generated for the function crate::ops::tanh. See the function documentation for more details.
tensordot
Macro generated for the function crate::ops::tensordot. See the function documentation for more details.
tile
Macro generated for the function crate::ops::tile. See the function documentation for more details.
topk
Macro generated for the function crate::ops::indexing::topk. See the function documentation for more details.
topk_all
Macro generated for the function crate::ops::indexing::topk_all. See the function documentation for more details.
transpose
Macro generated for the function crate::ops::transpose. See the function documentation for more details.
transpose_all
Macro generated for the function crate::ops::transpose_all. See the function documentation for more details.
tri
Macro generated for the function crate::ops::tri. See the function documentation for more details.
tri_inv
Macro generated for the function crate::linalg::tri_inv. See the function documentation for more details.
tril
Macro generated for the function crate::ops::tril. See the function documentation for more details.
triu
Macro generated for the function crate::ops::triu. See the function documentation for more details.
truncated_normal
Macro generated for the function crate::random::truncated_normal. See the function documentation for more details.
unflatten
Macro generated for the function crate::ops::unflatten. See the function documentation for more details.
uniform
Macro generated for the function crate::random::uniform. See the function documentation for more details.
variance
Macro generated for the function crate::ops::variance. See the function documentation for more details.
which
Macro generated for the function crate::ops::which. See the function documentation for more details.
zeros
Macro generated for the function crate::ops::zeros. See the function documentation for more details.
zeros_dtype
Macro generated for the function crate::ops::zeros_dtype. See the function documentation for more details.
zeros_like
Macro generated for the function crate::ops::zeros_like. See the function documentation for more details.

Structs§

Array
An n-dimensional array.
Device
Representation of a Device in MLX.
DtypeIter
An iterator over the variants of Dtype
Stream
A stream of evaluation attached to a particular device.
StreamOrDevice
Parameter type for all MLX operations.

Enums§

DeviceType
Type of device.
Dtype
Array element type

Traits§

ArrayElement
A marker trait for array elements.
FromNested
A helper trait to construct Array from nested arrays or slices.
FromScalar
A helper trait to construct Array from scalar values.

Functions§

stop_gradient
Stop gradients from being computed.
stop_gradient_device
Stop gradients from being computed.

Type Aliases§

complex64
Type alias for num_complex::Complex<f32>.