mlx_rs/transforms/compile/
compile.rs

1//! Compilation of functions.
2
// TODO: there's plenty of boilerplate code here, but it's not clear how to reduce it
4
5use std::marker::PhantomData;
6
7use crate::{error::Exception, Array};
8
9use super::{type_id_to_usize, Closure, Compiled, CompiledState, Guarded, VectorArray};
10
11/// Returns a compiled function that produces the same output as `f`.
12///
13/// Please refer to the [swift binding
14/// documentation](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/compilation)
15/// for more information.
16pub fn compile<F, A, O, E>(
17    f: F,
18    shapeless: impl Into<Option<bool>>,
19) -> impl for<'a> FnMut(F::Args<'a>) -> Result<O, Exception>
20where
21    F: Compile<A, O, E> + 'static + Copy,
22{
23    let shapeless = shapeless.into().unwrap_or(false);
24    move |args| {
25        // NOTE: we have to place this here to avoid the lifetime issue
26        // `f.compile` will look up the cached compiled function so it shouldn't result in re-compilation
27        let mut compiled = f.compile(shapeless);
28        compiled.call_mut(args)
29    }
30}
31
32/// A trait for functions that can be compiled.
33///
34/// # Generic parameters
35///
36/// - `A`: The type of the array arguments
37/// - `O`: The type of the output
38/// - `E`: The type of the error
39pub trait Compile<A, O, E>: Sized {
40    /// The type of the arguments that the returned closure takes.
41    ///
42    /// This is needed to relax the lifetime requirements of the returned
43    /// closure. Otherwise, the arguments to the returned closure would have to
44    /// live longer than the closure itself.
45    type Args<'a>;
46
47    /// Compiles the function.
48    fn compile<'args>(self, shapeless: bool) -> impl CallMut<Self::Args<'args>, O, E>;
49}
50
51impl<F> Compile<&[Array], Vec<Array>, ()> for F
52where
53    F: FnMut(&[Array]) -> Vec<Array> + 'static,
54{
55    type Args<'a> = &'a [Array];
56
57    fn compile<'args>(self, shapeless: bool) -> impl CallMut<Self::Args<'args>, Vec<Array>, ()> {
58        let id = type_id_to_usize(&self);
59        let state = CompiledState {
60            f: self,
61
62            shapeless,
63            id,
64        };
65        Compiled {
66            f_marker: PhantomData::<F>,
67            state,
68        }
69    }
70}
71
72impl<F> Compile<&Array, Array, ()> for F
73where
74    F: FnMut(&Array) -> Array + 'static,
75{
76    type Args<'a> = &'a Array;
77
78    fn compile<'args>(mut self, shapeless: bool) -> impl CallMut<Self::Args<'args>, Array, ()> {
79        let id = type_id_to_usize(&self);
80        let f = move |args: &[Array]| -> Vec<Array> {
81            let result = (self)(&args[0]);
82            vec![result]
83        };
84        let state = CompiledState { f, shapeless, id };
85        Compiled {
86            f_marker: PhantomData::<F>,
87            state,
88        }
89    }
90}
91
92impl<F> Compile<(&Array, &Array), Array, ()> for F
93where
94    F: FnMut((&Array, &Array)) -> Array + 'static,
95{
96    type Args<'a> = (&'a Array, &'a Array);
97
98    fn compile<'args>(mut self, shapeless: bool) -> impl CallMut<Self::Args<'args>, Array, ()> {
99        let id = type_id_to_usize(&self);
100        let f = move |args: &[Array]| -> Vec<Array> {
101            let result = (self)((&args[0], &args[1]));
102            vec![result]
103        };
104        let state = CompiledState { f, shapeless, id };
105        Compiled {
106            f_marker: PhantomData::<F>,
107            state,
108        }
109    }
110}
111
112impl<F> Compile<(&Array, &Array, &Array), Array, ()> for F
113where
114    F: FnMut((&Array, &Array, &Array)) -> Array + 'static,
115{
116    type Args<'a> = (&'a Array, &'a Array, &'a Array);
117
118    fn compile<'args>(mut self, shapeless: bool) -> impl CallMut<Self::Args<'args>, Array, ()> {
119        let id = type_id_to_usize(&self);
120        let f = move |args: &[Array]| -> Vec<Array> {
121            let result = (self)((&args[0], &args[1], &args[2]));
122            vec![result]
123        };
124        let state = CompiledState { f, shapeless, id };
125        Compiled {
126            f_marker: PhantomData::<F>,
127            state,
128        }
129    }
130}
131
132impl<F> Compile<&[Array], Vec<Array>, Exception> for F
133where
134    F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'static,
135{
136    type Args<'a> = &'a [Array];
137
138    fn compile<'args>(
139        self,
140        shapeless: bool,
141    ) -> impl CallMut<Self::Args<'args>, Vec<Array>, Exception> {
142        let id = type_id_to_usize(&self);
143        let state = CompiledState {
144            f: self,
145            shapeless,
146            id,
147        };
148        Compiled {
149            f_marker: PhantomData::<F>,
150            state,
151        }
152    }
153}
154
155impl<F> Compile<&Array, Array, Exception> for F
156where
157    F: FnMut(&Array) -> Result<Array, Exception> + 'static,
158{
159    type Args<'a> = &'a Array;
160
161    fn compile<'args>(
162        mut self,
163        shapeless: bool,
164    ) -> impl CallMut<Self::Args<'args>, Array, Exception> {
165        let id = type_id_to_usize(&self);
166        let f = move |args: &[Array]| -> Result<Vec<Array>, Exception> {
167            let result = (self)(&args[0])?;
168            Ok(vec![result])
169        };
170        let state = CompiledState { f, shapeless, id };
171        Compiled {
172            f_marker: PhantomData::<F>,
173            state,
174        }
175    }
176}
177
178impl<F> Compile<(&Array, &Array), Array, Exception> for F
179where
180    F: FnMut((&Array, &Array)) -> Result<Array, Exception> + 'static,
181{
182    type Args<'a> = (&'a Array, &'a Array);
183
184    fn compile<'args>(
185        mut self,
186        shapeless: bool,
187    ) -> impl CallMut<Self::Args<'args>, Array, Exception> {
188        let id = type_id_to_usize(&self);
189        let f = move |args: &[Array]| -> Result<Vec<Array>, Exception> {
190            let result = (self)((&args[0], &args[1]))?;
191            Ok(vec![result])
192        };
193        let state = CompiledState { f, shapeless, id };
194        Compiled {
195            f_marker: PhantomData::<F>,
196            state,
197        }
198    }
199}
200
201impl<F> Compile<(&Array, &Array, &Array), Array, Exception> for F
202where
203    F: FnMut((&Array, &Array, &Array)) -> Result<Array, Exception> + 'static,
204{
205    type Args<'a> = (&'a Array, &'a Array, &'a Array);
206
207    fn compile<'args>(
208        mut self,
209        shapeless: bool,
210    ) -> impl CallMut<Self::Args<'args>, Array, Exception> {
211        let id = type_id_to_usize(&self);
212        let f = move |args: &[Array]| -> Result<Vec<Array>, Exception> {
213            let result = (self)((&args[0], &args[1], &args[2]))?;
214            Ok(vec![result])
215        };
216        let state = CompiledState { f, shapeless, id };
217        Compiled {
218            f_marker: PhantomData::<F>,
219            state,
220        }
221    }
222}
223
224/// A trait for a compiled function that can be called.
225pub trait CallMut<A, O, E> {
226    /// Calls the compiled function with the given arguments.
227    fn call_mut(&mut self, args: A) -> Result<O, Exception>;
228}
229
230impl<'a, F, G> CallMut<&'a [Array], Vec<Array>, ()> for Compiled<F, G>
231where
232    F: FnMut(&[Array]) -> Vec<Array> + 'a,
233    G: FnMut(&[Array]) -> Vec<Array> + 'a,
234{
235    fn call_mut(&mut self, args: &[Array]) -> Result<Vec<Array>, Exception> {
236        self.state.call_mut(args)
237    }
238}
239
240impl<'a, F, G> CallMut<&'a Array, Array, ()> for Compiled<F, G>
241where
242    F: FnMut(&Array) -> Array + 'a,
243    G: FnMut(&[Array]) -> Vec<Array> + 'a,
244{
245    fn call_mut(&mut self, args: &Array) -> Result<Array, Exception> {
246        let args = std::slice::from_ref(args);
247        let result = self.state.call_mut(args)?;
248        Ok(result.into_iter().next().unwrap())
249    }
250}
251
252impl<'a, F, G> CallMut<(&'a Array, &'a Array), Array, ()> for Compiled<F, G>
253where
254    F: FnMut((&Array, &Array)) -> Array + 'a,
255    G: FnMut(&[Array]) -> Vec<Array> + 'a,
256{
257    fn call_mut(&mut self, args: (&Array, &Array)) -> Result<Array, Exception> {
258        let args = &[args.0, args.1];
259        let result = self.state.call_mut(args)?;
260        Ok(result.into_iter().next().unwrap())
261    }
262}
263
264impl<'a, F, G> CallMut<(&'a Array, &'a Array, &'a Array), Array, ()> for Compiled<F, G>
265where
266    F: FnMut((&Array, &Array, &Array)) -> Array + 'a,
267    G: FnMut(&[Array]) -> Vec<Array> + 'a,
268{
269    fn call_mut(&mut self, args: (&Array, &Array, &Array)) -> Result<Array, Exception> {
270        // Is there any way to avoid this shallow clone?
271        let args = &[args.0, args.1, args.2];
272        let result = self.state.call_mut(args)?;
273        Ok(result.into_iter().next().unwrap())
274    }
275}
276
277impl<'a, F, G> CallMut<&'a [Array], Vec<Array>, Exception> for Compiled<F, G>
278where
279    F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
280    G: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
281{
282    fn call_mut(&mut self, args: &[Array]) -> Result<Vec<Array>, Exception> {
283        self.state.fallible_call_mut(args)
284    }
285}
286
287impl<'a, F, G> CallMut<&'a Array, Array, Exception> for Compiled<F, G>
288where
289    F: FnMut(&Array) -> Result<Array, Exception> + 'a,
290    G: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
291{
292    fn call_mut(&mut self, args: &Array) -> Result<Array, Exception> {
293        let args = &[args];
294        let result = self.state.fallible_call_mut(args)?;
295        Ok(result.into_iter().next().unwrap())
296    }
297}
298
299impl<'a, F, G> CallMut<(&'a Array, &'a Array), Array, Exception> for Compiled<F, G>
300where
301    F: FnMut((&Array, &Array)) -> Result<Array, Exception> + 'a,
302    G: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
303{
304    fn call_mut(&mut self, args: (&Array, &Array)) -> Result<Array, Exception> {
305        let args = &[args.0, args.1];
306        let result = self.state.fallible_call_mut(args)?;
307        Ok(result.into_iter().next().unwrap())
308    }
309}
310
311impl<'a, F, G> CallMut<(&'a Array, &'a Array, &'a Array), Array, Exception> for Compiled<F, G>
312where
313    F: FnMut((&Array, &Array, &Array)) -> Result<Array, Exception> + 'a,
314    G: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
315{
316    fn call_mut(&mut self, args: (&Array, &Array, &Array)) -> Result<Array, Exception> {
317        let args = &[args.0, args.1, args.2];
318        let result = self.state.fallible_call_mut(args)?;
319        Ok(result.into_iter().next().unwrap())
320    }
321}
322
323#[inline]
324fn call_mut_inner(
325    inner_closure: Closure,
326    fun_id: usize,
327    shapeless: bool,
328    args: &[impl AsRef<Array>],
329) -> crate::error::Result<Vec<Array>> {
330    // note: this will use the cached compile (via the id)
331    // but will be able to re-evaluate with fresh state if needed
332    let compiled = Closure::try_from_op(|res| unsafe {
333        let constants = &[];
334        mlx_sys::mlx_detail_compile(
335            res,
336            inner_closure.as_ptr(),
337            fun_id,
338            shapeless,
339            constants.as_ptr(),
340            0,
341        )
342    })?;
343
344    let inner_inputs_vector = VectorArray::try_from_iter(args.iter())?;
345
346    // will compile the function (if needed) and evaluate the
347    // compiled graph
348    let result_vector = VectorArray::try_from_op(|res| unsafe {
349        mlx_sys::mlx_closure_apply(res, compiled.as_ptr(), inner_inputs_vector.as_ptr())
350    })?;
351    let result_plus_state_output: Vec<Array> = result_vector.try_into_values()?;
352
353    let result_len = result_plus_state_output.len();
354    Ok(result_plus_state_output
355        .into_iter()
356        .take(result_len)
357        .collect())
358}
359
360impl<F> CompiledState<F> {
361    fn call_mut(&mut self, args: &[impl AsRef<Array>]) -> Result<Vec<Array>, Exception>
362    where
363        F: FnMut(&[Array]) -> Vec<Array>,
364    {
365        let inner_closure = Closure::new(&mut self.f);
366
367        call_mut_inner(inner_closure, self.id, self.shapeless, args)
368    }
369
370    fn fallible_call_mut(&mut self, args: &[impl AsRef<Array>]) -> Result<Vec<Array>, Exception>
371    where
372        F: FnMut(&[Array]) -> Result<Vec<Array>, Exception>,
373    {
374        let inner_closure = Closure::new_fallible(&mut self.f);
375
376        call_mut_inner(inner_closure, self.id, self.shapeless, args)
377    }
378}
379
#[cfg(test)]
mod tests {
    use core::panic;

