mlx_rs/array/
mod.rs

1use crate::{
2    dtype::Dtype,
3    error::AsSliceError,
4    sealed::Sealed,
5    utils::{guard::Guarded, SUCCESS},
6    Stream,
7};
8use element::FromSliceElement;
9use mlx_internal_macros::default_device;
10use mlx_sys::mlx_array;
11use num_complex::Complex;
12use std::{
13    ffi::{c_void, CStr},
14    iter::Sum,
15};
16
17mod element;
18mod operators;
19
20cfg_safetensors! {
21    mod safetensors;
22}
23
24pub use element::ArrayElement;
25
// Not using `Complex64` as the alias name because `num_complex::Complex64` is actually
// `Complex<f64>`, whereas this alias (matching MLX's 8-byte `Complex64` dtype) is a pair of `f32`s.

/// Type alias for `num_complex::Complex<f32>`.
///
/// This is the Rust-side element type of MLX's `Complex64` dtype: two 32-bit floats
/// (real, imaginary), 8 bytes per element.
// Lowercase on purpose, presumably so it reads like a primitive such as `f32`.
#[allow(non_camel_case_types)]
pub type complex64 = Complex<f32>;
31
/// An n-dimensional array.
///
/// A thin owning wrapper around an MLX `mlx_array` handle. `#[repr(transparent)]` guarantees
/// that `Array` has exactly the same layout as `mlx_array`. The handle is released with
/// `mlx_sys::mlx_array_free` when the `Array` is dropped.
#[repr(transparent)]
pub struct Array {
    // Raw MLX handle owned by this value; freed in the `Drop` impl.
    c_array: mlx_array,
}
37
// Marker impls of the crate's `Sealed` trait (imported from `crate::sealed`) for both owned and
// borrowed arrays. NOTE(review): presumably this lets sealed traits accept `Array` / `&Array`
// while preventing downstream implementations — confirm in `crate::sealed`.
impl Sealed for Array {}

