1use crate::{
2 dtype::Dtype,
3 error::AsSliceError,
4 sealed::Sealed,
5 utils::{guard::Guarded, SUCCESS},
6 Stream, StreamOrDevice,
7};
8use element::FromSliceElement;
9use mlx_internal_macros::default_device;
10use mlx_sys::mlx_array;
11use num_complex::Complex;
12use std::{
13 ffi::{c_void, CStr},
14 iter::Sum,
15};
16
17mod element;
18mod operators;
19
20cfg_safetensors! {
21 mod safetensors;
22}
23
24pub use element::ArrayElement;
25
/// Single-precision complex number (`Complex<f32>`), the element type of
/// `Dtype::Complex64` arrays.
// The lowercase name mirrors the `complex64` dtype naming, hence the lint allowance.
#[allow(non_camel_case_types)]
pub type complex64 = Complex<f32>;
31
/// An owning handle to an mlx array.
///
/// The wrapped handle is released with `mlx_array_free` when the `Array` is dropped.
// `repr(transparent)` gives `Array` exactly the same layout as the wrapped `mlx_array`.
#[repr(transparent)]
pub struct Array {
    // Raw mlx-c handle, owned by this value.
    c_array: mlx_array,
}
37
// Lets `Array` satisfy `Sealed`-bounded traits; presumably this keeps such traits from being
// implemented outside the crate (see `crate::sealed`).
impl Sealed for Array {}
39
// Borrowed arrays are sealed too, so `&Array` can be passed wherever a `Sealed` bound applies.
impl Sealed for &Array {}
41
42impl std::fmt::Debug for Array {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 write!(f, "{}", self)
45 }
46}
47
48impl std::fmt::Display for Array {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 unsafe {
51 let mut mlx_str = mlx_sys::mlx_string_new();
52 let status = mlx_sys::mlx_array_tostring(&mut mlx_str as *mut _, self.as_ptr());
53 if status != SUCCESS {
54 return Err(std::fmt::Error);
55 }
56 let ptr = mlx_sys::mlx_string_data(mlx_str);
57 let c_str = CStr::from_ptr(ptr);
58 write!(f, "{:?}", c_str)?;
59 mlx_sys::mlx_string_free(mlx_str);
60 Ok(())
61 }
62 }
63}
64
65impl Drop for Array {
66 fn drop(&mut self) {
67 unsafe { mlx_sys::mlx_array_free(self.as_ptr()) };
71 }
72}
73
// SAFETY (NOTE(review)): `Array` only holds an owned `mlx_array` handle. This impl assumes mlx
// allows moving that handle to, and freeing it on, another thread; confirm against mlx-c's
// threading guarantees. `Sync` is not implemented in the visible code.
unsafe impl Send for Array {}
75
76impl PartialEq for Array {
77 fn eq(&self, other: &Self) -> bool {
85 self.array_eq(other, None).unwrap().item()
86 }
87}
88
89impl Array {
90 pub unsafe fn from_ptr(c_array: mlx_array) -> Array {
97 Self { c_array }
98 }
99
100 pub fn as_ptr(&self) -> mlx_array {
102 self.c_array
103 }
104
105 pub fn from_bool(val: bool) -> Array {
107 let c_array = unsafe { mlx_sys::mlx_array_new_bool(val) };
108 Array { c_array }
109 }
110
111 pub fn from_int(val: i32) -> Array {
113 let c_array = unsafe { mlx_sys::mlx_array_new_int(val) };
114 Array { c_array }
115 }
116
117 pub fn from_f32(val: f32) -> Array {
119 let c_array = unsafe { mlx_sys::mlx_array_new_float32(val) };
120 Array { c_array }
121 }
122
123 pub fn from_complex(val: complex64) -> Array {
132 let c_array = unsafe { mlx_sys::mlx_array_new_complex(val.re, val.im) };
133 Array { c_array }
134 }
135
136 pub fn from_slice<T: FromSliceElement>(data: &[T], shape: &[i32]) -> Self {
152 assert_eq!(data.len(), shape.iter().product::<i32>() as usize);
154
155 unsafe { Self::from_raw_data(data.as_ptr() as *const c_void, shape, T::DTYPE) }
156 }
157
158 pub fn from_slice_f64(data: &[f64], shape: &[i32]) -> Self {
163 assert_eq!(data.len(), shape.iter().product::<i32>() as usize);
165
166 unsafe { Self::from_raw_data(data.as_ptr() as *const c_void, shape, Dtype::Float64) }
167 }
168
169 #[inline]
178 pub unsafe fn from_raw_data(data: *const c_void, shape: &[i32], dtype: Dtype) -> Self {
179 let dim = if shape.len() > i32::MAX as usize {
180 panic!("Shape is too large")
181 } else {
182 shape.len() as i32
183 };
184
185 let c_array = mlx_sys::mlx_array_new_data(data, shape.as_ptr(), dim, dtype.into());
186 Array { c_array }
187 }
188
189 pub fn from_iter<I: IntoIterator<Item = T>, T: FromSliceElement>(
211 iter: I,
212 shape: &[i32],
213 ) -> Self {
214 let data: Vec<T> = iter.into_iter().collect();
215 Self::from_slice(&data, shape)
216 }
217
218 pub fn from_iter_f64<I: IntoIterator<Item = f64>>(iter: I, shape: &[i32]) -> Self {
223 let data: Vec<f64> = iter.into_iter().collect();
224 Self::from_slice_f64(&data, shape)
225 }
226
227 pub fn item_size(&self) -> usize {
229 unsafe { mlx_sys::mlx_array_itemsize(self.as_ptr()) }
230 }
231
232 pub fn size(&self) -> usize {
234 unsafe { mlx_sys::mlx_array_size(self.as_ptr()) }
235 }
236
237 pub fn strides(&self) -> &[usize] {
239 let ndim = self.ndim();
240 if ndim == 0 {
241 return &[];
243 }
244
245 unsafe {
246 let data = mlx_sys::mlx_array_strides(self.as_ptr());
247 std::slice::from_raw_parts(data, ndim)
248 }
249 }
250
251 pub fn nbytes(&self) -> usize {
253 unsafe { mlx_sys::mlx_array_nbytes(self.as_ptr()) }
254 }
255
256 pub fn ndim(&self) -> usize {
258 unsafe { mlx_sys::mlx_array_ndim(self.as_ptr()) }
259 }
260
261 pub fn shape(&self) -> &[i32] {
265 let ndim = self.ndim();
266 if ndim == 0 {
267 return &[];
269 }
270
271 unsafe {
272 let data = mlx_sys::mlx_array_shape(self.as_ptr());
273 std::slice::from_raw_parts(data, ndim)
274 }
275 }
276
277 pub fn dim(&self, dim: i32) -> i32 {
285 let dim = if dim.is_negative() {
286 (self.ndim() as i32).checked_add(dim).unwrap()
287 } else {
288 dim
289 };
290
291 unsafe { mlx_sys::mlx_array_dim(self.as_ptr(), dim) }
293 }
294
295 pub fn dtype(&self) -> Dtype {
297 let dtype = unsafe { mlx_sys::mlx_array_dtype(self.as_ptr()) };
298 Dtype::try_from(dtype).unwrap()
299 }
300
301 pub fn eval(&self) -> crate::error::Result<()> {
304 <() as Guarded>::try_from_op(|_| unsafe { mlx_sys::mlx_array_eval(self.as_ptr()) })
305 }
306
307 pub fn item<T: ArrayElement>(&self) -> T {
312 self.try_item().unwrap()
313 }
314
315 pub fn try_item<T: ArrayElement>(&self) -> crate::error::Result<T> {
320 self.eval()?;
321
322 self.eval()?;
324
325 if self.dtype() != T::DTYPE {
328 let new_array = Array::try_from_op(|res| unsafe {
329 mlx_sys::mlx_astype(
330 res,
331 self.as_ptr(),
332 T::DTYPE.into(),
333 Stream::default().as_ptr(),
334 )
335 })?;
336 new_array.eval()?;
337 return T::array_item(&new_array);
338 }
339
340 T::array_item(self)
341 }
342
343 pub unsafe fn as_slice_unchecked<T: ArrayElement>(&self) -> &[T] {
364 self.eval().unwrap();
365
366 unsafe {
367 let data = T::array_data(self);
368 let size = self.size();
369 std::slice::from_raw_parts(data, size)
370 }
371 }
372
373 pub fn try_as_slice<T: ArrayElement>(&self) -> Result<&[T], AsSliceError> {
387 if self.dtype() != T::DTYPE {
388 return Err(AsSliceError::DtypeMismatch {
389 expecting: T::DTYPE,
390 found: self.dtype(),
391 });
392 }
393
394 self.eval()?;
395
396 unsafe {
397 let size = self.size();
398 let data = T::array_data(self);
399 if data.is_null() || size == 0 {
400 return Err(AsSliceError::Null);
401 }
402
403 Ok(std::slice::from_raw_parts(data, size))
404 }
405 }
406
407 pub fn as_slice<T: ArrayElement>(&self) -> &[T] {
426 self.try_as_slice().unwrap()
427 }
428
429 pub fn deep_clone(&self) -> Self {
433 unsafe {
434 let dtype = self.dtype();
435 let shape = self.shape();
436 let data = match dtype {
437 Dtype::Bool => mlx_sys::mlx_array_data_bool(self.as_ptr()) as *const c_void,
438 Dtype::Uint8 => mlx_sys::mlx_array_data_uint8(self.as_ptr()) as *const c_void,
439 Dtype::Uint16 => mlx_sys::mlx_array_data_uint16(self.as_ptr()) as *const c_void,
440 Dtype::Uint32 => mlx_sys::mlx_array_data_uint32(self.as_ptr()) as *const c_void,
441 Dtype::Uint64 => mlx_sys::mlx_array_data_uint64(self.as_ptr()) as *const c_void,
442 Dtype::Int8 => mlx_sys::mlx_array_data_int8(self.as_ptr()) as *const c_void,
443 Dtype::Int16 => mlx_sys::mlx_array_data_int16(self.as_ptr()) as *const c_void,
444 Dtype::Int32 => mlx_sys::mlx_array_data_int32(self.as_ptr()) as *const c_void,
445 Dtype::Int64 => mlx_sys::mlx_array_data_int64(self.as_ptr()) as *const c_void,
446 Dtype::Float16 => mlx_sys::mlx_array_data_float16(self.as_ptr()) as *const c_void,
447 Dtype::Float32 => mlx_sys::mlx_array_data_float32(self.as_ptr()) as *const c_void,
448 Dtype::Float64 => mlx_sys::mlx_array_data_float64(self.as_ptr()) as *const c_void,
449 Dtype::Bfloat16 => mlx_sys::mlx_array_data_bfloat16(self.as_ptr()) as *const c_void,
450 Dtype::Complex64 => {
451 mlx_sys::mlx_array_data_complex64(self.as_ptr()) as *const c_void
452 }
453 };
454
455 let new_c_array =
456 mlx_sys::mlx_array_new_data(data, shape.as_ptr(), shape.len() as i32, dtype.into());
457
458 Array::from_ptr(new_c_array)
459 }
460 }
461}
462
463impl Clone for Array {
464 fn clone(&self) -> Self {
465 Array::try_from_op(|res| unsafe { mlx_sys::mlx_array_set(res, self.as_ptr()) })
466 .expect("Failed to clone array")
468 }
469}
470
471impl Sum for Array {
472 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
473 iter.fold(Array::from_int(0), |acc, x| acc.add(&x).unwrap())
474 }
475}
476
477#[default_device]
482pub fn stop_gradient_device(
483 a: impl AsRef<Array>,
484 stream: impl AsRef<Stream>,
485) -> crate::error::Result<Array> {
486 Array::try_from_op(|res| unsafe {
487 mlx_sys::mlx_stop_gradient(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
488 })
489}
490
491impl From<bool> for Array {
492 fn from(value: bool) -> Self {
493 Array::from_bool(value)
494 }
495}
496
497impl From<i32> for Array {
498 fn from(value: i32) -> Self {
499 Array::from_int(value)
500 }
501}
502
503impl From<f32> for Array {
504 fn from(value: f32) -> Self {
505 Array::from_f32(value)
506 }
507}
508
509impl From<complex64> for Array {
510 fn from(value: complex64) -> Self {
511 Array::from_complex(value)
512 }
513}
514
515impl<T> From<T> for Array
516where
517 Array: FromNested<T>,
518{
519 fn from(value: T) -> Self {
520 Array::from_nested(value)
521 }
522}
523
// Lets APIs taking `impl AsRef<Array>` (e.g. `stop_gradient_device`) accept an owned `Array`.
impl AsRef<Array> for Array {
    /// Returns `self` unchanged.
    fn as_ref(&self) -> &Array {
        self
    }
}
529
530pub trait FromScalar<T>
534where
535 T: ArrayElement,
536{
537 fn from_scalar(val: T) -> Array;
539}
540
541impl FromScalar<bool> for Array {
542 fn from_scalar(val: bool) -> Array {
543 Array::from_bool(val)
544 }
545}
546
547impl FromScalar<i32> for Array {
548 fn from_scalar(val: i32) -> Array {
549 Array::from_int(val)
550 }
551}
552
553impl FromScalar<f32> for Array {
554 fn from_scalar(val: f32) -> Array {
555 Array::from_f32(val)
556 }
557}
558
559impl FromScalar<complex64> for Array {
560 fn from_scalar(val: complex64) -> Array {
561 Array::from_complex(val)
562 }
563}
564
/// Conversion from (possibly nested) Rust slices and fixed-size arrays into an [`Array`].
///
/// The implementations in this module cover one to three levels of nesting. The resulting
/// shape follows the nesting from the outside in, e.g. `[[T; N]; M]` becomes an `[M, N]`
/// array.
pub trait FromNested<T> {
    /// Builds an [`Array`] whose shape is taken from the nesting of `data`.
    fn from_nested(data: T) -> Array;
}
577
578impl<T: FromSliceElement> FromNested<&[T]> for Array {
579 fn from_nested(data: &[T]) -> Self {
580 Array::from_slice(data, &[data.len() as i32])
581 }
582}
583
584impl<T: FromSliceElement, const N: usize> FromNested<[T; N]> for Array {
585 fn from_nested(data: [T; N]) -> Self {
586 Array::from_slice(&data, &[N as i32])
587 }
588}
589
590impl<T: FromSliceElement, const N: usize> FromNested<&[T; N]> for Array {
591 fn from_nested(data: &[T; N]) -> Self {
592 Array::from_slice(data, &[N as i32])
593 }
594}
595
596impl<T: FromSliceElement + Copy> FromNested<&[&[T]]> for Array {
597 fn from_nested(data: &[&[T]]) -> Self {
598 let row_len = data[0].len();
600 assert!(
601 data.iter().all(|row| row.len() == row_len),
602 "Rows must have the same length"
603 );
604
605 let shape = [data.len() as i32, row_len as i32];
606 let data = data
607 .iter()
608 .flat_map(|x| x.iter())
609 .copied()
610 .collect::<Vec<T>>();
611 Array::from_slice(&data, &shape)
612 }
613}
614
615impl<T: FromSliceElement + Copy, const N: usize> FromNested<[&[T]; N]> for Array {
616 fn from_nested(data: [&[T]; N]) -> Self {
617 let row_len = data[0].len();
619 assert!(
620 data.iter().all(|row| row.len() == row_len),
621 "Rows must have the same length"
622 );
623
624 let shape = [N as i32, row_len as i32];
625 let data = data
626 .iter()
627 .flat_map(|x| x.iter())
628 .copied()
629 .collect::<Vec<T>>();
630 Array::from_slice(&data, &shape)
631 }
632}
633
634impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[[T; N]]> for Array {
635 fn from_nested(data: &[[T; N]]) -> Self {
636 let shape = [data.len() as i32, N as i32];
637 let data = data
638 .iter()
639 .flat_map(|x| x.iter().copied())
640 .collect::<Vec<T>>();
641 Array::from_slice(&data, &shape)
642 }
643}
644
645impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[&[T; N]]> for Array {
646 fn from_nested(data: &[&[T; N]]) -> Self {
647 let shape = [data.len() as i32, N as i32];
648 let data = data
649 .iter()
650 .flat_map(|x| x.iter().copied())
651 .collect::<Vec<T>>();
652 Array::from_slice(&data, &shape)
653 }
654}
655
656impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<[[T; N]; M]> for Array {
657 fn from_nested(data: [[T; N]; M]) -> Self {
658 let shape = [M as i32, N as i32];
659 let data = data
660 .iter()
661 .flat_map(|x| x.iter().copied())
662 .collect::<Vec<T>>();
663 Array::from_slice(&data, &shape)
664 }
665}
666
667impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[[T; N]; M]>
668 for Array
669{
670 fn from_nested(data: &[[T; N]; M]) -> Self {
671 let shape = [M as i32, N as i32];
672 let data = data
673 .iter()
674 .flat_map(|x| x.iter().copied())
675 .collect::<Vec<T>>();
676 Array::from_slice(&data, &shape)
677 }
678}
679
680impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[&[T; N]; M]>
681 for Array
682{
683 fn from_nested(data: &[&[T; N]; M]) -> Self {
684 let shape = [M as i32, N as i32];
685 let data = data
686 .iter()
687 .flat_map(|x| x.iter().copied())
688 .collect::<Vec<T>>();
689 Array::from_slice(&data, &shape)
690 }
691}
692
693impl<T: FromSliceElement + Copy> FromNested<&[&[&[T]]]> for Array {
694 fn from_nested(data: &[&[&[T]]]) -> Self {
695 let len_2d = data[0].len();
697 assert!(
698 data.iter().all(|x| x.len() == len_2d),
699 "2nd dimension must have the same length"
700 );
701
702 let len_3d = data[0][0].len();
704 assert!(
705 data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
706 "3rd dimension must have the same length"
707 );
708
709 let shape = [data.len() as i32, len_2d as i32, len_3d as i32];
710 let data = data
711 .iter()
712 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
713 .collect::<Vec<T>>();
714 Array::from_slice(&data, &shape)
715 }
716}
717
718impl<T: FromSliceElement + Copy, const N: usize> FromNested<[&[&[T]]; N]> for Array {
719 fn from_nested(data: [&[&[T]]; N]) -> Self {
720 let len_2d = data[0].len();
722 assert!(
723 data.iter().all(|x| x.len() == len_2d),
724 "2nd dimension must have the same length"
725 );
726
727 let len_3d = data[0][0].len();
729 assert!(
730 data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
731 "3rd dimension must have the same length"
732 );
733
734 let shape = [N as i32, len_2d as i32, len_3d as i32];
735 let data = data
736 .iter()
737 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
738 .collect::<Vec<T>>();
739 Array::from_slice(&data, &shape)
740 }
741}
742
743impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[[&[T]; N]]> for Array {
744 fn from_nested(data: &[[&[T]; N]]) -> Self {
745 let len_3d = data[0][0].len();
747 assert!(
748 data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
749 "3rd dimension must have the same length"
750 );
751
752 let shape = [data.len() as i32, N as i32, len_3d as i32];
753 let data = data
754 .iter()
755 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
756 .collect::<Vec<T>>();
757 Array::from_slice(&data, &shape)
758 }
759}
760
761impl<T: FromSliceElement + Copy, const N: usize> FromNested<&[&[[T; N]]]> for Array {
762 fn from_nested(data: &[&[[T; N]]]) -> Self {
763 let len_2d = data[0].len();
765 assert!(
766 data.iter().all(|x| x.len() == len_2d),
767 "2nd dimension must have the same length"
768 );
769
770 let shape = [data.len() as i32, len_2d as i32, N as i32];
771 let data = data
772 .iter()
773 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
774 .collect::<Vec<T>>();
775 Array::from_slice(&data, &shape)
776 }
777}
778
779impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<[[&[T]; N]; M]>
780 for Array
781{
782 fn from_nested(data: [[&[T]; N]; M]) -> Self {
783 let len_3d = data[0][0].len();
785 assert!(
786 data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
787 "3rd dimension must have the same length"
788 );
789
790 let shape = [M as i32, N as i32, len_3d as i32];
791 let data = data
792 .iter()
793 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
794 .collect::<Vec<T>>();
795 Array::from_slice(&data, &shape)
796 }
797}
798
799impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[[&[T]; N]; M]>
800 for Array
801{
802 fn from_nested(data: &[[&[T]; N]; M]) -> Self {
803 let len_3d = data[0][0].len();
805 assert!(
806 data.iter().all(|x| x.iter().all(|y| y.len() == len_3d)),
807 "3rd dimension must have the same length"
808 );
809
810 let shape = [M as i32, N as i32, len_3d as i32];
811 let data = data
812 .iter()
813 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
814 .collect::<Vec<T>>();
815 Array::from_slice(&data, &shape)
816 }
817}
818
819impl<T: FromSliceElement + Copy, const N: usize, const M: usize> FromNested<&[&[[T; N]]; M]>
820 for Array
821{
822 fn from_nested(data: &[&[[T; N]]; M]) -> Self {
823 let len_2d = data[0].len();
825 assert!(
826 data.iter().all(|x| x.len() == len_2d),
827 "2nd dimension must have the same length"
828 );
829
830 let shape = [M as i32, len_2d as i32, N as i32];
831 let data = data
832 .iter()
833 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
834 .collect::<Vec<T>>();
835 Array::from_slice(&data, &shape)
836 }
837}
838
839impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
840 FromNested<[[[T; N]; M]; O]> for Array
841{
842 fn from_nested(data: [[[T; N]; M]; O]) -> Self {
843 let shape = [O as i32, M as i32, N as i32];
844 let data = data
845 .iter()
846 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
847 .collect::<Vec<T>>();
848 Array::from_slice(&data, &shape)
849 }
850}
851
852impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
853 FromNested<&[[[T; N]; M]; O]> for Array
854{
855 fn from_nested(data: &[[[T; N]; M]; O]) -> Self {
856 let shape = [O as i32, M as i32, N as i32];
857 let data = data
858 .iter()
859 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
860 .collect::<Vec<T>>();
861 Array::from_slice(&data, &shape)
862 }
863}
864
865impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
866 FromNested<&[&[[T; N]; M]; O]> for Array
867{
868 fn from_nested(data: &[&[[T; N]; M]; O]) -> Self {
869 let shape = [O as i32, M as i32, N as i32];
870 let data = data
871 .iter()
872 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
873 .collect::<Vec<T>>();
874 Array::from_slice(&data, &shape)
875 }
876}
877
878impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
879 FromNested<&[[&[T; N]; M]; O]> for Array
880{
881 fn from_nested(data: &[[&[T; N]; M]; O]) -> Self {
882 let shape = [O as i32, M as i32, N as i32];
883 let data = data
884 .iter()
885 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
886 .collect::<Vec<T>>();
887 Array::from_slice(&data, &shape)
888 }
889}
890
891impl<T: FromSliceElement + Copy, const N: usize, const M: usize, const O: usize>
892 FromNested<&[&[&[T; N]; M]; O]> for Array
893{
894 fn from_nested(data: &[&[&[T; N]; M]; O]) -> Self {
895 let shape = [O as i32, M as i32, N as i32];
896 let data = data
897 .iter()
898 .flat_map(|x| x.iter().flat_map(|y| y.iter().copied()))
899 .collect::<Vec<T>>();
900 Array::from_slice(&data, &shape)
901 }
902}
903
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_scalar_array_from_bool() {
        let array = Array::from_bool(true);
        assert!(array.item::<bool>());
        assert_eq!(array.item_size(), 1);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 1);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Bool);
    }

    #[test]
    fn new_scalar_array_from_int() {
        let array = Array::from_int(42);
        assert_eq!(array.item::<i32>(), 42);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_scalar_array_from_f32() {
        // 3.5 is exactly representable as an `f32`, so exact comparison is reliable, and
        // unlike `3.14` it does not trip clippy's deny-by-default `approx_constant` lint.
        let array = Array::from_f32(3.5);
        assert_eq!(array.item::<f32>(), 3.5);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Float32);
    }

    #[test]
    fn new_array_from_slice_f64() {
        let array = Array::from_slice_f64(&[1.0, 2.0, 3.0], &[3]);
        assert_eq!(array.item_size(), 8);
        assert_eq!(array.size(), 3);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 24);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 3);
        assert_eq!(array.shape(), &[3]);
        assert_eq!(array.dtype(), Dtype::Float64);
    }

    #[test]
    fn new_scalar_array_from_complex() {
        let val = complex64::new(1.0, 2.0);
        let array = Array::from_complex(val);
        assert_eq!(array.item::<complex64>(), val);
        assert_eq!(array.item_size(), 8);
        assert_eq!(array.size(), 1);
        assert!(array.strides().is_empty());
        assert_eq!(array.nbytes(), 8);
        assert_eq!(array.ndim(), 0);
        assert!(array.shape().is_empty());
        assert_eq!(array.dtype(), Dtype::Complex64);
    }

    #[test]
    fn new_array_from_single_element_slice() {
        let data = [1i32];
        let array = Array::from_slice(&data, &[1]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item::<i32>(), 1);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 1);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 4);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 1);
        assert_eq!(array.shape(), &[1]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_array_from_multi_element_slice() {
        let data = [1i32, 2, 3, 4, 5];
        let array = Array::from_slice(&data, &[5]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 5);
        assert_eq!(array.strides(), &[1]);
        assert_eq!(array.nbytes(), 20);
        assert_eq!(array.ndim(), 1);
        assert_eq!(array.dim(0), 5);
        assert_eq!(array.shape(), &[5]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn new_2d_array_from_slice() {
        let data = [1i32, 2, 3, 4, 5, 6];
        let array = Array::from_slice(&data, &[2, 3]);
        assert_eq!(array.as_slice::<i32>(), &data[..]);
        assert_eq!(array.item_size(), 4);
        assert_eq!(array.size(), 6);
        assert_eq!(array.strides(), &[3, 1]);
        assert_eq!(array.nbytes(), 24);
        assert_eq!(array.ndim(), 2);
        assert_eq!(array.dim(0), 2);
        assert_eq!(array.dim(1), 3);
        // Negative indices count from the last dimension.
        assert_eq!(array.dim(-1), 3);
        assert_eq!(array.dim(-2), 2);
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Int32);
    }

    #[test]
    fn deep_cloned_array_has_different_ptr() {
        let data = [1i32, 2, 3, 4, 5];
        let orig = Array::from_slice(&data, &[5]);
        let clone = orig.deep_clone();

        // Same contents...
        assert_eq!(orig.as_slice::<i32>(), clone.as_slice::<i32>());

        // ...but a distinct mlx handle...
        assert_ne!(orig.as_ptr().ctx, clone.as_ptr().ctx);

        // ...and a distinct data buffer.
        assert_ne!(
            orig.as_slice::<i32>().as_ptr(),
            clone.as_slice::<i32>().as_ptr()
        );
    }

    #[test]
    fn test_array_eq() {
        let data = [1i32, 2, 3, 4, 5];
        let array1 = Array::from_slice(&data, &[5]);
        let array2 = Array::from_slice(&data, &[5]);
        let array3 = Array::from_slice(&[1i32, 2, 3, 4, 6], &[5]);

        assert_eq!(&array1, &array2);
        assert_ne!(&array1, &array3);
    }

    #[test]
    fn test_array_item_non_scalar() {
        let data = [1i32, 2, 3, 4, 5];
        let array = Array::from_slice(&data, &[5]);
        // `item` only makes sense for single-element arrays.
        assert!(array.try_item::<i32>().is_err());
    }

    #[test]
    fn test_item_type_conversion() {
        let array = Array::from_f32(1.0);
        assert_eq!(array.item::<i32>(), 1);
        assert_eq!(array.item::<complex64>(), complex64::new(1.0, 0.0));
        assert_eq!(array.item::<u8>(), 1);

        // Reading with a converted dtype must not change the original array.
        assert_eq!(array.as_slice::<f32>(), &[1.0]);
    }
}