mlx_rs/utils/
guard.rs

1use half::{bf16, f16};
2use mlx_sys::{__BindgenComplex, bfloat16_t, float16_t, mlx_array};
3
4use crate::{complex64, error::Exception, Array};
5
6use super::{VectorArray, SUCCESS};
7
8type Status = i32;
9
10pub trait Guard<T>: Default {
11    type MutRawPtr;
12
13    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr;
14
15    fn set_init_success(&mut self, success: bool);
16
17    fn try_into_guarded(self) -> Result<T, Exception>;
18}
19
20pub(crate) trait Guarded: Sized {
21    type Guard: Guard<Self>;
22
23    #[track_caller]
24    fn try_from_op<F>(f: F) -> Result<Self, Exception>
25    where
26        F: FnOnce(<Self::Guard as Guard<Self>>::MutRawPtr) -> Status,
27    {
28        crate::error::INIT_ERR_HANDLER
29            .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
30
31        let mut guard = Self::Guard::default();
32        let status = f(guard.as_mut_raw_ptr());
33        match status {
34            SUCCESS => {
35                guard.set_init_success(true);
36                guard.try_into_guarded()
37            }
38            _ => {
39                // Err(crate::error::get_and_clear_last_mlx_error()
40                // .expect("MLX operation failed but no error was set"))
41                let what = crate::error::get_and_clear_last_mlx_error()
42                    .expect("MLX operation failed but no error was set")
43                    .what;
44                let location = std::panic::Location::caller();
45                Err(Exception { what, location })
46            }
47        }
48    }
49}
50
51pub(crate) struct MaybeUninitArray {
52    pub(crate) ptr: mlx_array,
53    pub(crate) init_success: bool,
54}
55
56impl Default for MaybeUninitArray {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl MaybeUninitArray {
63    pub fn new() -> Self {
64        unsafe {
65            Self {
66                ptr: mlx_sys::mlx_array_new(),
67                init_success: false,
68            }
69        }
70    }
71}
72
73impl Drop for MaybeUninitArray {
74    fn drop(&mut self) {
75        if !self.init_success {
76            unsafe {
77                mlx_sys::mlx_array_free(self.ptr);
78            }
79        }
80    }
81}
82
83impl Guard<Array> for MaybeUninitArray {
84    type MutRawPtr = *mut mlx_array;
85
86    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
87        &mut self.ptr
88    }
89
90    fn set_init_success(&mut self, success: bool) {
91        self.init_success = success;
92    }
93
94    fn try_into_guarded(self) -> Result<Array, Exception> {
95        debug_assert!(self.init_success);
96        unsafe { Ok(Array::from_ptr(self.ptr)) }
97    }
98}
99
100impl Guarded for Array {
101    type Guard = MaybeUninitArray;
102}
103
104pub(crate) struct MaybeUninitVectorArray {
105    pub(crate) ptr: mlx_sys::mlx_vector_array,
106    pub(crate) init_success: bool,
107}
108
109impl Default for MaybeUninitVectorArray {
110    fn default() -> Self {
111        Self::new()
112    }
113}
114
115impl MaybeUninitVectorArray {
116    pub fn new() -> Self {
117        unsafe {
118            Self {
119                ptr: mlx_sys::mlx_vector_array_new(),
120                init_success: false,
121            }
122        }
123    }
124}
125
126impl Drop for MaybeUninitVectorArray {
127    fn drop(&mut self) {
128        if !self.init_success {
129            unsafe {
130                mlx_sys::mlx_vector_array_free(self.ptr);
131            }
132        }
133    }
134}
135
136impl Guard<Vec<Array>> for MaybeUninitVectorArray {
137    type MutRawPtr = *mut mlx_sys::mlx_vector_array;
138
139    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
140        &mut self.ptr
141    }
142
143    fn set_init_success(&mut self, success: bool) {
144        self.init_success = success;
145    }
146
147    fn try_into_guarded(self) -> Result<Vec<Array>, Exception> {
148        unsafe {
149            let size = mlx_sys::mlx_vector_array_size(self.ptr);
150            (0..size)
151                .map(|i| Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, self.ptr, i)))
152                .collect()
153        }
154    }
155}
156
157impl Guarded for Vec<Array> {
158    type Guard = MaybeUninitVectorArray;
159}
160
161impl Guard<VectorArray> for MaybeUninitVectorArray {
162    type MutRawPtr = *mut mlx_sys::mlx_vector_array;
163
164    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
165        &mut self.ptr
166    }
167
168    fn set_init_success(&mut self, success: bool) {
169        self.init_success = success;
170    }
171
172    fn try_into_guarded(self) -> Result<VectorArray, Exception> {
173        Ok(VectorArray { c_vec: self.ptr })
174    }
175}
176
177impl Guarded for VectorArray {
178    type Guard = MaybeUninitVectorArray;
179}
180
181impl Guard<(Array, Array)> for (MaybeUninitArray, MaybeUninitArray) {
182    type MutRawPtr = (*mut mlx_array, *mut mlx_array);
183
184    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
185        (self.0.as_mut_raw_ptr(), self.1.as_mut_raw_ptr())
186    }
187
188    fn set_init_success(&mut self, success: bool) {
189        self.0.set_init_success(success);
190        self.1.set_init_success(success);
191    }
192
193    fn try_into_guarded(self) -> Result<(Array, Array), Exception> {
194        Ok((self.0.try_into_guarded()?, self.1.try_into_guarded()?))
195    }
196}
197
198impl Guarded for (Array, Array) {
199    type Guard = (MaybeUninitArray, MaybeUninitArray);
200}
201
202impl Guard<(Array, Array, Array)> for (MaybeUninitArray, MaybeUninitArray, MaybeUninitArray) {
203    type MutRawPtr = (*mut mlx_array, *mut mlx_array, *mut mlx_array);
204
205    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
206        (
207            self.0.as_mut_raw_ptr(),
208            self.1.as_mut_raw_ptr(),
209            self.2.as_mut_raw_ptr(),
210        )
211    }
212
213    fn set_init_success(&mut self, success: bool) {
214        self.0.set_init_success(success);
215        self.1.set_init_success(success);
216        self.2.set_init_success(success);
217    }
218
219    fn try_into_guarded(self) -> Result<(Array, Array, Array), Exception> {
220        Ok((
221            self.0.try_into_guarded()?,
222            self.1.try_into_guarded()?,
223            self.2.try_into_guarded()?,
224        ))
225    }
226}
227
228impl Guarded for (Array, Array, Array) {
229    type Guard = (MaybeUninitArray, MaybeUninitArray, MaybeUninitArray);
230}
231
232impl Guard<(Vec<Array>, Vec<Array>)> for (MaybeUninitVectorArray, MaybeUninitVectorArray) {
233    type MutRawPtr = (
234        *mut mlx_sys::mlx_vector_array,
235        *mut mlx_sys::mlx_vector_array,
236    );
237
238    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
239        (
240            <MaybeUninitVectorArray as Guard<Vec<Array>>>::as_mut_raw_ptr(&mut self.0),
241            <MaybeUninitVectorArray as Guard<Vec<Array>>>::as_mut_raw_ptr(&mut self.1),
242        )
243    }
244
245    fn set_init_success(&mut self, success: bool) {
246        <MaybeUninitVectorArray as Guard<Vec<Array>>>::set_init_success(&mut self.0, success);
247        <MaybeUninitVectorArray as Guard<Vec<Array>>>::set_init_success(&mut self.1, success);
248    }
249
250    fn try_into_guarded(self) -> Result<(Vec<Array>, Vec<Array>), Exception> {
251        Ok((self.0.try_into_guarded()?, self.1.try_into_guarded()?))
252    }
253}
254
255impl Guarded for (Vec<Array>, Vec<Array>) {
256    type Guard = (MaybeUninitVectorArray, MaybeUninitVectorArray);
257}
258
259pub(crate) struct MaybeUninitDevice {
260    pub(crate) ptr: mlx_sys::mlx_device,
261    pub(crate) init_success: bool,
262}
263
264impl Default for MaybeUninitDevice {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
270impl MaybeUninitDevice {
271    pub fn new() -> Self {
272        unsafe {
273            Self {
274                ptr: mlx_sys::mlx_device_new(),
275                init_success: false,
276            }
277        }
278    }
279}
280
281impl Drop for MaybeUninitDevice {
282    fn drop(&mut self) {
283        if !self.init_success {
284            unsafe {
285                mlx_sys::mlx_device_free(self.ptr);
286            }
287        }
288    }
289}
290
291impl Guard<crate::Device> for MaybeUninitDevice {
292    type MutRawPtr = *mut mlx_sys::mlx_device;
293
294    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
295        &mut self.ptr
296    }
297
298    fn set_init_success(&mut self, success: bool) {
299        self.init_success = success;
300    }
301
302    fn try_into_guarded(self) -> Result<crate::Device, Exception> {
303        debug_assert!(self.init_success);
304        Ok(crate::Device { c_device: self.ptr })
305    }
306}
307
308impl Guarded for crate::DeviceType {
309    type Guard = mlx_sys::mlx_device_type;
310}
311
312impl Guard<crate::DeviceType> for mlx_sys::mlx_device_type {
313    type MutRawPtr = *mut mlx_sys::mlx_device_type;
314
315    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
316        self
317    }
318
319    fn set_init_success(&mut self, _: bool) {}
320
321    fn try_into_guarded(self) -> Result<crate::DeviceType, Exception> {
322        match self {
323            mlx_sys::mlx_device_type__MLX_CPU => Ok(crate::DeviceType::Cpu),
324            mlx_sys::mlx_device_type__MLX_GPU => Ok(crate::DeviceType::Gpu),
325            _ => Err(Exception {
326                what: "Unknown device type".to_string(),
327                location: std::panic::Location::caller(),
328            }),
329        }
330    }
331}
332
333impl Guarded for crate::Device {
334    type Guard = MaybeUninitDevice;
335}
336
337pub(crate) struct MaybeUninitStream {
338    pub(crate) ptr: mlx_sys::mlx_stream,
339    pub(crate) init_success: bool,
340}
341
342impl Default for MaybeUninitStream {
343    fn default() -> Self {
344        Self::new()
345    }
346}
347
348impl MaybeUninitStream {
349    pub fn new() -> Self {
350        unsafe {
351            Self {
352                ptr: mlx_sys::mlx_stream_new(),
353                init_success: false,
354            }
355        }
356    }
357}
358
359impl Drop for MaybeUninitStream {
360    fn drop(&mut self) {
361        if !self.init_success {
362            unsafe {
363                mlx_sys::mlx_stream_free(self.ptr);
364            }
365        }
366    }
367}
368
369impl Guard<crate::Stream> for MaybeUninitStream {
370    type MutRawPtr = *mut mlx_sys::mlx_stream;
371
372    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
373        &mut self.ptr
374    }
375
376    fn set_init_success(&mut self, success: bool) {
377        self.init_success = success;
378    }
379
380    fn try_into_guarded(self) -> Result<crate::Stream, Exception> {
381        debug_assert!(self.init_success);
382        Ok(crate::Stream { c_stream: self.ptr })
383    }
384}
385
386impl Guarded for crate::Stream {
387    type Guard = MaybeUninitStream;
388}
389
390pub(crate) struct MaybeUninitSafeTensors {
391    pub(crate) c_data: mlx_sys::mlx_map_string_to_array,
392    pub(crate) c_metadata: mlx_sys::mlx_map_string_to_string,
393    pub(crate) init_success: bool,
394}
395
396impl Default for MaybeUninitSafeTensors {
397    fn default() -> Self {
398        Self::new()
399    }
400}
401
402impl MaybeUninitSafeTensors {
403    pub fn new() -> Self {
404        unsafe {
405            Self {
406                c_metadata: mlx_sys::mlx_map_string_to_string_new(),
407                c_data: mlx_sys::mlx_map_string_to_array_new(),
408                init_success: false,
409            }
410        }
411    }
412}
413
414impl Drop for MaybeUninitSafeTensors {
415    fn drop(&mut self) {
416        if !self.init_success {
417            unsafe {
418                mlx_sys::mlx_map_string_to_string_free(self.c_metadata);
419                mlx_sys::mlx_map_string_to_array_free(self.c_data);
420            }
421        }
422    }
423}
424
425impl Guard<crate::utils::io::SafeTensors> for MaybeUninitSafeTensors {
426    type MutRawPtr = (
427        *mut mlx_sys::mlx_map_string_to_array,
428        *mut mlx_sys::mlx_map_string_to_string,
429    );
430
431    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
432        (&mut self.c_data, &mut self.c_metadata)
433    }
434
435    fn set_init_success(&mut self, success: bool) {
436        self.init_success = success;
437    }
438
439    fn try_into_guarded(self) -> Result<crate::utils::io::SafeTensors, Exception> {
440        debug_assert!(self.init_success);
441        Ok(crate::utils::io::SafeTensors {
442            c_metadata: self.c_metadata,
443            c_data: self.c_data,
444        })
445    }
446}
447
448impl Guarded for crate::utils::io::SafeTensors {
449    type Guard = MaybeUninitSafeTensors;
450}
451
452pub(crate) struct MaybeUninitClosure {
453    pub(crate) ptr: mlx_sys::mlx_closure,
454    pub(crate) init_success: bool,
455}
456
457impl Default for MaybeUninitClosure {
458    fn default() -> Self {
459        Self::new()
460    }
461}
462
463impl MaybeUninitClosure {
464    pub fn new() -> Self {
465        unsafe {
466            Self {
467                ptr: mlx_sys::mlx_closure_new(),
468                init_success: false,
469            }
470        }
471    }
472}
473
474impl Drop for MaybeUninitClosure {
475    fn drop(&mut self) {
476        if !self.init_success {
477            unsafe {
478                mlx_sys::mlx_closure_free(self.ptr);
479            }
480        }
481    }
482}
483
484impl<'a> Guard<crate::utils::Closure<'a>> for MaybeUninitClosure {
485    type MutRawPtr = *mut mlx_sys::mlx_closure;
486
487    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
488        &mut self.ptr
489    }
490
491    fn set_init_success(&mut self, success: bool) {
492        self.init_success = success;
493    }
494
495    fn try_into_guarded(self) -> Result<crate::utils::Closure<'a>, Exception> {
496        debug_assert!(self.init_success);
497        Ok(crate::utils::Closure {
498            c_closure: self.ptr,
499            lt_marker: std::marker::PhantomData,
500        })
501    }
502}
503
504impl Guarded for crate::utils::Closure<'_> {
505    type Guard = MaybeUninitClosure;
506}
507
508pub(crate) struct MaybeUninitClosureValueAndGrad {
509    pub(crate) ptr: mlx_sys::mlx_closure_value_and_grad,
510    pub(crate) init_success: bool,
511}
512
513impl Default for MaybeUninitClosureValueAndGrad {
514    fn default() -> Self {
515        Self::new()
516    }
517}
518
519impl MaybeUninitClosureValueAndGrad {
520    pub fn new() -> Self {
521        unsafe {
522            Self {
523                ptr: mlx_sys::mlx_closure_value_and_grad_new(),
524                init_success: false,
525            }
526        }
527    }
528}
529
530impl Drop for MaybeUninitClosureValueAndGrad {
531    fn drop(&mut self) {
532        if !self.init_success {
533            unsafe {
534                mlx_sys::mlx_closure_value_and_grad_free(self.ptr);
535            }
536        }
537    }
538}
539
540impl Guard<crate::transforms::ClosureValueAndGrad> for MaybeUninitClosureValueAndGrad {
541    type MutRawPtr = *mut mlx_sys::mlx_closure_value_and_grad;
542
543    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
544        &mut self.ptr
545    }
546
547    fn set_init_success(&mut self, success: bool) {
548        self.init_success = success;
549    }
550
551    fn try_into_guarded(self) -> Result<crate::transforms::ClosureValueAndGrad, Exception> {
552        debug_assert!(self.init_success);
553        Ok(crate::transforms::ClosureValueAndGrad {
554            c_closure_value_and_grad: self.ptr,
555        })
556    }
557}
558
559impl Guarded for crate::transforms::ClosureValueAndGrad {
560    type Guard = MaybeUninitClosureValueAndGrad;
561}
562
563macro_rules! impl_guarded_for_primitive {
564    ($type:ty) => {
565        impl Guarded for $type {
566            type Guard = $type;
567        }
568
569        impl Guard<$type> for $type {
570            type MutRawPtr = *mut $type;
571
572            fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
573                self
574            }
575
576            fn set_init_success(&mut self, _: bool) { }
577
578            fn try_into_guarded(self) -> Result<$type, Exception> {
579                Ok(self)
580            }
581        }
582    };
583
584    ($($type:ty),*) => {
585        $(impl_guarded_for_primitive!($type);)*
586    };
587}
588
589impl_guarded_for_primitive!(bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, ());
590
591impl Guarded for f16 {
592    type Guard = float16_t;
593}
594
595impl Guard<f16> for float16_t {
596    type MutRawPtr = *mut float16_t;
597
598    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
599        self
600    }
601
602    fn set_init_success(&mut self, _: bool) {}
603
604    fn try_into_guarded(self) -> Result<f16, Exception> {
605        Ok(f16::from_bits(self.0))
606    }
607}
608
609impl Guarded for bf16 {
610    type Guard = bfloat16_t;
611}
612
613impl Guard<bf16> for bfloat16_t {
614    type MutRawPtr = *mut bfloat16_t;
615
616    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
617        self
618    }
619
620    fn set_init_success(&mut self, _: bool) {}
621
622    fn try_into_guarded(self) -> Result<bf16, Exception> {
623        Ok(bf16::from_bits(self))
624    }
625}
626
627impl Guarded for complex64 {
628    type Guard = __BindgenComplex<f32>;
629}
630
631impl Guard<complex64> for __BindgenComplex<f32> {
632    type MutRawPtr = *mut __BindgenComplex<f32>;
633
634    fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
635        self
636    }
637
638    fn set_init_success(&mut self, _: bool) {}
639
640    fn try_into_guarded(self) -> Result<complex64, Exception> {
641        Ok(complex64::new(self.re, self.im))
642    }
643}