mlx_rs/array/
mod.rs

1use crate::{
2    dtype::Dtype,
3    error::AsSliceError,
4    sealed::Sealed,
5    utils::{guard::Guarded, SUCCESS},
6    Stream, StreamOrDevice,
7};
8use element::FromSliceElement;
9use mlx_internal_macros::default_device;
10use mlx_sys::mlx_array;
11use num_complex::Complex;
12use std::{
13    ffi::{c_void, CStr},
14    iter::Sum,
15};
16
17mod element;
18mod operators;
19
20cfg_safetensors! {
21    mod safetensors;
22}
23
24pub use element::ArrayElement;
25
// Not using Complex64 because `num_complex::Complex64` is actually Complex<f64>

/// Type alias for `num_complex::Complex<f32>`.
///
/// This is a complex number made of two `f32` parts (8 bytes in total), which is the element
/// type of arrays with dtype [`Dtype::Complex64`].
// The lowercase name is deliberate (see the note above), hence the lint allowance.
#[allow(non_camel_case_types)]
pub type complex64 = Complex<f32>;
31
/// An n-dimensional array.
///
/// This is a thin owning wrapper around an MLX `mlx_array` handle. The handle is released when
/// the `Array` is dropped.
// `repr(transparent)` guarantees `Array` has exactly the same layout as the wrapped handle.
#[repr(transparent)]
pub struct Array {
    c_array: mlx_array,
}
37
// Mark both the owned and the borrowed form of `Array` with the crate's `Sealed` marker.
// NOTE(review): presumably this limits which types can implement the crate's sealed traits;
// confirm in `crate::sealed`.
impl Sealed for Array {}

