mlx_rs/utils/
mod.rs

1//! Utility functions and types.
2
3use guard::Guarded;
4use mlx_sys::mlx_vector_array;
5
6use crate::error::set_closure_error;
7use crate::module::ModuleParameters;
8use crate::{complex64, error::Exception, Array, FromNested};
9use std::collections::HashMap;
10use std::{marker::PhantomData, rc::Rc};
11
/// Status code from the C bindings that this module treats as success.
pub(crate) const SUCCESS: i32 = 0;
/// Status code this module returns to the C side to report a failed callback.
pub(crate) const FAILURE: i32 = 1;
15
16pub(crate) mod guard;
17pub(crate) mod io;
18
/// Resolves a possibly negative `index` against `len`, staying in signed arithmetic.
///
/// Non-negative indices are returned unchanged. A negative index is added to `len`
/// (so `-1` maps to `len - 1`) using saturating arithmetic. No bounds check is
/// performed: the result may still be negative or `>= len`.
pub(crate) fn resolve_index_signed_unchecked(index: i32, len: i32) -> i32 {
    if index >= 0 {
        return index;
    }

    // Saturate rather than overflow when both operands are extreme.
    len.saturating_add(index)
}
26
/// Resolves a possibly negative `index` against a length of `len` elements.
///
/// Non-negative indices are returned unchanged. A negative index counts from the end,
/// so `-1` maps to `len - 1`. No bounds check is performed: a negative index whose
/// magnitude exceeds `len` wraps around to a very large `usize`, and a non-negative
/// index is never compared against `len`.
pub(crate) fn resolve_index_unchecked(index: i32, len: usize) -> usize {
    if index.is_negative() {
        // Do the arithmetic in `usize` so that lengths above `i32::MAX` are not
        // truncated. For every length that fits in an `i32` this yields the same value
        // as computing `len as i32 + index` and casting the result back.
        len.wrapping_add_signed(index as isize)
    } else {
        index as usize
    }
}
34
35/// Helper method to convert an optional slice of axes to a Vec covering all axes.
36pub(crate) fn axes_or_default_to_all<'a>(axes: impl IntoOption<&'a [i32]>, ndim: i32) -> Vec<i32> {
37    match axes.into_option() {
38        Some(axes) => axes.to_vec(),
39        None => {
40            let axes: Vec<i32> = (0..ndim).collect();
41            axes
42        }
43    }
44}
45
46pub(crate) struct VectorArray {
47    c_vec: mlx_sys::mlx_vector_array,
48}
49
50impl VectorArray {
51    pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_vector_array {
52        self.c_vec
53    }
54
55    pub(crate) unsafe fn from_ptr(c_vec: mlx_sys::mlx_vector_array) -> Self {
56        Self { c_vec }
57    }
58
59    pub(crate) fn try_from_iter(
60        iter: impl Iterator<Item = impl AsRef<Array>>,
61    ) -> Result<Self, Exception> {
62        VectorArray::try_from_op(|res| unsafe {
63            let mut status = SUCCESS;
64            for arr in iter {
65                status = mlx_sys::mlx_vector_array_append_value(*res, arr.as_ref().as_ptr());
66                if status != SUCCESS {
67                    return status;
68                }
69            }
70            status
71        })
72    }
73
74    pub(crate) fn try_into_values<T>(self) -> Result<T, Exception>
75    where
76        T: FromIterator<Array>,
77    {
78        unsafe {
79            let size = mlx_sys::mlx_vector_array_size(self.c_vec);
80            (0..size)
81                .map(|i| {
82                    Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, self.c_vec, i))
83                })
84                .collect::<Result<T, Exception>>()
85        }
86    }
87}
88
89impl Drop for VectorArray {
90    fn drop(&mut self) {
91        let status = unsafe { mlx_sys::mlx_vector_array_free(self.c_vec) };
92        debug_assert_eq!(status, SUCCESS);
93    }
94}
95
/// A helper trait that works like `Into<Option<T>>`, with extra impls so that `&[T; N]`
/// and `&Vec<T>` can be passed where an optional `&[T]` is expected.
pub trait IntoOption<T> {
    /// Convert into an [`Option`].
    fn into_option(self) -> Option<T>;
}

/// An `Option` is passed through unchanged.
impl<T> IntoOption<T> for Option<T> {
    fn into_option(self) -> Option<T> {
        self
    }
}

/// A bare value is wrapped in `Some`.
impl<T> IntoOption<T> for T {
    fn into_option(self) -> Option<T> {
        Some(self)
    }
}

/// A reference to a fixed-size array is viewed as a slice.
impl<'a, T, const N: usize> IntoOption<&'a [T]> for &'a [T; N] {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}

