mlx_rs/transforms/compile/
compile_with_state.rs

1//! Compilation of functions with state.
2//!
3//! # Unit tests
4//!
5//! See `mlx-rs/mlx-tests/tests/test_compile.rs` for unit tests.
6
// TODO: there is plenty of boilerplate code here, but it's not clear how to reduce it
8
9use std::{cell::RefCell, marker::PhantomData, rc::Rc};
10
11use crate::{
12    error::Exception,
13    transforms::compile::{type_id_to_usize, CompiledState},
14    utils::Updatable,
15    Array,
16};
17
18use super::{update_by_replace_with_ref_to_new_array, Closure, Compiled, Guarded, VectorArray};
19
20/// Similar to [`crate::transforms::compile`] but allows for functions that take
21/// a mutable reference to a state `U`.
22pub fn compile_with_state<F, U, A, O, E>(
23    f: F,
24    shapeless: impl Into<Option<bool>>,
25) -> impl for<'a> FnMut(&mut U, F::Args<'a>) -> Result<O, Exception>
26where
27    F: CompileWithState<U, A, O, E> + Copy + 'static,
28    U: Updatable,
29{
30    let shapeless = shapeless.into().unwrap_or(false);
31    move |state, args| {
32        let mut compiled = f.compile(shapeless);
33        compiled.call_mut(state, args)
34    }
35}
36
/// A trait for functions that can be compiled with state.
///
/// This trait is used to compile a function that takes a mutable reference to a state
/// and some arguments and returns a result.
///
/// # Generic parameters
///
/// - `U`: The type of the state.
/// - `A`: The type of the arguments.
/// - `O`: The type of the output.
/// - `E`: A marker distinguishing the implementations in this module: `()`
///   for functions that return their output directly, and `Exception` for
///   functions that return a `Result<_, Exception>`.
pub trait CompileWithState<U, A, O, E> {
    /// The type of the arguments that the returned closure takes.
    ///
    /// This is needed to relax the lifetime requirements of the returned
    /// closure. Otherwise, the arguments to the returned closure would have to
    /// live longer than the closure itself.
    type Args<'a>;

