1use guard::Guarded;
4use mlx_sys::mlx_vector_array;
5
6use crate::error::set_closure_error;
7use crate::module::ModuleParameters;
8use crate::{complex64, error::Exception, Array, FromNested};
9use std::collections::HashMap;
10use std::{marker::PhantomData, rc::Rc};
11
/// Status code that mlx-c return values (e.g. `mlx_vector_array_append_value`,
/// `mlx_*_free`) are compared against in this file to detect success.
pub(crate) const SUCCESS: i32 = 0;
/// Status code the closure trampolines in this file hand back to mlx to report
/// that a callback failed.
pub(crate) const FAILURE: i32 = 1;

// Provides the `Guarded` trait imported above (its `try_from_op` constructor is used below).
pub(crate) mod guard;
pub(crate) mod io;
18
/// Resolves a possibly-negative index against `len`, Python style.
///
/// A non-negative `index` is returned unchanged. A negative one counts back from
/// the end (`len + index`), computed with saturating arithmetic so it cannot
/// overflow. No bounds check is performed: the result may still be negative or
/// `>= len`.
pub(crate) fn resolve_index_signed_unchecked(index: i32, len: i32) -> i32 {
    match index {
        non_negative if non_negative >= 0 => non_negative,
        negative => len.saturating_add(negative),
    }
}
26
/// Resolves a possibly-negative `index` against an unsigned length.
///
/// A non-negative `index` is cast straight to `usize`. A negative one becomes
/// `len + index`. Nothing is checked: `len` is truncated to `i32` for the
/// addition, and a sum that is still negative turns into a huge value when cast
/// to `usize`.
pub(crate) fn resolve_index_unchecked(index: i32, len: usize) -> usize {
    if index >= 0 {
        return index as usize;
    }
    let from_end = len as i32 + index;
    from_end as usize
}
34
35pub(crate) fn axes_or_default_to_all<'a>(axes: impl IntoOption<&'a [i32]>, ndim: i32) -> Vec<i32> {
37 match axes.into_option() {
38 Some(axes) => axes.to_vec(),
39 None => {
40 let axes: Vec<i32> = (0..ndim).collect();
41 axes
42 }
43 }
44}
45
46pub(crate) struct VectorArray {
47 c_vec: mlx_sys::mlx_vector_array,
48}
49
50impl VectorArray {
51 pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_vector_array {
52 self.c_vec
53 }
54
55 pub(crate) unsafe fn from_ptr(c_vec: mlx_sys::mlx_vector_array) -> Self {
56 Self { c_vec }
57 }
58
59 pub(crate) fn try_from_iter(
60 iter: impl Iterator<Item = impl AsRef<Array>>,
61 ) -> Result<Self, Exception> {
62 VectorArray::try_from_op(|res| unsafe {
63 let mut status = SUCCESS;
64 for arr in iter {
65 status = mlx_sys::mlx_vector_array_append_value(*res, arr.as_ref().as_ptr());
66 if status != SUCCESS {
67 return status;
68 }
69 }
70 status
71 })
72 }
73
74 pub(crate) fn try_into_values<T>(self) -> Result<T, Exception>
75 where
76 T: FromIterator<Array>,
77 {
78 unsafe {
79 let size = mlx_sys::mlx_vector_array_size(self.c_vec);
80 (0..size)
81 .map(|i| {
82 Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, self.c_vec, i))
83 })
84 .collect::<Result<T, Exception>>()
85 }
86 }
87}
88
89impl Drop for VectorArray {
90 fn drop(&mut self) {
91 let status = unsafe { mlx_sys::mlx_vector_array_free(self.c_vec) };
92 debug_assert_eq!(status, SUCCESS);
93 }
94}
95
/// Conversion into an `Option<T>`, letting APIs accept a plain value, an
/// `Option`, or (for slices) an array or `Vec` reference interchangeably.
pub trait IntoOption<T> {
    /// Converts `self` into an `Option<T>`.
    fn into_option(self) -> Option<T>;
}

/// An `Option` is passed through untouched.
impl<T> IntoOption<T> for Option<T> {
    fn into_option(self) -> Option<T> {
        self
    }
}

/// Any bare value is wrapped in `Some`.
impl<T> IntoOption<T> for T {
    fn into_option(self) -> Option<T> {
        Option::Some(self)
    }
}