impl Sealed for &Array {}
41
42impl std::fmt::Debug for Array {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        write!(f, "{self}")
45    }
46}
47
48impl std::fmt::Display for Array {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        unsafe {
51            let mut mlx_str = mlx_sys::mlx_string_new();
52            let status = mlx_sys::mlx_array_tostring(&mut mlx_str as *mut _, self.as_ptr());
53            if status != SUCCESS {
54                return Err(std::fmt::Error);
55            }
56            let ptr = mlx_sys::mlx_string_data(mlx_str);
57            let c_str = CStr::from_ptr(ptr);
58            write!(f, "{}", c_str.to_str().map_err(|_| std::fmt::Error)?)?;
59            mlx_sys::mlx_string_free(mlx_str);
60            Ok(())
61        }
62    }
63}
64
65impl Drop for Array {
66    fn drop(&mut self) {
67        // TODO: check memory leak with some tool?
68
69        // Decrease the reference count
70        unsafe { mlx_sys::mlx_array_free(self.as_ptr()) };
71    }
72}
73
// SAFETY: NOTE(review) — this asserts that an `mlx_array` handle may be moved to, used on, and
// freed (via `Drop`) from another thread. Nothing in this file establishes that; it relies on
// MLX's handle type being safe to transfer between threads — confirm against mlx-c. Only `Send`
// (not `Sync`) is implemented in this file.
unsafe impl Send for Array {}
75
76impl PartialEq for Array {
77    /// Array equality check.
78    ///
79    /// Compare two arrays for equality. Returns `true` iff the arrays have
80    /// the same shape and their values are equal. The arrays need not have
81    /// the same type to be considered equal.
82    ///
83    /// If you're looking for element-wise equality, use the [Array::eq()] method.
84    fn eq(&self, other: &Self) -> bool {
85        self.array_eq(other, None).unwrap().item()
86    }
87}
88
89impl Array {
90    /// Create a new array from an existing mlx_array pointer.
91    ///
92    /// # Safety
93    ///
94    /// The caller must ensure the reference count of the array is properly incremented with
95    /// `mlx_sys::mlx_retain`.
96    pub unsafe fn from_ptr(c_array: mlx_array) -> Array {
97        Self { c_array }
98    }
99
100    /// Get the underlying mlx_array pointer.
101    pub fn as_ptr(&self) -> mlx_array {
102        self.c_array
103    }
104
105    /// New array from a bool scalar.
106    pub fn from_bool(val: bool) -> Array {
107        let c_array = unsafe { mlx_sys::mlx_array_new_bool(val) };
108        Array { c_array }
109    }
110
111    /// New array from an int scalar.
112    pub fn from_int(val: i32) -> Array {
113        let c_array = unsafe { mlx_sys::mlx_array_new_int(val) };
114        Array { c_array }
115    }
116
117    /// New array from a f32 scalar.
118    pub fn from_f32(val: f32) -> Array {
119        let c_array = unsafe { mlx_sys::mlx_array_new_float32(val) };
120        Array { c_array }
121    }
122
123    /// New array from a f64 scalar.
124    pub fn from_f64(val: f64) -> Array {
125        let c_array = unsafe { mlx_sys::mlx_array_new_float64(val) };
126        Array { c_array }
127    }
128
129    /// New array from a complex scalar.
130    pub fn from_complex(val: complex64) -> Array {
131        let c_array = unsafe { mlx_sys::mlx_array_new_complex(val.re, val.im) };
132        Array { c_array }
133    }
134
135    /// New array from existing buffer.
136    ///
137    /// Please note that floating point literals are treated as f32 instead of
138    /// f64. Use [`Array::from_slice_f64`] for f64.
139    ///
140    /// # Parameters
141    ///
142    /// - `data`: A buffer which will be copied.
143    /// - `shape`: Shape of the array.
144    ///
145    /// # Panic
146    ///
147    /// - Panics if the product of the shape is not equal to the length of the
148    ///   data.
149    /// - Panics if the shape is too large.
150    pub fn from_slice<T: FromSliceElement>(data: &[T], shape: &[i32]) -> Self {
151        // Validate data size and shape
152        assert_eq!(data.len(), shape.iter().product::<i32>() as usize);
153
154        unsafe { Self::from_raw_data(data.as_ptr() as *const c_void, shape, T::DTYPE) }
155    }
156
157    /// New array from a slice of f64.
158    ///
159    /// A separate method is provided for f64 because f64 is not supported on GPU
160    /// and rust defaults to f64 for floating point literals
161    pub fn from_slice_f64(data: &[f64], shape: &[i32]) -> Self {
162        // Validate data size and shape
163        assert_eq!(data.len(), shape.iter().product::<i32>() as usize);
164
165        unsafe { Self::from_raw_data(data.as_ptr() as *const c_void, shape, Dtype::Float64) }
166    }
167
168    /// Create a new array from raw data buffer.
169    ///
170    /// This is a convenience wrapper around [`mlx_sy::mlx_array_new_data`].
171    ///
172    /// # Safety
173    ///
174    /// This is unsafe because the caller must ensure that the data buffer is valid and that the
175    /// shape is correct.
176    #[inline]
177    pub unsafe fn from_raw_data(data: *const c_void, shape: &[i32], dtype: Dtype) -> Self {
178        let dim = if shape.len() > i32::MAX as usize {
179            panic!("Shape is too large")
180        } else {
181            shape.len() as i32
182        };
183
184        let c_array = mlx_sys::mlx_array_new_data(data, shape.as_ptr(), dim, dtype.into());
185        Array { c_array }
186    }
187
188    /// New array from an iterator.
189    ///
190    /// Please note that floating point literals are treated as f32 instead of
191    /// f64. Use [`Array::from_iter_f64`] for f64.
192    ///
193    /// This is a convenience method that is equivalent to
194    ///
195    /// ```rust, ignore
196    /// let data: Vec<T> = iter.collect();
197    /// Array::from_slice(&data, shape)
198    /// ```
199    ///
200    /// # Example
201    ///
202    /// ```rust
203    /// use mlx_rs::Array;
204    ///
205    /// let data = vec![1i32, 2, 3, 4, 5];
206    /// let mut array = Array::from_iter(data.clone(), &[5]);
207    /// assert_eq!(array.as_slice::<i32>(), &data[..]);
208    /// ```
209    pub fn from_iter<I: IntoIterator<Item = T>, T: FromSliceElement>(
210        iter: I,
211        shape: &[i32],
212    ) -> Self {
213        let data: Vec<T> = iter.into_iter().collect();
214        Self::from_slice(&data, shape)
215    }
216
217    /// New array from an iterator of f64.
218    ///
219    /// A separate method is provided for f64 because f64 is not supported on GPU
220    /// and rust defaults to f64 for floating point literals
221    pub fn from_iter_f64<I: IntoIterator<Item = f64>>(iter: I, shape: &[i32]) -> Self {
222        let data: Vec<f64> = iter.into_iter().collect();
223        Self::from_slice_f64(&data, shape)
224    }
225
226    /// The size of the array’s datatype in bytes.
227    pub fn item_size(&self) -> usize {
228        unsafe { mlx_sys::mlx_array_itemsize(self.as_ptr()) }
229    }
230
231    /// Number of elements in the array.
232    pub fn size(&self) -> usize {
233        unsafe { mlx_sys::mlx_array_size(self.as_ptr()) }
234    }
235
236    /// The strides of the array.
237    pub fn strides(&self) -> &[usize] {
238        let ndim = self.ndim();
239        if ndim == 0 {
240            // The data pointer may be null which would panic even if len is 0
241            return &[];
242        }
243
244        unsafe {
245            let data = mlx_sys::mlx_array_strides(self.as_ptr());
246            std::slice::from_raw_parts(data, ndim)
247        }
248    }
249
250    /// The number of bytes in the array.
251    pub fn nbytes(&self) -> usize {
252        unsafe { mlx_sys::mlx_array_nbytes(self.as_ptr()) }
253    }
254
255    /// The array’s dimension.
256    pub fn ndim(&self) -> usize {
257        unsafe { mlx_sys::mlx_array_ndim(self.as_ptr()) }
258    }
259
260    /// The shape of the array.
261    ///
262    /// Returns: a pointer to the sizes of each dimension.
263    pub fn shape(&self) -> &[i32] {
264        let ndim = self.ndim();
265        if ndim == 0 {
266            // The data pointer may be null which would panic even if len is 0
267            return &[];
268        }
269
270        unsafe {
271            let data = mlx_sys::mlx_array_shape(self.as_ptr());
272            std::slice::from_raw_parts(data, ndim)
273        }
274    }
275
276    /// The shape of the array in a particular dimension.
277    ///
278    /// # Panic
279    ///
280    /// - Panics if the array is scalar.
281    /// - Panics if `dim` is negative and `dim + ndim` overflows
282    /// - Panics if the dimension is out of bounds.
283    pub fn dim(&self, dim: i32) -> i32 {
284        let dim = if dim.is_negative() {
285            (self.ndim() as i32).checked_add(dim).unwrap()
286        } else {
287            dim
288        };
289
290        // This will panic on a scalar array
291        unsafe { mlx_sys::mlx_array_dim(self.as_ptr(), dim) }
292    }
293
294    /// The array element type.
295    pub fn dtype(&self) -> Dtype {
296        let dtype = unsafe { mlx_sys::mlx_array_dtype(self.as_ptr()) };
297        Dtype::try_from(dtype).unwrap()
298    }
299
300    /// Evaluate the array.
301    pub fn eval(&self) -> crate::error::Result<()> {
302        <() as Guarded>::try_from_op(|_| unsafe { mlx_sys::mlx_array_eval(self.as_ptr()) })
303    }
304
305    /// Access the value of a scalar array.
306    /// If `T` does not match the array's `dtype` this will convert the type first.
307    ///
308    /// _Note: This will evaluate the array._
309    pub fn item<T: ArrayElement>(&self) -> T {
310        self.try_item().unwrap()
311    }
312
313    /// Access the value of a scalar array returning an error if the array is not a scalar.
314    /// If `T` does not match the array's `dtype` this will convert the type first.
315    ///
316    /// _Note: This will evaluate the array._
317    pub fn try_item<T: ArrayElement>(&self) -> crate::error::Result<T> {
318        self.eval()?;
319
320        // Evaluate the array, so we have content to work with in the conversion
321        self.eval()?;
322
323        // Though `mlx_array_item_<dtype>` returns a status code, it doesn't
324        // return any non-success status code even if the dtype doesn't match.
325        if self.dtype() != T::DTYPE {
326            let new_array = Array::try_from_op(|res| unsafe {
327                mlx_sys::mlx_astype(
328                    res,
329                    self.as_ptr(),
330                    T::DTYPE.into(),
331                    Stream::default().as_ptr(),
332                )
333            })?;
334            new_array.eval()?;
335            return T::array_item(&new_array);
336        }
337
338        T::array_item(self)
339    }
340
341    /// Returns a slice of the array data without validating the dtype.
342    ///
343    /// # Safety
344    ///
345    /// This is unsafe because the underlying data ptr is not checked for null or if the desired
346    /// dtype matches the actual dtype of the array.
347    ///
348    /// # Example
349    ///
350    /// ```rust
351    /// use mlx_rs::Array;
352    ///
353    /// let data = [1i32, 2, 3, 4, 5];
354    /// let mut array = Array::from_slice(&data[..], &[5]);
355    ///
356    /// unsafe {
357    ///    let slice = array.as_slice_unchecked::<i32>();
358    ///    assert_eq!(slice, &[1, 2, 3, 4, 5]);
359    /// }
360    /// ```
361    pub unsafe fn as_slice_unchecked<T: ArrayElement>(&self) -> &[T] {
362        self.eval().unwrap();
363
364        unsafe {
365            let data = T::array_data(self);
366            let size = self.size();
367            std::slice::from_raw_parts(data, size)
368        }
369    }
370
371    /// Returns a slice of the array data returning an error if the dtype does not match the actual dtype.
372    ///
373    /// # Example
374    ///
375    /// ```rust
376    /// use mlx_rs::Array;
377    ///
378    /// let data = [1i32, 2, 3, 4, 5];
379    /// let mut array = Array::from_slice(&data[..], &[5]);
380    ///
381    /// let slice = array.try_as_slice::<i32>();
382    /// assert_eq!(slice, Ok(&data[..]));
383    /// ```
384    pub fn try_as_slice<T: ArrayElement>(&self) -> Result<&[T], AsSliceError> {
385        if self.dtype() != T::DTYPE {
386            return Err(AsSliceError::DtypeMismatch {
387                expecting: T::DTYPE,
388                found: self.dtype(),
389            });
390        }
391
392        self.eval()?;
393
394        unsafe {
395            let size = self.size();
396            let data = T::array_data(self);
397            if data.is_null() || size == 0 {
398                return Err(AsSliceError::Null);
399            }
400
401            Ok(std::slice::from_raw_parts(data, size))
402        }
403    }
404
405    /// Returns a slice of the array data.
406    /// This method requires a mutable reference (`&self`) because it evaluates the array.
407    ///
408    /// # Panics
409    ///
410    /// Panics if the array is not evaluated or if the desired dtype does not match the actual dtype
411    ///
412    /// # Example
413    ///
414    /// ```rust
415    /// use mlx_rs::Array;
416    ///
417    /// let data = [1i32, 2, 3, 4, 5];
418    /// let mut array = Array::from_slice(&data[..], &[5]);
419    ///
420    /// let slice = array.as_slice::<i32>();
421    /// assert_eq!(slice, &data[..]);
422    /// ```
423    pub fn as_slice<T: ArrayElement>(&self) -> &[T] {
424        self.try_as_slice().unwrap()
425    }
426
427    /// Clone the array by copying the data.
428    ///
429    /// This is named `deep_clone` to avoid confusion with the `Clone` trait.
430    pub fn deep_clone(&self) -> Self {
431        unsafe {
432            let dtype = self.dtype();
433            let shape = self.shape();
434            let data = match dtype {
435                Dtype::Bool => mlx_sys::mlx_array_data_bool(self.as_ptr()) as *const c_void,
436                Dtype::Uint8 => mlx_sys::mlx_array_data_uint8(self.as_ptr()) as *const c_void,
437                Dtype::Uint16 => mlx_sys::mlx_array_data_uint16(self.as_ptr()) as *const c_void,
438                Dtype::Uint32 => mlx_sys::mlx_array_data_uint32(self.as_ptr()) as *const c_void,
439                Dtype::Uint64 => mlx_sys::mlx_array_data_uint64(self.as_ptr()) as *const c_void,
440                Dtype::Int8 => mlx_sys::mlx_array_data_int8(self.as_ptr()) as *const c_void,
441                Dtype::Int16 => mlx_sys::mlx_array_data_int16(self.as_ptr()) as *const c_void,
442                Dtype::Int32 => mlx_sys::mlx_array_data_int32(self.as_ptr()) as *const c_void,
443                Dtype::Int64 => mlx_sys::mlx_array_data_int64(self.as_ptr()) as *const c_void,
444                Dtype::Float16 => mlx_sys::mlx_array_data_float16(self.as_ptr()) as *const c_void,
445                Dtype::Float32 => mlx_sys::mlx_array_data_float32(self.as_ptr()) as *const c_void,
446                Dtype::Float64 => mlx_sys::mlx_array_data_float64(self.as_ptr()) as *const c_void,
447                Dtype::Bfloat16 => mlx_sys::mlx_array_data_bfloat16(self.as_ptr()) as *const c_void,
448                Dtype::Complex64 => {
449                    mlx_sys::mlx_array_data_complex64(self.as_ptr()) as *const c_void
450                }
451            };
452
453            let new_c_array =
454                mlx_sys::mlx_array_new_data(data, shape.as_ptr(), shape.len() as i32, dtype.into());
455
456            Array::from_ptr(new_c_array)
457        }
458    }
459}
460
461impl Clone for Array {
462    fn clone(&self) -> Self {
463        Array::try_from_op(|res| unsafe { mlx_sys::mlx_array_set(res, self.as_ptr()) })
464            // Exception may be thrown when calling `new` in cpp.
465            .expect("Failed to clone array")
466    }
467}
468
469impl Sum for Array {
470    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
471        iter.fold(Array::from_int(0), |acc, x| acc.add(&x).unwrap())
472    }
473}
474
475/// Stop gradients from being computed.
476///
477/// The operation is the identity but it prevents gradients from flowing
478/// through the array.
479#[default_device]
480pub fn stop_gradient_device(
481    a: impl AsRef<Array>,
482    stream: impl AsRef<Stream>,
483) -> crate::error::Result<Array> {
484    Array::try_from_op(|res| unsafe {
485        mlx_sys::mlx_stop_gradient(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
486    })
487}
488
489impl From<bool> for Array {
490    fn from(value: bool) -> Self {
491        Array::from_bool(value)
492    }
493}
494
495impl From<i32> for Array {
496    fn from(value: i32) -> Self {
497        Array::from_int(value)
498    }
499}
500
501impl From<f32> for Array {
502    fn from(value: f32) -> Self {
503        Array::from_f32(value)
504    }
505}
506
507impl From<complex64> for Array {
508    fn from(value: complex64) -> Self {
509        Array::from_complex(value)
510    }
511}
512
513impl<T> From<T> for Array
514where
515    Array: FromNested<T>,
516{
517    fn from(value: T) -> Self {
518        Array::from_nested(value)
519    }
520}
521
522impl AsRef<Array> for Array {
523    fn as_ref(&self) -> &Array {
524        self
525    }
526}
527
/// A helper trait to construct `Array` from scalar values.
///
/// This trait is intended to be used with the macro [`array!`] but can be used directly if needed.
/// This module implements it for `bool`, `i32`, `f32` and `complex64`, each forwarding to the
/// matching `Array::from_*` scalar constructor.
pub trait FromScalar<T>
where
    T: ArrayElement,
{
    /// Create a 0-dimensional (scalar) array holding `val`.
    fn from_scalar(val: T) -> Array;
}
538
539impl FromScalar<bool> for Array {
540    fn from_scalar(val: bool) -> Array {
541        Array::from_bool(val)
542    }
543}
544
545impl FromScalar<i32> for Array {
546    fn from_scalar(val: i32) -> Array {
547        Array::from_int(val)
548    }
549}
550
551impl FromScalar<f32> for Array {
552    fn from_scalar(val: f32) -> Array {
553        Array::from_f32(val)
554    }
555}
556
557impl FromScalar<complex64> for Array {
558    fn from_scalar(val: complex64) -> Array {
559        Array::from_complex(val)
560    }
561}
562
/// A helper trait to construct `Array` from nested arrays or slices.
///
/// Given that this is not intended for use other than the macro [`array!`], this trait is added
/// instead of directly implementing `From` for `Array` to avoid conflicts with other `From`
/// implementations. (This module then provides `From<T>` for every `T` with
/// `Array: FromNested<T>` through a single blanket impl that forwards here.)
///
/// The implementations in this module copy the elements in row-major order. The outermost
/// nesting level becomes the first dimension of the resulting shape.
///
/// Beware that this is subject to change in the future should we find a better way to implement
/// the macro without creating conflicts.
pub trait FromNested<T> {
    /// Create an array from nested arrays or slices.
    fn from_nested(data: T) -> Array;
}
575
576impl<T: FromSliceElement> FromNested<&[T]> for Array {
577    fn from_nested(data: &[T]) -> Self {
578        Array::from_slice(data, &[data.len() as i32])
579    }
580}
581
582impl<T: FromSliceElement, const N: usize> FromNested<[T; N]> for Array {
583    fn from_nested(data: [T; N]) -> Self {
584        Array::from_slice(&data, &[N as i32])
585    }
586}
587
588impl<T: FromSliceElement, const N: usize> FromNested<&[T; N]> for Array {
589    fn from_nested(data: &[T; N]) -> Self {
590        Array::from_slice(data, &[N as i32])
591    }
592}
593
594impl<T: FromSliceElement + Copy> FromNested<&[&[T]]> for Array {
595    fn from_nested(data: &[&[T]]) -> Self {
596        // check that all rows have the same length
597        let row_len = data[0].len();
598        assert!(
599            data.iter().all(|row| row.len() == row_len),
600            "Rows must have the same length"
601        );
602
603        let shape = [data.len() as i32, row_len as i32];
604        let data = data
605            .iter()
606            .flat_map(|x| x.iter())
607            .copied()
608            .collect::<Vec<T>>();
609        Array::from_slice(&data, &shape)
610    }
611}
612
613impl<T: FromSliceElement + Copy, const N: usize> FromNested<[&[T]; N]> for Array {
614    fn from_nested(data: [&[T]; N]) -> Self {
615        // check that all rows have the same length
616        let row_len = data[0].len();
617        assert!(
618            data.iter().all(|row| row.len() == row_len),
619            "Rows must have the same length"
620        );
621
622        let shape = [N as i32, row_len as i32];
623        let data = data
624            .iter()
625            .flat_map(|x| x.iter())
626            .copied()
627            .collect::<Vec<T>>();
628        Array::from_slice(&data, &shape)
629    }
630}
631
632impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[[T; N]]> for Array {
633    fn from_nested(data: &[[T; N]]) -> Self {
634        let shape = [data.len() as i32, N as i32];
635        let data = data
636            .iter()
637            .flat_map(|x| x.iter().copied())
638            .collect::<Vec<T>>();
639        Array::from_slice(&data, &shape)
640    }
641}
642
643impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[&[T; N]]> for Array {
644    fn from_nested(data: &[&[T; N]]) -> Self {
645        let shape = [data.len() as i32, N as i32];
646        let data = data
647            .iter()
648            .flat_map(|x| x.iter().copied())
649            .collect::<Vec<T>>();
650        Array::from_slice(&data, &shape)
651    }
652}
653
654impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<[[T; N]; M]> for Array {
655    fn from_nested(data: [[T; N]; M]) -> Self {
656        let shape = [M as i32, N as i32];
657        let data = data
658            .iter()
659            .flat_map(|x| x.iter().copied())
660            .collect::<Vec<T>>();
661        Array::from_slice(&data, &shape)
662    }
663}
664
665impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[[T; N]; M]>
666    for Array
667{
668    fn from_nested(data: &[[T; N]; M]) -> Self {
669        let shape = [M as i32, N as i32];
670        let data = data
671            .iter()
672            .flat_map(|x| x.iter().copied())
673            .collect::<Vec<T>>();
674        Array::from_slice(&data, &shape)
675    }
676}
677
678impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[&[T; N]; M]>
679    for Array
680{
681    fn from_nested(data: &[&[T; N]; M]) -> Self {
682        let shape = [M as i32, N as i32];
683        let data = data
684            .iter()
685            .flat_map(|x| x.iter().copied())
686            .collect::<Vec<T>>();
687        Array::from_slice(&data, &shape)
688    }
689}
690
691impl<T: FromSliceElement + Copy> FromNested<&[&[&[T]]]> for Array {
692    fn from_nested(data: &[&[&[T]]]) -> Self {
693        // check that 2nd dimension has the same length
694        let len_2d = data[0].len();
695        assert!(
696            data.iter().all(|x| x.len() == len_2d),
697            "2nd dimension must have the same length"
698        );
699
700        // check that 3rd dimension has the same length
701        let len_3d = data[0][0].len();
702        assert!(
703            data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
704            "3rd dimension must have the same length"
705        );
706
707        let shape = [data.len() as i32, len_2d as i32, len_3d as i32];
708        let data = data
709            .iter()
710            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
711            .collect::<Vec<T>>();
712        Array::from_slice(&data, &shape)
713    }
714}
715
716impl<T: FromSliceElement + Copy, const N: usize> FromNested<[&[&[T]]; N]> for Array {
717    fn from_nested(data: [&[&[T]]; N]) -> Self {
718        // check that 2nd dimension has the same length
719        let len_2d = data[0].len();
720        assert!(
721            data.iter().all(|x| x.len() == len_2d),
722            "2nd dimension must have the same length"
723        );
724
725        // check that 3rd dimension has the same length
726        let len_3d = data[0][0].len();
727        assert!(
728            data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
729            "3rd dimension must have the same length"
730        );
731
732        let shape = [N as i32, len_2d as i32, len_3d as i32];
733        let data = data
734            .iter()
735            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
736            .collect::<Vec<T>>();
737        Array::from_slice(&data, &shape)
738    }
739}
740
741impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[[&[T]; N]]> for Array {
742    fn from_nested(data: &[[&[T]; N]]) -> Self {
743        // check that 3rd dimension has the same length
744        let len_3d = data[0][0].len();
745        assert!(
746            data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
747            "3rd dimension must have the same length"
748        );
749
750        let shape = [data.len() as i32, N as i32, len_3d as i32];
751        let data = data
752            .iter()
753            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
754            .collect::<Vec<T>>();
755        Array::from_slice(&data, &shape)
756    }
757}
758
759impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[&[[T; N]]]> for Array {
760    fn from_nested(data: &[&[[T; N]]]) -> Self {
761        // check that 2nd dimension has the same length
762        let len_2d = data[0].len();
763        assert!(
764            data.iter().all(|x| x.len() == len_2d),
765            "2nd dimension must have the same length"
766        );
767
768        let shape = [data.len() as i32, len_2d as i32, N as i32];
769        let data = data
770            .iter()
771            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
772            .collect::<Vec<T>>();
773        Array::from_slice(&data, &shape)
774    }
775}
776
777impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<[[&[T]; N]; M]>
778    for Array
779{
780    fn from_nested(data: [[&[T]; N]; M]) -> Self {
781        // check that 3rd dimension has the same length
782        let len_3d = data[0][0].len();
783        assert!(
784            data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
785            "3rd dimension must have the same length"
786        );
787
788        let shape = [M as i32, N as i32, len_3d as i32];
789        let data = data
790            .iter()
791            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
792            .collect::<Vec<T>>();
793        Array::from_slice(&data, &shape)
794    }
795}
796
797impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[[&[T]; N]; M]>
798    for Array
799{
800    fn from_nested(data: &[[&[T]; N]; M]) -> Self {
801        // check that 3rd dimension has the same length
802        let len_3d = data[0][0].len();
803        assert!(
804            data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
805            "3rd dimension must have the same length"
806        );
807
808        let shape = [M as i32, N as i32, len_3d as i32];
809        let data = data
810            .iter()
811            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
812            .collect::<Vec<T>>();
813        Array::from_slice(&data, &shape)
814    }
815}
816
817impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[&[[T; N]]; M]>
818    for Array
819{
820    fn from_nested(data: &[&[[T; N]]; M]) -> Self {
821        // check that 2nd dimension has the same length
822        let len_2d = data[0].len();
823        assert!(
824            data.iter().all(|x| x.len() == len_2d),
825            "2nd dimension must have the same length"
826        );
827
828        let shape = [M as i32, len_2d as i32, N as i32];
829        let data = data
830            .iter()
831            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
832            .collect::<Vec<T>>();
833        Array::from_slice(&data, &shape)
834    }
835}
836
837impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
838    FromNested<[[[T; N]; M]; O]> for Array
839{
840    fn from_nested(data: [[[T; N]; M]; O]) -> Self {
841        let shape = [O as i32, M as i32, N as i32];
842        let data = data
843            .iter()
844            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
845            .collect::<Vec<T>>();
846        Array::from_slice(&data, &shape)
847    }
848}
849
850impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
851    FromNested<&[[[T; N]; M]; O]> for Array
852{
853    fn from_nested(data: &[[[T; N]; M]; O]) -> Self {
854        let shape = [O as i32, M as i32, N as i32];
855        let data = data
856            .iter()
857            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
858            .collect::<Vec<T>>();
859        Array::from_slice(&data, &shape)
860    }
861}
862
863impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
864    FromNested<&[&[[T; N]; M]; O]> for Array
865{
866    fn from_nested(data: &[&[[T; N]; M]; O]) -> Self {
867        let shape = [O as i32, M as i32, N as i32];
868        let data = data
869            .iter()
870            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
871            .collect::<Vec<T>>();
872        Array::from_slice(&data, &shape)
873    }
874}
875
876impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
877    FromNested<&[[&[T; N]; M]; O]> for Array
878{
879    fn from_nested(data: &[[&[T; N]; M]; O]) -> Self {
880        let shape = [O as i32, M as i32, N as i32];
881        let data = data
882            .iter()
883            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
884            .collect::<Vec<T>>();
885        Array::from_slice(&data, &shape)
886    }
887}
888
889impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
890    FromNested<&[&[&[T; N]; M]; O]> for Array
891{
892    fn from_nested(data: &[&[&[T; N]; M]; O]) -> Self {
893        let shape = [O as i32, M as i32, N as i32];
894        let data = data
895            .iter()
896            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
897            .collect::<Vec<T>>();
898        Array::from_slice(&data, &shape)
899    }
900}
901
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_scalar_array_from_bool() {
        let array = Array::from_bool(true);
        assert!(array.item::<bool>());
        assert_eq!(array.item_size(), 1);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 1);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Bool);
    }

    #[test]
    fn new_scalar_array_from_int() {
        let array = Array::from_int(42);
        assert_eq!(array.item::<i32>(), 42);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_scalar_array_from_f32() {
        // 2.5 is exactly representable and, unlike 3.14, does not trip clippy's
        // deny-by-default `approx_constant` lint.
        let array = Array::from_f32(2.5);
        assert_eq!(array.item::<f32>(), 2.5);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Float32);
    }

    #[test]
    fn new_scalar_array_from_f64() {
        let array = Array::from_f64(2.5).as_dtype(Dtype::Float64).unwrap();
        float_eq::assert_float_eq!(array.item::<f64>(), 2.5, abs <= 1e-5);
        assert_eq!(array.item_size(), 8);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 8);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Float64);
    }

    #[test]
    fn new_array_from_slice_f64() {
        let array = Array::from_slice_f64(&[1.0, 2.0, 3.0], &[3]);
        assert_eq!(array.item_size(), 8);
        assert_eq!(array.size(), 3);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 24);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 3);
        assert_eq!(array.shape(), &[3]);
        assert_eq!(array.dtype(), Dtype::Float64);
    }

    #[test]
    fn new_scalar_array_from_complex() {
        let val = complex64::new(1.0, 2.0);
        let array = Array::from_complex(val);
        assert_eq!(array.item::<complex64>(), val);
        assert_eq!(array.item_size(), 8);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 8);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Complex64);
    }

    #[test]
    fn new_array_from_single_element_slice() {
        let data = [1i32];
        let array = Array::from_slice(&data, &[1]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item::<i32>(), 1);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 1);
        assert_eq!(array.shape(), &[1]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_array_from_multi_element_slice() {
        let data = [1i32, 2, 3, 4, 5];
        let array = Array::from_slice(&data, &[5]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 5);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 20);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 5);
        assert_eq!(array.shape(), &[5]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_2d_array_from_slice() {
        let data = [1i32, 2, 3, 4, 5, 6];
        let array = Array::from_slice(&data, &[2, 3]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 6);
        assert_eq!(array.strides(), &[3, 1]);
        assert_eq!(array.nbytes(), 24);
        assert_eq!(array.ndim(), 2);
        assert_eq!(array.dim(0), 2);
        assert_eq!(array.dim(1), 3);
        assert_eq!(array.dim(-1), 3); // negative index
        assert_eq!(array.dim(-2), 2); // negative index
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_3d_array_from_nested_refs() {
        // Exercises the `FromNested<&[&[&[T; N]; M]; O]>` impl: shape is [O, M, N]
        // and the innermost rows are laid out contiguously in row-major order.
        let r00 = [1i32, 2];
        let r01 = [3i32, 4];
        let r10 = [5i32, 6];
        let r11 = [7i32, 8];
        let p0 = [&r00, &r01];
        let p1 = [&r10, &r11];
        let data = [&p0, &p1];
        let array = Array::from_nested(&data);
        assert_eq!(array.shape(), &[2, 2, 2]);
        assert_eq!(array.as_slice::<i32>(), &[1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn deep_cloned_array_has_different_ptr() {
        let data = [1i32, 2, 3, 4, 5];
        let orig = Array::from_slice(&data, &[5]);
        let clone = orig.deep_clone();

        // Data should be the same
        assert_eq!(orig.as_slice::<i32>(), clone.as_slice::<i32>());

        // Addr of `mlx_array` should be different
        assert_ne!(orig.as_ptr().ctx, clone.as_ptr().ctx);

        // Addr of data should be different
        assert_ne!(
            orig.as_slice::<i32>().as_ptr(),
            clone.as_slice::<i32>().as_ptr()
        );
    }

    #[test]
    fn test_array_eq() {
        let data = [1i32, 2, 3, 4, 5];
        let array1 = Array::from_slice(&data, &[5]);
        let array2 = Array::from_slice(&data, &[5]);
        let array3 = Array::from_slice(&[1i32, 2, 3, 4, 6], &[5]);

        assert_eq!(&array1, &array2);
        assert_ne!(&array1, &array3);
    }

    #[test]
    fn test_array_item_non_scalar() {
        let data = [1i32, 2, 3, 4, 5];
        let array = Array::from_slice(&data, &[5]);
        assert!(array.try_item::<i32>().is_err());
    }

    #[test]
    fn test_item_type_conversion() {
        let array = Array::from_f32(1.0);
        assert_eq!(array.item::<i32>(), 1);
        assert_eq!(array.item::<complex64>(), complex64::new(1.0, 0.0));
        assert_eq!(array.item::<u8>(), 1);

        assert_eq!(array.as_slice::<f32>(), &[1.0]);
    }
}