mlx_rs/utils/
mod.rs

1//! Utility functions and types.
2
3use guard::Guarded;
4use mlx_sys::mlx_vector_array;
5
6use crate::error::set_closure_error;
7use crate::module::ModuleParameters;
8use crate::{complex64, error::Exception, Array, FromNested};
9use std::collections::HashMap;
10use std::{marker::PhantomData, rc::Rc};
11
/// Success status code from the c binding
pub(crate) const SUCCESS: i32 = 0;
/// Failure status code handed back to the c binding (for example by the closure trampolines in
/// this module) to signal that an operation did not succeed.
pub(crate) const FAILURE: i32 = 1;
15
16pub(crate) mod guard;
17pub(crate) mod io;
18
/// Resolves a possibly negative (counted-from-the-end) index against `len`.
///
/// Negative indices are added to `len` with saturating arithmetic; non-negative indices are
/// returned unchanged. No bounds check is performed, so the result may still be out of range.
pub(crate) fn resolve_index_signed_unchecked(index: i32, len: i32) -> i32 {
    match index {
        from_end if from_end < 0 => len.saturating_add(from_end),
        from_start => from_start,
    }
}
26
/// Resolves a possibly negative (counted-from-the-end) index against a `usize` length.
///
/// No bounds check is performed: `len` is cast to `i32`, and if `index < -len` the intermediate
/// value is negative, so the final cast to `usize` wraps to a huge value. Callers must validate
/// the index beforehand.
pub(crate) fn resolve_index_unchecked(index: i32, len: usize) -> usize {
    let resolved = if index < 0 {
        len as i32 + index
    } else {
        index
    };
    resolved as usize
}
34
35/// Helper method to convert an optional slice of axes to a Vec covering all axes.
36pub(crate) fn axes_or_default_to_all<'a>(axes: impl IntoOption<&'a [i32]>, ndim: i32) -> Vec<i32> {
37    match axes.into_option() {
38        Some(axes) => axes.to_vec(),
39        None => {
40            let axes: Vec<i32> = (0..ndim).collect();
41            axes
42        }
43    }
44}
45
46pub(crate) struct VectorArray {
47    c_vec: mlx_sys::mlx_vector_array,
48}
49
50impl VectorArray {
51    pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_vector_array {
52        self.c_vec
53    }
54
55    pub(crate) fn try_from_iter(
56        iter: impl Iterator<Item = impl AsRef<Array>>,
57    ) -> Result<Self, Exception> {
58        VectorArray::try_from_op(|res| unsafe {
59            let mut status = SUCCESS;
60            for arr in iter {
61                status = mlx_sys::mlx_vector_array_append_value(*res, arr.as_ref().as_ptr());
62                if status != SUCCESS {
63                    return status;
64                }
65            }
66            status
67        })
68    }
69
70    pub(crate) fn try_into_values<T>(self) -> Result<T, Exception>
71    where
72        T: FromIterator<Array>,
73    {
74        unsafe {
75            let size = mlx_sys::mlx_vector_array_size(self.c_vec);
76            (0..size)
77                .map(|i| {
78                    Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, self.c_vec, i))
79                })
80                .collect::<Result<T, Exception>>()
81        }
82    }
83}
84
85impl Drop for VectorArray {
86    fn drop(&mut self) {
87        let status = unsafe { mlx_sys::mlx_vector_array_free(self.c_vec) };
88        debug_assert_eq!(status, SUCCESS);
89    }
90}
91
/// Like `Into<Option<T>>`, but also accepts `&[T; N]` and `&Vec<T>` wherever an optional slice
/// `&[T]` is expected, so callers can pass arrays and vectors without slicing them first.
///
/// A plain value becomes `Some(value)`; an `Option<T>` is passed through unchanged.
pub trait IntoOption<T> {
    /// Converts `self` into an [`Option`].
    fn into_option(self) -> Option<T>;
}

impl<T> IntoOption<T> for Option<T> {
    fn into_option(self) -> Option<T> {
        // Already optional: nothing to wrap.
        self
    }
}

impl<T> IntoOption<T> for T {
    fn into_option(self) -> Option<T> {
        Option::Some(self)
    }
}

impl<'a, T, const N: usize> IntoOption<&'a [T]> for &'a [T; N] {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}

