mlx_rs/utils/
mod.rs

1//! Utility functions and types.
2
3use guard::Guarded;
4use mlx_sys::mlx_vector_array;
5
6use crate::error::set_closure_error;
7use crate::module::ModuleParameters;
8use crate::{complex64, error::Exception, Array, FromNested};
9use std::collections::HashMap;
10use std::{marker::PhantomData, rc::Rc};
11
/// Status code the mlx-c bindings return when a call succeeds.
pub(crate) const SUCCESS: i32 = 0;
/// Status code reported back to mlx-c (e.g. from closure trampolines) to signal failure.
pub(crate) const FAILURE: i32 = 1;
15
16pub(crate) mod guard;
17pub(crate) mod io;
18
/// Resolves a possibly negative `index` against `len` without any bounds checking.
///
/// Non-negative indices are returned unchanged; negative ones count back from `len`
/// (`-1` maps to `len - 1`). The addition saturates at the `i32` limits, and the result
/// may still be negative or `>= len`, so callers must validate it themselves.
pub(crate) fn resolve_index_signed_unchecked(index: i32, len: i32) -> i32 {
    match index {
        negative if negative < 0 => len.saturating_add(negative),
        non_negative => non_negative,
    }
}
26
/// Resolves a possibly negative `index` against a `usize` length without bounds checking.
///
/// Non-negative indices are returned as-is; negative ones count back from `len`.
/// The caller must ensure the result is in range: if `len + index` is negative, the
/// final `as usize` cast wraps to a very large value.
pub(crate) fn resolve_index_unchecked(index: i32, len: usize) -> usize {
    if index >= 0 {
        return index as usize;
    }
    // Negative index: offset from the end of the axis.
    (len as i32 + index) as usize
}
34
35/// Helper method to convert an optional slice of axes to a Vec covering all axes.
36pub(crate) fn axes_or_default_to_all<'a>(axes: impl IntoOption<&'a [i32]>, ndim: i32) -> Vec<i32> {
37    match axes.into_option() {
38        Some(axes) => axes.to_vec(),
39        None => {
40            let axes: Vec<i32> = (0..ndim).collect();
41            axes
42        }
43    }
44}
45
46pub(crate) struct VectorArray {
47    c_vec: mlx_sys::mlx_vector_array,
48}
49
50impl VectorArray {
51    pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_vector_array {
52        self.c_vec
53    }
54
55    pub(crate) unsafe fn from_ptr(c_vec: mlx_sys::mlx_vector_array) -> Self {
56        Self { c_vec }
57    }
58
59    pub(crate) fn try_from_iter(
60        iter: impl Iterator<Item = impl AsRef<Array>>,
61    ) -> Result<Self, Exception> {
62        VectorArray::try_from_op(|res| unsafe {
63            let mut status = SUCCESS;
64            for arr in iter {
65                status = mlx_sys::mlx_vector_array_append_value(*res, arr.as_ref().as_ptr());
66                if status != SUCCESS {
67                    return status;
68                }
69            }
70            status
71        })
72    }
73
74    pub(crate) fn try_into_values<T>(self) -> Result<T, Exception>
75    where
76        T: FromIterator<Array>,
77    {
78        unsafe {
79            let size = mlx_sys::mlx_vector_array_size(self.c_vec);
80            (0..size)
81                .map(|i| {
82                    Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, self.c_vec, i))
83                })
84                .collect::<Result<T, Exception>>()
85        }
86    }
87}
88
89impl Drop for VectorArray {
90    fn drop(&mut self) {
91        let status = unsafe { mlx_sys::mlx_vector_array_free(self.c_vec) };
92        debug_assert_eq!(status, SUCCESS);
93    }
94}
95
/// Conversion into an [`Option`], like `Into<Option<T>>` but friendlier for slices.
///
/// Besides `Option<T>` and a bare `T`, it also accepts `&[T; N]` and `&Vec<T>`
/// where an optional `&[T]` is expected, so callers need not write `Some(&x[..])`.
pub trait IntoOption<T> {
    /// Convert into an [`Option`].
    fn into_option(self) -> Option<T>;
}