/// A fixed-size array reference becomes `Some` of the equivalent slice.
impl<'a, T, const N: usize> IntoOption<&'a [T]> for &'a [T; N] {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}

/// A `Vec` reference becomes `Some` of its slice.
impl<'a, T> IntoOption<&'a [T]> for &'a Vec<T> {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}
126
127pub trait ScalarOrArray<'a> {
129 type Array: AsRef<Array> + 'a;
131
132 fn into_owned_or_ref_array(self) -> Self::Array;
134}
135
136impl ScalarOrArray<'_> for Array {
137 type Array = Array;
138
139 fn into_owned_or_ref_array(self) -> Array {
140 self
141 }
142}
143
144impl<'a> ScalarOrArray<'a> for &'a Array {
145 type Array = &'a Array;
146
147 fn into_owned_or_ref_array(self) -> &'a Array {
149 self
150 }
151}
152
153impl ScalarOrArray<'static> for bool {
154 type Array = Array;
155
156 fn into_owned_or_ref_array(self) -> Array {
157 Array::from_bool(self)
158 }
159}
160
161impl ScalarOrArray<'static> for i32 {
162 type Array = Array;
163
164 fn into_owned_or_ref_array(self) -> Array {
165 Array::from_int(self)
166 }
167}
168
169impl ScalarOrArray<'static> for f32 {
170 type Array = Array;
171
172 fn into_owned_or_ref_array(self) -> Array {
173 Array::from_f32(self)
174 }
175}
176
177impl ScalarOrArray<'static> for complex64 {
187 type Array = Array;
188
189 fn into_owned_or_ref_array(self) -> Array {
190 Array::from_complex(self)
191 }
192}
193
194impl<T> ScalarOrArray<'static> for T
195where
196 Array: FromNested<T>,
197{
198 type Array = Array;
199
200 fn into_owned_or_ref_array(self) -> Array {
201 Array::from_nested(self)
202 }
203}
204
205#[derive(Debug)]
206pub(crate) struct Closure<'a> {
207 c_closure: mlx_sys::mlx_closure,
208 lt_marker: PhantomData<&'a ()>,
209}
210
211impl<'a> Closure<'a> {
212 pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_closure {
213 self.c_closure
214 }
215
216 pub(crate) fn new<F>(closure: F) -> Self
217 where
218 F: FnMut(&[Array]) -> Vec<Array> + 'a,
219 {
220 let c_closure = new_mlx_closure(closure);
221 Self {
222 c_closure,
223 lt_marker: PhantomData,
224 }
225 }
226
227 pub(crate) fn new_fallible<F>(closure: F) -> Self
228 where
229 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
230 {
231 let c_closure = new_mlx_fallible_closure(closure);
232 Self {
233 c_closure,
234 lt_marker: PhantomData,
235 }
236 }
237}
238
239impl Drop for Closure<'_> {
240 fn drop(&mut self) {
241 let status = unsafe { mlx_sys::mlx_closure_free(self.c_closure) };
242 debug_assert_eq!(status, SUCCESS);
243 }
244}
245
246fn new_mlx_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
248where
249 F: FnMut(&[Array]) -> Vec<Array> + 'a,
250{
251 let boxed = Box::new(closure);
253
254 let raw = Box::into_raw(boxed);
256 let payload = raw as *mut std::ffi::c_void;
257
258 unsafe {
259 mlx_sys::mlx_closure_new_func_payload(
260 Some(trampoline::<F>),
261 payload,
262 Some(closure_dtor::<F>),
263 )
264 }
265}
266
267fn new_mlx_fallible_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
268where
269 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
270{
271 let boxed = Box::new(closure);
272 let raw = Box::into_raw(boxed);
273 let payload = raw as *mut std::ffi::c_void;
274
275 unsafe {
276 mlx_sys::mlx_closure_new_func_payload(
277 Some(trampoline_fallible::<F>),
278 payload,
279 Some(closure_dtor::<F>),
280 )
281 }
282}
283
284fn new_mlx_vector_array(arrays: Vec<Array>) -> mlx_sys::mlx_vector_array {
286 unsafe {
287 let result = mlx_sys::mlx_vector_array_new();
288 let ctx_ptrs: Vec<mlx_sys::mlx_array> = arrays.iter().map(|array| array.as_ptr()).collect();
289 mlx_sys::mlx_vector_array_append_data(result, ctx_ptrs.as_ptr(), arrays.len());
290 result
291 }
292}
293
294fn mlx_vector_array_values(
295 vector_array: mlx_sys::mlx_vector_array,
296) -> Result<Vec<Array>, Exception> {
297 unsafe {
298 let size = mlx_sys::mlx_vector_array_size(vector_array);
299 (0..size)
300 .map(|index| {
301 Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, vector_array, index))
302 })
303 .collect()
304 }
305}
306
307extern "C" fn trampoline<'a, F>(
308 ret: *mut mlx_vector_array,
309 vector_array: mlx_vector_array,
310 payload: *mut std::ffi::c_void,
311) -> i32
312where
313 F: FnMut(&[Array]) -> Vec<Array> + 'a,
314{
315 unsafe {
316 let raw_closure: *mut F = payload as *mut _;
317 let mut closure = Box::from_raw(raw_closure);
319 let arrays = match mlx_vector_array_values(vector_array) {
320 Ok(arrays) => arrays,
321 Err(_) => {
322 let _ = Box::into_raw(closure); return FAILURE;
324 }
325 };
326 let result = closure(&arrays);
327 let _ = Box::into_raw(closure); *ret = new_mlx_vector_array(result);
332
333 SUCCESS
334 }
335}
336
337extern "C" fn trampoline_fallible<'a, F>(
338 ret: *mut mlx_vector_array,
339 vector_array: mlx_vector_array,
340 payload: *mut std::ffi::c_void,
341) -> i32
342where
343 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
344{
345 unsafe {
346 let raw_closure: *mut F = payload as *mut _;
347 let mut closure = Box::from_raw(raw_closure);
348 let arrays = match mlx_vector_array_values(vector_array) {
349 Ok(arrays) => arrays,
350 Err(e) => {
351 let _ = Box::into_raw(closure); set_closure_error(e);
353 return FAILURE;
354 }
355 };
356 let result = closure(&arrays);
357 let _ = Box::into_raw(closure); match result {
360 Ok(result) => {
361 *ret = new_mlx_vector_array(result);
362 SUCCESS
363 }
364 Err(err) => {
365 set_closure_error(err);
366 FAILURE
367 }
368 }
369 }
370}
371
/// Destructor registered with mlx for closure payloads.
///
/// Reclaims and drops the boxed `F` created by `new_mlx_closure` /
/// `new_mlx_fallible_closure`. A null payload is ignored.
extern "C" fn closure_dtor<F>(payload: *mut std::ffi::c_void) {
    if !payload.is_null() {
        // SAFETY: a non-null payload is a `Box<F>` leaked with `Box::into_raw`
        // and is reclaimed exactly once, here.
        let boxed: Box<F> = unsafe { Box::from_raw(payload.cast::<F>()) };
        drop(boxed);
    }
}
382
/// Returns a mutable reference to the value stored under `key`. If the key is
/// absent, it first inserts the value produced by `f`.
///
/// `f` is called only when the key is missing. Uses the entry API, so the map is
/// hashed once instead of up to three times (`contains_key` + `insert` +
/// `get_mut`), and no `unwrap` is needed. Cloning the `Rc` key is just a
/// refcount bump. If the key already exists, the stored key is kept.
pub(crate) fn get_mut_or_insert_with<'a, T>(
    map: &'a mut HashMap<Rc<str>, T>,
    key: &Rc<str>,
    f: impl FnOnce() -> T,
) -> &'a mut T {
    map.entry(Rc::clone(key)).or_insert_with(f)
}
394
395pub trait Updatable {
400 fn updatable_states_len(&self) -> usize;
406
407 fn updatable_states(&self) -> impl IntoIterator<Item = &Array>;
412
413 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array>;
418}
419
420impl<T> Updatable for T
421where
422 T: ModuleParameters,
423{
424 fn updatable_states_len(&self) -> usize {
425 self.num_parameters()
426 }
427
428 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
429 use itertools::Itertools;
430
431 self.parameters()
433 .flatten()
434 .into_iter()
435 .sorted_by(|a, b| a.0.cmp(&b.0))
436 .map(|(_, v)| v)
437 }
438
439 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
440 use itertools::Itertools;
441
442 self.parameters_mut()
443 .flatten()
444 .into_iter()
445 .sorted_by(|a, b| a.0.cmp(&b.0))
446 .map(|(_, v)| v)
447 }
448}
449
450impl<T1, T2> Updatable for (T1, T2)
451where
452 T1: Updatable,
453 T2: Updatable,
454{
455 fn updatable_states_len(&self) -> usize {
456 let (a, b) = self;
457 a.updatable_states_len() + b.updatable_states_len()
458 }
459
460 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
461 let (a, b) = self;
462 let params = a.updatable_states();
463 params.into_iter().chain(b.updatable_states())
464 }
465
466 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
467 let (a, b) = self;
468 let params = a.updatable_states_mut();
469 params.into_iter().chain(b.updatable_states_mut())
470 }
471}
472
473impl Updatable for Vec<Array> {
474 fn updatable_states_len(&self) -> usize {
475 self.len()
476 }
477
478 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
479 self.iter()
480 }
481
482 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
483 self.iter_mut()
484 }
485}
486
/// Either one value used for both positions, or an explicit pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrPair<T = i32> {
    /// One value standing in for both the first and second element.
    Single(T),

    /// Explicit first and second elements.
    Pair(T, T),
}