impl<'a, T> IntoOption<&'a [T]> for &'a Vec<T> {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}
122
123/// A trait for a scalar or an array.
124pub trait ScalarOrArray<'a> {
125    /// The reference type of the array.
126    type Array: AsRef<Array> + 'a;
127
128    /// Convert to an owned or reference array.
129    fn into_owned_or_ref_array(self) -> Self::Array;
130}
131
132impl ScalarOrArray<'_> for Array {
133    type Array = Array;
134
135    fn into_owned_or_ref_array(self) -> Array {
136        self
137    }
138}
139
140impl<'a> ScalarOrArray<'a> for &'a Array {
141    type Array = &'a Array;
142
143    // TODO: clippy would complain about `as_array`. Is there a better name?
144    fn into_owned_or_ref_array(self) -> &'a Array {
145        self
146    }
147}
148
149impl ScalarOrArray<'static> for bool {
150    type Array = Array;
151
152    fn into_owned_or_ref_array(self) -> Array {
153        Array::from_bool(self)
154    }
155}
156
157impl ScalarOrArray<'static> for i32 {
158    type Array = Array;
159
160    fn into_owned_or_ref_array(self) -> Array {
161        Array::from_int(self)
162    }
163}
164
165impl ScalarOrArray<'static> for f32 {
166    type Array = Array;
167
168    fn into_owned_or_ref_array(self) -> Array {
169        Array::from_f32(self)
170    }
171}
172
173// TODO: this is bugged right now. See https://github.com/ml-explore/mlx/issues/1994
174// impl ScalarOrArray<'static> for f64 {
175//     type Array = Array;
176
177//     fn into_owned_or_ref_array(self) -> Array {
178//         Array::from_f64(self)
179//     }
180// }
181
182impl ScalarOrArray<'static> for complex64 {
183    type Array = Array;
184
185    fn into_owned_or_ref_array(self) -> Array {
186        Array::from_complex(self)
187    }
188}
189
190impl<T> ScalarOrArray<'static> for T
191where
192    Array: FromNested<T>,
193{
194    type Array = Array;
195
196    fn into_owned_or_ref_array(self) -> Array {
197        Array::from_nested(self)
198    }
199}
200
201#[derive(Debug)]
202pub(crate) struct Closure<'a> {
203    c_closure: mlx_sys::mlx_closure,
204    lt_marker: PhantomData<&'a ()>,
205}
206
207impl<'a> Closure<'a> {
208    pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_closure {
209        self.c_closure
210    }
211
212    pub(crate) fn new<F>(closure: F) -> Self
213    where
214        F: FnMut(&[Array]) -> Vec<Array> + 'a,
215    {
216        let c_closure = new_mlx_closure(closure);
217        Self {
218            c_closure,
219            lt_marker: PhantomData,
220        }
221    }
222
223    pub(crate) fn new_fallible<F>(closure: F) -> Self
224    where
225        F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
226    {
227        let c_closure = new_mlx_fallible_closure(closure);
228        Self {
229            c_closure,
230            lt_marker: PhantomData,
231        }
232    }
233}
234
235impl Drop for Closure<'_> {
236    fn drop(&mut self) {
237        let status = unsafe { mlx_sys::mlx_closure_free(self.c_closure) };
238        debug_assert_eq!(status, SUCCESS);
239    }
240}
241
242/// Helper method to create a mlx_closure from a Rust closure.
243fn new_mlx_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
244where
245    F: FnMut(&[Array]) -> Vec<Array> + 'a,
246{
247    // Box the closure to keep it on the heap
248    let boxed = Box::new(closure);
249
250    // Create a raw pointer from the Box, transferring ownership to C
251    let raw = Box::into_raw(boxed);
252    let payload = raw as *mut std::ffi::c_void;
253
254    unsafe {
255        mlx_sys::mlx_closure_new_func_payload(Some(trampoline::<F>), payload, Some(noop_dtor))
256    }
257}
258
259fn new_mlx_fallible_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
260where
261    F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
262{
263    let boxed = Box::new(closure);
264    let raw = Box::into_raw(boxed);
265    let payload = raw as *mut std::ffi::c_void;
266
267    unsafe {
268        mlx_sys::mlx_closure_new_func_payload(
269            Some(trampoline_fallible::<F>),
270            payload,
271            Some(noop_dtor),
272        )
273    }
274}
275
276/// Function to create a new (+1 reference) mlx_vector_array from a vector of Array
277fn new_mlx_vector_array(arrays: Vec<Array>) -> mlx_sys::mlx_vector_array {
278    unsafe {
279        let result = mlx_sys::mlx_vector_array_new();
280        let ctx_ptrs: Vec<mlx_sys::mlx_array> = arrays.iter().map(|array| array.as_ptr()).collect();
281        mlx_sys::mlx_vector_array_append_data(result, ctx_ptrs.as_ptr(), arrays.len());
282        result
283    }
284}
285
286fn mlx_vector_array_values(
287    vector_array: mlx_sys::mlx_vector_array,
288) -> Result<Vec<Array>, Exception> {
289    unsafe {
290        let size = mlx_sys::mlx_vector_array_size(vector_array);
291        (0..size)
292            .map(|index| {
293                Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, vector_array, index))
294            })
295            .collect()
296    }
297}
298
299extern "C" fn trampoline<'a, F>(
300    ret: *mut mlx_vector_array,
301    vector_array: mlx_vector_array,
302    payload: *mut std::ffi::c_void,
303) -> i32
304where
305    F: FnMut(&[Array]) -> Vec<Array> + 'a,
306{
307    unsafe {
308        let raw_closure: *mut F = payload as *mut _;
309        // Let the box take care of freeing the closure
310        let mut closure = Box::from_raw(raw_closure);
311        let arrays = match mlx_vector_array_values(vector_array) {
312            Ok(arrays) => arrays,
313            Err(_) => {
314                return FAILURE;
315            }
316        };
317        let result = closure(&arrays);
318        // We should probably keep using new_mlx_vector_array here instead of VectorArray
319        // since we probably don't want to drop the arrays in the closure
320        *ret = new_mlx_vector_array(result);
321
322        SUCCESS
323    }
324}
325
326extern "C" fn trampoline_fallible<'a, F>(
327    ret: *mut mlx_vector_array,
328    vector_array: mlx_vector_array,
329    payload: *mut std::ffi::c_void,
330) -> i32
331where
332    F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
333{
334    unsafe {
335        let raw_closure: *mut F = payload as *mut _;
336        let mut closure = Box::from_raw(raw_closure);
337        let arrays = match mlx_vector_array_values(vector_array) {
338            Ok(arrays) => arrays,
339            Err(e) => {
340                set_closure_error(e);
341                return FAILURE;
342            }
343        };
344        let result = closure(&arrays);
345        match result {
346            Ok(result) => {
347                *ret = new_mlx_vector_array(result);
348                SUCCESS
349            }
350            Err(err) => {
351                set_closure_error(err);
352                FAILURE
353            }
354        }
355    }
356}
357
358extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {}
359
/// Returns a mutable reference to the value stored under `key`, first inserting the result of
/// `f` if the key is absent.
///
/// `f` is only called when `key` is missing. The entry API hashes the key once, instead of
/// separate `contains_key` / `insert` / `get_mut` lookups. Cloning the `Rc<str>` key only bumps
/// a reference count, and an existing key in the map is kept as-is.
pub(crate) fn get_mut_or_insert_with<'a, T>(
    map: &'a mut HashMap<Rc<str>, T>,
    key: &Rc<str>,
    f: impl FnOnce() -> T,
) -> &'a mut T {
    map.entry(Rc::clone(key)).or_insert_with(f)
}
371
372/// Helper trait for compiling a function that takes a Module and/or an Optimizer.
373/// The implementation must ensure consistent ordering of the returned states.
374///
375/// This is automatically implemented for all types that implement ModuleParameters.
376pub trait Updatable {
377    /// Returns a list of references to the updatable states.
378    ///
379    /// The order of the states should be consistent across calls and should be the same as the
380    /// order of the states returned by [`Updatable::updatable_states_mut`].
381    fn updatable_states(&self) -> impl IntoIterator<Item = &Array>;
382
383    /// Returns a list of mutable references to the updatable states.
384    ///
385    /// The order of the states should be consistent across calls and should be the same as the
386    /// order of the states returned by [`Updatable::updatable_states`].
387    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array>;
388}
389
390impl<T> Updatable for T
391where
392    T: ModuleParameters,
393{
394    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
395        use itertools::Itertools;
396
397        // TODO: should we change the parameter map to a BTreeMap because it is sorted?
398        self.parameters()
399            .flatten()
400            .into_iter()
401            .sorted_by(|a, b| a.0.cmp(&b.0))
402            .map(|(_, v)| v)
403    }
404
405    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
406        use itertools::Itertools;
407
408        self.parameters_mut()
409            .flatten()
410            .into_iter()
411            .sorted_by(|a, b| a.0.cmp(&b.0))
412            .map(|(_, v)| v)
413    }
414}
415
416impl<T1, T2> Updatable for (T1, T2)
417where
418    T1: Updatable,
419    T2: Updatable,
420{
421    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
422        let (a, b) = self;
423        let params = a.updatable_states();
424        params.into_iter().chain(b.updatable_states())
425    }
426
427    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
428        let (a, b) = self;
429        let params = a.updatable_states_mut();
430        params.into_iter().chain(b.updatable_states_mut())
431    }
432}
433
434impl Updatable for Vec<Array> {
435    fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
436        self.iter()
437    }
438
439    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
440        self.iter_mut()
441    }
442}
443
/// Either one value used for both positions, or an explicit pair of values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrPair<T = i32> {
    /// One value that stands in for both elements.
    Single(T),

    /// Two explicit values.
    Pair(T, T),
}

