1use guard::Guarded;
4use mlx_sys::mlx_vector_array;
5
6use crate::error::set_closure_error;
7use crate::module::ModuleParameters;
8use crate::{complex64, error::Exception, Array, FromNested};
9use std::collections::HashMap;
10use std::{marker::PhantomData, rc::Rc};
11
/// Status code that the `mlx_sys` calls in this module are compared against to detect success.
pub(crate) const SUCCESS: i32 = 0;
/// Status code returned by this module's callbacks to report a failure.
pub(crate) const FAILURE: i32 = 1;
15
16pub(crate) mod guard;
17pub(crate) mod io;
18
/// Resolves a possibly negative `index` against a dimension of length `len`.
///
/// A negative index counts from the end (`-1` becomes `len - 1`); a non-negative index is
/// returned as is. No bounds check is performed, so the result may still fall outside
/// `0..len`. The addition saturates instead of overflowing.
pub(crate) fn resolve_index_signed_unchecked(index: i32, len: i32) -> i32 {
    match index {
        from_end if from_end < 0 => len.saturating_add(from_end),
        from_start => from_start,
    }
}
26
/// Resolves a possibly negative `index` against a dimension of length `len`.
///
/// A negative index counts from the end (`-1` becomes `len - 1`); a non-negative index is
/// returned as is. No bounds check is performed: when `index` reaches further back than `len`
/// the subtraction wraps around, so the result is a huge value that can never be a valid
/// position.
pub(crate) fn resolve_index_unchecked(index: i32, len: usize) -> usize {
    if index.is_negative() {
        // Work in `usize` rather than casting `len` down to `i32`: that cast truncates lengths
        // above `i32::MAX` and the following `i32` addition can overflow. `unsigned_abs` is
        // exact even for `i32::MIN`.
        len.wrapping_sub(index.unsigned_abs() as usize)
    } else {
        index as usize
    }
}
34
35pub(crate) fn axes_or_default_to_all<'a>(axes: impl IntoOption<&'a [i32]>, ndim: i32) -> Vec<i32> {
37 match axes.into_option() {
38 Some(axes) => axes.to_vec(),
39 None => {
40 let axes: Vec<i32> = (0..ndim).collect();
41 axes
42 }
43 }
44}
45
46pub(crate) struct VectorArray {
47 c_vec: mlx_sys::mlx_vector_array,
48}
49
50impl VectorArray {
51 pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_vector_array {
52 self.c_vec
53 }
54
55 pub(crate) fn try_from_iter(
56 iter: impl Iterator<Item = impl AsRef<Array>>,
57 ) -> Result<Self, Exception> {
58 VectorArray::try_from_op(|res| unsafe {
59 let mut status = SUCCESS;
60 for arr in iter {
61 status = mlx_sys::mlx_vector_array_append_value(*res, arr.as_ref().as_ptr());
62 if status != SUCCESS {
63 return status;
64 }
65 }
66 status
67 })
68 }
69
70 pub(crate) fn try_into_values<T>(self) -> Result<T, Exception>
71 where
72 T: FromIterator<Array>,
73 {
74 unsafe {
75 let size = mlx_sys::mlx_vector_array_size(self.c_vec);
76 (0..size)
77 .map(|i| {
78 Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, self.c_vec, i))
79 })
80 .collect::<Result<T, Exception>>()
81 }
82 }
83}
84
85impl Drop for VectorArray {
86 fn drop(&mut self) {
87 let status = unsafe { mlx_sys::mlx_vector_array_free(self.c_vec) };
88 debug_assert_eq!(status, SUCCESS);
89 }
90}
91
/// Conversion into an `Option<T>`.
///
/// Implemented for `Option<T>` itself, for any bare `T`, and for array and `Vec` references
/// that can be viewed as a slice, so a single parameter can accept all of these forms.
pub trait IntoOption<T> {
    /// Converts `self` into an `Option<T>`.
    fn into_option(self) -> Option<T>;
}

/// An `Option` is passed through untouched.
impl<T> IntoOption<T> for Option<T> {
    fn into_option(self) -> Self {
        self
    }
}

/// A bare value is always present, so it is wrapped in `Some`.
impl<T> IntoOption<T> for T {
    fn into_option(self) -> Option<T> {
        Some(self)
    }
}