impl<T: Clone> SingleOrPair<T> {
    /// Returns the first element (the shared value for `Single`).
    pub fn first(&self) -> T {
        match self {
            SingleOrPair::Single(first) | SingleOrPair::Pair(first, _) => first.clone(),
        }
    }

    /// Returns the second element (the shared value for `Single`).
    pub fn second(&self) -> T {
        match self {
            SingleOrPair::Single(second) | SingleOrPair::Pair(_, second) => second.clone(),
        }
    }
}

impl<T> From<T> for SingleOrPair<T> {
    fn from(single: T) -> Self {
        Self::Single(single)
    }
}

impl<T> From<(T, T)> for SingleOrPair<T> {
    fn from((first, second): (T, T)) -> Self {
        Self::Pair(first, second)
    }
}

/// Expands into a tuple. `Single(v)` becomes `(v, v)`.
impl<T: Clone> From<SingleOrPair<T>> for (T, T) {
    fn from(value: SingleOrPair<T>) -> Self {
        match value {
            SingleOrPair::Pair(first, second) => (first, second),
            SingleOrPair::Single(both) => (both.clone(), both),
        }
    }
}
535
/// Either one value used for all three positions, or an explicit triple.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrTriple<T = i32> {
    /// One value standing in for all three elements.
    Single(T),

    /// Explicit first, second and third elements.
    Triple(T, T, T),
}