/// A reference to a `Vec` is viewed as a slice.
impl<'a, T> IntoOption<&'a [T]> for &'a Vec<T> {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}
126
127/// A trait for a scalar or an array.
128pub trait ScalarOrArray<'a> {
129    /// The reference type of the array.
130    type Array: AsRef<Array> + 'a;
131
132    /// Convert to an owned or reference array.
133    fn into_owned_or_ref_array(self) -> Self::Array;
134}
135
136impl ScalarOrArray<'_> for Array {
137    type Array = Array;
138
139    fn into_owned_or_ref_array(self) -> Array {
140        self
141    }
142}
143
144impl<'a> ScalarOrArray<'a> for &'a Array {
145    type Array = &'a Array;
146
147    // TODO: clippy would complain about `as_array`. Is there a better name?
148    fn into_owned_or_ref_array(self) -> &'a Array {
149        self
150    }
151}
152
153impl ScalarOrArray<'static> for bool {
154    type Array = Array;
155
156    fn into_owned_or_ref_array(self) -> Array {
157        Array::from_bool(self)
158    }
159}
160
161impl ScalarOrArray<'static> for i32 {
162    type Array = Array;
163
164    fn into_owned_or_ref_array(self) -> Array {
165        Array::from_int(self)
166    }
167}
168
169impl ScalarOrArray<'static> for f32 {
170    type Array = Array;
171
172    fn into_owned_or_ref_array(self) -> Array {
173        Array::from_f32(self)
174    }
175}
176
177// TODO: this is bugged right now. See https://github.com/ml-explore/mlx/issues/1994
178// impl ScalarOrArray<'static> for f64 {
179//     type Array = Array;
180
181//     fn into_owned_or_ref_array(self) -> Array {
182//         Array::from_f64(self)
183//     }
184// }
185
186impl ScalarOrArray<'static> for complex64 {
187    type Array = Array;
188
189    fn into_owned_or_ref_array(self) -> Array {
190        Array::from_complex(self)
191    }
192}
193
194impl<T> ScalarOrArray<'static> for T
195where
196    Array: FromNested<T>,
197{
198    type Array = Array;
199
200    fn into_owned_or_ref_array(self) -> Array {
201        Array::from_nested(self)
202    }
203}
204
205#[derive(Debug)]
206pub(crate) struct Closure<'a> {
207    c_closure: mlx_sys::mlx_closure,
208    lt_marker: PhantomData<&'a ()>,
209}
210
211impl<'a> Closure<'a> {
212    pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_closure {
213        self.c_closure
214    }
215
216    pub(crate) fn new<F>(closure: F) -> Self
217    where
218        F: FnMut(&[Array]) -> Vec<Array> + 'a,
219    {
220        let c_closure = new_mlx_closure(closure);
221        Self {
222            c_closure,
223            lt_marker: PhantomData,
224        }
225    }
226
227    pub(crate) fn new_fallible<F>(closure: F) -> Self
228    where
229        F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
230    {
231        let c_closure = new_mlx_fallible_closure(closure);
232        Self {
233            c_closure,
234            lt_marker: PhantomData,
235        }
236    }
237}
238
239impl Drop for Closure<'_> {
240    fn drop(&mut self) {
241        let status = unsafe { mlx_sys::mlx_closure_free(self.c_closure) };
242        debug_assert_eq!(status, SUCCESS);
243    }
244}
245
246/// Helper method to create a mlx_closure from a Rust closure.
247fn new_mlx_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
248where
249    F: FnMut(&[Array]) -> Vec<Array> + 'a,
250{
251    // Box the closure to keep it on the heap
252    let boxed = Box::new(closure);
253
254    // Create a raw pointer from the Box, transferring ownership to C
255    let raw = Box::into_raw(boxed);
256    let payload = raw as *mut std::ffi::c_void;
257
258    unsafe {
259        mlx_sys::mlx_closure_new_func_payload(Some(trampoline::<F>), payload, Some(noop_dtor))
260    }
261}
262
263fn new_mlx_fallible_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
264where
265    F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
266{
267    let boxed = Box::new(closure);
268    let raw = Box::into_raw(boxed);
269    let payload = raw as *mut std::ffi::c_void;
270
271    unsafe {
272        mlx_sys::mlx_closure_new_func_payload(
273            Some(trampoline_fallible::<F>),
274            payload,
275            Some(noop_dtor),
276        )
277    }
278}
279
280/// Function to create a new (+1 reference) mlx_vector_array from a vector of Array
281fn new_mlx_vector_array(arrays: Vec<Array>) -> mlx_sys::mlx_vector_array {
282    unsafe {
283        let result = mlx_sys::mlx_vector_array_new();
284        let ctx_ptrs: Vec<mlx_sys::mlx_array> = arrays.iter().map(|array| array.as_ptr()).collect();
285        mlx_sys::mlx_vector_array_append_data(result, ctx_ptrs.as_ptr(), arrays.len());
286        result
287    }
288}
289
290fn mlx_vector_array_values(
291    vector_array: mlx_sys::mlx_vector_array,
292) -> Result<Vec<Array>, Exception> {
293    unsafe {
294        let size = mlx_sys::mlx_vector_array_size(vector_array);
295        (0..size)
296            .map(|index| {
297                Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, vector_array, index))
298            })
299            .collect()
300    }
301}
302
303extern "C" fn trampoline<'a, F>(
304    ret: *mut mlx_vector_array,
305    vector_array: mlx_vector_array,
306    payload: *mut std::ffi::c_void,
307) -> i32
308where
309    F: FnMut(&[Array]) -> Vec<Array> + 'a,
310{
311    unsafe {
312        let raw_closure: *mut F = payload as *mut _;
313        // Let the box take care of freeing the closure
314        let mut closure = Box::from_raw(raw_closure);
315        let arrays = match mlx_vector_array_values(vector_array) {
316            Ok(arrays) => arrays,
317            Err(_) => {
318                return FAILURE;
319            }
320        };
321        let result = closure(&arrays);
322        // We should probably keep using new_mlx_vector_array here instead of VectorArray
323        // since we probably don't want to drop the arrays in the closure
324        *ret = new_mlx_vector_array(result);
325
326        SUCCESS
327    }
328}
329
330extern "C" fn trampoline_fallible<'a, F>(
331    ret: *mut mlx_vector_array,
332    vector_array: mlx_vector_array,
333    payload: *mut std::ffi::c_void,
334) -> i32
335where
336    F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
337{
338    unsafe {
339        let raw_closure: *mut F = payload as *mut _;
340        let mut closure = Box::from_raw(raw_closure);
341        let arrays = match mlx_vector_array_values(vector_array) {
342            Ok(arrays) => arrays,
343            Err(e) => {
344                set_closure_error(e);
345                return FAILURE;
346            }
347        };
348        let result = closure(&arrays);
349        match result {
350            Ok(result) => {
351                *ret = new_mlx_vector_array(result);
352                SUCCESS
353            }
354            Err(err) => {
355                set_closure_error(err);
356                FAILURE
357            }
358        }
359    }
360}
361
362extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {}
363
/// Returns a mutable reference to the value stored under `key`, inserting one first if
/// the key is absent.
///
/// `f` is called only when `key` is not yet in `map`; its result is stored under a clone
/// of `key` (a reference-count bump, the string data is not copied).
pub(crate) fn get_mut_or_insert_with<'a, T>(
    map: &'a mut HashMap<Rc<str>, T>,
    key: &Rc<str>,
    f: impl FnOnce() -> T,
) -> &'a mut T {
    // The entry API does a single lookup and needs no `unwrap`.
    map.entry(Rc::clone(key)).or_insert_with(f)
}
375
376/// Helper trait for compiling a function that takes a Module and/or an Optimizer.
377/// The implementation must ensure consistent ordering of the returned states.
378///
379/// This is automatically implemented for all types that implement ModuleParameters.
380pub trait Updatable {
381    /// Returns the number of updatable states.
382    ///
383    /// The number should be the same as calling `self.updatable_states().len()` but
384    /// this method should be more efficient in general. The implementation should
385    /// avoid iterating over the states if possible.
386    fn updatable_states_len(&self) -> usize;
387
388    /// Returns a list of references to the updatable states.
389    ///
390    /// The order of the states should be consistent across calls and should be the same as the
391    /// order of the states returned by [`Updatable::updatable_states_mut`].
392    fn updatable_states(&self) -> impl IntoIterator<Item = &Array>;
393
394    /// Returns a list of mutable references to the updatable states.
395    ///
396    /// The order of the states should be consistent across calls and should be the same as the
397    /// order of the states returned by [`Updatable::updatable_states`].
398    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array>;
399}
400
401impl<T> Updatable for T
402where
403    T: ModuleParameters,
404{
405    fn updatable_states_len(&self) -> usize {
406        self.num_parameters()
407    }
408
409    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
410        use itertools::Itertools;
411
412        // TODO: should we change the parameter map to a BTreeMap because it is sorted?
413        self.parameters()
414            .flatten()
415            .into_iter()
416            .sorted_by(|a, b| a.0.cmp(&b.0))
417            .map(|(_, v)| v)
418    }
419
420    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
421        use itertools::Itertools;
422
423        self.parameters_mut()
424            .flatten()
425            .into_iter()
426            .sorted_by(|a, b| a.0.cmp(&b.0))
427            .map(|(_, v)| v)
428    }
429}
430
431impl<T1, T2> Updatable for (T1, T2)
432where
433    T1: Updatable,
434    T2: Updatable,
435{
436    fn updatable_states_len(&self) -> usize {
437        let (a, b) = self;
438        a.updatable_states_len() + b.updatable_states_len()
439    }
440
441    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
442        let (a, b) = self;
443        let params = a.updatable_states();
444        params.into_iter().chain(b.updatable_states())
445    }
446
447    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
448        let (a, b) = self;
449        let params = a.updatable_states_mut();
450        params.into_iter().chain(b.updatable_states_mut())
451    }
452}
453
454impl Updatable for Vec<Array> {
455    fn updatable_states_len(&self) -> usize {
456        self.len()
457    }
458
459    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
460        self.iter()
461    }
462
463    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
464        self.iter_mut()
465    }
466}
467
/// Helper type holding either one value or a pair of values.
///
/// A `Single` value stands in for both positions when a pair is requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrPair<T = i32> {
    /// Single value.
    Single(T),

    /// Pair of values.
    Pair(T, T),
}