/// A reference to a fixed-size array is viewed as a slice.
impl<'a, T, const N: usize> IntoOption<&'a [T]> for &'a [T; N] {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}

/// A reference to a `Vec` is viewed as a slice.
impl<'a, T> IntoOption<&'a [T]> for &'a Vec<T> {
    fn into_option(self) -> Option<&'a [T]> {
        Some(self.as_slice())
    }
}
122
123pub trait ScalarOrArray<'a> {
125 type Array: AsRef<Array> + 'a;
127
128 fn into_owned_or_ref_array(self) -> Self::Array;
130}
131
132impl ScalarOrArray<'_> for Array {
133 type Array = Array;
134
135 fn into_owned_or_ref_array(self) -> Array {
136 self
137 }
138}
139
140impl<'a> ScalarOrArray<'a> for &'a Array {
141 type Array = &'a Array;
142
143 fn into_owned_or_ref_array(self) -> &'a Array {
145 self
146 }
147}
148
149impl ScalarOrArray<'static> for bool {
150 type Array = Array;
151
152 fn into_owned_or_ref_array(self) -> Array {
153 Array::from_bool(self)
154 }
155}
156
157impl ScalarOrArray<'static> for i32 {
158 type Array = Array;
159
160 fn into_owned_or_ref_array(self) -> Array {
161 Array::from_int(self)
162 }
163}
164
165impl ScalarOrArray<'static> for f32 {
166 type Array = Array;
167
168 fn into_owned_or_ref_array(self) -> Array {
169 Array::from_f32(self)
170 }
171}
172
173impl ScalarOrArray<'static> for complex64 {
183 type Array = Array;
184
185 fn into_owned_or_ref_array(self) -> Array {
186 Array::from_complex(self)
187 }
188}
189
190impl<T> ScalarOrArray<'static> for T
191where
192 Array: FromNested<T>,
193{
194 type Array = Array;
195
196 fn into_owned_or_ref_array(self) -> Array {
197 Array::from_nested(self)
198 }
199}
200
201#[derive(Debug)]
202pub(crate) struct Closure<'a> {
203 c_closure: mlx_sys::mlx_closure,
204 lt_marker: PhantomData<&'a ()>,
205}
206
207impl<'a> Closure<'a> {
208 pub(crate) fn as_ptr(&self) -> mlx_sys::mlx_closure {
209 self.c_closure
210 }
211
212 pub(crate) fn new<F>(closure: F) -> Self
213 where
214 F: FnMut(&[Array]) -> Vec<Array> + 'a,
215 {
216 let c_closure = new_mlx_closure(closure);
217 Self {
218 c_closure,
219 lt_marker: PhantomData,
220 }
221 }
222
223 pub(crate) fn new_fallible<F>(closure: F) -> Self
224 where
225 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
226 {
227 let c_closure = new_mlx_fallible_closure(closure);
228 Self {
229 c_closure,
230 lt_marker: PhantomData,
231 }
232 }
233}
234
235impl Drop for Closure<'_> {
236 fn drop(&mut self) {
237 let status = unsafe { mlx_sys::mlx_closure_free(self.c_closure) };
238 debug_assert_eq!(status, SUCCESS);
239 }
240}
241
242fn new_mlx_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
244where
245 F: FnMut(&[Array]) -> Vec<Array> + 'a,
246{
247 let boxed = Box::new(closure);
249
250 let raw = Box::into_raw(boxed);
252 let payload = raw as *mut std::ffi::c_void;
253
254 unsafe {
255 mlx_sys::mlx_closure_new_func_payload(Some(trampoline::<F>), payload, Some(noop_dtor))
256 }
257}
258
259fn new_mlx_fallible_closure<'a, F>(closure: F) -> mlx_sys::mlx_closure
260where
261 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
262{
263 let boxed = Box::new(closure);
264 let raw = Box::into_raw(boxed);
265 let payload = raw as *mut std::ffi::c_void;
266
267 unsafe {
268 mlx_sys::mlx_closure_new_func_payload(
269 Some(trampoline_fallible::<F>),
270 payload,
271 Some(noop_dtor),
272 )
273 }
274}
275
276fn new_mlx_vector_array(arrays: Vec<Array>) -> mlx_sys::mlx_vector_array {
278 unsafe {
279 let result = mlx_sys::mlx_vector_array_new();
280 let ctx_ptrs: Vec<mlx_sys::mlx_array> = arrays.iter().map(|array| array.as_ptr()).collect();
281 mlx_sys::mlx_vector_array_append_data(result, ctx_ptrs.as_ptr(), arrays.len());
282 result
283 }
284}
285
286fn mlx_vector_array_values(
287 vector_array: mlx_sys::mlx_vector_array,
288) -> Result<Vec<Array>, Exception> {
289 unsafe {
290 let size = mlx_sys::mlx_vector_array_size(vector_array);
291 (0..size)
292 .map(|index| {
293 Array::try_from_op(|res| mlx_sys::mlx_vector_array_get(res, vector_array, index))
294 })
295 .collect()
296 }
297}
298
299extern "C" fn trampoline<'a, F>(
300 ret: *mut mlx_vector_array,
301 vector_array: mlx_vector_array,
302 payload: *mut std::ffi::c_void,
303) -> i32
304where
305 F: FnMut(&[Array]) -> Vec<Array> + 'a,
306{
307 unsafe {
308 let raw_closure: *mut F = payload as *mut _;
309 let mut closure = Box::from_raw(raw_closure);
311 let arrays = match mlx_vector_array_values(vector_array) {
312 Ok(arrays) => arrays,
313 Err(_) => {
314 return FAILURE;
315 }
316 };
317 let result = closure(&arrays);
318 *ret = new_mlx_vector_array(result);
321
322 SUCCESS
323 }
324}
325
326extern "C" fn trampoline_fallible<'a, F>(
327 ret: *mut mlx_vector_array,
328 vector_array: mlx_vector_array,
329 payload: *mut std::ffi::c_void,
330) -> i32
331where
332 F: FnMut(&[Array]) -> Result<Vec<Array>, Exception> + 'a,
333{
334 unsafe {
335 let raw_closure: *mut F = payload as *mut _;
336 let mut closure = Box::from_raw(raw_closure);
337 let arrays = match mlx_vector_array_values(vector_array) {
338 Ok(arrays) => arrays,
339 Err(e) => {
340 set_closure_error(e);
341 return FAILURE;
342 }
343 };
344 let result = closure(&arrays);
345 match result {
346 Ok(result) => {
347 *ret = new_mlx_vector_array(result);
348 SUCCESS
349 }
350 Err(err) => {
351 set_closure_error(err);
352 FAILURE
353 }
354 }
355 }
356}
357
358extern "C" fn noop_dtor(_data: *mut std::ffi::c_void) {}
359
/// Returns a mutable reference to the value stored under `key`, inserting the result of `f`
/// first when the key is absent.
///
/// `f` is only called when an insertion is needed. The entry API does the lookup once and
/// removes the need for an `unwrap`; cloning the `Rc<str>` key only bumps a reference count.
pub(crate) fn get_mut_or_insert_with<'a, T>(
    map: &'a mut HashMap<Rc<str>, T>,
    key: &Rc<str>,
    f: impl FnOnce() -> T,
) -> &'a mut T {
    map.entry(Rc::clone(key)).or_insert_with(f)
}
371
372pub trait Updatable {
377 fn updatable_states(&self) -> impl IntoIterator<Item = &Array>;
382
383 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array>;
388}
389
390impl<T> Updatable for T
391where
392 T: ModuleParameters,
393{
394 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
395 use itertools::Itertools;
396
397 self.parameters()
399 .flatten()
400 .into_iter()
401 .sorted_by(|a, b| a.0.cmp(&b.0))
402 .map(|(_, v)| v)
403 }
404
405 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
406 use itertools::Itertools;
407
408 self.parameters_mut()
409 .flatten()
410 .into_iter()
411 .sorted_by(|a, b| a.0.cmp(&b.0))
412 .map(|(_, v)| v)
413 }
414}
415
416impl<T1, T2> Updatable for (T1, T2)
417where
418 T1: Updatable,
419 T2: Updatable,
420{
421 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
422 let (a, b) = self;
423 let params = a.updatable_states();
424 params.into_iter().chain(b.updatable_states())
425 }
426
427 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
428 let (a, b) = self;
429 let params = a.updatable_states_mut();
430 params.into_iter().chain(b.updatable_states_mut())
431 }
432}
433
434impl Updatable for Vec<Array> {
435 fn updatable_states(&self) -> impl IntoIterator<Item = &Array> {
436 self.iter()
437 }
438
439 fn updatable_states_mut(&mut self) -> impl IntoIterator<Item = &mut Array> {
440 self.iter_mut()
441 }
442}
443
/// Either one value used for both positions, or an explicit pair of values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrPair<T = i32> {
    /// One value that stands for both the first and the second position.
    Single(T),

    /// Separate values for the first and the second position.
    Pair(T, T),
}

