mlx_rs/transforms/
mod.rs

1//! Function transforms
2//!
3//! This mod provides functions for automatic differentiation and other
4//! transformations on functions.
5//!
//! **WARN**: Because function transforms, including compilation, work on
//! the computation graph, the user must ensure that all `Array`s are passed
//! as inputs to the function/closure. Closures with captured `Array`s may
//! not work as expected and may lead to undefined behavior.
10//!
11//! # Automatic Differentiation
12//!
13//! Automatic differentiation in MLX works on functions rather than on implicit
14//! graphs.
15//!
16//! **NOTE**: If you are coming to MLX from PyTorch, you no longer need
17//! functions like backward, zero_grad, and detach, or properties like
18//! requires_grad.
19//!
//! You can use the [`grad()`] and [`value_and_grad()`] functions to compute
//! gradients of more complex functions. These functions compute the gradient
//! with respect to the first argument. To manually specify the argument to
//! compute the gradient with respect to, use [`grad_with_argnums()`] or
//! [`value_and_grad_with_argnums()`].
25//!
//! TODO: update the example once <https://github.com/oxideai/mlx-rs/pull/218> is merged
27//!
28//! ```rust,ignore
29//! use mlx_rs::{Array, error::Result, transforms::grad};
30//!
31//! fn f(x: &Array) -> Result<Array> {
32//!     x.square()
33//! }
34//!
35//! fn calculate_grad(func: impl Fn(&Array) -> Result<Array>, arg: &Array) -> Result<Array> {
36//!     grad(&func, &[0])(arg)
37//! }
38//!
39//! let x = Array::from(1.5);
40//!
41//! let dfdx = calculate_grad(f, &x).unwrap();
42//! assert_eq!(dfdx.item::<f32>(), 2.0 * 1.5);
43//!
44//! let dfdx2 = calculate_grad(|args| calculate_grad(f, args), &x).unwrap();
45//! assert_eq!(dfdx2.item::<f32>(), 2.0);
46//! ```
47
48use mlx_sys::mlx_closure_value_and_grad;
49
50use crate::{
51    error::{get_and_clear_closure_error, Result},
52    module::ModuleParamRef,
53    utils::{guard::Guarded, Closure, VectorArray},
54    Array,
55};
56
57pub mod compile;
58mod grad;
59mod keyed_value_and_grad;
60mod value_and_grad;
61
62pub use grad::*;
63pub use keyed_value_and_grad::*;
64pub use value_and_grad::*;
65
66/// Evaluate an iterator of [`Array`]s.
67pub fn eval<'a>(outputs: impl IntoIterator<Item = &'a Array>) -> Result<()> {
68    let vec = VectorArray::try_from_iter(outputs.into_iter())?;
69    <() as Guarded>::try_from_op(|_| unsafe { mlx_sys::mlx_eval(vec.as_ptr()) })
70}
71
72/// Evaluate a module's parameters.
73///
74/// This is a convenience function that flattens the parameters and evaluates them.
75pub fn eval_params(params: ModuleParamRef<'_>) -> Result<()> {
76    eval(params.flatten().values().copied())
77}
78
79/// Asynchronously evaluate an iterator of [`Array`]s.
80///
81/// Please note that this is not a rust async function.
82pub fn async_eval<'a>(outputs: impl IntoIterator<Item = &'a Array>) -> Result<()> {
83    let vec = VectorArray::try_from_iter(outputs.into_iter())?;
84    <() as Guarded>::try_from_op(|_| unsafe { mlx_sys::mlx_async_eval(vec.as_ptr()) })
85}
86
87/// Asynchronously evaluate a module's parameters.
88///
89/// This is a convenience function that flattens the parameters and evaluates them.
90pub fn async_eval_params(params: ModuleParamRef<'_>) -> Result<()> {
91    async_eval(params.flatten().values().copied())
92}
93
94#[inline]
95fn jvp_inner(
96    closure: Closure<'_>,
97    primals: &[Array],
98    tangents: &[Array],
99) -> Result<(Vec<Array>, Vec<Array>)> {
100    let c_primals = VectorArray::try_from_iter(primals.iter())?;
101    let c_tangents = VectorArray::try_from_iter(tangents.iter())?;
102
103    <(Vec<Array>, Vec<Array>) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
104        mlx_sys::mlx_jvp(
105            res_0,
106            res_1,
107            closure.as_ptr(),
108            c_primals.as_ptr(),
109            c_tangents.as_ptr(),
110        )
111    })
112    .map_err(|e| match get_and_clear_closure_error() {
113        Some(err) => err,
114        None => e,
115    })
116}
117
118/// Compute the Jacobian-vector product.
119///
120/// This computes the product of the Jacobian of a function `f` evaluated at
121/// `primals` with the `tangents`.
122///
123/// # Params:
124///
125/// - `f`: function which takes an array of `Array` and returns an array of
126///   `Array`
127/// - `primals`: array of `Array` at which to evaluate the Jacobian
128/// - `tangents`: array of `Array` which are the "vector" in the Jacobian-vector
129///   product.  The `tangents` should be the same in number, shape and type as
130///   the inputs of `f`, e.g. the `primals`
131///
132/// # Returns:
133///
134/// Array of the Jacobian-vector products which is the same in number, shape and
135/// type of the outputs of `f`
136pub fn jvp<'a, F>(f: F, primals: &[Array], tangents: &[Array]) -> Result<(Vec<Array>, Vec<Array>)>
137where
138    F: FnMut(&[Array]) -> Vec<Array> + 'a,
139{
140    let closure = Closure::new(f);
141    jvp_inner(closure, primals, tangents)
142}
143
144/// Similar to [`jvp`] but handles closures that can return an error.
145pub fn fallible_jvp<'a, F>(
146    f: F,
147    primals: &[Array],
148    tangents: &[Array],
149) -> Result<(Vec<Array>, Vec<Array>)>
150where
151    F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
152{
153    let closure = Closure::new_fallible(f);
154    jvp_inner(closure, primals, tangents)
155}
156
157#[inline]
158fn vjp_inner(
159    closure: Closure<'_>,
160    primals: &[Array],
161    cotangents: &[Array],
162) -> Result<(Vec<Array>, Vec<Array>)> {
163    let c_primals = VectorArray::try_from_iter(primals.iter())?;
164    let c_cotangents = VectorArray::try_from_iter(cotangents.iter())?;
165
166    <(Vec<Array>, Vec<Array>) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
167        mlx_sys::mlx_vjp(
168            res_0,
169            res_1,
170            closure.as_ptr(),
171            c_primals.as_ptr(),
172            c_cotangents.as_ptr(),
173        )
174    })
175    .map_err(|e| match get_and_clear_closure_error() {
176        Some(err) => err,
177        None => e,
178    })
179}
180
181/// Compute the vector-Jacobian product.
182///
183/// Computes the product of the `cotangents` with the Jacobian of a function `f` evaluated at
184/// `primals`.
185///
186/// # Params:
187///
188/// - f: function which takes an array of `Array` and returns an array of `Array`
189/// - primals: array of `Array` at which to evaluate the Jacobian
190/// - cotangents: array of `Array` which are the "vector" in the vector-Jacobian product. The
191///   `cotangents` should be the same in number, shape and type as the outputs of `f`
192///
193/// # Returns:
194///
195/// array of the vector-Jacobian products which is the same in number, shape and type of the outputs
196/// of `f`
197pub fn vjp<'a, F>(f: F, primals: &[Array], cotangents: &[Array]) -> Result<(Vec<Array>, Vec<Array>)>
198where
199    F: FnMut(&[Array]) -> Vec<Array> + 'a,
200{
201    let closure = Closure::new(f);
202    vjp_inner(closure, primals, cotangents)
203}
204
205/// Similar to [`vjp`] but handles closures that can return an error.
206pub fn fallible_vjp<'a, F>(
207    f: F,
208    primals: &[Array],
209    cotangents: &[Array],
210) -> Result<(Vec<Array>, Vec<Array>)>
211where
212    F: FnMut(&[Array]) -> Result<Vec<Array>> + 'a,
213{
214    let closure = Closure::new_fallible(f);
215    vjp_inner(closure, primals, cotangents)
216}
217
218pub(crate) struct ClosureValueAndGrad {
219    pub(crate) c_closure_value_and_grad: mlx_closure_value_and_grad,
220}
221
222impl ClosureValueAndGrad {
223    pub fn as_ptr(&self) -> mlx_closure_value_and_grad {
224        self.c_closure_value_and_grad
225    }
226}
227
228fn value_and_gradient(
229    value_and_grad: mlx_closure_value_and_grad,
230    arrays: impl Iterator<Item = impl AsRef<Array>>,
231) -> Result<(Vec<Array>, Vec<Array>)> {
232    let input_vector = VectorArray::try_from_iter(arrays)?;
233
234    <(Vec<Array>, Vec<Array>) as Guarded>::try_from_op(|(res_0, res_1)| unsafe {
235        mlx_sys::mlx_closure_value_and_grad_apply(
236            res_0,
237            res_1,
238            value_and_grad,
239            input_vector.as_ptr(),
240        )
241    })
242    .map_err(|e| match get_and_clear_closure_error() {
243        Some(err) => err,
244        None => e,
245    })
246}
247
#[cfg(test)]
mod tests {