/// An `Option` is passed through untouched.
impl<T> IntoOption<T> for Option<T> {
    fn into_option(self) -> Option<T> {
        self
    }
}

/// Any plain value becomes `Some(value)`.
impl<T> IntoOption<T> for T {
    fn into_option(self) -> Option<T> {
        Option::Some(self)
    }
}

/// A borrowed fixed-size array becomes `Some` of the whole slice.
impl<'a, T, const N: usize> IntoOption<&'a [T]> for &'a [T; N] {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}

/// A borrowed `Vec` becomes `Some` of its contents as a slice.
impl<'a, T> IntoOption<&'a [T]> for &'a Vec<T> {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}
126
127/// A trait for a scalar or an array.
128pub trait ScalarOrArray<'a> {
129    /// The reference type of the array.
130    type Array: AsRef<Array> + 'a;
131
132    /// Convert to an owned or reference array.
133    fn into_owned_or_ref_array(self) -> Self::Array;
134}
135
136impl ScalarOrArray<'_> for Array {
137    type Array = Array;
138
139    fn into_owned_or_ref_array(self) -> Array {
140        self
141    }
142}
143
144impl<'a> ScalarOrArray<'a> for &'a Array {
145    type Array = &'a Array;
146
147    // TODO: clippy would complain about `as_array`. Is there a better name?
148    fn into_owned_or_ref_array(self) -> &'a Array {
149        self
150    }
151}
152
153impl ScalarOrArray<'static> for bool {
154    type Array = Array;
155
156    fn into_owned_or_ref_array(self) -> Array {
157        Array::from_bool(self)
158    }
159}
160
161impl ScalarOrArray<'static> for i32 {
162    type Array = Array;
163
164    fn into_owned_or_ref_array(self) -> Array {
165        Array::from_int(self)
166    }
167}
168
169impl ScalarOrArray<'static> for f32 {
170    type Array = Array;
171
172    fn into_owned_or_ref_array(self) -> Array {
173        Array::from_f32(self)
174    }
175}
176
177// TODO: this is bugged right now. See https://github.com/ml-explore/mlx/issues/1994
178// impl ScalarOrArray<'static> for f64 {
179//     type Array = Array;
180
181//     fn into_owned_or_ref_array(self) -> Array {
182//         Array::from_f64(self)
183//     }
184// }
185
186impl ScalarOrArray<'static> for complex64 {
187    type Array = Array;
188
189    fn into_owned_or_ref_array(self) -> Array {
190        Array::from_complex(self)
191    }
192}
193
194impl<T> ScalarOrArray<'static> for T
195where
196    Array: FromNested<T>,
197{
198    type Array = Array;
199
200    fn into_owned_or_ref_array(self) -> Array {
201        Array::from_nested(self)
202    }
203}
204
205#[derive(Debug)]
206pub(crate) struct Closure<'a> {
207    c_closure: mlx_sys::mlx_closure,
208    lt_marker: PhantomData<&'a ()>,
209}
210
211impl<'a> Closure<'a> {
212    pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_closure {
213        self.c_closure
214    }
215
216    pub(crate) fn new<F>(closure: F) -> Self
217    where
218        F: FnMut(&[Array]) -> Vec<Array> + 'a,
219    {
220        let c_closure = new_mlx_closure(closure);
221        Self {
222            c_closure,
223            lt_marker: PhantomData,
224        }
225    }
226
227    pub(crate) fn new_fallible<F>(closure: F) -> Self
228    where
229        F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
230    {
231        let c_closure = new_mlx_fallible_closure(closure);
232        Self {
233            c_closure,
234            lt_marker: PhantomData,
235        }
236    }
237}
238
239impl Drop for Closure<'_> {
240    fn drop(&mut self) {
241        let status = unsafe { mlx_sys::mlx_closure_free(self.c_closure) };
242        debug_assert_eq!(status, SUCCESS);
243    }
244}
245
246/// Helper method to create a mlx_closure from a Rust closure.
247fn new_mlx_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
248where
249    F: FnMut(&[Array]) -> Vec<Array> + 'a,
250{
251    // Box the closure to keep it on the heap
252    let boxed = Box::new(closure);
253
254    // Create a raw pointer from the Box, transferring ownership to C
255    let raw = Box::into_raw(boxed);
256    let payload = raw as *mut std::ffi::c_void;
257
258    unsafe {
259        mlx_sys::mlx_closure_new_func_payload(
260            Some(trampoline::<F>),
261            payload,
262            Some(closure_dtor::<F>),
263        )
264    }
265}
266
267fn new_mlx_fallible_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
268where
269    F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
270{
271    let boxed = Box::new(closure);
272    let raw = Box::into_raw(boxed);
273    let payload = raw as *mut std::ffi::c_void;
274
275    unsafe {
276        mlx_sys::mlx_closure_new_func_payload(
277            Some(trampoline_fallible::<F>),
278            payload,
279            Some(closure_dtor::<F>),
280        )
281    }
282}
283
284/// Function to create a new (+1 reference) mlx_vector_array from a vector of Array
285fn new_mlx_vector_array(arrays: Vec<Array>) -> mlx_sys::mlx_vector_array {
286    unsafe {
287        let result = mlx_sys::mlx_vector_array_new();
288        let ctx_ptrs: Vec<mlx_sys::mlx_array> = arrays.iter().map(|array| array.as_ptr()).collect();
289        mlx_sys::mlx_vector_array_append_data(result, ctx_ptrs.as_ptr(), arrays.len());
290        result
291    }
292}
293
294fn mlx_vector_array_values(
295    vector_array: mlx_sys::mlx_vector_array,
296) -> Result<Vec<Array>, Exception> {
297    unsafe {
298        let size = mlx_sys::mlx_vector_array_size(vector_array);
299        (0..size)
300            .map(|index| {
301                Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, vector_array, index))
302            })
303            .collect()
304    }
305}
306
307extern "C" fn trampoline<'a, F>(
308    ret: *mut mlx_vector_array,
309    vector_array: mlx_vector_array,
310    payload: *mut std::ffi::c_void,
311) -> i32
312where
313    F: FnMut(&[Array]) -> Vec<Array> + 'a,
314{
315    unsafe {
316        let raw_closure: *mut F = payload as *mut _;
317        // Let the box take care of freeing the closure
318        let mut closure = Box::from_raw(raw_closure);
319        let arrays = match mlx_vector_array_values(vector_array) {
320            Ok(arrays) => arrays,
321            Err(_) => {
322                let _ = Box::into_raw(closure); // prevent premature drop
323                return FAILURE;
324            }
325        };
326        let result = closure(&arrays);
327        let _ = Box::into_raw(closure); // prevent premature drop
328
329        // We should probably keep using new_mlx_vector_array here instead of VectorArray
330        // since we probably don't want to drop the arrays in the closure
331        *ret = new_mlx_vector_array(result);
332
333        SUCCESS
334    }
335}
336
337extern "C" fn trampoline_fallible<'a, F>(
338    ret: *mut mlx_vector_array,
339    vector_array: mlx_vector_array,
340    payload: *mut std::ffi::c_void,
341) -> i32
342where
343    F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
344{
345    unsafe {
346        let raw_closure: *mut F = payload as *mut _;
347        let mut closure = Box::from_raw(raw_closure);
348        let arrays = match mlx_vector_array_values(vector_array) {
349            Ok(arrays) => arrays,
350            Err(e) => {
351                let _ = Box::into_raw(closure); // prevent premature drop
352                set_closure_error(e);
353                return FAILURE;
354            }
355        };
356        let result = closure(&arrays);
357        let _ = Box::into_raw(closure); // prevent premature drop
358
359        match result {
360            Ok(result) => {
361                *ret = new_mlx_vector_array(result);
362                SUCCESS
363            }
364            Err(err) => {
365                set_closure_error(err);
366                FAILURE
367            }
368        }
369    }
370}
371
372// extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {}
373
/// Payload destructor registered with mlx-c next to each closure trampoline.
///
/// Rebuilds the `Box<F>` that was leaked with `Box::into_raw` and drops it. A null
/// `payload` is ignored.
extern "C" fn closure_dtor<F>(payload: *mut std::ffi::c_void) {
    if !payload.is_null() {
        // SAFETY: a non-null `payload` came from `Box::into_raw(Box::<F>::new(..))`.
        // Assumes mlx-c invokes this destructor at most once per payload.
        let boxed: Box<F> = unsafe { Box::from_raw(payload.cast::<F>()) };
        drop(boxed);
    }
}
382
/// Returns a mutable reference to the value stored under `key`, first inserting
/// `f()` if the key is absent.
///
/// `f` is called only when `key` is missing. The key is cloned for the lookup, which
/// is just an `Rc` reference-count bump. If the key already exists, the clone is dropped
/// and the stored key is kept.
pub(crate) fn get_mut_or_insert_with<'a, T>(
    map: &'a mut HashMap<Rc<str>, T>,
    key: &Rc<str>,
    f: impl FnOnce() -> T,
) -> &'a mut T {
    // One hash lookup via the entry API, replacing the previous
    // `contains_key` + `insert` + `get_mut(..).unwrap()` sequence.
    map.entry(Rc::clone(key)).or_insert_with(f)
}
394
395/// Helper trait for compiling a function that takes a Module and/or an Optimizer.
396/// The implementation must ensure consistent ordering of the returned states.
397///
398/// This is automatically implemented for all types that implement ModuleParameters.
399pub trait Updatable {
400    /// Returns the number of updatable states.
401    ///
402    /// The number should be the same as calling `self.updatable_states().len()` but
403    /// this method should be more efficient in general. The implementation should
404    /// avoid iterating over the states if possible.
405    fn updatable_states_len(&self) -> usize;
406
407    /// Returns a list of references to the updatable states.
408    ///
409    /// The order of the states should be consistent across calls and should be the same as the
410    /// order of the states returned by [`Updatable::updatable_states_mut`].
411    fn updatable_states(&self) -> impl IntoIterator<Item = &Array>;
412
413    /// Returns a list of mutable references to the updatable states.
414    ///
415    /// The order of the states should be consistent across calls and should be the same as the
416    /// order of the states returned by [`Updatable::updatable_states`].
417    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array>;
418}
419
420impl<T> Updatable for T
421where
422    T: ModuleParameters,
423{
424    fn updatable_states_len(&self) -> usize {
425        self.num_parameters()
426    }
427
428    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
429        use itertools::Itertools;
430
431        // TODO: should we change the parameter map to a BTreeMap because it is sorted?
432        self.parameters()
433            .flatten()
434            .into_iter()
435            .sorted_by(|a, b| a.0.cmp(&b.0))
436            .map(|(_, v)| v)
437    }
438
439    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
440        use itertools::Itertools;
441
442        self.parameters_mut()
443            .flatten()
444            .into_iter()
445            .sorted_by(|a, b| a.0.cmp(&b.0))
446            .map(|(_, v)| v)
447    }
448}
449
450impl<T1, T2> Updatable for (T1, T2)
451where
452    T1: Updatable,
453    T2: Updatable,
454{
455    fn updatable_states_len(&self) -> usize {
456        let (a, b) = self;
457        a.updatable_states_len() + b.updatable_states_len()
458    }
459
460    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
461        let (a, b) = self;
462        let params = a.updatable_states();
463        params.into_iter().chain(b.updatable_states())
464    }
465
466    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
467        let (a, b) = self;
468        let params = a.updatable_states_mut();
469        params.into_iter().chain(b.updatable_states_mut())
470    }
471}
472
473impl Updatable for Vec<Array> {
474    fn updatable_states_len(&self) -> usize {
475        self.len()
476    }
477
478    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
479        self.iter()
480    }
481
482    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
483        self.iter_mut()
484    }
485}
486
/// Either one value standing for both components, or two explicit values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrPair<T = i32> {
    /// One value, used for both components.
    Single(T),

    /// Two separate values.
    Pair(T, T),
}