impl<T: Clone> SingleOrPair<T> {
    /// Returns the first element (the single value when only one was given).
    pub fn first(&self) -> T {
        let (Self::Single(value) | Self::Pair(value, _)) = self;
        value.clone()
    }

    /// Returns the second element (the single value when only one was given).
    pub fn second(&self) -> T {
        let (Self::Single(value) | Self::Pair(_, value)) = self;
        value.clone()
    }
}

impl<T> From<T> for SingleOrPair<T> {
    fn from(value: T) -> Self {
        Self::Single(value)
    }
}

impl<T> From<(T, T)> for SingleOrPair<T> {
    fn from((first, second): (T, T)) -> Self {
        Self::Pair(first, second)
    }
}

impl<T: Clone> From<SingleOrPair<T>> for (T, T) {
    fn from(value: SingleOrPair<T>) -> Self {
        match value {
            SingleOrPair::Pair(first, second) => (first, second),
            SingleOrPair::Single(shared) => (shared.clone(), shared),
        }
    }
}
492
/// Either one value used for all three positions, or an explicit triple of values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrTriple<T = i32> {
    /// One value that stands in for all three elements.
    Single(T),

    /// Three explicit values.
    Triple(T, T, T),
}

impl<T: Clone> SingleOrTriple<T> {
    /// Returns the first element (the single value when only one was given).
    pub fn first(&self) -> T {
        let (Self::Single(value) | Self::Triple(value, _, _)) = self;
        value.clone()
    }