impl<T: Clone> SingleOrPair<T> {
    /// Returns a clone of the first value (the only value for `Single`).
    pub fn first(&self) -> T {
        let (Self::Single(head) | Self::Pair(head, _)) = self;
        head.clone()
    }

    /// Returns a clone of the second value (the only value for `Single`).
    pub fn second(&self) -> T {
        let (Self::Single(tail) | Self::Pair(_, tail)) = self;
        tail.clone()
    }
}

/// A bare value becomes `Single`.
impl<T> From<T> for SingleOrPair<T> {
    fn from(value: T) -> Self {
        Self::Single(value)
    }
}

/// A 2-tuple becomes `Pair`, keeping the element order.
impl<T> From<(T, T)> for SingleOrPair<T> {
    fn from((first, second): (T, T)) -> Self {
        Self::Pair(first, second)
    }
}

/// Expands into a 2-tuple, duplicating the value of a `Single`.
impl<T: Clone> From<SingleOrPair<T>> for (T, T) {
    fn from(value: SingleOrPair<T>) -> Self {
        match value {
            SingleOrPair::Pair(first, second) => (first, second),
            SingleOrPair::Single(only) => (only.clone(), only),
        }
    }
}
516
/// Helper type holding either one value or a triple of values.
///
/// A `Single` value stands in for all three positions when a triple is requested.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrTriple<T = i32> {
    /// Single value.
    Single(T),

    /// Triple of values.
    Triple(T, T, T),
}