    use crate::{
        array,
        error::Exception,
        ops::{multiply, ones},
        Array,
    };

    use super::compile;

    fn example_fn_0(x: f32) -> f32 {
        x + 1.0
    }

    fn example_fn_3(x: f32) -> f32 {
        x + 1.0
    }

    /// Pulls the first array out of a list of outputs.
    fn first_output(mut outputs: Vec<Array>) -> Array {
        outputs.drain(0..1).next().unwrap()
    }

    #[test]
    fn test_type_id_to_usize() {
        // Functions and closures that share a signature (and here even a
        // body) are still expected to produce different ids.

        let example_fn_1 = |x: f32| x + 1.0;
        let example_fn_2 = |x: f32| x + 1.0;

        let mut ids = vec![super::type_id_to_usize(&example_fn_0)];

        let id1 = super::type_id_to_usize(&example_fn_1);
        match ids.contains(&id1) {
            true => panic!("id1 already exists"),
            false => ids.push(id1),
        }

        let id2 = super::type_id_to_usize(&example_fn_2);
        match ids.contains(&id2) {
            true => panic!("id2 already exists"),
            false => ids.push(id2),
        }

        let id3 = super::type_id_to_usize(&example_fn_3);
        match ids.contains(&id3) {
            true => panic!("id3 already exists"),
            false => ids.push(id3),
        }

        assert_eq!(ids.len(), 4);
    }

