1use guard::Guarded;
4use mlx_sys::mlx_vector_array;
5
6use crate::error::set_closure_error;
7use crate::module::ModuleParameters;
8use crate::{complex64, error::Exception, Array, FromNested};
9use std::collections::HashMap;
10use std::{marker::PhantomData, rc::Rc};
11
/// Status code that the mlx C API calls in this module return on success; also returned by
/// the closure trampolines in this module when the wrapped Rust closure succeeds.
pub(crate) const SUCCESS: i32 = 0;
/// Non-zero status returned by the closure trampolines in this module to signal failure.
pub(crate) const FAILURE: i32 = 1;
15
16pub(crate) mod guard;
17pub(crate) mod io;
18
/// Resolves a possibly negative index relative to `len`.
///
/// Negative values count back from `len` (so `-1` maps to `len - 1`); non-negative values
/// are returned unchanged. No bounds check is performed: the negative case uses saturating
/// addition, so the result may still be negative or out of range.
pub(crate) fn resolve_index_signed_unchecked(index: i32, len: i32) -> i32 {
    match index {
        negative if negative < 0 => len.saturating_add(negative),
        non_negative => non_negative,
    }
}
26
/// Resolves a possibly negative index against a length given as `usize`.
///
/// Non-negative indices are returned as-is. Negative indices count back from `len`.
/// Nothing is validated: `len` is cast to `i32` with `as`, and a sum that is still negative
/// wraps when cast back to `usize`. Callers are expected to have checked bounds already.
pub(crate) fn resolve_index_unchecked(index: i32, len: usize) -> usize {
    if index >= 0 {
        return index as usize;
    }
    let from_end = len as i32 + index;
    from_end as usize
}
34
35pub(crate) fn axes_or_default_to_all<'a>(axes: impl IntoOption<&'a [i32]>, ndim: i32) -> Vec<i32> {
37 match axes.into_option() {
38 Some(axes) => axes.to_vec(),
39 None => {
40 let axes: Vec<i32> = (0..ndim).collect();
41 axes
42 }
43 }
44}
45
46pub(crate) struct VectorArray {
47 c_vec: mlx_sys::mlx_vector_array,
48}
49
50impl VectorArray {
51 pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_vector_array {
52 self.c_vec
53 }
54
55 pub(crate) unsafe fn from_ptr(c_vec: mlx_sys::mlx_vector_array) -> Self {
56 Self { c_vec }
57 }
58
59 pub(crate) fn try_from_iter(
60 iter: impl Iterator<Item = impl AsRef<Array>>,
61 ) -> Result<Self, Exception> {
62 VectorArray::try_from_op(|res| unsafe {
63 let mut status = SUCCESS;
64 for arr in iter {
65 status = mlx_sys::mlx_vector_array_append_value(*res, arr.as_ref().as_ptr());
66 if status != SUCCESS {
67 return status;
68 }
69 }
70 status
71 })
72 }
73
74 pub(crate) fn try_into_values<T>(self) -> Result<T, Exception>
75 where
76 T: FromIterator<Array>,
77 {
78 unsafe {
79 let size = mlx_sys::mlx_vector_array_size(self.c_vec);
80 (0..size)
81 .map(|i| {
82 Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, self.c_vec, i))
83 })
84 .collect::<Result<T, Exception>>()
85 }
86 }
87}
88
89impl Drop for VectorArray {
90 fn drop(&mut self) {
91 let status = unsafe { mlx_sys::mlx_vector_array_free(self.c_vec) };
92 debug_assert_eq!(status, SUCCESS);
93 }
94}
95
/// Conversion into an `Option<T>`.
///
/// Lets a function parameter accept either a bare value or an explicit `Option`
/// (`Some(x)` / `None`) for an optional argument.
pub trait IntoOption<T> {
    /// Converts `self` into an `Option<T>`.
    fn into_option(self) -> Option<T>;
}

/// An `Option<T>` is passed through unchanged.
impl<T> IntoOption<T> for Option<T> {
    fn into_option(self) -> Option<T> {
        self
    }
}

/// A bare value is wrapped in `Some`.
impl<T> IntoOption<T> for T {
    fn into_option(self) -> Option<T> {
        Option::Some(self)
    }
}

/// A borrowed fixed-size array is viewed as a slice.
impl<'a, T, const N: usize> IntoOption<&'a [T]> for &'a [T; N] {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}

