1use half::{bf16, f16};
2use mlx_sys::{__BindgenComplex, bfloat16_t, float16_t, mlx_array};
3
4use crate::{complex64, error::Exception, Array};
5
6use super::{VectorArray, SUCCESS};
7
8type Status = i32;
9
10pub trait Guard<T>: Default {
11 type MutRawPtr;
12
13 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr;
14
15 fn set_init_success(&mut self, success: bool);
16
17 fn try_into_guarded(self) -> Result<T, Exception>;
18}
19
20pub(crate) trait Guarded: Sized {
21 type Guard: Guard<Self>;
22
23 #[track_caller]
24 fn try_from_op<F>(f: F) -> Result<Self, Exception>
25 where
26 F: FnOnce(<Self::Guard as Guard<Self>>::MutRawPtr) -> Status,
27 {
28 crate::error::INIT_ERR_HANDLER
29 .with(|init| init.call_once(crate::error::setup_mlx_error_handler));
30
31 let mut guard = Self::Guard::default();
32 let status = f(guard.as_mut_raw_ptr());
33 match status {
34 SUCCESS => {
35 guard.set_init_success(true);
36 guard.try_into_guarded()
37 }
38 _ => {
39 let what = crate::error::get_and_clear_last_mlx_error()
42 .expect("MLX operation failed but no error was set")
43 .what;
44 let location = std::panic::Location::caller();
45 Err(Exception { what, location })
46 }
47 }
48 }
49}
50
51pub(crate) struct MaybeUninitArray {
52 pub(crate) ptr: mlx_array,
53 pub(crate) init_success: bool,
54}
55
56impl Default for MaybeUninitArray {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl MaybeUninitArray {
63 pub fn new() -> Self {
64 unsafe {
65 Self {
66 ptr: mlx_sys::mlx_array_new(),
67 init_success: false,
68 }
69 }
70 }
71}
72
73impl Drop for MaybeUninitArray {
74 fn drop(&mut self) {
75 if !self.init_success {
76 unsafe {
77 mlx_sys::mlx_array_free(self.ptr);
78 }
79 }
80 }
81}
82
83impl Guard<Array> for MaybeUninitArray {
84 type MutRawPtr = *mut mlx_array;
85
86 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
87 &mut self.ptr
88 }
89
90 fn set_init_success(&mut self, success: bool) {
91 self.init_success = success;
92 }
93
94 fn try_into_guarded(self) -> Result<Array, Exception> {
95 debug_assert!(self.init_success);
96 unsafe { Ok(Array::from_ptr(self.ptr)) }
97 }
98}
99
100impl Guarded for Array {
101 type Guard = MaybeUninitArray;
102}
103
104pub(crate) struct MaybeUninitVectorArray {
105 pub(crate) ptr: mlx_sys::mlx_vector_array,
106 pub(crate) init_success: bool,
107}
108
109impl Default for MaybeUninitVectorArray {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115impl MaybeUninitVectorArray {
116 pub fn new() -> Self {
117 unsafe {
118 Self {
119 ptr: mlx_sys::mlx_vector_array_new(),
120 init_success: false,
121 }
122 }
123 }
124}
125
126impl Drop for MaybeUninitVectorArray {
127 fn drop(&mut self) {
128 if !self.init_success {
129 unsafe {
130 mlx_sys::mlx_vector_array_free(self.ptr);
131 }
132 }
133 }
134}
135
136impl Guard<Vec<Array>> for MaybeUninitVectorArray {
137 type MutRawPtr = *mut mlx_sys::mlx_vector_array;
138
139 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
140 &mut self.ptr
141 }
142
143 fn set_init_success(&mut self, success: bool) {
144 self.init_success = success;
145 }
146
147 fn try_into_guarded(self) -> Result<Vec<Array>, Exception> {
148 unsafe {
149 let size = mlx_sys::mlx_vector_array_size(self.ptr);
150 (0..size)
151 .map(|i| Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, self.ptr, i)))
152 .collect()
153 }
154 }
155}
156
157impl Guarded for Vec<Array> {
158 type Guard = MaybeUninitVectorArray;
159}
160
161impl Guard<VectorArray> for MaybeUninitVectorArray {
162 type MutRawPtr = *mut mlx_sys::mlx_vector_array;
163
164 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
165 &mut self.ptr
166 }
167
168 fn set_init_success(&mut self, success: bool) {
169 self.init_success = success;
170 }
171
172 fn try_into_guarded(self) -> Result<VectorArray, Exception> {
173 Ok(VectorArray { c_vec: self.ptr })
174 }
175}
176
177impl Guarded for VectorArray {
178 type Guard = MaybeUninitVectorArray;
179}
180
181impl Guard<(Array, Array)> for (MaybeUninitArray, MaybeUninitArray) {
182 type MutRawPtr = (*mut mlx_array, *mut mlx_array);
183
184 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
185 (self.0.as_mut_raw_ptr(), self.1.as_mut_raw_ptr())
186 }
187
188 fn set_init_success(&mut self, success: bool) {
189 self.0.set_init_success(success);
190 self.1.set_init_success(success);
191 }
192
193 fn try_into_guarded(self) -> Result<(Array, Array), Exception> {
194 Ok((self.0.try_into_guarded()?, self.1.try_into_guarded()?))
195 }
196}
197
198impl Guarded for (Array, Array) {
199 type Guard = (MaybeUninitArray, MaybeUninitArray);
200}
201
202impl Guard<(Array, Array, Array)> for (MaybeUninitArray, MaybeUninitArray, MaybeUninitArray) {
203 type MutRawPtr = (*mut mlx_array, *mut mlx_array, *mut mlx_array);
204
205 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
206 (
207 self.0.as_mut_raw_ptr(),
208 self.1.as_mut_raw_ptr(),
209 self.2.as_mut_raw_ptr(),
210 )
211 }
212
213 fn set_init_success(&mut self, success: bool) {
214 self.0.set_init_success(success);
215 self.1.set_init_success(success);
216 self.2.set_init_success(success);
217 }
218
219 fn try_into_guarded(self) -> Result<(Array, Array, Array), Exception> {
220 Ok((
221 self.0.try_into_guarded()?,
222 self.1.try_into_guarded()?,
223 self.2.try_into_guarded()?,
224 ))
225 }
226}
227
228impl Guarded for (Array, Array, Array) {
229 type Guard = (MaybeUninitArray, MaybeUninitArray, MaybeUninitArray);
230}
231
232impl Guard<(Vec<Array>, Vec<Array>)> for (MaybeUninitVectorArray, MaybeUninitVectorArray) {
233 type MutRawPtr = (
234 *mut mlx_sys::mlx_vector_array,
235 *mut mlx_sys::mlx_vector_array,
236 );
237
238 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
239 (
240 <MaybeUninitVectorArray as Guard<Vec<Array>>>::as_mut_raw_ptr(&mut self.0),
241 <MaybeUninitVectorArray as Guard<Vec<Array>>>::as_mut_raw_ptr(&mut self.1),
242 )
243 }
244
245 fn set_init_success(&mut self, success: bool) {
246 <MaybeUninitVectorArray as Guard<Vec<Array>>>::set_init_success(&mut self.0, success);
247 <MaybeUninitVectorArray as Guard<Vec<Array>>>::set_init_success(&mut self.1, success);
248 }
249
250 fn try_into_guarded(self) -> Result<(Vec<Array>, Vec<Array>), Exception> {
251 Ok((self.0.try_into_guarded()?, self.1.try_into_guarded()?))
252 }
253}
254
255impl Guarded for (Vec<Array>, Vec<Array>) {
256 type Guard = (MaybeUninitVectorArray, MaybeUninitVectorArray);
257}
258
259pub(crate) struct MaybeUninitDevice {
260 pub(crate) ptr: mlx_sys::mlx_device,
261 pub(crate) init_success: bool,
262}
263
264impl Default for MaybeUninitDevice {
265 fn default() -> Self {
266 Self::new()
267 }
268}
269
270impl MaybeUninitDevice {
271 pub fn new() -> Self {
272 unsafe {
273 Self {
274 ptr: mlx_sys::mlx_device_new(),
275 init_success: false,
276 }
277 }
278 }
279}
280
281impl Drop for MaybeUninitDevice {
282 fn drop(&mut self) {
283 if !self.init_success {
284 unsafe {
285 mlx_sys::mlx_device_free(self.ptr);
286 }
287 }
288 }
289}
290
291impl Guard<crate::Device> for MaybeUninitDevice {
292 type MutRawPtr = *mut mlx_sys::mlx_device;
293
294 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
295 &mut self.ptr
296 }
297
298 fn set_init_success(&mut self, success: bool) {
299 self.init_success = success;
300 }
301
302 fn try_into_guarded(self) -> Result<crate::Device, Exception> {
303 debug_assert!(self.init_success);
304 Ok(crate::Device { c_device: self.ptr })
305 }
306}
307
308impl Guarded for crate::DeviceType {
309 type Guard = mlx_sys::mlx_device_type;
310}
311
312impl Guard<crate::DeviceType> for mlx_sys::mlx_device_type {
313 type MutRawPtr = *mut mlx_sys::mlx_device_type;
314
315 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
316 self
317 }
318
319 fn set_init_success(&mut self, _: bool) {}
320
321 fn try_into_guarded(self) -> Result<crate::DeviceType, Exception> {
322 match self {
323 mlx_sys::mlx_device_type__MLX_CPU => Ok(crate::DeviceType::Cpu),
324 mlx_sys::mlx_device_type__MLX_GPU => Ok(crate::DeviceType::Gpu),
325 _ => Err(Exception {
326 what: "Unknown device type".to_string(),
327 location: std::panic::Location::caller(),
328 }),
329 }
330 }
331}
332
333impl Guarded for crate::Device {
334 type Guard = MaybeUninitDevice;
335}
336
337pub(crate) struct MaybeUninitStream {
338 pub(crate) ptr: mlx_sys::mlx_stream,
339 pub(crate) init_success: bool,
340}
341
342impl Default for MaybeUninitStream {
343 fn default() -> Self {
344 Self::new()
345 }
346}
347
348impl MaybeUninitStream {
349 pub fn new() -> Self {
350 unsafe {
351 Self {
352 ptr: mlx_sys::mlx_stream_new(),
353 init_success: false,
354 }
355 }
356 }
357}
358
359impl Drop for MaybeUninitStream {
360 fn drop(&mut self) {
361 if !self.init_success {
362 unsafe {
363 mlx_sys::mlx_stream_free(self.ptr);
364 }
365 }
366 }
367}
368
369impl Guard<crate::Stream> for MaybeUninitStream {
370 type MutRawPtr = *mut mlx_sys::mlx_stream;
371
372 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
373 &mut self.ptr
374 }
375
376 fn set_init_success(&mut self, success: bool) {
377 self.init_success = success;
378 }
379
380 fn try_into_guarded(self) -> Result<crate::Stream, Exception> {
381 debug_assert!(self.init_success);
382 Ok(crate::Stream { c_stream: self.ptr })
383 }
384}
385
386impl Guarded for crate::Stream {
387 type Guard = MaybeUninitStream;
388}
389
390pub(crate) struct MaybeUninitSafeTensors {
391 pub(crate) c_data: mlx_sys::mlx_map_string_to_array,
392 pub(crate) c_metadata: mlx_sys::mlx_map_string_to_string,
393 pub(crate) init_success: bool,
394}
395
396impl Default for MaybeUninitSafeTensors {
397 fn default() -> Self {
398 Self::new()
399 }
400}
401
402impl MaybeUninitSafeTensors {
403 pub fn new() -> Self {
404 unsafe {
405 Self {
406 c_metadata: mlx_sys::mlx_map_string_to_string_new(),
407 c_data: mlx_sys::mlx_map_string_to_array_new(),
408 init_success: false,
409 }
410 }
411 }
412}
413
414impl Drop for MaybeUninitSafeTensors {
415 fn drop(&mut self) {
416 if !self.init_success {
417 unsafe {
418 mlx_sys::mlx_map_string_to_string_free(self.c_metadata);
419 mlx_sys::mlx_map_string_to_array_free(self.c_data);
420 }
421 }
422 }
423}
424
425impl Guard<crate::utils::io::SafeTensors> for MaybeUninitSafeTensors {
426 type MutRawPtr = (
427 *mut mlx_sys::mlx_map_string_to_array,
428 *mut mlx_sys::mlx_map_string_to_string,
429 );
430
431 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
432 (&mut self.c_data, &mut self.c_metadata)
433 }
434
435 fn set_init_success(&mut self, success: bool) {
436 self.init_success = success;
437 }
438
439 fn try_into_guarded(self) -> Result<crate::utils::io::SafeTensors, Exception> {
440 debug_assert!(self.init_success);
441 Ok(crate::utils::io::SafeTensors {
442 c_metadata: self.c_metadata,
443 c_data: self.c_data,
444 })
445 }
446}
447
448impl Guarded for crate::utils::io::SafeTensors {
449 type Guard = MaybeUninitSafeTensors;
450}
451
452pub(crate) struct MaybeUninitClosure {
453 pub(crate) ptr: mlx_sys::mlx_closure,
454 pub(crate) init_success: bool,
455}
456
457impl Default for MaybeUninitClosure {
458 fn default() -> Self {
459 Self::new()
460 }
461}
462
463impl MaybeUninitClosure {
464 pub fn new() -> Self {
465 unsafe {
466 Self {
467 ptr: mlx_sys::mlx_closure_new(),
468 init_success: false,
469 }
470 }
471 }
472}
473
474impl Drop for MaybeUninitClosure {
475 fn drop(&mut self) {
476 if !self.init_success {
477 unsafe {
478 mlx_sys::mlx_closure_free(self.ptr);
479 }
480 }
481 }
482}
483
484impl<'a> Guard<crate::utils::Closure<'a>> for MaybeUninitClosure {
485 type MutRawPtr = *mut mlx_sys::mlx_closure;
486
487 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
488 &mut self.ptr
489 }
490
491 fn set_init_success(&mut self, success: bool) {
492 self.init_success = success;
493 }
494
495 fn try_into_guarded(self) -> Result<crate::utils::Closure<'a>, Exception> {
496 debug_assert!(self.init_success);
497 Ok(crate::utils::Closure {
498 c_closure: self.ptr,
499 lt_marker: std::marker::PhantomData,
500 })
501 }
502}
503
504impl Guarded for crate::utils::Closure<'_> {
505 type Guard = MaybeUninitClosure;
506}
507
508pub(crate) struct MaybeUninitClosureValueAndGrad {
509 pub(crate) ptr: mlx_sys::mlx_closure_value_and_grad,
510 pub(crate) init_success: bool,
511}
512
513impl Default for MaybeUninitClosureValueAndGrad {
514 fn default() -> Self {
515 Self::new()
516 }
517}
518
519impl MaybeUninitClosureValueAndGrad {
520 pub fn new() -> Self {
521 unsafe {
522 Self {
523 ptr: mlx_sys::mlx_closure_value_and_grad_new(),
524 init_success: false,
525 }
526 }
527 }
528}
529
530impl Drop for MaybeUninitClosureValueAndGrad {
531 fn drop(&mut self) {
532 if !self.init_success {
533 unsafe {
534 mlx_sys::mlx_closure_value_and_grad_free(self.ptr);
535 }
536 }
537 }
538}
539
540impl Guard<crate::transforms::ClosureValueAndGrad> for MaybeUninitClosureValueAndGrad {
541 type MutRawPtr = *mut mlx_sys::mlx_closure_value_and_grad;
542
543 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
544 &mut self.ptr
545 }
546
547 fn set_init_success(&mut self, success: bool) {
548 self.init_success = success;
549 }
550
551 fn try_into_guarded(self) -> Result<crate::transforms::ClosureValueAndGrad, Exception> {
552 debug_assert!(self.init_success);
553 Ok(crate::transforms::ClosureValueAndGrad {
554 c_closure_value_and_grad: self.ptr,
555 })
556 }
557}
558
559impl Guarded for crate::transforms::ClosureValueAndGrad {
560 type Guard = MaybeUninitClosureValueAndGrad;
561}
562
563macro_rules! impl_guarded_for_primitive {
564 ($type:ty) => {
565 impl Guarded for $type {
566 type Guard = $type;
567 }
568
569 impl Guard<$type> for $type {
570 type MutRawPtr = *mut $type;
571
572 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
573 self
574 }
575
576 fn set_init_success(&mut self, _: bool) { }
577
578 fn try_into_guarded(self) -> Result<$type, Exception> {
579 Ok(self)
580 }
581 }
582 };
583
584 ($($type:ty),*) => {
585 $(impl_guarded_for_primitive!($type);)*
586 };
587}
588
589impl_guarded_for_primitive!(bool, u8, u16, u32, u64, i8, i16, i32, i64, f32, f64, ());
590
591impl Guarded for f16 {
592 type Guard = float16_t;
593}
594
595impl Guard<f16> for float16_t {
596 type MutRawPtr = *mut float16_t;
597
598 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
599 self
600 }
601
602 fn set_init_success(&mut self, _: bool) {}
603
604 fn try_into_guarded(self) -> Result<f16, Exception> {
605 Ok(f16::from_bits(self.0))
606 }
607}
608
609impl Guarded for bf16 {
610 type Guard = bfloat16_t;
611}
612
613impl Guard<bf16> for bfloat16_t {
614 type MutRawPtr = *mut bfloat16_t;
615
616 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
617 self
618 }
619
620 fn set_init_success(&mut self, _: bool) {}
621
622 fn try_into_guarded(self) -> Result<bf16, Exception> {
623 Ok(bf16::from_bits(self))
624 }
625}
626
627impl Guarded for complex64 {
628 type Guard = __BindgenComplex<f32>;
629}
630
631impl Guard<complex64> for __BindgenComplex<f32> {
632 type MutRawPtr = *mut __BindgenComplex<f32>;
633
634 fn as_mut_raw_ptr(&mut self) -> Self::MutRawPtr {
635 self
636 }
637
638 fn set_init_success(&mut self, _: bool) {}
639
640 fn try_into_guarded(self) -> Result<complex64, Exception> {
641 Ok(complex64::new(self.re, self.im))
642 }
643}