    use crate::{
        array,
        transforms::{jvp, vjp},
        Array,
    };

    use super::*;

    // The unit tests below are adapted from the mlx c++ codebase

    /// Infallible test function: the sum of the first two inputs.
    fn sum_of_pair(inputs: &[Array]) -> Vec<Array> {
        vec![&inputs[0] + &inputs[1]]
    }

    /// Fallible test function: the sum of the first two inputs, forwarding
    /// any error reported by `add`.
    fn try_sum_of_pair(inputs: &[Array]) -> Result<Vec<Array>> {
        let total = inputs[0].add(&inputs[1])?;
        Ok(vec![total])
    }

    #[test]
    fn test_jvp() {
        let primals = [array!(1.0f32), array!(1.0f32)];
        let tangents = [array!(1.0f32), array!(3.0f32)];

        let (values, jvps) = jvp(sum_of_pair, &primals, &tangents).unwrap();

        assert_eq!(values[0].item::<f32>(), 2.0f32);
        assert_eq!(jvps[0].item::<f32>(), 4.0f32);
    }

    #[test]
    fn test_jvp_with_error() {
        // Success case
        let primals = [array!(1.0f32), array!(1.0f32)];
        let tangents = [array!(1.0f32), array!(3.0f32)];
        let (values, jvps) = fallible_jvp(try_sum_of_pair, &primals, &tangents).unwrap();
        assert_eq!(values[0].item::<f32>(), 2.0f32);
        assert_eq!(jvps[0].item::<f32>(), 4.0f32);

        // Error case
        // Use non-broadcastable shapes
        let mismatched = [array!([1.0, 2.0, 3.0]), array!([4.0, 5.0])];
        let fresh_tangents = [array!(1.0f32), array!(3.0f32)];
        let failed = fallible_jvp(try_sum_of_pair, &mismatched, &fresh_tangents);
        assert!(failed.is_err());

        // Check that the error is not just "mlx_closure returned a non-zero value"
        let reported = failed.unwrap_err();
        assert!(!reported.what().contains("non-zero value"));
    }