/// A borrowed `Vec` is viewed as a slice.
impl<'a, T> IntoOption<&'a [T]> for &'a Vec<T> {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}
126
/// A value usable as an array operand: an existing `Array` (owned or borrowed), or a scalar /
/// nested value that is converted into a new `Array`.
///
/// The lifetime `'a` bounds the associated `Array` type, which allows it to be a borrow such
/// as `&'a Array`.
pub trait ScalarOrArray<'a> {
    /// The array representation produced: an owned `Array` or a reference to one.
    type Array: AsRef<Array> + 'a;

    /// Converts `self` into `Self::Array`. A new array is created only for non-array inputs.
    fn into_owned_or_ref_array(self) -> Self::Array;
}

/// An owned array is passed through unchanged.
impl ScalarOrArray<'_> for Array {
    type Array = Array;

    fn into_owned_or_ref_array(self) -> Array {
        self
    }
}

/// A borrowed array is returned as the same borrow, without creating a new array.
impl<'a> ScalarOrArray<'a> for &'a Array {
    type Array = &'a Array;

    fn into_owned_or_ref_array(self) -> &'a Array {
        self
    }
}

/// A `bool` scalar becomes a new array via `Array::from_bool`.
impl ScalarOrArray<'static> for bool {
    type Array = Array;

    fn into_owned_or_ref_array(self) -> Array {
        Array::from_bool(self)
    }
}

/// An `i32` scalar becomes a new array via `Array::from_int`.
impl ScalarOrArray<'static> for i32 {
    type Array = Array;

    fn into_owned_or_ref_array(self) -> Array {
        Array::from_int(self)
    }
}

/// An `f32` scalar becomes a new array via `Array::from_f32`.
impl ScalarOrArray<'static> for f32 {
    type Array = Array;

    fn into_owned_or_ref_array(self) -> Array {
        Array::from_f32(self)
    }
}

/// A `complex64` scalar becomes a new array via `Array::from_complex`.
impl ScalarOrArray<'static> for complex64 {
    type Array = Array;

    fn into_owned_or_ref_array(self) -> Array {
        Array::from_complex(self)
    }
}