impl<T: Clone> SingleOrPair<T> {
    /// Returns a clone of the value for the first position.
    pub fn first(&self) -> T {
        let (Self::Single(value) | Self::Pair(value, _)) = self;
        value.clone()
    }

    /// Returns a clone of the value for the second position.
    pub fn second(&self) -> T {
        let (Self::Single(value) | Self::Pair(_, value)) = self;
        value.clone()
    }
}

impl<T> From<T> for SingleOrPair<T> {
    /// Wraps a lone value as `Single`.
    fn from(single: T) -> Self {
        Self::Single(single)
    }
}

impl<T> From<(T, T)> for SingleOrPair<T> {
    /// Wraps a tuple as `Pair`, keeping the element order.
    fn from((first, second): (T, T)) -> Self {
        Self::Pair(first, second)
    }
}

impl<T: Clone> From<SingleOrPair<T>> for (T, T) {
    /// Expands into a tuple, duplicating the value of a `Single`.
    fn from(source: SingleOrPair<T>) -> Self {
        match source {
            SingleOrPair::Pair(first, second) => (first, second),
            SingleOrPair::Single(both) => (both.clone(), both),
        }
    }
}
492
/// Either one value used for all three positions, or an explicit triple of values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingleOrTriple<T = i32> {
    /// One value that stands for the first, second and third position.
    Single(T),

    /// Separate values for the first, second and third position.
    Triple(T, T, T),
}

