1use std::borrow::Cow;
2
3use mlx_internal_macros::{default_device, generate_macro};
4use smallvec::SmallVec;
5
6use crate::{
7 constants::DEFAULT_STACK_VEC_LEN,
8 error::Result,
9 utils::{guard::Guarded, IntoOption, VectorArray},
10 Array, Stream, StreamOrDevice,
11};
12
13impl Array {
14 #[default_device]
16 pub fn expand_dims_device(&self, axes: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
17 expand_dims_device(self, axes, stream)
18 }
19
20 #[default_device]
22 pub fn flatten_device(
23 &self,
24 start_axis: impl Into<Option<i32>>,
25 end_axis: impl Into<Option<i32>>,
26 stream: impl AsRef<Stream>,
27 ) -> Result<Array> {
28 flatten_device(self, start_axis, end_axis, stream)
29 }
30
31 #[default_device]
33 pub fn reshape_device(&self, shape: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
34 reshape_device(self, shape, stream)
35 }
36
37 #[default_device]
39 pub fn squeeze_device<'a>(
40 &'a self,
41 axes: impl IntoOption<&'a [i32]>,
42 stream: impl AsRef<Stream>,
43 ) -> Result<Array> {
44 squeeze_device(self, axes, stream)
45 }
46
47 #[default_device]
49 pub fn as_strided_device<'a>(
50 &'a self,
51 shape: impl IntoOption<&'a [i32]>,
52 strides: impl IntoOption<&'a [i64]>,
53 offset: impl Into<Option<usize>>,
54 stream: impl AsRef<Stream>,
55 ) -> Result<Array> {
56 as_strided_device(self, shape, strides, offset, stream)
57 }
58
59 #[default_device]
61 pub fn at_least_1d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
62 at_least_1d_device(self, stream)
63 }
64
65 #[default_device]
67 pub fn at_least_2d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
68 at_least_2d_device(self, stream)
69 }
70
71 #[default_device]
73 pub fn at_least_3d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
74 at_least_3d_device(self, stream)
75 }
76
77 #[default_device]
79 pub fn move_axis_device(
80 &self,
81 src: i32,
82 dst: i32,
83 stream: impl AsRef<Stream>,
84 ) -> Result<Array> {
85 move_axis_device(self, src, dst, stream)
86 }
87
88 #[default_device]
90 pub fn split_device(
91 &self,
92 indices: &[i32],
93 axis: impl Into<Option<i32>>,
94 stream: impl AsRef<Stream>,
95 ) -> Result<Vec<Array>> {
96 split_device(self, indices, axis, stream)
97 }
98
99 #[default_device]
101 pub fn split_equal_device(
102 &self,
103 num_parts: i32,
104 axis: impl Into<Option<i32>>,
105 stream: impl AsRef<Stream>,
106 ) -> Result<Vec<Array>> {
107 split_equal_device(self, num_parts, axis, stream)
108 }
109
110 #[default_device]
112 pub fn swap_axes_device(
113 &self,
114 axis1: i32,
115 axis2: i32,
116 stream: impl AsRef<Stream>,
117 ) -> Result<Array> {
118 swap_axes_device(self, axis1, axis2, stream)
119 }
120
121 #[default_device]
123 pub fn transpose_device(&self, axes: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
124 transpose_device(self, axes, stream)
125 }
126
127 #[default_device]
129 pub fn transpose_all_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
130 transpose_all_device(self, stream)
131 }
132
133 pub fn t(&self) -> Array {
135 self.transpose_all().unwrap()
136 }
137}
138
139fn axes_or_default_to_all_size_one_axes<'a>(
140 axes: impl IntoOption<&'a [i32]>,
141 shape: &[i32],
142) -> Cow<'a, [i32]> {
143 match axes.into_option() {
144 Some(axes) => Cow::Borrowed(axes),
145 None => shape
146 .iter()
147 .enumerate()
148 .filter_map(|(i, &dim)| if dim == 1 { Some(i as i32) } else { None })
149 .collect(),
150 }
151}
152
153fn resolve_strides(
154 shape: &[i32],
155 strides: Option<&[i64]>,
156) -> SmallVec<[i64; DEFAULT_STACK_VEC_LEN]> {
157 match strides {
158 Some(strides) => SmallVec::from_slice(strides),
159 None => {
160 let result = shape
161 .iter()
162 .rev()
163 .scan(1, |acc, &dim| {
164 let result = *acc;
165 *acc *= dim as i64;
166 Some(result)
167 })
168 .collect::<SmallVec<[i64; DEFAULT_STACK_VEC_LEN]>>();
169 result.into_iter().rev().collect()
170 }
171 }
172}
173
174#[generate_macro]
181#[default_device]
182pub fn broadcast_arrays_device(
183 arrays: &[impl AsRef<Array>],
184 #[optional] stream: impl AsRef<Stream>,
185) -> Result<Vec<Array>> {
186 let c_vec = VectorArray::try_from_iter(arrays.iter())?;
187 Vec::<Array>::try_from_op(|res| unsafe {
188 mlx_sys::mlx_broadcast_arrays(res, c_vec.as_ptr(), stream.as_ref().as_ptr())
189 })
190}
191
192#[generate_macro]
203#[default_device]
204pub fn as_strided_device<'a>(
205 a: impl AsRef<Array>,
206 #[optional] shape: impl IntoOption<&'a [i32]>,
207 #[optional] strides: impl IntoOption<&'a [i64]>,
208 #[optional] offset: impl Into<Option<usize>>,
209 #[optional] stream: impl AsRef<Stream>,
210) -> Result<Array> {
211 let a = a.as_ref();
212 let shape = shape.into_option().unwrap_or(a.shape());
213 let resolved_strides = resolve_strides(shape, strides.into_option());
214 let offset = offset.into().unwrap_or(0);
215
216 Array::try_from_op(|res| unsafe {
217 mlx_sys::mlx_as_strided(
218 res,
219 a.as_ptr(),
220 shape.as_ptr(),
221 shape.len(),
222 resolved_strides.as_ptr(),
223 resolved_strides.len(),
224 offset,
225 stream.as_ref().as_ptr(),
226 )
227 })
228}
229
230#[generate_macro]
246#[default_device]
247pub fn broadcast_to_device(
248 a: impl AsRef<Array>,
249 shape: &[i32],
250 #[optional] stream: impl AsRef<Stream>,
251) -> Result<Array> {
252 Array::try_from_op(|res| unsafe {
253 mlx_sys::mlx_broadcast_to(
254 res,
255 a.as_ref().as_ptr(),
256 shape.as_ptr(),
257 shape.len(),
258 stream.as_ref().as_ptr(),
259 )
260 })
261}
262
263#[generate_macro]
280#[default_device]
281pub fn concatenate_device(
282 arrays: &[impl AsRef<Array>],
283 #[optional] axis: impl Into<Option<i32>>,
284 #[optional] stream: impl AsRef<Stream>,
285) -> Result<Array> {
286 let axis = axis.into().unwrap_or(0);
287 let c_arrays = VectorArray::try_from_iter(arrays.iter())?;
288 Array::try_from_op(|res| unsafe {
289 mlx_sys::mlx_concatenate(res, c_arrays.as_ptr(), axis, stream.as_ref().as_ptr())
290 })
291}
292
293#[generate_macro]
309#[default_device]
310pub fn expand_dims_device(
311 a: impl AsRef<Array>,
312 axes: &[i32],
313 #[optional] stream: impl AsRef<Stream>,
314) -> Result<Array> {
315 Array::try_from_op(|res| unsafe {
316 mlx_sys::mlx_expand_dims(
317 res,
318 a.as_ref().as_ptr(),
319 axes.as_ptr(),
320 axes.len(),
321 stream.as_ref().as_ptr(),
322 )
323 })
324}
325
326#[generate_macro]
347#[default_device]
348pub fn flatten_device(
349 a: impl AsRef<Array>,
350 #[optional] start_axis: impl Into<Option<i32>>,
351 #[optional] end_axis: impl Into<Option<i32>>,
352 #[optional] stream: impl AsRef<Stream>,
353) -> Result<Array> {
354 let start_axis = start_axis.into().unwrap_or(0);
355 let end_axis = end_axis.into().unwrap_or(-1);
356
357 Array::try_from_op(|res| unsafe {
358 mlx_sys::mlx_flatten(
359 res,
360 a.as_ref().as_ptr(),
361 start_axis,
362 end_axis,
363 stream.as_ref().as_ptr(),
364 )
365 })
366}
367
368#[generate_macro]
376#[default_device]
377pub fn unflatten_device(
378 a: impl AsRef<Array>,
379 axis: i32,
380 shape: &[i32],
381 #[optional] stream: impl AsRef<Stream>,
382) -> Result<Array> {
383 Array::try_from_op(|res| unsafe {
384 mlx_sys::mlx_unflatten(
385 res,
386 a.as_ref().as_ptr(),
387 axis,
388 shape.as_ptr(),
389 shape.len(),
390 stream.as_ref().as_ptr(),
391 )
392 })
393}
394
395#[generate_macro]
411#[default_device]
412pub fn reshape_device(
413 a: impl AsRef<Array>,
414 shape: &[i32],
415 #[optional] stream: impl AsRef<Stream>,
416) -> Result<Array> {
417 Array::try_from_op(|res| unsafe {
418 mlx_sys::mlx_reshape(
419 res,
420 a.as_ref().as_ptr(),
421 shape.as_ptr(),
422 shape.len(),
423 stream.as_ref().as_ptr(),
424 )
425 })
426}
427
428#[generate_macro]
444#[default_device]
445pub fn squeeze_device<'a>(
446 a: impl AsRef<Array>,
447 #[optional] axes: impl IntoOption<&'a [i32]>,
448 #[optional] stream: impl AsRef<Stream>,
449) -> Result<Array> {
450 let a = a.as_ref();
451 let axes = axes_or_default_to_all_size_one_axes(axes, a.shape());
452 Array::try_from_op(|res| unsafe {
453 mlx_sys::mlx_squeeze(
454 res,
455 a.as_ptr(),
456 axes.as_ptr(),
457 axes.len(),
458 stream.as_ref().as_ptr(),
459 )
460 })
461}
462
463#[generate_macro]
478#[default_device]
479pub fn at_least_1d_device(
480 a: impl AsRef<Array>,
481 #[optional] stream: impl AsRef<Stream>,
482) -> Result<Array> {
483 Array::try_from_op(|res| unsafe {
484 mlx_sys::mlx_atleast_1d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
485 })
486}
487
488#[generate_macro]
503#[default_device]
504pub fn at_least_2d_device(
505 a: impl AsRef<Array>,
506 #[optional] stream: impl AsRef<Stream>,
507) -> Result<Array> {
508 Array::try_from_op(|res| unsafe {
509 mlx_sys::mlx_atleast_2d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
510 })
511}
512
513#[generate_macro]
528#[default_device]
529pub fn at_least_3d_device(
530 a: impl AsRef<Array>,
531 #[optional] stream: impl AsRef<Stream>,
532) -> Result<Array> {
533 Array::try_from_op(|res| unsafe {
534 mlx_sys::mlx_atleast_3d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
535 })
536}
537
538#[generate_macro]
555#[default_device]
556pub fn move_axis_device(
557 a: impl AsRef<Array>,
558 src: i32,
559 dst: i32,
560 #[optional] stream: impl AsRef<Stream>,
561) -> Result<Array> {
562 Array::try_from_op(|res| unsafe {
563 mlx_sys::mlx_moveaxis(res, a.as_ref().as_ptr(), src, dst, stream.as_ref().as_ptr())
564 })
565}
566
567#[generate_macro]
584#[default_device]
585pub fn split_device(
586 a: impl AsRef<Array>,
587 indices: &[i32],
588 #[optional] axis: impl Into<Option<i32>>,
589 #[optional] stream: impl AsRef<Stream>,
590) -> Result<Vec<Array>> {
591 let axis = axis.into().unwrap_or(0);
592 Vec::<Array>::try_from_op(|res| unsafe {
593 mlx_sys::mlx_split(
594 res,
595 a.as_ref().as_ptr(),
596 indices.as_ptr(),
597 indices.len(),
598 axis,
599 stream.as_ref().as_ptr(),
600 )
601 })
602}
603
604#[generate_macro]
622#[default_device]
623pub fn split_equal_device(
624 a: impl AsRef<Array>,
625 num_parts: i32,
626 #[optional] axis: impl Into<Option<i32>>,
627 #[optional] stream: impl AsRef<Stream>,
628) -> Result<Vec<Array>> {
629 let axis = axis.into().unwrap_or(0);
630 Vec::<Array>::try_from_op(|res| unsafe {
631 mlx_sys::mlx_split_equal_parts(
632 res,
633 a.as_ref().as_ptr(),
634 num_parts,
635 axis,
636 stream.as_ref().as_ptr(),
637 )
638 })
639}
640
/// Amount of padding to add to an array, expressed as `(low, high)` pairs.
///
/// `low` is the padding placed before the start of an axis and `high` the padding placed after
/// its end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PadWidth<'a> {
    /// A single `(low, high)` pair applied to every axis.
    Same((i32, i32)),

    /// One `(low, high)` pair per axis, in axis order.
    Widths(&'a [(i32, i32)]),
}
650
651impl From<i32> for PadWidth<'_> {
652 fn from(width: i32) -> Self {
653 PadWidth::Same((width, width))
654 }
655}
656
657impl From<(i32, i32)> for PadWidth<'_> {
658 fn from(width: (i32, i32)) -> Self {
659 PadWidth::Same(width)
660 }
661}
662
663impl<'a> From<&'a [(i32, i32)]> for PadWidth<'a> {
664 fn from(widths: &'a [(i32, i32)]) -> Self {
665 PadWidth::Widths(widths)
666 }
667}
668
669impl<'a, const N: usize> From<&'a [(i32, i32); N]> for PadWidth<'a> {
670 fn from(widths: &'a [(i32, i32); N]) -> Self {
671 PadWidth::Widths(widths)
672 }
673}
674
675impl PadWidth<'_> {
676 fn low_pads(&self, ndim: usize) -> SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> {
677 match self {
678 PadWidth::Same((low, _high)) => (0..ndim).map(|_| *low).collect(),
679 PadWidth::Widths(widths) => widths.iter().map(|(low, _high)| *low).collect(),
680 }
681 }
682
683 fn high_pads(&self, ndim: usize) -> SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> {
684 match self {
685 PadWidth::Same((_low, high)) => (0..ndim).map(|_| *high).collect(),
686 PadWidth::Widths(widths) => widths.iter().map(|(_low, high)| *high).collect(),
687 }
688 }
689}
690
/// How `pad` fills the padded region.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PadMode {
    /// Fill the padded region with the constant `value` passed to `pad`.
    Constant,

    /// MLX's `"edge"` padding mode, which fills the region with the array's edge values.
    Edge,
}
700
701impl PadMode {
702 unsafe fn as_c_str(&self) -> *const i8 {
703 static CONSTANT: &[u8] = b"constant\0";
704 static EDGE: &[u8] = b"edge\0";
705
706 match self {
707 PadMode::Constant => CONSTANT.as_ptr() as *const _,
708 PadMode::Edge => EDGE.as_ptr() as *const _,
709 }
710 }
711}
712
713#[generate_macro]
734#[default_device]
735pub fn pad_device<'a>(
736 a: impl AsRef<Array>,
737 #[optional] width: impl Into<PadWidth<'a>>,
738 #[optional] value: impl Into<Option<Array>>,
739 #[optional] mode: impl Into<Option<PadMode>>,
740 #[optional] stream: impl AsRef<Stream>,
741) -> Result<Array> {
742 let a = a.as_ref();
743 let width = width.into();
744 let ndim = a.ndim();
745 let axes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = (0..ndim).map(|i| i as i32).collect();
746 let low_pads = width.low_pads(ndim);
747 let high_pads = width.high_pads(ndim);
748 let value = value
749 .into()
750 .map(Ok)
751 .unwrap_or_else(|| Array::from_int(0).as_dtype(a.dtype()))?;
752 let mode = mode.into().unwrap_or(PadMode::Constant);
753
754 Array::try_from_op(|res| unsafe {
755 mlx_sys::mlx_pad(
756 res,
757 a.as_ptr(),
758 axes.as_ptr(),
759 axes.len(),
760 low_pads.as_ptr(),
761 low_pads.len(),
762 high_pads.as_ptr(),
763 high_pads.len(),
764 value.as_ptr(),
765 mode.as_c_str(),
766 stream.as_ref().as_ptr(),
767 )
768 })
769}
770
771#[generate_macro]
788#[default_device]
789pub fn stack_device(
790 arrays: &[impl AsRef<Array>],
791 axis: i32,
792 #[optional] stream: impl AsRef<Stream>,
793) -> Result<Array> {
794 let c_vec = VectorArray::try_from_iter(arrays.iter())?;
795 Array::try_from_op(|res| unsafe {
796 mlx_sys::mlx_stack(res, c_vec.as_ptr(), axis, stream.as_ref().as_ptr())
797 })
798}
799
800#[generate_macro]
816#[default_device]
817pub fn stack_all_device(
818 arrays: &[impl AsRef<Array>],
819 #[optional] stream: impl AsRef<Stream>,
820) -> Result<Array> {
821 let c_vec = VectorArray::try_from_iter(arrays.iter())?;
822 Array::try_from_op(|res| unsafe {
823 mlx_sys::mlx_stack_all(res, c_vec.as_ptr(), stream.as_ref().as_ptr())
824 })
825}
826
827#[generate_macro]
844#[default_device]
845pub fn swap_axes_device(
846 a: impl AsRef<Array>,
847 axis1: i32,
848 axis2: i32,
849 #[optional] stream: impl AsRef<Stream>,
850) -> Result<Array> {
851 Array::try_from_op(|res| unsafe {
852 mlx_sys::mlx_swapaxes(
853 res,
854 a.as_ref().as_ptr(),
855 axis1,
856 axis2,
857 stream.as_ref().as_ptr(),
858 )
859 })
860}
861
862#[generate_macro]
878#[default_device]
879pub fn tile_device(
880 a: impl AsRef<Array>,
881 reps: &[i32],
882 #[optional] stream: impl AsRef<Stream>,
883) -> Result<Array> {
884 Array::try_from_op(|res| unsafe {
885 mlx_sys::mlx_tile(
886 res,
887 a.as_ref().as_ptr(),
888 reps.as_ptr(),
889 reps.len(),
890 stream.as_ref().as_ptr(),
891 )
892 })
893}
894
895#[generate_macro]
917#[default_device]
918pub fn transpose_device(
919 a: impl AsRef<Array>,
920 axes: &[i32],
921 #[optional] stream: impl AsRef<Stream>,
922) -> Result<Array> {
923 Array::try_from_op(|res| unsafe {
924 mlx_sys::mlx_transpose(
925 res,
926 a.as_ref().as_ptr(),
927 axes.as_ptr(),
928 axes.len(),
929 stream.as_ref().as_ptr(),
930 )
931 })
932}
933
934#[generate_macro]
936#[default_device]
937pub fn transpose_all_device(
938 a: impl AsRef<Array>,
939 #[optional] stream: impl AsRef<Stream>,
940) -> Result<Array> {
941 Array::try_from_op(|res| unsafe {
942 mlx_sys::mlx_transpose_all(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
943 })
944}
945
// Unit tests for the shape ops in this module. They call the stream-less wrappers (`squeeze`,
// `flatten`, `split`, ...) that `#[default_device]` is expected to generate from the
// `*_device` functions, plus `Array` methods such as `reshape` and `transpose`.
#[cfg(test)]
mod tests {
    use crate::{array, Array, Dtype};

    use super::*;

    #[test]
    fn test_squeeze() {
        // Size-one axes sit at positions 1, 3 and 5.
        let a = Array::zeros::<i32>(&[2, 1, 2, 1, 2, 1]).unwrap();
        assert_eq!(squeeze(&a, &[1, 3, 5][..]).unwrap().shape(), &[2, 2, 2]);
        assert_eq!(squeeze(&a, &[-1, -3, -5][..]).unwrap().shape(), &[2, 2, 2]);
        assert_eq!(squeeze(&a, &[1][..]).unwrap().shape(), &[2, 2, 1, 2, 1]);
        assert_eq!(squeeze(&a, &[-1][..]).unwrap().shape(), &[2, 1, 2, 1, 2]);

        // Squeezing an axis whose extent is not 1 fails. So does naming the same axis twice,
        // either literally or through its negative alias (-3 == 3 here).
        assert!(squeeze(&a, &[0][..]).is_err());
        assert!(squeeze(&a, &[2][..]).is_err());
        assert!(squeeze(&a, &[1, 3, 1][..]).is_err());
        assert!(squeeze(&a, &[1, 3, -3][..]).is_err());
    }

    #[test]
    fn test_expand_dims() {
        let a = Array::zeros::<i32>(&[2, 2]).unwrap();
        assert_eq!(expand_dims(&a, &[0][..]).unwrap().shape(), &[1, 2, 2]);
        assert_eq!(expand_dims(&a, &[-1][..]).unwrap().shape(), &[2, 2, 1]);
        assert_eq!(expand_dims(&a, &[1][..]).unwrap().shape(), &[2, 1, 2]);
        // Axes index positions in the *output*, so several new axes can be placed at once.
        assert_eq!(
            expand_dims(&a, &[0, 1, 2]).unwrap().shape(),
            &[1, 1, 1, 2, 2]
        );
        assert_eq!(
            expand_dims(&a, &[0, 1, 2, 5, 6, 7]).unwrap().shape(),
            &[1, 1, 1, 2, 2, 1, 1, 1]
        );

        // An out-of-range axis and duplicated axes (including a negative alias) are errors.
        assert!(expand_dims(&a, &[3]).is_err());
        assert!(expand_dims(&a, &[0, 1, 0]).is_err());
        assert!(expand_dims(&a, &[0, 1, -4]).is_err());
    }

    #[test]
    fn test_flatten() {
        // With no arguments, all axes are merged.
        let x = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
        assert_eq!(flatten(&x, None, None).unwrap().shape(), &[2 * 3 * 4]);

        // The range is inclusive of `end_axis`. Per these assertions, an `end_axis` of 3 and a
        // `start_axis` of -4 are tolerated on a rank-3 array.
        assert_eq!(flatten(&x, 1, 1).unwrap().shape(), &[2, 3, 4]);
        assert_eq!(flatten(&x, 1, 2).unwrap().shape(), &[2, 3 * 4]);
        assert_eq!(flatten(&x, 1, 3).unwrap().shape(), &[2, 3 * 4]);
        assert_eq!(flatten(&x, 1, -1).unwrap().shape(), &[2, 3 * 4]);
        assert_eq!(flatten(&x, -2, -1).unwrap().shape(), &[2, 3 * 4]);
        assert_eq!(flatten(&x, -3, -1).unwrap().shape(), &[2 * 3 * 4]);
        assert_eq!(flatten(&x, -4, -1).unwrap().shape(), &[2 * 3 * 4]);

        // start after end is rejected.
        assert!(flatten(&x, 2, 1).is_err());

        // A range lying entirely past the last axis is rejected.
        assert!(flatten(&x, 5, 6).is_err());

        // A range lying entirely before the first axis is rejected.
        assert!(flatten(&x, -5, -4).is_err());

        // A scalar flattens to a single-element 1-D array.
        let x = Array::from_int(1);
        assert_eq!(flatten(&x, -3, -1).unwrap().shape(), &[1]);
        assert_eq!(flatten(&x, 0, 0).unwrap().shape(), &[1]);
    }

    #[test]
    fn test_unflatten() {
        // The -1 entry is inferred (4 / 2 = 2).
        let a = array!([1, 2, 3, 4]);
        let b = unflatten(&a, 0, &[2, -1]).unwrap();
        let expected = array!([[1, 2], [3, 4]]);
        assert_eq!(b, expected);
    }

    #[test]
    fn test_reshape() {
        // Scalar: only shapes with exactly one element are valid; -1 may stand in for a 1.
        let x = Array::from_int(1);
        assert!(reshape(&x, &[]).unwrap().shape().is_empty());
        assert!(reshape(&x, &[2]).is_err());
        let y = reshape(&x, &[1, 1, 1]).unwrap();
        assert_eq!(y.shape(), &[1, 1, 1]);
        let y = reshape(&x, &[-1, 1, 1]).unwrap();
        assert_eq!(y.shape(), &[1, 1, 1]);
        let y = reshape(&x, &[1, 1, -1]).unwrap();
        assert_eq!(y.shape(), &[1, 1, 1]);
        // At most one -1 is allowed, and the remaining extents must divide the element count.
        assert!(reshape(&x, &[1, -1, -1]).is_err());
        assert!(reshape(&x, &[2, -1]).is_err());

        // Eight elements.
        let x = Array::zeros::<i32>(&[2, 2, 2]).unwrap();
        let y = reshape(&x, &[8]).unwrap();
        assert_eq!(y.shape(), &[8]);
        assert!(reshape(&x, &[7]).is_err());
        let y = reshape(&x, &[-1]).unwrap();
        assert_eq!(y.shape(), &[8]);
        let y = reshape(&x, &[-1, 2]).unwrap();
        assert_eq!(y.shape(), &[4, 2]);
        assert!(reshape(&x, &[-1, 7]).is_err());

        // Zero-sized arrays can only be reshaped to other zero-sized shapes.
        let x = Array::from_slice::<i32>(&[], &[0]);
        let y = reshape(&x, &[0, 0, 0]).unwrap();
        assert_eq!(y.shape(), &[0, 0, 0]);
        y.eval().unwrap();
        assert_eq!(y.size(), 0);
        assert!(reshape(&x, &[]).is_err());
        assert!(reshape(&x, &[1]).is_err());
        let y = reshape(&x, &[1, 5, 0]).unwrap();
        assert_eq!(y.shape(), &[1, 5, 0]);
    }

    #[test]
    fn test_as_strided() {
        // Overlapping sliding windows: each row starts one element after the previous one.
        let x = Array::from_iter(0..10, &[10]);
        let y = as_strided(&x, &[3, 3][..], &[1, 1][..], 0).unwrap();
        let expected = Array::from_slice(&[0, 1, 2, 1, 2, 3, 2, 3, 4], &[3, 3]);
        assert_eq!(y, expected);

        // A zero stride repeats the same row.
        let y = as_strided(&x, &[3, 3][..], &[0, 3][..], 0).unwrap();
        let expected = Array::from_slice(&[0, 3, 6, 0, 3, 6, 0, 3, 6], &[3, 3]);
        assert_eq!(y, expected);

        // On a transposed input, strides and offset are applied to the element order of the
        // transposed array, [0, 5, 1, 6, 2, 7, 3, 8, 4, 9], as the expected values show.
        let x = x.reshape(&[2, 5]).unwrap();
        let x = x.transpose(&[1, 0][..]).unwrap();
        let y = as_strided(&x, &[3, 3][..], &[2, 1][..], 1).unwrap();
        let expected = Array::from_slice(&[5, 1, 6, 6, 2, 7, 7, 3, 8], &[3, 3]);
        assert_eq!(y, expected);
    }

    #[test]
    fn test_at_least_1d() {
        // Scalars gain one axis; arrays that already have >= 1 dimension are unchanged.
        let x = Array::from_int(1);
        let out = at_least_1d(&x).unwrap();
        assert_eq!(out.ndim(), 1);
        assert_eq!(out.shape(), &[1]);

        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let out = at_least_1d(&x).unwrap();
        assert_eq!(out.ndim(), 1);
        assert_eq!(out.shape(), &[3]);

        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
        let out = at_least_1d(&x).unwrap();
        assert_eq!(out.ndim(), 2);
        assert_eq!(out.shape(), &[3, 1]);
    }

    #[test]
    fn test_at_least_2d() {
        // A 1-D input gains a leading axis.
        let x = Array::from_int(1);
        let out = at_least_2d(&x).unwrap();
        assert_eq!(out.ndim(), 2);
        assert_eq!(out.shape(), &[1, 1]);

        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let out = at_least_2d(&x).unwrap();
        assert_eq!(out.ndim(), 2);
        assert_eq!(out.shape(), &[1, 3]);

        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
        let out = at_least_2d(&x).unwrap();
        assert_eq!(out.ndim(), 2);
        assert_eq!(out.shape(), &[3, 1]);
    }

    #[test]
    fn test_at_least_3d() {
        // A 1-D input of shape [3] becomes [1, 3, 1]; a 2-D input gains a trailing axis.
        let x = Array::from_int(1);
        let out = at_least_3d(&x).unwrap();
        assert_eq!(out.ndim(), 3);
        assert_eq!(out.shape(), &[1, 1, 1]);

        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let out = at_least_3d(&x).unwrap();
        assert_eq!(out.ndim(), 3);
        assert_eq!(out.shape(), &[1, 3, 1]);

        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
        let out = at_least_3d(&x).unwrap();
        assert_eq!(out.ndim(), 3);
        assert_eq!(out.shape(), &[3, 1, 1]);
    }

    #[test]
    fn test_move_axis() {
        // A scalar has no axes, so even axis 0 is out of range.
        let a = Array::from_int(0);
        assert!(move_axis(&a, 0, 0).is_err());

        let a = Array::zeros::<i32>(&[2]).unwrap();
        assert!(move_axis(&a, 0, 1).is_err());
        assert_eq!(move_axis(&a, 0, 0).unwrap().shape(), &[2]);
        assert_eq!(move_axis(&a, -1, -1).unwrap().shape(), &[2]);

        // Both `src` and `dst` must lie in [-ndim, ndim).
        let a = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
        assert!(move_axis(&a, 0, -4).is_err());
        assert!(move_axis(&a, 0, 3).is_err());
        assert!(move_axis(&a, 3, 0).is_err());
        assert!(move_axis(&a, -4, 0).is_err());
        assert_eq!(move_axis(&a, 0, 2).unwrap().shape(), &[3, 4, 2]);
        assert_eq!(move_axis(&a, 0, 1).unwrap().shape(), &[3, 2, 4]);
        assert_eq!(move_axis(&a, 0, -1).unwrap().shape(), &[3, 4, 2]);
        assert_eq!(move_axis(&a, -2, 2).unwrap().shape(), &[2, 4, 3]);
    }

    #[test]
    fn test_split_equal() {
        // A scalar has no axis to split along.
        let x = Array::from_int(3);
        assert!(split_equal(&x, 0, 0).is_err());

        // Out-of-range axis and a negative part count are rejected.
        let x = Array::from_slice(&[0, 1, 2], &[3]);
        assert!(split_equal(&x, 3, 1).is_err());
        assert!(split_equal(&x, -2, 1).is_err());

        let out = split_equal(&x, 3, 0).unwrap();
        assert_eq!(out.len(), 3);

        // Each piece of the [0, 1, 2] split holds its original element.
        let mut out = split_equal(&x, 3, -1).unwrap();
        assert_eq!(out.len(), 3);
        for (i, a) in out.iter_mut().enumerate() {
            assert_eq!(a.shape(), &[1]);
            assert_eq!(a.dtype(), Dtype::Int32);
            assert_eq!(a.item::<i32>(), i as i32);
        }

        // The default axis is 0 (rows).
        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5], &[2, 3]);
        let out = split_equal(&x, 2, None).unwrap();
        assert_eq!(out[0], Array::from_slice(&[0, 1, 2], &[1, 3]));
        assert_eq!(out[1], Array::from_slice(&[3, 4, 5], &[1, 3]));

        let out = split_equal(&x, 3, 1).unwrap();
        assert_eq!(out[0], Array::from_slice(&[0, 3], &[2, 1]));
        assert_eq!(out[1], Array::from_slice(&[1, 4], &[2, 1]));
        assert_eq!(out[2], Array::from_slice(&[2, 5], &[2, 1]));

        let x = Array::zeros::<i32>(&[8, 12]).unwrap();
        let out = split_equal(&x, 2, None).unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].shape(), &[4, 12]);
        assert_eq!(out[1].shape(), &[4, 12]);

        let out = split_equal(&x, 3, 1).unwrap();
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].shape(), &[8, 4]);
        assert_eq!(out[1].shape(), &[8, 4]);
        assert_eq!(out[2].shape(), &[8, 4]);
    }

    #[test]
    fn test_split() {
        let x = Array::zeros::<i32>(&[8, 12]).unwrap();

        // No split points: a single piece equal in shape to the input.
        let out = split(&x, &[], None).unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].shape(), x.shape());

        // `k` split points yield `k + 1` pieces.
        let out = split(&x, &[3, 7], None).unwrap();
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].shape(), &[3, 12]);
        assert_eq!(out[1].shape(), &[4, 12]);
        assert_eq!(out[2].shape(), &[1, 12]);

        // A split point past the end produces an empty trailing piece.
        let out = split(&x, &[20], None).unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].shape(), &[8, 12]);
        assert_eq!(out[1].shape(), &[0, 12]);

        // Negative split points count from the end.
        let out = split(&x, &[-5], None).unwrap();
        assert_eq!(out[0].shape(), &[3, 12]);
        assert_eq!(out[1].shape(), &[5, 12]);

        let out = split(&x, &[2, 8], Some(1)).unwrap();
        assert_eq!(out[0].shape(), &[8, 2]);
        assert_eq!(out[1].shape(), &[8, 6]);
        assert_eq!(out[2].shape(), &[8, 4]);

        // Non-monotonic split points are accepted. Per these assertions, a decreasing pair
        // yields an empty piece.
        let x = Array::from_iter(0i32..5, &[5]);
        let out = split(&x, &[2, 1, 2], None).unwrap();
        assert_eq!(out[0], Array::from_slice(&[0, 1], &[2]));
        assert_eq!(out[1], Array::from_slice::<i32>(&[], &[0]));
        assert_eq!(out[2], Array::from_slice(&[1], &[1]));
        assert_eq!(out[3], Array::from_slice(&[2, 3, 4], &[3]));
    }

    #[test]
    fn test_pad() {
        // Cover the three `PadWidth` inputs: an integer (both sides, every axis), one
        // `(low, high)` pair (every axis) and one pair per axis.
        let x = Array::zeros::<f32>(&[1, 2, 3]).unwrap();
        assert_eq!(pad(&x, 1, None, None).unwrap().shape(), &[3, 4, 5]);
        assert_eq!(pad(&x, (0, 1), None, None).unwrap().shape(), &[2, 3, 4]);
        assert_eq!(
            pad(&x, &[(1, 1), (1, 2), (3, 1)], None, None)
                .unwrap()
                .shape(),
            &[3, 5, 7]
        );
    }

    #[test]
    fn test_stack() {
        // Stacking one empty array adds a size-one axis at the requested position.
        let x = Array::from_slice::<f32>(&[], &[0]);
        let x = vec![x];
        assert_eq!(stack(&x, 0).unwrap().shape(), &[1, 0]);
        assert_eq!(stack(&x, 1).unwrap().shape(), &[0, 1]);

        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let x = vec![x];
        assert_eq!(stack(&x, 0).unwrap().shape(), &[1, 3]);
        assert_eq!(stack(&x, 1).unwrap().shape(), &[3, 1]);

        // Two [3] arrays; `stack_all` stacks along a new leading axis.
        let y = Array::from_slice(&[4, 5, 6], &[3]);
        let mut z = x;
        z.push(y);
        assert_eq!(stack_all(&z).unwrap().shape(), &[2, 3]);
        assert_eq!(stack(&z, 1).unwrap().shape(), &[3, 2]);
        assert_eq!(stack(&z, -1).unwrap().shape(), &[3, 2]);
        assert_eq!(stack(&z, -2).unwrap().shape(), &[2, 3]);

        // Stacking nothing is an error.
        let empty: Vec<Array> = Vec::new();
        assert!(stack(&empty, 0).is_err());

        // Stacking Float16 with Int32 yields Float16.
        let x = Array::from_slice(&[1, 2, 3], &[3])
            .as_dtype(Dtype::Float16)
            .unwrap();
        let y = Array::from_slice(&[4, 5, 6], &[3])
            .as_dtype(Dtype::Int32)
            .unwrap();
        assert_eq!(stack(&[x, y], 0).unwrap().dtype(), Dtype::Float16);

        // Inputs of different shapes cannot be stacked.
        let x = Array::from_slice(&[1, 2, 3], &[3])
            .as_dtype(Dtype::Int32)
            .unwrap();
        let y = Array::from_slice(&[4, 5, 6, 7], &[4])
            .as_dtype(Dtype::Int32)
            .unwrap();
        assert!(stack(&[x, y], 0).is_err());
    }

    #[test]
    fn test_swap_axes() {
        // A scalar has no axes to swap.
        let a = Array::from_int(0);
        assert!(swap_axes(&a, 0, 0).is_err());

        let a = Array::zeros::<i32>(&[2]).unwrap();
        assert!(swap_axes(&a, 0, 1).is_err());
        assert_eq!(swap_axes(&a, 0, 0).unwrap().shape(), &[2]);
        assert_eq!(swap_axes(&a, -1, -1).unwrap().shape(), &[2]);

        // Both axes must lie in [-ndim, ndim).
        let a = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
        assert!(swap_axes(&a, 0, -4).is_err());
        assert!(swap_axes(&a, 0, 3).is_err());
        assert!(swap_axes(&a, 3, 0).is_err());
        assert!(swap_axes(&a, -4, 0).is_err());
        assert_eq!(swap_axes(&a, 0, 2).unwrap().shape(), &[4, 3, 2]);
        assert_eq!(swap_axes(&a, 0, 1).unwrap().shape(), &[3, 2, 4]);
        assert_eq!(swap_axes(&a, 0, -1).unwrap().shape(), &[4, 3, 2]);
        assert_eq!(swap_axes(&a, -2, 2).unwrap().shape(), &[2, 4, 3]);
    }

    #[test]
    fn test_tile() {
        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let y = tile(&x, &[2]).unwrap();
        let expected = Array::from_slice(&[1, 2, 3, 1, 2, 3], &[6]);
        assert_eq!(y, expected);

        // With fewer reps than axes, the reps apply to the trailing axes.
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        let y = tile(&x, &[2]).unwrap();
        let expected = Array::from_slice(&[1, 2, 1, 2, 3, 4, 3, 4], &[2, 4]);
        assert_eq!(y, expected);

        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        let y = tile(&x, &[4, 1]).unwrap();
        let expected =
            Array::from_slice(&[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], &[8, 2]);
        assert_eq!(y, expected);

        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        let y = tile(&x, &[2, 2]).unwrap();
        let expected =
            Array::from_slice(&[1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4], &[4, 4]);
        assert_eq!(y, expected);

        // With more reps than axes, the input is treated as having leading size-one axes.
        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let y = tile(&x, &[2, 2, 2]).unwrap();
        let expected = Array::from_slice(
            &[
                1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
            ],
            &[2, 2, 6],
        );
        assert_eq!(y, expected);
    }

    #[test]
    fn test_transpose() {
        // Scalar: transposing all (zero) axes is a no-op, and any explicit axis is invalid.
        let x = Array::from_int(1);
        let y = transpose_all(&x).unwrap();
        assert!(y.shape().is_empty());
        assert_eq!(y.item::<i32>(), 1);
        assert!(transpose(&x, &[0][..]).is_err());
        assert!(transpose(&x, &[1][..]).is_err());

        let x = Array::from_slice(&[1], &[1]);
        let y = transpose_all(&x).unwrap();
        assert_eq!(y.shape(), &[1]);
        assert_eq!(y.item::<i32>(), 1);

        let y = transpose(&x, &[-1][..]).unwrap();
        assert_eq!(y.shape(), &[1]);
        assert_eq!(y.item::<i32>(), 1);

        // Out-of-range or repeated axes are rejected.
        assert!(transpose(&x, &[1][..]).is_err());
        assert!(transpose(&x, &[0, 0][..]).is_err());

        let x = Array::from_slice::<i32>(&[], &[0]);
        let y = transpose_all(&x).unwrap();
        assert_eq!(y.shape(), &[0]);
        y.eval().unwrap();
        assert_eq!(y.size(), 0);

        let x = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        let mut y = transpose_all(&x).unwrap();
        assert_eq!(y.shape(), &[3, 2]);
        y = transpose(&x, &[-1, 0][..]).unwrap();
        assert_eq!(y.shape(), &[3, 2]);
        y = transpose(&x, &[-1, -2][..]).unwrap();
        assert_eq!(y.shape(), &[3, 2]);
        y.eval().unwrap();
        assert_eq!(y, Array::from_slice(&[1, 4, 2, 5, 3, 6], &[3, 2]));

        // The identity permutation leaves the array unchanged.
        let y = transpose(&x, &[0, 1][..]).unwrap();
        assert_eq!(y.shape(), &[2, 3]);
        assert_eq!(y, x);

        let y = transpose(&x, &[0, -1][..]).unwrap();
        assert_eq!(y.shape(), &[2, 3]);
        assert_eq!(y, x);

        // The permutation must name each axis exactly once.
        assert!(transpose(&x, &[][..]).is_err());
        assert!(transpose(&x, &[0][..]).is_err());
        assert!(transpose(&x, &[0, 0][..]).is_err());
        assert!(transpose(&x, &[0, 0, 0][..]).is_err());
        assert!(transpose(&x, &[0, 1, 1][..]).is_err());

        // transpose_all reverses the axes: [2, 3, 2] stays [2, 3, 2] but the data is reordered.
        let x = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], &[2, 3, 2]);
        let y = transpose_all(&x).unwrap();
        assert_eq!(y.shape(), &[2, 3, 2]);
        let expected = Array::from_slice(&[1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12], &[2, 3, 2]);
        assert_eq!(y, expected);

        let y = transpose(&x, &[0, 1, 2][..]).unwrap();
        assert_eq!(y.shape(), &[2, 3, 2]);
        assert_eq!(y, x);

        let y = transpose(&x, &[1, 0, 2][..]).unwrap();
        assert_eq!(y.shape(), &[3, 2, 2]);
        let expected = Array::from_slice(&[1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12], &[3, 2, 2]);
        assert_eq!(y, expected);

        let y = transpose(&x, &[0, 2, 1][..]).unwrap();
        assert_eq!(y.shape(), &[2, 2, 3]);
        let expected = Array::from_slice(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12], &[2, 2, 3]);
        assert_eq!(y, expected);

        // Reshaping a transposed (non-contiguous) array uses the transposed element order.
        let mut x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7], &[4, 2]);
        x = reshape(transpose_all(&x).unwrap(), &[2, 2, 2]).unwrap();
        let expected = Array::from_slice(&[0, 2, 4, 6, 1, 3, 5, 7], &[2, 2, 2]);
        assert_eq!(x, expected);

        // Permuting only size-one axes must still evaluate successfully.
        let mut x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7], &[1, 4, 1, 2]);
        x = transpose(&x, &[2, 1, 0, 3][..]).unwrap();
        x.eval().unwrap();
    }
}