impl<T: Clone> SingleOrTriple<T> {
    /// Returns a clone of the first value (the only value for `Single`).
    pub fn first(&self) -> T {
        let (Self::Single(picked) | Self::Triple(picked, _, _)) = self;
        picked.clone()
    }

    /// Returns a clone of the second value (the only value for `Single`).
    pub fn second(&self) -> T {
        let (Self::Single(picked) | Self::Triple(_, picked, _)) = self;
        picked.clone()
    }

    /// Returns a clone of the third value (the only value for `Single`).
    pub fn third(&self) -> T {
        let (Self::Single(picked) | Self::Triple(_, _, picked)) = self;
        picked.clone()
    }
}

/// A bare value becomes `Single`.
impl<T> From<T> for SingleOrTriple<T> {
    fn from(value: T) -> Self {
        Self::Single(value)
    }
}

/// A 3-tuple becomes `Triple`, keeping the element order.
impl<T> From<(T, T, T)> for SingleOrTriple<T> {
    fn from((first, second, third): (T, T, T)) -> Self {
        Self::Triple(first, second, third)
    }
}

/// Expands into a 3-tuple, duplicating the value of a `Single`.
impl<T: Clone> From<SingleOrTriple<T>> for (T, T, T) {
    fn from(value: SingleOrTriple<T>) -> Self {
        match value {
            SingleOrTriple::Triple(first, second, third) => (first, second, third),
            SingleOrTriple::Single(only) => (only.clone(), only.clone(), only),
        }
    }
}
573
/// Helper type holding either one value or a vector of values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SingleOrVec<T> {
    /// Single value.
    Single(T),

    /// Vector of values.
    Vec(Vec<T>),
}

/// A bare value becomes `Single`.
impl<T> From<T> for SingleOrVec<T> {
    fn from(single: T) -> Self {
        Self::Single(single)
    }
}

/// A `Vec` becomes the `Vec` variant without copying its elements.
impl<T> From<Vec<T>> for SingleOrVec<T> {
    fn from(many: Vec<T>) -> Self {
        Self::Vec(many)
    }
}