impl<T: Clone> SingleOrPair<T> {
    /// The first component (the single value, if there is only one).
    pub fn first(&self) -> T {
        let (Self::Single(value) | Self::Pair(value, _)) = self;
        value.clone()
    }

    /// The second component (the single value, if there is only one).
    pub fn second(&self) -> T {
        let (Self::Single(value) | Self::Pair(_, value)) = self;
        value.clone()
    }
}

impl<T> From<T> for SingleOrPair<T> {
    fn from(value: T) -> Self {
        Self::Single(value)
    }
}

impl<T> From<(T, T)> for SingleOrPair<T> {
    fn from((first, second): (T, T)) -> Self {
        Self::Pair(first, second)
    }
}

/// Expands into a tuple, duplicating a single value into both slots.
impl<T: Clone> From<SingleOrPair<T>> for (T, T) {
    fn from(value: SingleOrPair<T>) -> Self {
        match value {
            SingleOrPair::Pair(first, second) => (first, second),
            SingleOrPair::Single(only) => (only.clone(), only),
        }
    }
}
535
/// Either one value standing for all three components, or three explicit values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrTriple<T = i32> {
    /// One value, used for every component.
    Single(T),

    /// Three separate values.
    Triple(T, T, T),
}

impl<T: Clone> SingleOrTriple<T> {
    /// The first component (the single value, if there is only one).
    pub fn first(&self) -> T {
        let (Self::Single(value) | Self::Triple(value, _, _)) = self;
        value.clone()
    }