impl Sealed for &Array {}
41
42impl std::fmt::Debug for Array {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        write!(f, "{}", self)
45    }
46}
47
48impl std::fmt::Display for Array {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        unsafe {
51            let mut mlx_str = mlx_sys::mlx_string_new();
52            let status = mlx_sys::mlx_array_tostring(&mut mlx_str as *mut _, self.as_ptr());
53            if status != SUCCESS {
54                return Err(std::fmt::Error);
55            }
56            let ptr = mlx_sys::mlx_string_data(mlx_str);
57            let c_str = CStr::from_ptr(ptr);
58            write!(f, "{:?}", c_str)?;
59            mlx_sys::mlx_string_free(mlx_str);
60            Ok(())
61        }
62    }
63}
64
65impl Drop for Array {
66    fn drop(&mut self) {
67        // TODO: check memory leak with some tool?
68
69        // Decrease the reference count
70        unsafe { mlx_sys::mlx_array_free(self.as_ptr()) };
71    }
72}
73
// SAFETY(review): this asserts that an `Array` may be moved to, used on, and dropped on
// another thread. It relies on MLX allowing an array handle to be used and freed from a thread
// other than the one that created it — TODO confirm against MLX's threading guarantees.
// This file does not implement `Sync`, so shared `&Array` access across threads is not granted
// here.
unsafe impl Send for Array {}
75
76impl PartialEq for Array {
77    /// Array equality check.
78    ///
79    /// Compare two arrays for equality. Returns `true` iff the arrays have
80    /// the same shape and their values are equal. The arrays need not have
81    /// the same type to be considered equal.
82    ///
83    /// If you're looking for element-wise equality, use the [Array::eq()] method.
84    fn eq(&self, other: &Self) -> bool {
85        self.array_eq(other, None).unwrap().item()
86    }
87}
88
89impl Array {
90    /// Create a new array from an existing mlx_array pointer.
91    ///
92    /// # Safety
93    ///
94    /// The caller must ensure the reference count of the array is properly incremented with
95    /// `mlx_sys::mlx_retain`.
96    pub unsafe fn from_ptr(c_array: mlx_array) -> Array {
97        Self { c_array }
98    }
99
100    /// Get the underlying mlx_array pointer.
101    pub fn as_ptr(&self) -> mlx_array {
102        self.c_array
103    }
104
105    /// New array from a bool scalar.
106    pub fn from_bool(val: bool) -> Array {
107        let c_array = unsafe { mlx_sys::mlx_array_new_bool(val) };
108        Array { c_array }
109    }
110
111    /// New array from an int scalar.
112    pub fn from_int(val: i32) -> Array {
113        let c_array = unsafe { mlx_sys::mlx_array_new_int(val) };
114        Array { c_array }
115    }
116
117    /// New array from a f32 scalar.
118    pub fn from_f32(val: f32) -> Array {
119        let c_array = unsafe { mlx_sys::mlx_array_new_float32(val) };
120        Array { c_array }
121    }
122
123    // // TODO: This is bugged right now. See https://github.com/ml-explore/mlx/issues/1994
124    // /// New array from a f64 scalar.
125    // pub fn from_f64(val: f64) -> Array {
126    //     let c_array = unsafe { mlx_sys::mlx_array_new_float64(val) };
127    //     Array { c_array }
128    // }
129
130    /// New array from a complex scalar.
131    pub fn from_complex(val: complex64) -> Array {
132        let c_array = unsafe { mlx_sys::mlx_array_new_complex(val.re, val.im) };
133        Array { c_array }
134    }
135
136    /// New array from existing buffer.
137    ///
138    /// Please note that floating point literals are treated as f32 instead of
139    /// f64. Use [`Array::from_slice_f64`] for f64.
140    ///
141    /// # Parameters
142    ///
143    /// - `data`: A buffer which will be copied.
144    /// - `shape`: Shape of the array.
145    ///
146    /// # Panic
147    ///
148    /// - Panics if the product of the shape is not equal to the length of the
149    ///   data.
150    /// - Panics if the shape is too large.
151    pub fn from_slice<T: FromSliceElement>(data: &[T], shape: &[i32]) -> Self {
152        // Validate data size and shape
153        assert_eq!(data.len(), shape.iter().product::<i32>() as usize);
154
155        unsafe { Self::from_raw_data(data.as_ptr() as *const c_void, shape, T::DTYPE) }
156    }
157
158    /// New array from a slice of f64.
159    ///
160    /// A separate method is provided for f64 because f64 is not supported on GPU
161    /// and rust defaults to f64 for floating point literals
162    pub fn from_slice_f64(data: &[f64], shape: &[i32]) -> Self {
163        // Validate data size and shape
164        assert_eq!(data.len(), shape.iter().product::<i32>() as usize);
165
166        unsafe { Self::from_raw_data(data.as_ptr() as *const c_void, shape, Dtype::Float64) }
167    }
168
169    /// Create a new array from raw data buffer.
170    ///
171    /// This is a convenience wrapper around [`mlx_sy::mlx_array_new_data`].
172    ///
173    /// # Safety
174    ///
175    /// This is unsafe because the caller must ensure that the data buffer is valid and that the
176    /// shape is correct.
177    #[inline]
178    pub unsafe fn from_raw_data(data: *const c_void, shape: &[i32], dtype: Dtype) -> Self {
179        let dim = if shape.len() > i32::MAX as usize {
180            panic!("Shape is too large")
181        } else {
182            shape.len() as i32
183        };
184
185        let c_array = mlx_sys::mlx_array_new_data(data, shape.as_ptr(), dim, dtype.into());
186        Array { c_array }
187    }
188
189    /// New array from an iterator.
190    ///
191    /// Please note that floating point literals are treated as f32 instead of
192    /// f64. Use [`Array::from_iter_f64`] for f64.
193    ///
194    /// This is a convenience method that is equivalent to
195    ///
196    /// ```rust, ignore
197    /// let data: Vec<T> = iter.collect();
198    /// Array::from_slice(&data, shape)
199    /// ```
200    ///
201    /// # Example
202    ///
203    /// ```rust
204    /// use mlx_rs::Array;
205    ///
206    /// let data = vec![1i32, 2, 3, 4, 5];
207    /// let mut array = Array::from_iter(data.clone(), &[5]);
208    /// assert_eq!(array.as_slice::<i32>(), &data[..]);
209    /// ```
210    pub fn from_iter<I: IntoIterator<Item = T>, T: FromSliceElement>(
211        iter: I,
212        shape: &[i32],
213    ) -> Self {
214        let data: Vec<T> = iter.into_iter().collect();
215        Self::from_slice(&data, shape)
216    }
217
218    /// New array from an iterator of f64.
219    ///
220    /// A separate method is provided for f64 because f64 is not supported on GPU
221    /// and rust defaults to f64 for floating point literals
222    pub fn from_iter_f64<I: IntoIterator<Item = f64>>(iter: I, shape: &[i32]) -> Self {
223        let data: Vec<f64> = iter.into_iter().collect();
224        Self::from_slice_f64(&data, shape)
225    }
226
227    /// The size of the array’s datatype in bytes.
228    pub fn item_size(&self) -> usize {
229        unsafe { mlx_sys::mlx_array_itemsize(self.as_ptr()) }
230    }
231
232    /// Number of elements in the array.
233    pub fn size(&self) -> usize {
234        unsafe { mlx_sys::mlx_array_size(self.as_ptr()) }
235    }
236
237    /// The strides of the array.
238    pub fn strides(&self) -> &[usize] {
239        let ndim = self.ndim();
240        if ndim == 0 {
241            // The data pointer may be null which would panic even if len is 0
242            return &[];
243        }
244
245        unsafe {
246            let data = mlx_sys::mlx_array_strides(self.as_ptr());
247            std::slice::from_raw_parts(data, ndim)
248        }
249    }
250
251    /// The number of bytes in the array.
252    pub fn nbytes(&self) -> usize {
253        unsafe { mlx_sys::mlx_array_nbytes(self.as_ptr()) }
254    }
255
256    /// The array’s dimension.
257    pub fn ndim(&self) -> usize {
258        unsafe { mlx_sys::mlx_array_ndim(self.as_ptr()) }
259    }
260
261    /// The shape of the array.
262    ///
263    /// Returns: a pointer to the sizes of each dimension.
264    pub fn shape(&self) -> &[i32] {
265        let ndim = self.ndim();
266        if ndim == 0 {
267            // The data pointer may be null which would panic even if len is 0
268            return &[];
269        }
270
271        unsafe {
272            let data = mlx_sys::mlx_array_shape(self.as_ptr());
273            std::slice::from_raw_parts(data, ndim)
274        }
275    }
276
277    /// The shape of the array in a particular dimension.
278    ///
279    /// # Panic
280    ///
281    /// - Panics if the array is scalar.
282    /// - Panics if `dim` is negative and `dim + ndim` overflows
283    /// - Panics if the dimension is out of bounds.
284    pub fn dim(&self, dim: i32) -> i32 {
285        let dim = if dim.is_negative() {
286            (self.ndim() as i32).checked_add(dim).unwrap()
287        } else {
288            dim
289        };
290
291        // This will panic on a scalar array
292        unsafe { mlx_sys::mlx_array_dim(self.as_ptr(), dim) }
293    }
294
295    /// The array element type.
296    pub fn dtype(&self) -> Dtype {
297        let dtype = unsafe { mlx_sys::mlx_array_dtype(self.as_ptr()) };
298        Dtype::try_from(dtype).unwrap()
299    }
300
301    // TODO: document that mlx is lazy
302    /// Evaluate the array.
303    pub fn eval(&self) -> crate::error::Result<()> {
304        <() as Guarded>::try_from_op(|_| unsafe { mlx_sys::mlx_array_eval(self.as_ptr()) })
305    }
306
307    /// Access the value of a scalar array.
308    /// If `T` does not match the array's `dtype` this will convert the type first.
309    ///
310    /// _Note: This will evaluate the array._
311    pub fn item<T: ArrayElement>(&self) -> T {
312        self.try_item().unwrap()
313    }
314
315    /// Access the value of a scalar array returning an error if the array is not a scalar.
316    /// If `T` does not match the array's `dtype` this will convert the type first.
317    ///
318    /// _Note: This will evaluate the array._
319    pub fn try_item<T: ArrayElement>(&self) -> crate::error::Result<T> {
320        self.eval()?;
321
322        // Evaluate the array, so we have content to work with in the conversion
323        self.eval()?;
324
325        // Though `mlx_array_item_<dtype>` returns a status code, it doesn't
326        // return any non-success status code even if the dtype doesn't match.
327        if self.dtype() != T::DTYPE {
328            let new_array = Array::try_from_op(|res| unsafe {
329                mlx_sys::mlx_astype(
330                    res,
331                    self.as_ptr(),
332                    T::DTYPE.into(),
333                    Stream::default().as_ptr(),
334                )
335            })?;
336            new_array.eval()?;
337            return T::array_item(&new_array);
338        }
339
340        T::array_item(self)
341    }
342
343    /// Returns a slice of the array data without validating the dtype.
344    ///
345    /// # Safety
346    ///
347    /// This is unsafe because the underlying data ptr is not checked for null or if the desired
348    /// dtype matches the actual dtype of the array.
349    ///
350    /// # Example
351    ///
352    /// ```rust
353    /// use mlx_rs::Array;
354    ///
355    /// let data = [1i32, 2, 3, 4, 5];
356    /// let mut array = Array::from_slice(&data[..], &[5]);
357    ///
358    /// unsafe {
359    ///    let slice = array.as_slice_unchecked::<i32>();
360    ///    assert_eq!(slice, &[1, 2, 3, 4, 5]);
361    /// }
362    /// ```
363    pub unsafe fn as_slice_unchecked<T: ArrayElement>(&self) -> &[T] {
364        self.eval().unwrap();
365
366        unsafe {
367            let data = T::array_data(self);
368            let size = self.size();
369            std::slice::from_raw_parts(data, size)
370        }
371    }
372
373    /// Returns a slice of the array data returning an error if the dtype does not match the actual dtype.
374    ///
375    /// # Example
376    ///
377    /// ```rust
378    /// use mlx_rs::Array;
379    ///
380    /// let data = [1i32, 2, 3, 4, 5];
381    /// let mut array = Array::from_slice(&data[..], &[5]);
382    ///
383    /// let slice = array.try_as_slice::<i32>();
384    /// assert_eq!(slice, Ok(&data[..]));
385    /// ```
386    pub fn try_as_slice<T: ArrayElement>(&self) -> Result<&[T], AsSliceError> {
387        if self.dtype() != T::DTYPE {
388            return Err(AsSliceError::DtypeMismatch {
389                expecting: T::DTYPE,
390                found: self.dtype(),
391            });
392        }
393
394        self.eval()?;
395
396        unsafe {
397            let size = self.size();
398            let data = T::array_data(self);
399            if data.is_null() || size == 0 {
400                return Err(AsSliceError::Null);
401            }
402
403            Ok(std::slice::from_raw_parts(data, size))
404        }
405    }
406
407    /// Returns a slice of the array data.
408    /// This method requires a mutable reference (`&self`) because it evaluates the array.
409    ///
410    /// # Panics
411    ///
412    /// Panics if the array is not evaluated or if the desired dtype does not match the actual dtype
413    ///
414    /// # Example
415    ///
416    /// ```rust
417    /// use mlx_rs::Array;
418    ///
419    /// let data = [1i32, 2, 3, 4, 5];
420    /// let mut array = Array::from_slice(&data[..], &[5]);
421    ///
422    /// let slice = array.as_slice::<i32>();
423    /// assert_eq!(slice, &data[..]);
424    /// ```
425    pub fn as_slice<T: ArrayElement>(&self) -> &[T] {
426        self.try_as_slice().unwrap()
427    }
428
429    /// Clone the array by copying the data.
430    ///
431    /// This is named `deep_clone` to avoid confusion with the `Clone` trait.
432    pub fn deep_clone(&self) -> Self {
433        unsafe {
434            let dtype = self.dtype();
435            let shape = self.shape();
436            let data = match dtype {
437                Dtype::Bool => mlx_sys::mlx_array_data_bool(self.as_ptr()) as *const c_void,
438                Dtype::Uint8 => mlx_sys::mlx_array_data_uint8(self.as_ptr()) as *const c_void,
439                Dtype::Uint16 => mlx_sys::mlx_array_data_uint16(self.as_ptr()) as *const c_void,
440                Dtype::Uint32 => mlx_sys::mlx_array_data_uint32(self.as_ptr()) as *const c_void,
441                Dtype::Uint64 => mlx_sys::mlx_array_data_uint64(self.as_ptr()) as *const c_void,
442                Dtype::Int8 => mlx_sys::mlx_array_data_int8(self.as_ptr()) as *const c_void,
443                Dtype::Int16 => mlx_sys::mlx_array_data_int16(self.as_ptr()) as *const c_void,
444                Dtype::Int32 => mlx_sys::mlx_array_data_int32(self.as_ptr()) as *const c_void,
445                Dtype::Int64 => mlx_sys::mlx_array_data_int64(self.as_ptr()) as *const c_void,
446                Dtype::Float16 => mlx_sys::mlx_array_data_float16(self.as_ptr()) as *const c_void,
447                Dtype::Float32 => mlx_sys::mlx_array_data_float32(self.as_ptr()) as *const c_void,
448                Dtype::Float64 => mlx_sys::mlx_array_data_float64(self.as_ptr()) as *const c_void,
449                Dtype::Bfloat16 => mlx_sys::mlx_array_data_bfloat16(self.as_ptr()) as *const c_void,
450                Dtype::Complex64 => {
451                    mlx_sys::mlx_array_data_complex64(self.as_ptr()) as *const c_void
452                }
453            };
454
455            let new_c_array =
456                mlx_sys::mlx_array_new_data(data, shape.as_ptr(), shape.len() as i32, dtype.into());
457
458            Array::from_ptr(new_c_array)
459        }
460    }
461}
462
463impl Clone for Array {
464    fn clone(&self) -> Self {
465        Array::try_from_op(|res| unsafe { mlx_sys::mlx_array_set(res, self.as_ptr()) })
466            // Exception may be thrown when calling `new` in cpp.
467            .expect("Failed to clone array")
468    }
469}
470
471impl Sum for Array {
472    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
473        iter.fold(Array::from_int(0), |acc, x| acc.add(&x).unwrap())
474    }
475}
476
477/// Stop gradients from being computed.
478///
479/// The operation is the identity but it prevents gradients from flowing
480/// through the array.
481#[default_device]
482pub fn stop_gradient_device(
483    a: impl AsRef<Array>,
484    stream: impl AsRef<Stream>,
485) -> crate::error::Result<Array> {
486    Array::try_from_op(|res| unsafe {
487        mlx_sys::mlx_stop_gradient(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
488    })
489}
490
491impl From<bool> for Array {
492    fn from(value: bool) -> Self {
493        Array::from_bool(value)
494    }
495}
496
497impl From<i32> for Array {
498    fn from(value: i32) -> Self {
499        Array::from_int(value)
500    }
501}
502
503impl From<f32> for Array {
504    fn from(value: f32) -> Self {
505        Array::from_f32(value)
506    }
507}
508
509impl From<complex64> for Array {
510    fn from(value: complex64) -> Self {
511        Array::from_complex(value)
512    }
513}
514
/// Converts any nested array/slice type that has a [`FromNested`] impl, so `Array::from(...)`
/// and `.into()` accept the same inputs as the [`array!`] macro.
impl<T> From<T> for Array
where
    Array: FromNested<T>,
{
    fn from(value: T) -> Self {
        Array::from_nested(value)
    }
}
523
524impl AsRef<Array> for Array {
525    fn as_ref(&self) -> &Array {
526        self
527    }
528}
529
/// A helper trait to construct `Array` from scalar values.
///
/// This trait is intended to be used with the macro [`array!`] but can be used directly if needed.
/// In this file it is implemented for `bool`, `i32`, `f32` and [`complex64`].
pub trait FromScalar<T>
where
    T: ArrayElement,
{
    /// Create a 0-dimensional array holding `val`.
    fn from_scalar(val: T) -> Array;
}
540
541impl FromScalar<bool> for Array {
542    fn from_scalar(val: bool) -> Array {
543        Array::from_bool(val)
544    }
545}
546
547impl FromScalar<i32> for Array {
548    fn from_scalar(val: i32) -> Array {
549        Array::from_int(val)
550    }
551}
552
553impl FromScalar<f32> for Array {
554    fn from_scalar(val: f32) -> Array {
555        Array::from_f32(val)
556    }
557}
558
559impl FromScalar<complex64> for Array {
560    fn from_scalar(val: complex64) -> Array {
561        Array::from_complex(val)
562    }
563}
564
/// A helper trait to construct `Array` from nested arrays or slices.
///
/// Given that this is not intended for use other than the macro [`array!`], this trait is added
/// instead of directly implementing `From` for `Array` to avoid conflicts with other `From`
/// implementations.
///
/// Beware that this is subject to change in the future should we find a better way to implement
/// the macro without creating conflicts.
pub trait FromNested<T> {
    /// Create an array from nested arrays or slices.
    ///
    /// The nesting depth determines the number of dimensions (1 to 3 in this file), and the
    /// elements are copied in row-major order.
    fn from_nested(data: T) -> Array;
}
577
578impl<T: FromSliceElement> FromNested<&[T]> for Array {
579    fn from_nested(data: &[T]) -> Self {
580        Array::from_slice(data, &[data.len() as i32])
581    }
582}
583
584impl<T: FromSliceElement, const N: usize> FromNested<[T; N]> for Array {
585    fn from_nested(data: [T; N]) -> Self {
586        Array::from_slice(&data, &[N as i32])
587    }
588}
589
590impl<T: FromSliceElement, const N: usize> FromNested<&[T; N]> for Array {
591    fn from_nested(data: &[T; N]) -> Self {
592        Array::from_slice(data, &[N as i32])
593    }
594}
595
596impl<T: FromSliceElement + Copy> FromNested<&[&[T]]> for Array {
597    fn from_nested(data: &[&[T]]) -> Self {
598        // check that all rows have the same length
599        let row_len = data[0].len();
600        assert!(
601            data.iter().all(|row| row.len() == row_len),
602            "Rows must have the same length"
603        );
604
605        let shape = [data.len() as i32, row_len as i32];
606        let data = data
607            .iter()
608            .flat_map(|x| x.iter())
609            .copied()
610            .collect::<Vec<T>>();
611        Array::from_slice(&data, &shape)
612    }
613}
614
615impl<T: FromSliceElement + Copy, const N: usize> FromNested<[&[T]; N]> for Array {
616    fn from_nested(data: [&[T]; N]) -> Self {
617        // check that all rows have the same length
618        let row_len = data[0].len();
619        assert!(
620            data.iter().all(|row| row.len() == row_len),
621            "Rows must have the same length"
622        );
623
624        let shape = [N as i32, row_len as i32];
625        let data = data
626            .iter()
627            .flat_map(|x| x.iter())
628            .copied()
629            .collect::<Vec<T>>();
630        Array::from_slice(&data, &shape)
631    }
632}
633
634impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[[T; N]]> for Array {
635    fn from_nested(data: &[[T; N]]) -> Self {
636        let shape = [data.len() as i32, N as i32];
637        let data = data
638            .iter()
639            .flat_map(|x| x.iter().copied())
640            .collect::<Vec<T>>();
641        Array::from_slice(&data, &shape)
642    }
643}
644
645impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[&[T; N]]> for Array {
646    fn from_nested(data: &[&[T; N]]) -> Self {
647        let shape = [data.len() as i32, N as i32];
648        let data = data
649            .iter()
650            .flat_map(|x| x.iter().copied())
651            .collect::<Vec<T>>();
652        Array::from_slice(&data, &shape)
653    }
654}
655
656impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<[[T; N]; M]> for Array {
657    fn from_nested(data: [[T; N]; M]) -> Self {
658        let shape = [M as i32, N as i32];
659        let data = data
660            .iter()
661            .flat_map(|x| x.iter().copied())
662            .collect::<Vec<T>>();
663        Array::from_slice(&data, &shape)
664    }
665}
666
667impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[[T; N]; M]>
668    for Array
669{
670    fn from_nested(data: &[[T; N]; M]) -> Self {
671        let shape = [M as i32, N as i32];
672        let data = data
673            .iter()
674            .flat_map(|x| x.iter().copied())
675            .collect::<Vec<T>>();
676        Array::from_slice(&data, &shape)
677    }
678}
679
680impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[&[T; N]; M]>
681    for Array
682{
683    fn from_nested(data: &[&[T; N]; M]) -> Self {
684        let shape = [M as i32, N as i32];
685        let data = data
686            .iter()
687            .flat_map(|x| x.iter().copied())
688            .collect::<Vec<T>>();
689        Array::from_slice(&data, &shape)
690    }
691}
692
693impl<T: FromSliceElement + Copy> FromNested<&[&[&[T]]]> for Array {
694    fn from_nested(data: &[&[&[T]]]) -> Self {
695        // check that 2nd dimension has the same length
696        let len_2d = data[0].len();
697        assert!(
698            data.iter().all(|x| x.len() == len_2d),
699            "2nd dimension must have the same length"
700        );
701
702        // check that 3rd dimension has the same length
703        let len_3d = data[0][0].len();
704        assert!(
705            data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
706            "3rd dimension must have the same length"
707        );
708
709        let shape = [data.len() as i32, len_2d as i32, len_3d as i32];
710        let data = data
711            .iter()
712            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
713            .collect::<Vec<T>>();
714        Array::from_slice(&data, &shape)
715    }
716}
717
718impl<T: FromSliceElement + Copy, const N: usize> FromNested<[&[&[T]]; N]> for Array {
719    fn from_nested(data: [&[&[T]]; N]) -> Self {
720        // check that 2nd dimension has the same length
721        let len_2d = data[0].len();
722        assert!(
723            data.iter().all(|x| x.len() == len_2d),
724            "2nd dimension must have the same length"
725        );
726
727        // check that 3rd dimension has the same length
728        let len_3d = data[0][0].len();
729        assert!(
730            data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
731            "3rd dimension must have the same length"
732        );
733
734        let shape = [N as i32, len_2d as i32, len_3d as i32];
735        let data = data
736            .iter()
737            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
738            .collect::<Vec<T>>();
739        Array::from_slice(&data, &shape)
740    }
741}
742
743impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[[&[T]; N]]> for Array {
744    fn from_nested(data: &[[&[T]; N]]) -> Self {
745        // check that 3rd dimension has the same length
746        let len_3d = data[0][0].len();
747        assert!(
748            data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
749            "3rd dimension must have the same length"
750        );
751
752        let shape = [data.len() as i32, N as i32, len_3d as i32];
753        let data = data
754            .iter()
755            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
756            .collect::<Vec<T>>();
757        Array::from_slice(&data, &shape)
758    }
759}
760
761impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[&[[T; N]]]> for Array {
762    fn from_nested(data: &[&[[T; N]]]) -> Self {
763        // check that 2nd dimension has the same length
764        let len_2d = data[0].len();
765        assert!(
766            data.iter().all(|x| x.len() == len_2d),
767            "2nd dimension must have the same length"
768        );
769
770        let shape = [data.len() as i32, len_2d as i32, N as i32];
771        let data = data
772            .iter()
773            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
774            .collect::<Vec<T>>();
775        Array::from_slice(&data, &shape)
776    }
777}
778
779impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<[[&[T]; N]; M]>
780    for Array
781{
782    fn from_nested(data: [[&[T]; N]; M]) -> Self {
783        // check that 3rd dimension has the same length
784        let len_3d = data[0][0].len();
785        assert!(
786            data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
787            "3rd dimension must have the same length"
788        );
789
790        let shape = [M as i32, N as i32, len_3d as i32];
791        let data = data
792            .iter()
793            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
794            .collect::<Vec<T>>();
795        Array::from_slice(&data, &shape)
796    }
797}
798
799impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[[&[T]; N]; M]>
800    for Array
801{
802    fn from_nested(data: &[[&[T]; N]; M]) -> Self {
803        // check that 3rd dimension has the same length
804        let len_3d = data[0][0].len();
805        assert!(
806            data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
807            "3rd dimension must have the same length"
808        );
809
810        let shape = [M as i32, N as i32, len_3d as i32];
811        let data = data
812            .iter()
813            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
814            .collect::<Vec<T>>();
815        Array::from_slice(&data, &shape)
816    }
817}
818
819impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[&[[T; N]]; M]>
820    for Array
821{
822    fn from_nested(data: &[&[[T; N]]; M]) -> Self {
823        // check that 2nd dimension has the same length
824        let len_2d = data[0].len();
825        assert!(
826            data.iter().all(|x| x.len() == len_2d),
827            "2nd dimension must have the same length"
828        );
829
830        let shape = [M as i32, len_2d as i32, N as i32];
831        let data = data
832            .iter()
833            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
834            .collect::<Vec<T>>();
835        Array::from_slice(&data, &shape)
836    }
837}
838
839impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
840    FromNested<[[[T; N]; M]; O]> for Array
841{
842    fn from_nested(data: [[[T; N]; M]; O]) -> Self {
843        let shape = [O as i32, M as i32, N as i32];
844        let data = data
845            .iter()
846            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
847            .collect::<Vec<T>>();
848        Array::from_slice(&data, &shape)
849    }
850}
851
852impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
853    FromNested<&[[[T; N]; M]; O]> for Array
854{
855    fn from_nested(data: &[[[T; N]; M]; O]) -> Self {
856        let shape = [O as i32, M as i32, N as i32];
857        let data = data
858            .iter()
859            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
860            .collect::<Vec<T>>();
861        Array::from_slice(&data, &shape)
862    }
863}
864
865impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
866    FromNested<&[&[[T; N]; M]; O]> for Array
867{
868    fn from_nested(data: &[&[[T; N]; M]; O]) -> Self {
869        let shape = [O as i32, M as i32, N as i32];
870        let data = data
871            .iter()
872            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
873            .collect::<Vec<T>>();
874        Array::from_slice(&data, &shape)
875    }
876}
877
878impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
879    FromNested<&[[&[T; N]; M]; O]> for Array
880{
881    fn from_nested(data: &[[&[T; N]; M]; O]) -> Self {
882        let shape = [O as i32, M as i32, N as i32];
883        let data = data
884            .iter()
885            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
886            .collect::<Vec<T>>();
887        Array::from_slice(&data, &shape)
888    }
889}
890
891impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
892    FromNested<&[&[&[T; N]; M]; O]> for Array
893{
894    fn from_nested(data: &[&[&[T; N]; M]; O]) -> Self {
895        let shape = [O as i32, M as i32, N as i32];
896        let data = data
897            .iter()
898            .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
899            .collect::<Vec<T>>();
900        Array::from_slice(&data, &shape)
901    }
902}
903
#[cfg(test)]
mod tests {
    //! Unit tests for `Array` construction, metadata accessors, equality and item conversion.

    use super::*;

    #[test]
    fn new_scalar_array_from_bool() {
        let array = Array::from_bool(true);
        assert!(array.item::<bool>());
        assert_eq!(array.item_size(), 1);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 1);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Bool);
    }

    #[test]
    fn new_scalar_array_from_int() {
        let array = Array::from_int(42);
        assert_eq!(array.item::<i32>(), 42);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_scalar_array_from_f32() {
        let array = Array::from_f32(3.14);
        assert_eq!(array.item::<f32>(), 3.14);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Float32);
    }

    /// Scalar `f64` arrays should keep their value once cast to `Float64`.
    ///
    /// Kept as a real (ignored) test instead of commented-out code so it keeps compiling
    /// against API changes. It is disabled because of an upstream bug:
    /// https://github.com/ml-explore/mlx/issues/1994. Run it explicitly with
    /// `cargo test -- --ignored` to check whether the bug has been resolved.
    #[test]
    #[ignore = "bugged upstream, see https://github.com/ml-explore/mlx/issues/1994"]
    fn new_scalar_array_from_f64() {
        let array = Array::from_f64(3.14).as_dtype(Dtype::Float64).unwrap();
        let value = array.item::<f64>();
        // Tolerance check with std only, so the test needs no extra dev-dependency.
        assert!(
            (value - 3.14).abs() <= 1e-5,
            "expected a value within 1e-5 of 3.14, got {}",
            value
        );
        assert_eq!(array.item_size(), 8);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 8);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Float64);
    }

    #[test]
    fn new_array_from_slice_f64() {
        let array = Array::from_slice_f64(&[1.0, 2.0, 3.0], &[3]);
        assert_eq!(array.item_size(), 8);
        assert_eq!(array.size(), 3);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 24);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 3);
        assert_eq!(array.shape(), &[3]);
        assert_eq!(array.dtype(), Dtype::Float64);
    }

    #[test]
    fn new_scalar_array_from_complex() {
        let val = complex64::new(1.0, 2.0);
        let array = Array::from_complex(val);
        assert_eq!(array.item::<complex64>(), val);
        assert_eq!(array.item_size(), 8);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 8);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Complex64);
    }

    #[test]
    fn new_array_from_single_element_slice() {
        let data = [1i32];
        let array = Array::from_slice(&data, &[1]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item::<i32>(), 1);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 1);
        assert_eq!(array.shape(), &[1]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_array_from_multi_element_slice() {
        let data = [1i32, 2, 3, 4, 5];
        let array = Array::from_slice(&data, &[5]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 5);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 20);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 5);
        assert_eq!(array.shape(), &[5]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_2d_array_from_slice() {
        let data = [1i32, 2, 3, 4, 5, 6];
        let array = Array::from_slice(&data, &[2, 3]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 6);
        assert_eq!(array.strides(), &[3, 1]);
        assert_eq!(array.nbytes(), 24);
        assert_eq!(array.ndim(), 2);
        assert_eq!(array.dim(0), 2);
        assert_eq!(array.dim(1), 3);
        assert_eq!(array.dim(-1), 3); // negative index
        assert_eq!(array.dim(-2), 2); // negative index
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    /// Nested fixed-size array references should flatten row-major into shape `[O, M, N]`.
    #[test]
    fn new_3d_array_from_nested_array_refs() {
        let (r0, r1, r2, r3) = ([1i32, 2], [3i32, 4], [5i32, 6], [7i32, 8]);
        let plane0 = [&r0, &r1];
        let plane1 = [&r2, &r3];
        let nested = [&plane0, &plane1];
        // Fully qualified so the call resolves to this trait impl, even if another
        // `from_nested` is in scope.
        let array = <Array as FromNested<_>>::from_nested(&nested);
        assert_eq!(array.shape(), &[2, 2, 2]);
        assert_eq!(array.size(), 8);
        assert_eq!(array.as_slice::<i32>(), &[1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn deep_cloned_array_has_different_ptr() {
        let data = [1i32, 2, 3, 4, 5];
        let orig = Array::from_slice(&data, &[5]);
        let clone = orig.deep_clone();

        // Data should be the same
        assert_eq!(orig.as_slice::<i32>(), clone.as_slice::<i32>());

        // Addr of `mlx_array` should be different
        assert_ne!(orig.as_ptr().ctx, clone.as_ptr().ctx);

        // Addr of data should be different
        assert_ne!(
            orig.as_slice::<i32>().as_ptr(),
            clone.as_slice::<i32>().as_ptr()
        );
    }

    #[test]
    fn test_array_eq() {
        let data = [1i32, 2, 3, 4, 5];
        let array1 = Array::from_slice(&data, &[5]);
        let array2 = Array::from_slice(&data, &[5]);
        let array3 = Array::from_slice(&[1i32, 2, 3, 4, 6], &[5]);

        assert_eq!(&array1, &array2);
        assert_ne!(&array1, &array3);
    }

    #[test]
    fn test_array_item_non_scalar() {
        let data = [1i32, 2, 3, 4, 5];
        let array = Array::from_slice(&data, &[5]);
        assert!(array.try_item::<i32>().is_err());
    }

    #[test]
    fn test_item_type_conversion() {
        let array = Array::from_f32(1.0);
        assert_eq!(array.item::<i32>(), 1);
        assert_eq!(array.item::<complex64>(), complex64::new(1.0, 0.0));
        assert_eq!(array.item::<u8>(), 1);

        assert_eq!(array.as_slice::<f32>(), &[1.0]);
    }
}