    #[test]
    fn test_compile() {
        // This unit test is modified from the mlx-swift codebase

        let f = |inputs: &[Array]| -> Vec<Array> { vec![&inputs[0] * &inputs[1]] };
        let mut compiled = compile(f, None);

        let args = [
            ones::<f32>(&[20, 20]).unwrap(),
            ones::<f32>(&[20, 20]).unwrap(),
        ];

        // Reference value from calling the function directly.
        let expected = first_output(f(&args));

        // First call through the compiled wrapper.
        let first_run = first_output(compiled(&args).unwrap());
        assert_eq!(&expected, &first_run);

        // Calling the wrapper again must give the same answer.
        let second_run = first_output(compiled(&args).unwrap());
        assert_eq!(&expected, &second_run);
    }

    #[test]
    fn test_compile_with_error() {
        let f = |inputs: &[Array]| -> Result<Vec<Array>, Exception> {
            Ok(vec![multiply(&inputs[0], &inputs[1])?])
        };

        // Success case
        let args = [
            ones::<f32>(&[20, 20]).unwrap(),
            ones::<f32>(&[20, 20]).unwrap(),
        ];

        // Reference value from calling the function directly.
        let expected = first_output(f(&args).unwrap());

        // Through the compiled wrapper, twice.
        let mut compiled = compile(f, None);
        let first_run = first_output(compiled(&args).unwrap());
        assert_eq!(&expected, &first_run);

        let second_run = first_output(compiled(&args).unwrap());
        assert_eq!(&expected, &second_run);

        // Error case
        let args = [array!([1.0, 2.0, 3.0]), array!([4.0, 5.0])];

        // The cache is keyed by function pointer and argument shapes
        let another_args = [array!([4.0, 5.0, 6.0]), array!([7.0, 8.0])];

        // The direct call fails...
        assert!(f(&args).is_err());

        // ...and so does every call through the compiled wrapper.
        let mut compiled = compile(f, None);
        assert!(compiled(&args).is_err());
        assert!(compiled(&args).is_err());
        assert!(compiled(&another_args).is_err());
    }