/// Any `T` that `Array` can be built from through `FromNested` becomes a new array via
/// `Array::from_nested`.
impl<T> ScalarOrArray<'static> for T
where
    Array: FromNested<T>,
{
    type Array = Array;

    fn into_owned_or_ref_array(self) -> Array {
        Array::from_nested(self)
    }
}
204
205#[derive(Debug)]
206pub(crate) struct Closure<'a> {
207 c_closure: mlx_sys::mlx_closure,
208 lt_marker: PhantomData<&'a ()>,
209}
210
211impl<'a> Closure<'a> {
212 pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_closure {
213 self.c_closure
214 }
215
216 pub(crate) fn new<F>(closure: F) -> Self
217 where
218 F: FnMut(&[Array]) -> Vec<Array> + 'a,
219 {
220 let c_closure = new_mlx_closure(closure);
221 Self {
222 c_closure,
223 lt_marker: PhantomData,
224 }
225 }
226
227 pub(crate) fn new_fallible<F>(closure: F) -> Self
228 where
229 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
230 {
231 let c_closure = new_mlx_fallible_closure(closure);
232 Self {
233 c_closure,
234 lt_marker: PhantomData,
235 }
236 }
237}
238
239impl Drop for Closure<'_> {
240 fn drop(&mut self) {
241 let status = unsafe { mlx_sys::mlx_closure_free(self.c_closure) };
242 debug_assert_eq!(status, SUCCESS);
243 }
244}
245
246fn new_mlx_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
248where
249 F: FnMut(&[Array]) -> Vec<Array> + 'a,
250{
251 let boxed = Box::new(closure);
253
254 let raw = Box::into_raw(boxed);
256 let payload = raw as *mut std::ffi::c_void;
257
258 unsafe {
259 mlx_sys::mlx_closure_new_func_payload(Some(trampoline::<F>), payload, Some(noop_dtor))
260 }
261}
262
263fn new_mlx_fallible_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
264where
265 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
266{
267 let boxed = Box::new(closure);
268 let raw = Box::into_raw(boxed);
269 let payload = raw as *mut std::ffi::c_void;
270
271 unsafe {
272 mlx_sys::mlx_closure_new_func_payload(
273 Some(trampoline_fallible::<F>),
274 payload,
275 Some(noop_dtor),
276 )
277 }
278}
279
280fn new_mlx_vector_array(arrays: Vec<Array>) -> mlx_sys::mlx_vector_array {
282 unsafe {
283 let result = mlx_sys::mlx_vector_array_new();
284 let ctx_ptrs: Vec<mlx_sys::mlx_array> = arrays.iter().map(|array| array.as_ptr()).collect();
285 mlx_sys::mlx_vector_array_append_data(result, ctx_ptrs.as_ptr(), arrays.len());
286 result
287 }
288}
289
290fn mlx_vector_array_values(
291 vector_array: mlx_sys::mlx_vector_array,
292) -> Result<Vec<Array>, Exception> {
293 unsafe {
294 let size = mlx_sys::mlx_vector_array_size(vector_array);
295 (0..size)
296 .map(|index| {
297 Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, vector_array, index))
298 })
299 .collect()
300 }
301}
302
303extern "C" fn trampoline<'a, F>(
304 ret: *mut mlx_vector_array,
305 vector_array: mlx_vector_array,
306 payload: *mut std::ffi::c_void,
307) -> i32
308where
309 F: FnMut(&[Array]) -> Vec<Array> + 'a,
310{
311 unsafe {
312 let raw_closure: *mut F = payload as *mut _;
313 let mut closure = Box::from_raw(raw_closure);
315 let arrays = match mlx_vector_array_values(vector_array) {
316 Ok(arrays) => arrays,
317 Err(_) => {
318 return FAILURE;
319 }
320 };
321 let result = closure(&arrays);
322 *ret = new_mlx_vector_array(result);
325
326 SUCCESS
327 }
328}
329
330extern "C" fn trampoline_fallible<'a, F>(
331 ret: *mut mlx_vector_array,
332 vector_array: mlx_vector_array,
333 payload: *mut std::ffi::c_void,
334) -> i32
335where
336 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
337{
338 unsafe {
339 let raw_closure: *mut F = payload as *mut _;
340 let mut closure = Box::from_raw(raw_closure);
341 let arrays = match mlx_vector_array_values(vector_array) {
342 Ok(arrays) => arrays,
343 Err(e) => {
344 set_closure_error(e);
345 return FAILURE;
346 }
347 };
348 let result = closure(&arrays);
349 match result {
350 Ok(result) => {
351 *ret = new_mlx_vector_array(result);
352 SUCCESS
353 }
354 Err(err) => {
355 set_closure_error(err);
356 FAILURE
357 }
358 }
359 }
360}
361
362extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {}
363
/// Returns a mutable reference to the value stored under `key`, first inserting `f()` if the
/// key is absent.
///
/// `f` is only called when an insertion is needed.
pub(crate) fn get_mut_or_insert_with<'a, T>(
    map: &'a mut HashMap<Rc<str>, T>,
    key: &Rc<str>,
    f: impl FnOnce() -> T,
) -> &'a mut T {
    // One entry lookup replaces the previous contains_key/insert/get_mut sequence (up to
    // three hash lookups plus an unwrap). Cloning an `Rc` only bumps a reference count.
    map.entry(Rc::clone(key)).or_insert_with(f)
}
375
/// Types that expose a collection of `Array` states that can be read and updated in place.
///
/// NOTE(review): every impl in this module yields states in the same order from
/// `updatable_states` and `updatable_states_mut` (key-sorted for `ModuleParameters`). Callers
/// presumably rely on that pairing — confirm before adding impls with a different order.
pub trait Updatable {
    /// Returns the number of updatable states.
    fn updatable_states_len(&self) -> usize;

    /// Returns shared references to the updatable states.
    fn updatable_states(&self) -> impl IntoIterator<Item = &Array>;