impl<T: Clone> SingleOrTriple<T> {
    /// Returns a clone of the value for the first position.
    pub fn first(&self) -> T {
        let (Self::Single(value) | Self::Triple(value, _, _)) = self;
        value.clone()
    }

    /// Returns a clone of the value for the second position.
    pub fn second(&self) -> T {
        let (Self::Single(value) | Self::Triple(_, value, _)) = self;
        value.clone()
    }

    /// Returns a clone of the value for the third position.
    pub fn third(&self) -> T {
        let (Self::Single(value) | Self::Triple(_, _, value)) = self;
        value.clone()
    }
}

impl<T> From<T> for SingleOrTriple<T> {
    /// Wraps a lone value as `Single`.
    fn from(single: T) -> Self {
        Self::Single(single)
    }
}

impl<T> From<(T, T, T)> for SingleOrTriple<T> {
    /// Wraps a tuple as `Triple`, keeping the element order.
    fn from((first, second, third): (T, T, T)) -> Self {
        Self::Triple(first, second, third)
    }
}

impl<T: Clone> From<SingleOrTriple<T>> for (T, T, T) {
    /// Expands into a tuple, duplicating the value of a `Single`.
    fn from(source: SingleOrTriple<T>) -> Self {
        match source {
            SingleOrTriple::Triple(first, second, third) => (first, second, third),
            SingleOrTriple::Single(all) => (all.clone(), all.clone(), all),
        }
    }
}
549
/// Either a single value or a vector of values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SingleOrVec<T> {
    /// Exactly one value.
    Single(T),

    /// Any number of values.
    Vec(Vec<T>),
}

impl<T> From<T> for SingleOrVec<T> {
    /// Wraps a lone value as `Single`.
    fn from(single: T) -> Self {
        Self::Single(single)
    }
}

impl<T> From<Vec<T>> for SingleOrVec<T> {
    /// Wraps a vector as `Vec` without copying its elements.
    fn from(values: Vec<T>) -> Self {
        Self::Vec(values)
    }
}