    #[test]
    fn test_compile_with_one_arg() {
        let f = |x: &Array| x * x;

        let input = ones::<f32>(&[20, 20]).unwrap();

        // Reference value from calling the function directly.
        let expected = f(&input);

        // Through the compiled wrapper, twice.
        let mut compiled = compile(f, None);
        let first_run = compiled(&input).unwrap();
        assert_eq!(&expected, &first_run);

        let second_run = compiled(&input).unwrap();
        assert_eq!(&expected, &second_run);
    }

    #[test]
    fn test_compile_with_two_args() {
        let f = |(x, y): (&Array, &Array)| x * y;

        let lhs = ones::<f32>(&[20, 20]).unwrap();
        let rhs = ones::<f32>(&[20, 20]).unwrap();

        // Reference value from calling the function directly.
        let expected = f((&lhs, &rhs));

        // Through the compiled wrapper, twice.
        let mut compiled = compile(f, None);
        let first_run = compiled((&lhs, &rhs)).unwrap();
        assert_eq!(&expected, &first_run);

        let second_run = compiled((&lhs, &rhs)).unwrap();
        assert_eq!(&expected, &second_run);
    }

    #[test]
    fn test_compile_with_three_args() {
        let f = |(x, y, z): (&Array, &Array, &Array)| x * y * z;
        let mut compiled = compile(f, None);

        let a = ones::<f32>(&[20, 20]).unwrap();
        let b = ones::<f32>(&[20, 20]).unwrap();
        let c = ones::<f32>(&[20, 20]).unwrap();

        // Reference value from calling the function directly.
        let expected = f((&a, &b, &c));

        // Through the compiled wrapper, twice.
        let first_run = compiled((&a, &b, &c)).unwrap();
        assert_eq!(&expected, &first_run);

        let second_run = compiled((&a, &b, &c)).unwrap();
        assert_eq!(&expected, &second_run);
    }
}
565}