1use crate::{
2 dtype::Dtype,
3 error::AsSliceError,
4 sealed::Sealed,
5 utils::{guard::Guarded, SUCCESS},
6 Stream,
7};
8use element::FromSliceElement;
9use mlx_internal_macros::default_device;
10use mlx_sys::mlx_array;
11use num_complex::Complex;
12use std::{
13 ffi::{c_void, CStr},
14 iter::Sum,
15};
16
17mod element;
18mod operators;
19
20cfg_safetensors! {
21 mod safetensors;
22}
23
24pub use element::ArrayElement;
25
26#[allow(non_camel_case_types)]
30pub type complex64 = Complex<f32>;
31
32#[repr(transparent)]
34pub struct Array {
35 c_array: mlx_array,
36}
37
38impl Sealed for Array {}
39
40impl Sealed for &Array {}
41
42impl std::fmt::Debug for Array {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 write!(f, "{self}")
45 }
46}
47
48impl std::fmt::Display for Array {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 unsafe {
51 let mut mlx_str = mlx_sys::mlx_string_new();
52 let status = mlx_sys::mlx_array_tostring(&mut mlx_str as *mut _, self.as_ptr());
53 if status != SUCCESS {
54 return Err(std::fmt::Error);
55 }
56 let ptr = mlx_sys::mlx_string_data(mlx_str);
57 let c_str = CStr::from_ptr(ptr);
58 write!(f, "{}", c_str.to_str().map_err(|_| std::fmt::Error)?)?;
59 mlx_sys::mlx_string_free(mlx_str);
60 Ok(())
61 }
62 }
63}
64
65impl Drop for Array {
66 fn drop(&mut self) {
67 unsafe { mlx_sys::mlx_array_free(self.as_ptr()) };
71 }
72}
73
74unsafe impl Send for Array {}
75
76impl PartialEq for Array {
77 fn eq(&self, other: &Self) -> bool {
85 self.array_eq(other, None).unwrap().item()
86 }
87}
88
89impl Array {
90 pub unsafe fn from_ptr(c_array: mlx_array) -> Array {
97 Self { c_array }
98 }
99
100 pub fn as_ptr(&self) -> mlx_array {
102 self.c_array
103 }
104
105 pub fn from_bool(val: bool) -> Array {
107 let c_array = unsafe { mlx_sys::mlx_array_new_bool(val) };
108 Array { c_array }
109 }
110
111 pub fn from_int(val: i32) -> Array {
113 let c_array = unsafe { mlx_sys::mlx_array_new_int(val) };
114 Array { c_array }
115 }
116
117 pub fn from_f32(val: f32) -> Array {
119 let c_array = unsafe { mlx_sys::mlx_array_new_float32(val) };
120 Array { c_array }
121 }
122
123 pub fn from_f64(val: f64) -> Array {
125 let c_array = unsafe { mlx_sys::mlx_array_new_float64(val) };
126 Array { c_array }
127 }
128
129 pub fn from_complex(val: complex64) -> Array {
131 let c_array = unsafe { mlx_sys::mlx_array_new_complex(val.re, val.im) };
132 Array { c_array }
133 }
134
135 pub fn from_slice<T: FromSliceElement>(data: &[T], shape: &[i32]) -> Self {
151 assert_eq!(data.len(), shape.iter().product::<i32>() as usize);
153
154 unsafe { Self::from_raw_data(data.as_ptr() as *const c_void, shape, T::DTYPE) }
155 }
156
157 pub fn from_slice_f64(data: &[f64], shape: &[i32]) -> Self {
162 assert_eq!(data.len(), shape.iter().product::<i32>() as usize);
164
165 unsafe { Self::from_raw_data(data.as_ptr() as *const c_void, shape, Dtype::Float64) }
166 }
167
168 #[inline]
177 pub unsafe fn from_raw_data(data: *const c_void, shape: &[i32], dtype: Dtype) -> Self {
178 let dim = if shape.len() > i32::MAX as usize {
179 panic!("Shape is too large")
180 } else {
181 shape.len() as i32
182 };
183
184 let c_array = mlx_sys::mlx_array_new_data(data, shape.as_ptr(), dim, dtype.into());
185 Array { c_array }
186 }
187
188 pub fn from_iter<I: IntoIterator<Item = T>, T: FromSliceElement>(
210 iter: I,
211 shape: &[i32],
212 ) -> Self {
213 let data: Vec<T> = iter.into_iter().collect();
214 Self::from_slice(&data, shape)
215 }
216
217 pub fn from_iter_f64<I: IntoIterator<Item = f64>>(iter: I, shape: &[i32]) -> Self {
222 let data: Vec<f64> = iter.into_iter().collect();
223 Self::from_slice_f64(&data, shape)
224 }
225
226 pub fn item_size(&self) -> usize {
228 unsafe { mlx_sys::mlx_array_itemsize(self.as_ptr()) }
229 }
230
231 pub fn size(&self) -> usize {
233 unsafe { mlx_sys::mlx_array_size(self.as_ptr()) }
234 }
235
236 pub fn strides(&self) -> &[usize] {
238 let ndim = self.ndim();
239 if ndim == 0 {
240 return &[];
242 }
243
244 unsafe {
245 let data = mlx_sys::mlx_array_strides(self.as_ptr());
246 std::slice::from_raw_parts(data, ndim)
247 }
248 }
249
250 pub fn nbytes(&self) -> usize {
252 unsafe { mlx_sys::mlx_array_nbytes(self.as_ptr()) }
253 }
254
255 pub fn ndim(&self) -> usize {
257 unsafe { mlx_sys::mlx_array_ndim(self.as_ptr()) }
258 }
259
260 pub fn shape(&self) -> &[i32] {
264 let ndim = self.ndim();
265 if ndim == 0 {
266 return &[];
268 }
269
270 unsafe {
271 let data = mlx_sys::mlx_array_shape(self.as_ptr());
272 std::slice::from_raw_parts(data, ndim)
273 }
274 }
275
276 pub fn dim(&self, dim: i32) -> i32 {
284 let dim = if dim.is_negative() {
285 (self.ndim() as i32).checked_add(dim).unwrap()
286 } else {
287 dim
288 };
289
290 unsafe { mlx_sys::mlx_array_dim(self.as_ptr(), dim) }
292 }
293
294 pub fn dtype(&self) -> Dtype {
296 let dtype = unsafe { mlx_sys::mlx_array_dtype(self.as_ptr()) };
297 Dtype::try_from(dtype).unwrap()
298 }
299
300 pub fn eval(&self) -> crate::error::Result<()> {
302 <() as Guarded>::try_from_op(|_| unsafe { mlx_sys::mlx_array_eval(self.as_ptr()) })
303 }
304
305 pub fn item<T: ArrayElement>(&self) -> T {
310 self.try_item().unwrap()
311 }
312
313 pub fn try_item<T: ArrayElement>(&self) -> crate::error::Result<T> {
318 self.eval()?;
319
320 self.eval()?;
322
323 if self.dtype() != T::DTYPE {
326 let new_array = Array::try_from_op(|res| unsafe {
327 mlx_sys::mlx_astype(
328 res,
329 self.as_ptr(),
330 T::DTYPE.into(),
331 Stream::default().as_ptr(),
332 )
333 })?;
334 new_array.eval()?;
335 return T::array_item(&new_array);
336 }
337
338 T::array_item(self)
339 }
340
341 pub unsafe fn as_slice_unchecked<T: ArrayElement>(&self) -> &[T] {
362 self.eval().unwrap();
363
364 unsafe {
365 let data = T::array_data(self);
366 let size = self.size();
367 std::slice::from_raw_parts(data, size)
368 }
369 }
370
371 pub fn try_as_slice<T: ArrayElement>(&self) -> Result<&[T], AsSliceError> {
385 if self.dtype() != T::DTYPE {
386 return Err(AsSliceError::DtypeMismatch {
387 expecting: T::DTYPE,
388 found: self.dtype(),
389 });
390 }
391
392 self.eval()?;
393
394 unsafe {
395 let size = self.size();
396 let data = T::array_data(self);
397 if data.is_null() || size == 0 {
398 return Err(AsSliceError::Null);
399 }
400
401 Ok(std::slice::from_raw_parts(data, size))
402 }
403 }
404
405 pub fn as_slice<T: ArrayElement>(&self) -> &[T] {
424 self.try_as_slice().unwrap()
425 }
426
427 pub fn deep_clone(&self) -> Self {
431 unsafe {
432 let dtype = self.dtype();
433 let shape = self.shape();
434 let data = match dtype {
435 Dtype::Bool => mlx_sys::mlx_array_data_bool(self.as_ptr()) as *const c_void,
436 Dtype::Uint8 => mlx_sys::mlx_array_data_uint8(self.as_ptr()) as *const c_void,
437 Dtype::Uint16 => mlx_sys::mlx_array_data_uint16(self.as_ptr()) as *const c_void,
438 Dtype::Uint32 => mlx_sys::mlx_array_data_uint32(self.as_ptr()) as *const c_void,
439 Dtype::Uint64 => mlx_sys::mlx_array_data_uint64(self.as_ptr()) as *const c_void,
440 Dtype::Int8 => mlx_sys::mlx_array_data_int8(self.as_ptr()) as *const c_void,
441 Dtype::Int16 => mlx_sys::mlx_array_data_int16(self.as_ptr()) as *const c_void,
442 Dtype::Int32 => mlx_sys::mlx_array_data_int32(self.as_ptr()) as *const c_void,
443 Dtype::Int64 => mlx_sys::mlx_array_data_int64(self.as_ptr()) as *const c_void,
444 Dtype::Float16 => mlx_sys::mlx_array_data_float16(self.as_ptr()) as *const c_void,
445 Dtype::Float32 => mlx_sys::mlx_array_data_float32(self.as_ptr()) as *const c_void,
446 Dtype::Float64 => mlx_sys::mlx_array_data_float64(self.as_ptr()) as *const c_void,
447 Dtype::Bfloat16 => mlx_sys::mlx_array_data_bfloat16(self.as_ptr()) as *const c_void,
448 Dtype::Complex64 => {
449 mlx_sys::mlx_array_data_complex64(self.as_ptr()) as *const c_void
450 }
451 };
452
453 let new_c_array =
454 mlx_sys::mlx_array_new_data(data, shape.as_ptr(), shape.len() as i32, dtype.into());
455
456 Array::from_ptr(new_c_array)
457 }
458 }
459}
460
461impl Clone for Array {
462 fn clone(&self) -> Self {
463 Array::try_from_op(|res| unsafe { mlx_sys::mlx_array_set(res, self.as_ptr()) })
464 .expect("Failed to clone array")
466 }
467}
468
469impl Sum for Array {
470 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
471 iter.fold(Array::from_int(0), |acc, x| acc.add(&x).unwrap())
472 }
473}
474
475#[default_device]
480pub fn stop_gradient_device(
481 a: impl AsRef<Array>,
482 stream: impl AsRef<Stream>,
483) -> crate::error::Result<Array> {
484 Array::try_from_op(|res| unsafe {
485 mlx_sys::mlx_stop_gradient(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
486 })
487}
488
489impl From<bool> for Array {
490 fn from(value: bool) -> Self {
491 Array::from_bool(value)
492 }
493}
494
495impl From<i32> for Array {
496 fn from(value: i32) -> Self {
497 Array::from_int(value)
498 }
499}
500
501impl From<f32> for Array {
502 fn from(value: f32) -> Self {
503 Array::from_f32(value)
504 }
505}
506
507impl From<complex64> for Array {
508 fn from(value: complex64) -> Self {
509 Array::from_complex(value)
510 }
511}
512
513impl<T> From<T> for Array
514where
515 Array: FromNested<T>,
516{
517 fn from(value: T) -> Self {
518 Array::from_nested(value)
519 }
520}
521
522impl AsRef<Array> for Array {
523 fn as_ref(&self) -> &Array {
524 self
525 }
526}
527
528pub trait FromScalar<T>
532where
533 T: ArrayElement,
534{
535 fn from_scalar(val: T) -> Array;
537}
538
539impl FromScalar<bool> for Array {
540 fn from_scalar(val: bool) -> Array {
541 Array::from_bool(val)
542 }
543}
544
545impl FromScalar<i32> for Array {
546 fn from_scalar(val: i32) -> Array {
547 Array::from_int(val)
548 }
549}
550
551impl FromScalar<f32> for Array {
552 fn from_scalar(val: f32) -> Array {
553 Array::from_f32(val)
554 }
555}
556
557impl FromScalar<complex64> for Array {
558 fn from_scalar(val: complex64) -> Array {
559 Array::from_complex(val)
560 }
561}
562
563pub trait FromNested<T> {
572 fn from_nested(data: T) -> Array;
574}
575
576impl<T: FromSliceElement> FromNested<&[T]> for Array {
577 fn from_nested(data: &[T]) -> Self {
578 Array::from_slice(data, &[data.len() as i32])
579 }
580}
581
582impl<T: FromSliceElement, const N: usize> FromNested<[T; N]> for Array {
583 fn from_nested(data: [T; N]) -> Self {
584 Array::from_slice(&data, &[N as i32])
585 }
586}
587
588impl<T: FromSliceElement, const N: usize> FromNested<&[T; N]> for Array {
589 fn from_nested(data: &[T; N]) -> Self {
590 Array::from_slice(data, &[N as i32])
591 }
592}
593
594impl<T: FromSliceElement + Copy> FromNested<&[&[T]]> for Array {
595 fn from_nested(data: &[&[T]]) -> Self {
596 let row_len = data[0].len();
598 assert!(
599 data.iter().all(|row| row.len() == row_len),
600 "Rows must have the same length"
601 );
602
603 let shape = [data.len() as i32, row_len as i32];
604 let data = data
605 .iter()
606 .flat_map(|x| x.iter())
607 .copied()
608 .collect::<Vec<T>>();
609 Array::from_slice(&data, &shape)
610 }
611}
612
613impl<T: FromSliceElement + Copy, const N: usize> FromNested<[&[T]; N]> for Array {
614 fn from_nested(data: [&[T]; N]) -> Self {
615 let row_len = data[0].len();
617 assert!(
618 data.iter().all(|row| row.len() == row_len),
619 "Rows must have the same length"
620 );
621
622 let shape = [N as i32, row_len as i32];
623 let data = data
624 .iter()
625 .flat_map(|x| x.iter())
626 .copied()
627 .collect::<Vec<T>>();
628 Array::from_slice(&data, &shape)
629 }
630}
631
632impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[[T; N]]> for Array {
633 fn from_nested(data: &[[T; N]]) -> Self {
634 let shape = [data.len() as i32, N as i32];
635 let data = data
636 .iter()
637 .flat_map(|x| x.iter().copied())
638 .collect::<Vec<T>>();
639 Array::from_slice(&data, &shape)
640 }
641}
642
643impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[&[T; N]]> for Array {
644 fn from_nested(data: &[&[T; N]]) -> Self {
645 let shape = [data.len() as i32, N as i32];
646 let data = data
647 .iter()
648 .flat_map(|x| x.iter().copied())
649 .collect::<Vec<T>>();
650 Array::from_slice(&data, &shape)
651 }
652}
653
654impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<[[T; N]; M]> for Array {
655 fn from_nested(data: [[T; N]; M]) -> Self {
656 let shape = [M as i32, N as i32];
657 let data = data
658 .iter()
659 .flat_map(|x| x.iter().copied())
660 .collect::<Vec<T>>();
661 Array::from_slice(&data, &shape)
662 }
663}
664
665impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[[T; N]; M]>
666 for Array
667{
668 fn from_nested(data: &[[T; N]; M]) -> Self {
669 let shape = [M as i32, N as i32];
670 let data = data
671 .iter()
672 .flat_map(|x| x.iter().copied())
673 .collect::<Vec<T>>();
674 Array::from_slice(&data, &shape)
675 }
676}
677
678impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[&[T; N]; M]>
679 for Array
680{
681 fn from_nested(data: &[&[T; N]; M]) -> Self {
682 let shape = [M as i32, N as i32];
683 let data = data
684 .iter()
685 .flat_map(|x| x.iter().copied())
686 .collect::<Vec<T>>();
687 Array::from_slice(&data, &shape)
688 }
689}
690
691impl<T: FromSliceElement + Copy> FromNested<&[&[&[T]]]> for Array {
692 fn from_nested(data: &[&[&[T]]]) -> Self {
693 let len_2d = data[0].len();
695 assert!(
696 data.iter().all(|x| x.len() == len_2d),
697 "2nd dimension must have the same length"
698 );
699
700 let len_3d = data[0][0].len();
702 assert!(
703 data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
704 "3rd dimension must have the same length"
705 );
706
707 let shape = [data.len() as i32, len_2d as i32, len_3d as i32];
708 let data = data
709 .iter()
710 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
711 .collect::<Vec<T>>();
712 Array::from_slice(&data, &shape)
713 }
714}
715
716impl<T: FromSliceElement + Copy, const N: usize> FromNested<[&[&[T]]; N]> for Array {
717 fn from_nested(data: [&[&[T]]; N]) -> Self {
718 let len_2d = data[0].len();
720 assert!(
721 data.iter().all(|x| x.len() == len_2d),
722 "2nd dimension must have the same length"
723 );
724
725 let len_3d = data[0][0].len();
727 assert!(
728 data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
729 "3rd dimension must have the same length"
730 );
731
732 let shape = [N as i32, len_2d as i32, len_3d as i32];
733 let data = data
734 .iter()
735 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
736 .collect::<Vec<T>>();
737 Array::from_slice(&data, &shape)
738 }
739}
740
741impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[[&[T]; N]]> for Array {
742 fn from_nested(data: &[[&[T]; N]]) -> Self {
743 let len_3d = data[0][0].len();
745 assert!(
746 data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
747 "3rd dimension must have the same length"
748 );
749
750 let shape = [data.len() as i32, N as i32, len_3d as i32];
751 let data = data
752 .iter()
753 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
754 .collect::<Vec<T>>();
755 Array::from_slice(&data, &shape)
756 }
757}
758
759impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[&[[T; N]]]> for Array {
760 fn from_nested(data: &[&[[T; N]]]) -> Self {
761 let len_2d = data[0].len();
763 assert!(
764 data.iter().all(|x| x.len() == len_2d),
765 "2nd dimension must have the same length"
766 );
767
768 let shape = [data.len() as i32, len_2d as i32, N as i32];
769 let data = data
770 .iter()
771 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
772 .collect::<Vec<T>>();
773 Array::from_slice(&data, &shape)
774 }
775}
776
777impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<[[&[T]; N]; M]>
778 for Array
779{
780 fn from_nested(data: [[&[T]; N]; M]) -> Self {
781 let len_3d = data[0][0].len();
783 assert!(
784 data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
785 "3rd dimension must have the same length"
786 );
787
788 let shape = [M as i32, N as i32, len_3d as i32];
789 let data = data
790 .iter()
791 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
792 .collect::<Vec<T>>();
793 Array::from_slice(&data, &shape)
794 }
795}
796
797impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[[&[T]; N]; M]>
798 for Array
799{
800 fn from_nested(data: &[[&[T]; N]; M]) -> Self {
801 let len_3d = data[0][0].len();
803 assert!(
804 data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
805 "3rd dimension must have the same length"
806 );
807
808 let shape = [M as i32, N as i32, len_3d as i32];
809 let data = data
810 .iter()
811 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
812 .collect::<Vec<T>>();
813 Array::from_slice(&data, &shape)
814 }
815}
816
817impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[&[[T; N]]; M]>
818 for Array
819{
820 fn from_nested(data: &[&[[T; N]]; M]) -> Self {
821 let len_2d = data[0].len();
823 assert!(
824 data.iter().all(|x| x.len() == len_2d),
825 "2nd dimension must have the same length"
826 );
827
828 let shape = [M as i32, len_2d as i32, N as i32];
829 let data = data
830 .iter()
831 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
832 .collect::<Vec<T>>();
833 Array::from_slice(&data, &shape)
834 }
835}
836
837impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
838 FromNested<[[[T; N]; M]; O]> for Array
839{
840 fn from_nested(data: [[[T; N]; M]; O]) -> Self {
841 let shape = [O as i32, M as i32, N as i32];
842 let data = data
843 .iter()
844 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
845 .collect::<Vec<T>>();
846 Array::from_slice(&data, &shape)
847 }
848}
849
850impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
851 FromNested<&[[[T; N]; M]; O]> for Array
852{
853 fn from_nested(data: &[[[T; N]; M]; O]) -> Self {
854 let shape = [O as i32, M as i32, N as i32];
855 let data = data
856 .iter()
857 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
858 .collect::<Vec<T>>();
859 Array::from_slice(&data, &shape)
860 }
861}
862
863impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
864 FromNested<&[&[[T; N]; M]; O]> for Array
865{
866 fn from_nested(data: &[&[[T; N]; M]; O]) -> Self {
867 let shape = [O as i32, M as i32, N as i32];
868 let data = data
869 .iter()
870 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
871 .collect::<Vec<T>>();
872 Array::from_slice(&data, &shape)
873 }
874}
875
876impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
877 FromNested<&[[&[T; N]; M]; O]> for Array
878{
879 fn from_nested(data: &[[&[T; N]; M]; O]) -> Self {
880 let shape = [O as i32, M as i32, N as i32];
881 let data = data
882 .iter()
883 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
884 .collect::<Vec<T>>();
885 Array::from_slice(&data, &shape)
886 }
887}
888
889impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
890 FromNested<&[&[&[T; N]; M]; O]> for Array
891{
892 fn from_nested(data: &[&[&[T; N]; M]; O]) -> Self {
893 let shape = [O as i32, M as i32, N as i32];
894 let data = data
895 .iter()
896 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
897 .collect::<Vec<T>>();
898 Array::from_slice(&data, &shape)
899 }
900}
901
#[cfg(test)]
mod tests {
    // Unit tests for scalar/slice construction, metadata accessors, cloning,
    // equality and element extraction.
    use super::*;

    #[test]
    fn new_scalar_array_from_bool() {
        let array = Array::from_bool(true);
        assert!(array.item::<bool>());
        assert_eq!(array.item_size(), 1);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 1);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Bool);
    }

    #[test]
    fn new_scalar_array_from_int() {
        let array = Array::from_int(42);
        assert_eq!(array.item::<i32>(), 42);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_scalar_array_from_f32() {
        let array = Array::from_f32(3.14);
        assert_eq!(array.item::<f32>(), 3.14);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Float32);
    }

    #[test]
    fn new_scalar_array_from_f64() {
        let array = Array::from_f64(3.14).as_dtype(Dtype::Float64).unwrap();
        float_eq::assert_float_eq!(array.item::<f64>(), 3.14, abs <= 1e-5);
        assert_eq!(array.item_size(), 8);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 8);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Float64);
    }

    #[test]
    fn new_array_from_slice_f64() {
        let array = Array::from_slice_f64(&[1.0, 2.0, 3.0], &[3]);
        assert_eq!(array.item_size(), 8);
        assert_eq!(array.size(), 3);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 24);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 3);
        assert_eq!(array.shape(), &[3]);
        assert_eq!(array.dtype(), Dtype::Float64);
    }

    #[test]
    fn new_scalar_array_from_complex() {
        let val = complex64::new(1.0, 2.0);
        let array = Array::from_complex(val);
        assert_eq!(array.item::<complex64>(), val);
        assert_eq!(array.item_size(), 8);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 8);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Complex64);
    }

    #[test]
    fn new_array_from_single_element_slice() {
        let data = [1i32];
        let array = Array::from_slice(&data, &[1]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item::<i32>(), 1);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 1);
        assert_eq!(array.shape(), &[1]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_array_from_multi_element_slice() {
        let data = [1i32, 2, 3, 4, 5];
        let array = Array::from_slice(&data, &[5]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 5);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 20);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 5);
        assert_eq!(array.shape(), &[5]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_2d_array_from_slice() {
        let data = [1i32, 2, 3, 4, 5, 6];
        let array = Array::from_slice(&data, &[2, 3]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 6);
        assert_eq!(array.strides(), &[3, 1]);
        assert_eq!(array.nbytes(), 24);
        assert_eq!(array.ndim(), 2);
        assert_eq!(array.dim(0), 2);
        assert_eq!(array.dim(1), 3);
        // Negative indices count from the last dimension.
        assert_eq!(array.dim(-1), 3);
        assert_eq!(array.dim(-2), 2);
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn deep_cloned_array_has_different_ptr() {
        let data = [1i32, 2, 3, 4, 5];
        let orig = Array::from_slice(&data, &[5]);
        let clone = orig.deep_clone();

        // Same contents...
        assert_eq!(orig.as_slice::<i32>(), clone.as_slice::<i32>());

        // ...but a distinct MLX handle...
        assert_ne!(orig.as_ptr().ctx, clone.as_ptr().ctx);

        // ...and a distinct data buffer.
        assert_ne!(
            orig.as_slice::<i32>().as_ptr(),
            clone.as_slice::<i32>().as_ptr()
        );
    }

    #[test]
    fn test_array_eq() {
        let data = [1i32, 2, 3, 4, 5];
        let array1 = Array::from_slice(&data, &[5]);
        let array2 = Array::from_slice(&data, &[5]);
        let array3 = Array::from_slice(&[1i32, 2, 3, 4, 6], &[5]);

        assert_eq!(&array1, &array2);
        assert_ne!(&array1, &array3);
    }

    #[test]
    fn test_array_item_non_scalar() {
        let data = [1i32, 2, 3, 4, 5];
        let array = Array::from_slice(&data, &[5]);
        assert!(array.try_item::<i32>().is_err());
    }

    #[test]
    fn test_item_type_conversion() {
        let array = Array::from_f32(1.0);
        assert_eq!(array.item::<i32>(), 1);
        assert_eq!(array.item::<complex64>(), complex64::new(1.0, 0.0));
        assert_eq!(array.item::<u8>(), 1);

        // Converting for `item` must not change the original array's dtype.
        assert_eq!(array.as_slice::<f32>(), &[1.0]);
    }
}