    /// Returns the second element (the single value when only one was given).
    pub fn second(&self) -> T {
        let (Self::Single(value) | Self::Triple(_, value, _)) = self;
        value.clone()
    }

    /// Returns the third element (the single value when only one was given).
    pub fn third(&self) -> T {
        let (Self::Single(value) | Self::Triple(_, _, value)) = self;
        value.clone()
    }
}

impl<T> From<T> for SingleOrTriple<T> {
    fn from(value: T) -> Self {
        Self::Single(value)
    }
}

impl<T> From<(T, T, T)> for SingleOrTriple<T> {
    fn from((first, second, third): (T, T, T)) -> Self {
        Self::Triple(first, second, third)
    }
}

impl<T: Clone> From<SingleOrTriple<T>> for (T, T, T) {
    fn from(value: SingleOrTriple<T>) -> Self {
        match value {
            SingleOrTriple::Triple(first, second, third) => (first, second, third),
            SingleOrTriple::Single(shared) => (shared.clone(), shared.clone(), shared),
        }
    }
}
549
/// Either one value or a vector of values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SingleOrVec<T> {
    /// A lone value.
    Single(T),

    /// A vector of values.
    Vec(Vec<T>),
}

impl<T> From<T> for SingleOrVec<T> {
    fn from(value: T) -> Self {
        Self::Single(value)
    }
}

impl<T> From<Vec<T>> for SingleOrVec<T> {
    fn from(values: Vec<T>) -> Self {
        Self::Vec(values)
    }
}