    /// Compile the function.
    ///
    /// Consumes `self` and returns a value that can be invoked through
    /// [`CallMutWithState::call_mut`]. `shapeless` is stored as-is and
    /// forwarded to the underlying compile call when the function is invoked.
    fn compile<'args>(self, shapeless: bool) -> impl CallMutWithState<U, Self::Args<'args>, O, E>;
}
59
60impl<F, U> CompileWithState<U, &[Array], Vec<Array>, ()> for F
61where
62    F: FnMut(&mut U, &[Array]) -> Vec<Array> + 'static,
63    U: Updatable,
64{
65    type Args<'a> = &'a [Array];
66
67    fn compile<'args>(
68        self,
69        shapeless: bool,
70    ) -> impl CallMutWithState<U, Self::Args<'args>, Vec<Array>, ()> {
71        let id = type_id_to_usize(&self);
72        let state = CompiledState {
73            f: self,
74            shapeless,
75            id,
76        };
77        Compiled {
78            f_marker: PhantomData::<F>,
79            state,
80        }
81    }
82}
83
84impl<F, U> CompileWithState<U, &Array, Array, ()> for F
85where
86    F: FnMut(&mut U, &Array) -> Array + 'static,
87    U: Updatable,
88{
89    type Args<'a> = &'a Array;
90
91    fn compile<'args>(
92        mut self,
93        shapeless: bool,
94    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, ()> {
95        let id = type_id_to_usize(&self);
96        let f = move |state: &mut U, args: &[Array]| -> Vec<Array> {
97            let result = (self)(state, &args[0]);
98            vec![result]
99        };
100        let state = CompiledState { f, shapeless, id };
101        Compiled {
102            f_marker: PhantomData::<F>,
103            state,
104        }
105    }
106}
107
108impl<F, U> CompileWithState<U, (&Array, &Array), Array, ()> for F
109where
110    F: FnMut(&mut U, (&Array, &Array)) -> Array + 'static,
111    U: Updatable,
112{
113    type Args<'a> = (&'a Array, &'a Array);
114
115    fn compile<'args>(
116        mut self,
117        shapeless: bool,
118    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, ()> {
119        let id = type_id_to_usize(&self);
120        let f = move |state: &mut U, args: &[Array]| -> Vec<Array> {
121            let result = (self)(state, (&args[0], &args[1]));
122            vec![result]
123        };
124        let state = CompiledState { f, shapeless, id };
125        Compiled {
126            f_marker: PhantomData::<F>,
127            state,
128        }
129    }
130}
131
132impl<F, U> CompileWithState<U, (&Array, &Array, &Array), Array, ()> for F
133where
134    F: FnMut(&mut U, (&Array, &Array, &Array)) -> Array + 'static,
135    U: Updatable,
136{
137    type Args<'a> = (&'a Array, &'a Array, &'a Array);
138
139    fn compile<'args>(
140        mut self,
141        shapeless: bool,
142    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, ()> {
143        let id = type_id_to_usize(&self);
144        let f = move |state: &mut U, args: &[Array]| -> Vec<Array> {
145            let result = (self)(state, (&args[0], &args[1], &args[2]));
146            vec![result]
147        };
148        let state = CompiledState { f, shapeless, id };
149        Compiled {
150            f_marker: PhantomData::<F>,
151            state,
152        }
153    }
154}
155
156impl<F, U> CompileWithState<U, &[Array], Vec<Array>, Exception> for F
157where
158    F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception> + 'static,
159    U: Updatable,
160{
161    type Args<'a> = &'a [Array];
162
163    fn compile<'args>(
164        self,
165        shapeless: bool,
166    ) -> impl CallMutWithState<U, Self::Args<'args>, Vec<Array>, Exception> {
167        let id = type_id_to_usize(&self);
168        let state = CompiledState {
169            f: self,
170            shapeless,
171            id,
172        };
173        Compiled {
174            f_marker: PhantomData::<F>,
175            state,
176        }
177    }
178}
179
180impl<F, U> CompileWithState<U, &Array, Array, Exception> for F
181where
182    F: FnMut(&mut U, &Array) -> Result<Array, Exception> + 'static,
183    U: Updatable,
184{
185    type Args<'a> = &'a Array;
186
187    fn compile<'args>(
188        mut self,
189        shapeless: bool,
190    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, Exception> {
191        let id = type_id_to_usize(&self);
192        let f = move |state: &mut U, args: &[Array]| -> Result<Vec<Array>, Exception> {
193            let result = (self)(state, &args[0])?;
194            Ok(vec![result])
195        };
196        let state = CompiledState { f, shapeless, id };
197        Compiled {
198            f_marker: PhantomData::<F>,
199            state,
200        }
201    }
202}
203
204impl<F, U> CompileWithState<U, (&Array, &Array), Array, Exception> for F
205where
206    F: FnMut(&mut U, (&Array, &Array)) -> Result<Array, Exception> + 'static,
207    U: Updatable,
208{
209    type Args<'a> = (&'a Array, &'a Array);
210
211    fn compile<'args>(
212        mut self,
213        shapeless: bool,
214    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, Exception> {
215        let id = type_id_to_usize(&self);
216        let f = move |state: &mut U, args: &[Array]| -> Result<Vec<Array>, Exception> {
217            let result = (self)(state, (&args[0], &args[1]))?;
218            Ok(vec![result])
219        };
220        let state = CompiledState { f, shapeless, id };
221        Compiled {
222            f_marker: PhantomData::<F>,
223            state,
224        }
225    }
226}
227
228impl<F, U> CompileWithState<U, (&Array, &Array, &Array), Array, Exception> for F
229where
230    F: FnMut(&mut U, (&Array, &Array, &Array)) -> Result<Array, Exception> + 'static,
231    U: Updatable,
232{
233    type Args<'a> = (&'a Array, &'a Array, &'a Array);
234
235    fn compile<'args>(
236        mut self,
237        shapeless: bool,
238    ) -> impl CallMutWithState<U, Self::Args<'args>, Array, Exception> {
239        let id = type_id_to_usize(&self);
240        let f = move |state: &mut U, args: &[Array]| -> Result<Vec<Array>, Exception> {
241            let result = (self)(state, (&args[0], &args[1], &args[2]))?;
242            Ok(vec![result])
243        };
244        let state = CompiledState { f, shapeless, id };
245        Compiled {
246            f_marker: PhantomData::<F>,
247            state,
248        }
249    }
250}
251
/// A trait for functions that can be called with state.
///
/// # Generic parameters
///
/// - `U`: The type of the state passed by mutable reference.
/// - `A`: The type of the arguments.
/// - `O`: The type of the output on success.
/// - `E`: A marker that selects the implementation (`()` or `Exception` in
///   this module). It does not appear in the signature of `call_mut`, which
///   always reports failures as [`Exception`].
pub trait CallMutWithState<U, A, O, E> {
    /// Call the function with the given state and arguments.
    fn call_mut(&mut self, state: &mut U, args: A) -> Result<O, Exception>;
}
257
258impl<U, F, G> CallMutWithState<U, &[Array], Vec<Array>, ()> for Compiled<F, G>
259where
260    F: FnMut(&mut U, &[Array]) -> Vec<Array>,
261    G: FnMut(&mut U, &[Array]) -> Vec<Array>,
262    U: Updatable,
263{
264    fn call_mut(&mut self, state: &mut U, args: &[Array]) -> Result<Vec<Array>, Exception> {
265        self.state.retry_call_mut_with_state(state, args)
266    }
267}
268
269impl<U, F, G> CallMutWithState<U, &Array, Array, ()> for Compiled<F, G>
270where
271    F: FnMut(&mut U, &Array) -> Array,
272    G: FnMut(&mut U, &[Array]) -> Vec<Array>,
273    U: Updatable,
274{
275    fn call_mut(&mut self, state: &mut U, args: &Array) -> Result<Array, Exception> {
276        let args = std::slice::from_ref(args);
277        let result = self.state.retry_call_mut_with_state(state, args)?;
278        Ok(result.into_iter().next().unwrap())
279    }
280}
281
282impl<U, F, G> CallMutWithState<U, (&Array, &Array), Array, ()> for Compiled<F, G>
283where
284    F: FnMut(&mut U, (&Array, &Array)) -> Array,
285    G: FnMut(&mut U, &[Array]) -> Vec<Array>,
286    U: Updatable,
287{
288    fn call_mut(&mut self, state: &mut U, args: (&Array, &Array)) -> Result<Array, Exception> {
289        let args = &[args.0, args.1];
290        let result = self.state.retry_call_mut_with_state(state, args)?;
291        Ok(result.into_iter().next().unwrap())
292    }
293}
294
295impl<U, F, G> CallMutWithState<U, (&Array, &Array, &Array), Array, ()> for Compiled<F, G>
296where
297    F: FnMut(&mut U, (&Array, &Array, &Array)) -> Array,
298    G: FnMut(&mut U, &[Array]) -> Vec<Array>,
299    U: Updatable,
300{
301    fn call_mut(
302        &mut self,
303        state: &mut U,
304        args: (&Array, &Array, &Array),
305    ) -> Result<Array, Exception> {
306        let args = &[args.0, args.1, args.2];
307        let result = self.state.retry_call_mut_with_state(state, args)?;
308        Ok(result.into_iter().next().unwrap())
309    }
310}
311
312impl<U, F, G> CallMutWithState<U, &[Array], Vec<Array>, Exception> for Compiled<F, G>
313where
314    F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
315    G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
316    U: Updatable,
317{
318    fn call_mut(&mut self, state: &mut U, args: &[Array]) -> Result<Vec<Array>, Exception> {
319        self.state.retry_fallible_call_mut_with_state(state, args)
320    }
321}
322
323impl<U, F, G> CallMutWithState<U, &Array, Array, Exception> for Compiled<F, G>
324where
325    F: FnMut(&mut U, &Array) -> Result<Array, Exception>,
326    G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
327    U: Updatable,
328{
329    fn call_mut(&mut self, state: &mut U, args: &Array) -> Result<Array, Exception> {
330        let args = std::slice::from_ref(args);
331        let result = self.state.retry_fallible_call_mut_with_state(state, args)?;
332        Ok(result.into_iter().next().unwrap())
333    }
334}
335
336impl<U, F, G> CallMutWithState<U, (&Array, &Array), Array, Exception> for Compiled<F, G>
337where
338    F: FnMut(&mut U, (&Array, &Array)) -> Result<Array, Exception>,
339    G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
340    U: Updatable,
341{
342    fn call_mut(&mut self, state: &mut U, args: (&Array, &Array)) -> Result<Array, Exception> {
343        let args = &[args.0, args.1];
344        let result = self.state.retry_fallible_call_mut_with_state(state, args)?;
345        Ok(result.into_iter().next().unwrap())
346    }
347}
348
349impl<U, F, G> CallMutWithState<U, (&Array, &Array, &Array), Array, Exception> for Compiled<F, G>
350where
351    F: FnMut(&mut U, (&Array, &Array, &Array)) -> Result<Array, Exception>,
352    G: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
353    U: Updatable,
354{
355    fn call_mut(
356        &mut self,
357        state: &mut U,
358        args: (&Array, &Array, &Array),
359    ) -> Result<Array, Exception> {
360        let args = &[args.0, args.1, args.2];
361        let result = self.state.retry_fallible_call_mut_with_state(state, args)?;
362        Ok(result.into_iter().next().unwrap())
363    }
364}
365
366#[inline]
367fn call_mut_with_state_inner<U>(
368    inner_closure: Closure,
369    fun_id: usize,
370    shapeless: bool,
371    state: Rc<RefCell<&mut U>>,
372    args: &[impl AsRef<Array>],
373) -> crate::error::Result<Vec<Array>>
374where
375    U: Updatable,
376{
377    // note: this will use the cached compile (via the id)
378    // but will be able to re-evaluate with fresh state if needed
379    let compiled = Closure::try_from_op(|res| unsafe {
380        let constants = &[];
381        mlx_sys::mlx_detail_compile(
382            res,
383            inner_closure.as_ptr(),
384            fun_id,
385            shapeless,
386            constants.as_ptr(),
387            0,
388        )
389    })?;
390
391    let (state_params_len, inner_inputs_vector) = {
392        let borrow = state.borrow();
393        let state_params: Vec<_> = borrow.updatable_states().into_iter().collect();
394        let state_params_len = state_params.len();
395        let inner_inputs_vector = VectorArray::try_from_iter(
396            args.iter()
397                .map(AsRef::as_ref)
398                .chain(state_params.into_iter()),
399        )?;
400        (state_params_len, inner_inputs_vector)
401    };
402
403    // will compile the function (if needed) and evaluate the
404    // compiled graph
405    let result_vector = VectorArray::try_from_op(|res| unsafe {
406        mlx_sys::mlx_closure_apply(res, compiled.as_ptr(), inner_inputs_vector.as_ptr())
407    })?;
408    let result_plus_state_output: Vec<Array> = result_vector.try_into_values()?;
409
410    // push the stateOutput into the state
411    let result_plus_state_output_len = result_plus_state_output.len();
412    let suffix_len = result_plus_state_output_len - state_params_len;
413    for (s, new_values) in state
414        .borrow_mut()
415        .updatable_states_mut()
416        .into_iter()
417        .zip(result_plus_state_output[suffix_len..].iter())
418    {
419        update_by_replace_with_ref_to_new_array(s, new_values);
420    }
421
422    let result_len = result_plus_state_output.len() - state_params_len;
423    Ok(result_plus_state_output
424        .into_iter()
425        .take(result_len)
426        .collect())
427}
428
429impl<F> CompiledState<F> {
430    fn retry_call_mut_with_state<U>(
431        &mut self,
432        state: &mut U,
433        args: &[impl AsRef<Array>],
434    ) -> Result<Vec<Array>, Exception>
435    where
436        F: FnMut(&mut U, &[Array]) -> Vec<Array>,
437        U: Updatable,
438    {
439        self.call_mut_with_state(state, args).or_else(|_e| {
440            // Somehow the mlx_closure_apply may fail on the first call for
441            // certain types of state with the error message:
442            // "unordered_map::at: key not found", so we just try again.
443            //
444            // One type that is known to cause this is a tuple of
445            // `Module` and `Optimizer` eg. `(<Module>, <Optimizer>)`
446            self.call_mut_with_state(state, args)
447        })
448    }
449
450    fn retry_fallible_call_mut_with_state<U>(
451        &mut self,
452        state: &mut U,
453        args: &[impl AsRef<Array>],
454    ) -> Result<Vec<Array>, Exception>
455    where
456        F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
457        U: Updatable,
458    {
459        self.fallible_call_mut_with_state(state, args)
460            .or_else(|_e| {
461                // Somehow the mlx_closure_apply may fail on the first call for
462                // certain types of state with the error message:
463                // "unordered_map::at: key not found", so we just try again.
464                //
465                // One type that is known to cause this is a tuple of
466                // `Module` and `Optimizer` eg. `(<Module>, <Optimizer>)`
467                self.fallible_call_mut_with_state(state, args)
468            })
469    }
470
471    fn call_mut_with_state<U>(
472        &mut self,
473        state: &mut U,
474        args: &[impl AsRef<Array>],
475    ) -> Result<Vec<Array>, Exception>
476    where
477        F: FnMut(&mut U, &[Array]) -> Vec<Array>,
478        U: Updatable,
479    {
480        let args_len = args.len();
481        let state = Rc::new(RefCell::new(state));
482        let f = &mut self.f;
483
484        let state_clone = Rc::clone(&state);
485        let inner = move |tracers: &[Array]| -> Vec<Array> {
486            // put the tracers in their appropriate places:
487            // - arguments to the function
488            // - inner state
489
490            let tracer_args = &tracers[..args_len];
491
492            // save a snapshot of the inner state
493            let saved_state_inputs = state_clone
494                .borrow()
495                .updatable_states()
496                .into_iter()
497                .map(|array| (*array).clone())
498                .collect::<Vec<Array>>();
499
500            // replace the inner state with the tracers
501            for (s, tracer) in state_clone
502                .borrow_mut()
503                .updatable_states_mut()
504                .into_iter()
505                .zip(tracers.iter().skip(args_len))
506            {
507                update_by_replace_with_ref_to_new_array(s, tracer);
508            }
509
510            // call the function with the tracer arguments and the state holding tracers
511            let mut result = (f)(*state_clone.borrow_mut(), tracer_args);
512
513            // recapture the state as it may have changed
514            let mut state_output_tracers = state_clone
515                .borrow()
516                .updatable_states()
517                .into_iter()
518                .map(|array| (*array).clone())
519                .collect::<Vec<Array>>();
520
521            // put the original values back in the state
522            for (s, saved) in state_clone
523                .borrow_mut()
524                .updatable_states_mut()
525                .into_iter()
526                .zip(saved_state_inputs)
527            {
528                update_by_replace_with_ref_to_new_array(s, &saved);
529            }
530
531            // return the result of the function and the state
532            result.append(&mut state_output_tracers);
533
534            result
535        };
536
537        let inner_closure = Closure::new(inner);
538        call_mut_with_state_inner(inner_closure, self.id, self.shapeless, state, args)
539    }
540
541    fn fallible_call_mut_with_state<U>(
542        &mut self,
543        state: &mut U,
544        args: &[impl AsRef<Array>],
545    ) -> Result<Vec<Array>, Exception>
546    where
547        F: FnMut(&mut U, &[Array]) -> Result<Vec<Array>, Exception>,
548        U: Updatable,
549    {
550        let args_len = args.len();
551        let state = Rc::new(RefCell::new(state));
552        let f = &mut self.f;
553
554        let state_clone = Rc::clone(&state);
555        let inner = move |tracers: &[Array]| -> Result<Vec<Array>, Exception> {
556            // put the tracers in their appropriate places:
557            // - arguments to the function
558            // - inner state
559
560            let tracer_args = &tracers[..args_len];
561
562            // save a snapshot of the inner state
563            let saved_state_inputs = state_clone
564                .borrow()
565                .updatable_states()
566                .into_iter()
567                .map(|array| (*array).clone())
568                .collect::<Vec<Array>>();
569
570            // replace the inner state with the tracers
571            for (s, tracer) in state_clone
572                .borrow_mut()
573                .updatable_states_mut()
574                .into_iter()
575                .zip(tracers.iter().skip(args_len))
576            {
577                update_by_replace_with_ref_to_new_array(s, tracer);
578            }
579
580            // call the function with the tracer arguments and the state holding tracers
581            let mut result = (f)(*state_clone.borrow_mut(), tracer_args)?;
582
583            // recapture the state as it may have changed
584            let mut state_output_tracers = state_clone
585                .borrow()
586                .updatable_states()
587                .into_iter()
588                .map(|array| (*array).clone())
589                .collect::<Vec<Array>>();
590
591            // put the original values back in the state
592            for (s, saved) in state_clone
593                .borrow_mut()
594                .updatable_states_mut()
595                .into_iter()
596                .zip(saved_state_inputs)
597            {
598                update_by_replace_with_ref_to_new_array(s, &saved);
599            }
600
601            // return the result of the function and the state
602            result.append(&mut state_output_tracers);
603
604            Ok(result)
605        };
606
607        let inner_closure = Closure::new_fallible(inner);
608        call_mut_with_state_inner(inner_closure, self.id, self.shapeless, state, args)
609    }
610}