    /// The second component (the single value, if there is only one).
    pub fn second(&self) -> T {
        let (Self::Single(value) | Self::Triple(_, value, _)) = self;
        value.clone()
    }

    /// The third component (the single value, if there is only one).
    pub fn third(&self) -> T {
        let (Self::Single(value) | Self::Triple(_, _, value)) = self;
        value.clone()
    }
}

impl<T> From<T> for SingleOrTriple<T> {
    fn from(value: T) -> Self {
        Self::Single(value)
    }
}

impl<T> From<(T, T, T)> for SingleOrTriple<T> {
    fn from((first, second, third): (T, T, T)) -> Self {
        Self::Triple(first, second, third)
    }
}

/// Expands into a tuple, duplicating a single value into all three slots.
impl<T: Clone> From<SingleOrTriple<T>> for (T, T, T) {
    fn from(value: SingleOrTriple<T>) -> Self {
        match value {
            SingleOrTriple::Triple(first, second, third) => (first, second, third),
            SingleOrTriple::Single(only) => (only.clone(), only.clone(), only),
        }
    }
}
592
/// Either one value or an arbitrary-length list of values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SingleOrVec<T> {
    /// One value.
    Single(T),

    /// A list of values.
    Vec(Vec<T>),
}

/// A bare value becomes the `Single` variant.
impl<T> From<T> for SingleOrVec<T> {
    fn from(value: T) -> Self {
        Self::Single(value)
    }
}

/// A `Vec` becomes the `Vec` variant as-is (no element is unwrapped, even for length 1).
impl<T> From<Vec<T>> for SingleOrVec<T> {
    fn from(values: Vec<T>) -> Self {
        Self::Vec(values)
    }
}