    #[test]
    fn test_vjp() {
        let primals = vec![array!(1.0f32), array!(1.0f32)];
        let cotangents = vec![array!(1.0f32)];

        let (values, vjps) = vjp(sum_of_pair, &primals, &cotangents).unwrap();

        assert_eq!(values[0].item::<f32>(), 2.0f32);
        assert_eq!(vjps[0].item::<f32>(), 1.0f32);
    }

    #[test]
    fn test_vjp_with_error() {
        // Success case
        let primals = vec![array!(1.0f32), array!(1.0f32)];
        let cotangents = vec![array!(1.0f32)];
        let (values, vjps) = fallible_vjp(try_sum_of_pair, &primals, &cotangents).unwrap();
        assert_eq!(values[0].item::<f32>(), 2.0f32);
        assert_eq!(vjps[0].item::<f32>(), 1.0f32);

        // Error case
        // Use non-broadcastable shapes
        let mismatched = [array!([1.0, 2.0, 3.0]), array!([4.0, 5.0])];
        let fresh_cotangents = [array!(1.0f32)];
        let failed = fallible_vjp(try_sum_of_pair, &mismatched, &fresh_cotangents);
        assert!(failed.is_err());

        // Check that the error is not just "mlx_closure returned a non-zero value"
        let reported = failed.unwrap_err();
        assert!(!reported.what().contains("non-zero value"));
    }
}