    /// Returns mutable references to the updatable states.
    fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array>;
}
400
401impl<T> Updatable for T
402where
403 T: ModuleParameters,
404{
405 fn updatable_states_len(&self) -> usize {
406 self.num_parameters()
407 }
408
409 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
410 use itertools::Itertools;
411
412 self.parameters()
414 .flatten()
415 .into_iter()
416 .sorted_by(|a, b| a.0.cmp(&b.0))
417 .map(|(_, v)| v)
418 }
419
420 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
421 use itertools::Itertools;
422
423 self.parameters_mut()
424 .flatten()
425 .into_iter()
426 .sorted_by(|a, b| a.0.cmp(&b.0))
427 .map(|(_, v)| v)
428 }
429}
430
431impl<T1, T2> Updatable for (T1, T2)
432where
433 T1: Updatable,
434 T2: Updatable,
435{
436 fn updatable_states_len(&self) -> usize {
437 let (a, b) = self;
438 a.updatable_states_len() + b.updatable_states_len()
439 }
440
441 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
442 let (a, b) = self;
443 let params = a.updatable_states();
444 params.into_iter().chain(b.updatable_states())
445 }
446
447 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
448 let (a, b) = self;
449 let params = a.updatable_states_mut();
450 params.into_iter().chain(b.updatable_states_mut())
451 }
452}
453
454impl Updatable for Vec<Array> {
455 fn updatable_states_len(&self) -> usize {
456 self.len()
457 }
458
459 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
460 self.iter()
461 }
462
463 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
464 self.iter_mut()
465 }
466}
467
/// Either one value used for both positions, or an explicit pair of values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrPair<T = i32> {
    /// One value that stands in for both the first and the second position.
    Single(T),

    /// Distinct first and second values.
    Pair(T, T),
}

impl<T: Clone> SingleOrPair<T> {
    /// Returns the first value (the single value for `Single`).
    pub fn first(&self) -> T {
        match self {
            Self::Single(v) | Self::Pair(v, _) => v.clone(),
        }
    }

    /// Returns the second value (the single value for `Single`).
    pub fn second(&self) -> T {
        match self {
            Self::Single(v) | Self::Pair(_, v) => v.clone(),
        }
    }
}

/// A bare value becomes `Single`.
impl<T> From<T> for SingleOrPair<T> {
    fn from(value: T) -> Self {
        Self::Single(value)
    }
}

/// A 2-tuple becomes `Pair`.
impl<T> From<(T, T)> for SingleOrPair<T> {
    fn from((first, second): (T, T)) -> Self {
        Self::Pair(first, second)
    }
}

/// Expands into a 2-tuple, duplicating the value of a `Single`.
impl<T: Clone> From<SingleOrPair<T>> for (T, T) {
    fn from(value: SingleOrPair<T>) -> Self {
        match value {
            SingleOrPair::Pair(first, second) => (first, second),
            SingleOrPair::Single(only) => (only.clone(), only),
        }
    }
}
516
/// Either one value used for all three positions, or an explicit triple of values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrTriple<T = i32> {
    /// One value that stands in for the first, second and third positions.
    Single(T),

    /// Distinct first, second and third values.
    Triple(T, T, T),
}

impl<T: Clone> SingleOrTriple<T> {
    /// Returns the first value (the single value for `Single`).
    pub fn first(&self) -> T {
        match self {
            Self::Single(v) | Self::Triple(v, _, _) => v.clone(),
        }
    }

    /// Returns the second value (the single value for `Single`).
    pub fn second(&self) -> T {
        match self {
            Self::Single(v) | Self::Triple(_, v, _) => v.clone(),
        }
    }

    /// Returns the third value (the single value for `Single`).
    pub fn third(&self) -> T {
        match self {
            Self::Single(v) | Self::Triple(_, _, v) => v.clone(),
        }
    }
}

/// A bare value becomes `Single`.
impl<T> From<T> for SingleOrTriple<T> {
    fn from(value: T) -> Self {
        Self::Single(value)
    }
}

/// A 3-tuple becomes `Triple`.
impl<T> From<(T, T, T)> for SingleOrTriple<T> {
    fn from((x, y, z): (T, T, T)) -> Self {
        Self::Triple(x, y, z)
    }
}

/// Expands into a 3-tuple, duplicating the value of a `Single`.
impl<T: Clone> From<SingleOrTriple<T>> for (T, T, T) {
    fn from(value: SingleOrTriple<T>) -> Self {
        match value {
            SingleOrTriple::Triple(x, y, z) => (x, y, z),
            SingleOrTriple::Single(only) => (only.clone(), only.clone(), only),
        }
    }
}
573
/// Either a single value or a vector of values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SingleOrVec<T> {
    /// Exactly one value.
    Single(T),

    /// Any number of values.
    Vec(Vec<T>),
}

/// A bare value becomes `Single`.
impl<T> From<T> for SingleOrVec<T> {
    fn from(value: T) -> Self {
        Self::Single(value)
    }
}

/// A `Vec<T>` becomes the `Vec` variant.
impl<T> From<Vec<T>> for SingleOrVec<T> {
    fn from(values: Vec<T>) -> Self {
        Self::Vec(values)
    }
}
594}