impl<T: Clone> SingleOrTriple<T> {
    /// Returns the first element (the shared value for `Single`).
    pub fn first(&self) -> T {
        match self {
            SingleOrTriple::Single(first) | SingleOrTriple::Triple(first, _, _) => first.clone(),
        }
    }

    /// Returns the second element (the shared value for `Single`).
    pub fn second(&self) -> T {
        match self {
            SingleOrTriple::Single(second) | SingleOrTriple::Triple(_, second, _) => second.clone(),
        }
    }

    /// Returns the third element (the shared value for `Single`).
    pub fn third(&self) -> T {
        match self {
            SingleOrTriple::Single(third) | SingleOrTriple::Triple(_, _, third) => third.clone(),
        }
    }
}

impl<T> From<T> for SingleOrTriple<T> {
    fn from(single: T) -> Self {
        Self::Single(single)
    }
}

impl<T> From<(T, T, T)> for SingleOrTriple<T> {
    fn from((first, second, third): (T, T, T)) -> Self {
        Self::Triple(first, second, third)
    }
}

/// Expands into a tuple. `Single(v)` becomes `(v, v, v)`.
impl<T: Clone> From<SingleOrTriple<T>> for (T, T, T) {
    fn from(value: SingleOrTriple<T>) -> Self {
        match value {
            SingleOrTriple::Triple(first, second, third) => (first, second, third),
            SingleOrTriple::Single(all) => (all.clone(), all.clone(), all),
        }
    }
}
592
/// Either a single value or a vector of values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SingleOrVec<T> {
    /// Exactly one value.
    Single(T),

    /// Any number of values.
    Vec(Vec<T>),
}

impl<T> From<T> for SingleOrVec<T> {
    fn from(single: T) -> Self {
        Self::Single(single)
    }
}

impl<T> From<Vec<T>> for SingleOrVec<T> {
    fn from(values: Vec<T>) -> Self {
        Self::Vec(values)
    }
}