1use mlx_internal_macros::{default_device, generate_macro};
2use smallvec::SmallVec;
3
4use crate::{
5 constants::DEFAULT_STACK_VEC_LEN,
6 error::Result,
7 utils::{guard::Guarded, IntoOption, VectorArray},
8 Array, Stream,
9};
10
11impl Array {
12 #[default_device]
14 pub fn expand_dims_device(&self, axis: i32, stream: impl AsRef<Stream>) -> Result<Array> {
15 expand_dims_device(self, axis, stream)
16 }
17
18 #[default_device]
20 pub fn expand_dims_axes_device(
21 &self,
22 axes: &[i32],
23 stream: impl AsRef<Stream>,
24 ) -> Result<Array> {
25 expand_dims_axes_device(self, axes, stream)
26 }
27
28 #[default_device]
30 pub fn flatten_device(
31 &self,
32 start_axis: impl Into<Option<i32>>,
33 end_axis: impl Into<Option<i32>>,
34 stream: impl AsRef<Stream>,
35 ) -> Result<Array> {
36 flatten_device(self, start_axis, end_axis, stream)
37 }
38
39 #[default_device]
41 pub fn reshape_device(&self, shape: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
42 reshape_device(self, shape, stream)
43 }
44
45 #[default_device]
47 pub fn squeeze_axes_device(&self, axes: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
48 squeeze_axes_device(self, axes, stream)
49 }
50
51 #[default_device]
53 pub fn squeeze_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
54 squeeze_device(self, stream)
55 }
56
57 #[default_device]
59 pub fn as_strided_device<'a>(
60 &'a self,
61 shape: impl IntoOption<&'a [i32]>,
62 strides: impl IntoOption<&'a [i64]>,
63 offset: impl Into<Option<usize>>,
64 stream: impl AsRef<Stream>,
65 ) -> Result<Array> {
66 as_strided_device(self, shape, strides, offset, stream)
67 }
68
69 #[default_device]
71 pub fn at_least_1d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
72 at_least_1d_device(self, stream)
73 }
74
75 #[default_device]
77 pub fn at_least_2d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
78 at_least_2d_device(self, stream)
79 }
80
81 #[default_device]
83 pub fn at_least_3d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
84 at_least_3d_device(self, stream)
85 }
86
87 #[default_device]
89 pub fn move_axis_device(
90 &self,
91 src: i32,
92 dst: i32,
93 stream: impl AsRef<Stream>,
94 ) -> Result<Array> {
95 move_axis_device(self, src, dst, stream)
96 }
97
98 #[default_device]
100 pub fn split_axis_device(
101 &self,
102 indices: &[i32],
103 axis: impl Into<Option<i32>>,
104 stream: impl AsRef<Stream>,
105 ) -> Result<Vec<Array>> {
106 split_sections_device(self, indices, axis, stream)
107 }
108
109 #[default_device]
111 pub fn split_device(
112 &self,
113 num_parts: i32,
114 axis: impl Into<Option<i32>>,
115 stream: impl AsRef<Stream>,
116 ) -> Result<Vec<Array>> {
117 split_device(self, num_parts, axis, stream)
118 }
119
120 #[default_device]
122 pub fn swap_axes_device(
123 &self,
124 axis1: i32,
125 axis2: i32,
126 stream: impl AsRef<Stream>,
127 ) -> Result<Array> {
128 swap_axes_device(self, axis1, axis2, stream)
129 }
130
131 #[default_device]
133 pub fn transpose_axes_device(&self, axes: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
134 transpose_axes_device(self, axes, stream)
135 }
136
137 #[default_device]
139 pub fn transpose_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
140 transpose_device(self, stream)
141 }
142
143 pub fn t(&self) -> Array {
145 self.transpose().unwrap()
146 }
147}
148
149fn resolve_strides(
150 shape: &[i32],
151 strides: Option<&[i64]>,
152) -> SmallVec<[i64; DEFAULT_STACK_VEC_LEN]> {
153 match strides {
154 Some(strides) => SmallVec::from_slice(strides),
155 None => {
156 let result = shape
157 .iter()
158 .rev()
159 .scan(1, |acc, &dim| {
160 let result = *acc;
161 *acc *= dim as i64;
162 Some(result)
163 })
164 .collect::<SmallVec<[i64; DEFAULT_STACK_VEC_LEN]>>();
165 result.into_iter().rev().collect()
166 }
167 }
168}
169
170#[generate_macro]
177#[default_device]
178pub fn broadcast_arrays_device(
179 arrays: &[impl AsRef<Array>],
180 #[optional] stream: impl AsRef<Stream>,
181) -> Result<Vec<Array>> {
182 let c_vec = VectorArray::try_from_iter(arrays.iter())?;
183 Vec::<Array>::try_from_op(|res| unsafe {
184 mlx_sys::mlx_broadcast_arrays(res, c_vec.as_ptr(), stream.as_ref().as_ptr())
185 })
186}
187
188#[generate_macro]
199#[default_device]
200pub fn as_strided_device<'a>(
201 a: impl AsRef<Array>,
202 #[optional] shape: impl IntoOption<&'a [i32]>,
203 #[optional] strides: impl IntoOption<&'a [i64]>,
204 #[optional] offset: impl Into<Option<usize>>,
205 #[optional] stream: impl AsRef<Stream>,
206) -> Result<Array> {
207 let a = a.as_ref();
208 let shape = shape.into_option().unwrap_or(a.shape());
209 let resolved_strides = resolve_strides(shape, strides.into_option());
210 let offset = offset.into().unwrap_or(0);
211
212 Array::try_from_op(|res| unsafe {
213 mlx_sys::mlx_as_strided(
214 res,
215 a.as_ptr(),
216 shape.as_ptr(),
217 shape.len(),
218 resolved_strides.as_ptr(),
219 resolved_strides.len(),
220 offset,
221 stream.as_ref().as_ptr(),
222 )
223 })
224}
225
226#[generate_macro]
242#[default_device]
243pub fn broadcast_to_device(
244 a: impl AsRef<Array>,
245 shape: &[i32],
246 #[optional] stream: impl AsRef<Stream>,
247) -> Result<Array> {
248 Array::try_from_op(|res| unsafe {
249 mlx_sys::mlx_broadcast_to(
250 res,
251 a.as_ref().as_ptr(),
252 shape.as_ptr(),
253 shape.len(),
254 stream.as_ref().as_ptr(),
255 )
256 })
257}
258
259#[generate_macro]
276#[default_device]
277pub fn concatenate_axis_device(
278 arrays: &[impl AsRef<Array>],
279 axis: i32,
280 #[optional] stream: impl AsRef<Stream>,
281) -> Result<Array> {
282 let c_arrays = VectorArray::try_from_iter(arrays.iter())?;
283 Array::try_from_op(|res| unsafe {
284 mlx_sys::mlx_concatenate_axis(res, c_arrays.as_ptr(), axis, stream.as_ref().as_ptr())
285 })
286}
287
288#[generate_macro]
290#[default_device]
291pub fn concatenate_device(
292 arrays: &[impl AsRef<Array>],
293 #[optional] stream: impl AsRef<Stream>,
294) -> Result<Array> {
295 let c_arrays = VectorArray::try_from_iter(arrays.iter())?;
296 Array::try_from_op(|res| unsafe {
297 mlx_sys::mlx_concatenate(res, c_arrays.as_ptr(), stream.as_ref().as_ptr())
298 })
299}
300
301#[generate_macro]
317#[default_device]
318pub fn expand_dims_axes_device(
319 a: impl AsRef<Array>,
320 axes: &[i32],
321 #[optional] stream: impl AsRef<Stream>,
322) -> Result<Array> {
323 Array::try_from_op(|res| unsafe {
324 mlx_sys::mlx_expand_dims_axes(
325 res,
326 a.as_ref().as_ptr(),
327 axes.as_ptr(),
328 axes.len(),
329 stream.as_ref().as_ptr(),
330 )
331 })
332}
333
334#[generate_macro]
336#[default_device]
337pub fn expand_dims_device(
338 a: impl AsRef<Array>,
339 axis: i32,
340 #[optional] stream: impl AsRef<Stream>,
341) -> Result<Array> {
342 Array::try_from_op(|res| unsafe {
343 mlx_sys::mlx_expand_dims(res, a.as_ref().as_ptr(), axis, stream.as_ref().as_ptr())
344 })
345}
346
347#[generate_macro]
368#[default_device]
369pub fn flatten_device(
370 a: impl AsRef<Array>,
371 #[optional] start_axis: impl Into<Option<i32>>,
372 #[optional] end_axis: impl Into<Option<i32>>,
373 #[optional] stream: impl AsRef<Stream>,
374) -> Result<Array> {
375 let start_axis = start_axis.into().unwrap_or(0);
376 let end_axis = end_axis.into().unwrap_or(-1);
377
378 Array::try_from_op(|res| unsafe {
379 mlx_sys::mlx_flatten(
380 res,
381 a.as_ref().as_ptr(),
382 start_axis,
383 end_axis,
384 stream.as_ref().as_ptr(),
385 )
386 })
387}
388
389#[generate_macro]
397#[default_device]
398pub fn unflatten_device(
399 a: impl AsRef<Array>,
400 axis: i32,
401 shape: &[i32],
402 #[optional] stream: impl AsRef<Stream>,
403) -> Result<Array> {
404 Array::try_from_op(|res| unsafe {
405 mlx_sys::mlx_unflatten(
406 res,
407 a.as_ref().as_ptr(),
408 axis,
409 shape.as_ptr(),
410 shape.len(),
411 stream.as_ref().as_ptr(),
412 )
413 })
414}
415
416#[generate_macro]
432#[default_device]
433pub fn reshape_device(
434 a: impl AsRef<Array>,
435 shape: &[i32],
436 #[optional] stream: impl AsRef<Stream>,
437) -> Result<Array> {
438 Array::try_from_op(|res| unsafe {
439 mlx_sys::mlx_reshape(
440 res,
441 a.as_ref().as_ptr(),
442 shape.as_ptr(),
443 shape.len(),
444 stream.as_ref().as_ptr(),
445 )
446 })
447}
448
449#[generate_macro]
465#[default_device]
466pub fn squeeze_axes_device(
467 a: impl AsRef<Array>,
468 axes: &[i32],
469 #[optional] stream: impl AsRef<Stream>,
470) -> Result<Array> {
471 let a = a.as_ref();
472 Array::try_from_op(|res| unsafe {
473 mlx_sys::mlx_squeeze_axes(
474 res,
475 a.as_ptr(),
476 axes.as_ptr(),
477 axes.len(),
478 stream.as_ref().as_ptr(),
479 )
480 })
481}
482
483#[generate_macro]
485#[default_device]
486pub fn squeeze_device(
487 a: impl AsRef<Array>,
488 #[optional] stream: impl AsRef<Stream>,
489) -> Result<Array> {
490 let a = a.as_ref();
491 Array::try_from_op(|res| unsafe {
492 mlx_sys::mlx_squeeze(res, a.as_ptr(), stream.as_ref().as_ptr())
493 })
494}
495
496#[generate_macro]
511#[default_device]
512pub fn at_least_1d_device(
513 a: impl AsRef<Array>,
514 #[optional] stream: impl AsRef<Stream>,
515) -> Result<Array> {
516 Array::try_from_op(|res| unsafe {
517 mlx_sys::mlx_atleast_1d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
518 })
519}
520
521#[generate_macro]
536#[default_device]
537pub fn at_least_2d_device(
538 a: impl AsRef<Array>,
539 #[optional] stream: impl AsRef<Stream>,
540) -> Result<Array> {
541 Array::try_from_op(|res| unsafe {
542 mlx_sys::mlx_atleast_2d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
543 })
544}
545
546#[generate_macro]
561#[default_device]
562pub fn at_least_3d_device(
563 a: impl AsRef<Array>,
564 #[optional] stream: impl AsRef<Stream>,
565) -> Result<Array> {
566 Array::try_from_op(|res| unsafe {
567 mlx_sys::mlx_atleast_3d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
568 })
569}
570
571#[generate_macro]
588#[default_device]
589pub fn move_axis_device(
590 a: impl AsRef<Array>,
591 src: i32,
592 dst: i32,
593 #[optional] stream: impl AsRef<Stream>,
594) -> Result<Array> {
595 Array::try_from_op(|res| unsafe {
596 mlx_sys::mlx_moveaxis(res, a.as_ref().as_ptr(), src, dst, stream.as_ref().as_ptr())
597 })
598}
599
600#[generate_macro]
617#[default_device]
618pub fn split_sections_device(
619 a: impl AsRef<Array>,
620 indices: &[i32],
621 #[optional] axis: impl Into<Option<i32>>,
622 #[optional] stream: impl AsRef<Stream>,
623) -> Result<Vec<Array>> {
624 let axis = axis.into().unwrap_or(0);
625 Vec::<Array>::try_from_op(|res| unsafe {
626 mlx_sys::mlx_split_sections(
627 res,
628 a.as_ref().as_ptr(),
629 indices.as_ptr(),
630 indices.len(),
631 axis,
632 stream.as_ref().as_ptr(),
633 )
634 })
635}
636
637#[generate_macro]
655#[default_device]
656pub fn split_device(
657 a: impl AsRef<Array>,
658 num_parts: i32,
659 #[optional] axis: impl Into<Option<i32>>,
660 #[optional] stream: impl AsRef<Stream>,
661) -> Result<Vec<Array>> {
662 let axis = axis.into().unwrap_or(0);
663 Vec::<Array>::try_from_op(|res| unsafe {
664 mlx_sys::mlx_split(
665 res,
666 a.as_ref().as_ptr(),
667 num_parts,
668 axis,
669 stream.as_ref().as_ptr(),
670 )
671 })
672}
673
/// Padding widths accepted by `pad`.
///
/// Each `(low, high)` pair is the number of elements added before and after an axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PadWidth<'a> {
    /// Apply the same `(low, high)` widths to every axis.
    Same((i32, i32)),

    /// One `(low, high)` pair per axis, in axis order.
    Widths(&'a [(i32, i32)]),
}

impl From<i32> for PadWidth<'_> {
    /// Pads every axis by `width` on both sides.
    fn from(width: i32) -> Self {
        PadWidth::Same((width, width))
    }
}

impl From<(i32, i32)> for PadWidth<'_> {
    /// Pads every axis by `width.0` before and `width.1` after.
    fn from(width: (i32, i32)) -> Self {
        PadWidth::Same(width)
    }
}

impl<'a> From<&'a [(i32, i32)]> for PadWidth<'a> {
    /// Uses one `(low, high)` pair per axis.
    fn from(widths: &'a [(i32, i32)]) -> Self {
        PadWidth::Widths(widths)
    }
}

impl<'a, const N: usize> From<&'a [(i32, i32); N]> for PadWidth<'a> {
    /// Uses one `(low, high)` pair per axis (fixed-size array form).
    fn from(widths: &'a [(i32, i32); N]) -> Self {
        PadWidth::Widths(widths)
    }
}
707
708impl PadWidth<'_> {
709 fn low_pads(&self, ndim: usize) -> SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> {
710 match self {
711 PadWidth::Same((low, _high)) => (0..ndim).map(|_| *low).collect(),
712 PadWidth::Widths(widths) => widths.iter().map(|(low, _high)| *low).collect(),
713 }
714 }
715
716 fn high_pads(&self, ndim: usize) -> SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> {
717 match self {
718 PadWidth::Same((_low, high)) => (0..ndim).map(|_| *high).collect(),
719 PadWidth::Widths(widths) => widths.iter().map(|(_low, high)| *high).collect(),
720 }
721 }
722}
723
/// Padding mode accepted by `pad`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PadMode {
    /// Fill the padded region with a constant value (the `value` argument of `pad`).
    Constant,

    /// MLX `"edge"` mode; per the MLX `pad` documentation this repeats the edge values.
    Edge,
}

impl PadMode {
    /// Returns the NUL-terminated mode name passed to `mlx_pad` (`"constant"` or `"edge"`).
    ///
    /// The pointer refers to a `'static` byte string, so it remains valid for the whole
    /// program and must not be freed. Producing the pointer is safe; only dereferencing it
    /// (done on the C side) requires `unsafe`.
    fn as_c_str(&self) -> *const std::ffi::c_char {
        static CONSTANT: &[u8] = b"constant\0";
        static EDGE: &[u8] = b"edge\0";

        let name = match self {
            PadMode::Constant => CONSTANT,
            PadMode::Edge => EDGE,
        };
        // `c_char` is `i8` on some targets and `u8` on others. Casting to it, rather than to a
        // hard-coded `i8`, keeps the FFI argument type correct on every platform.
        name.as_ptr().cast()
    }
}
745
746#[generate_macro]
767#[default_device]
768pub fn pad_device<'a>(
769 a: impl AsRef<Array>,
770 #[optional] width: impl Into<PadWidth<'a>>,
771 #[optional] value: impl Into<Option<Array>>,
772 #[optional] mode: impl Into<Option<PadMode>>,
773 #[optional] stream: impl AsRef<Stream>,
774) -> Result<Array> {
775 let a = a.as_ref();
776 let width = width.into();
777 let ndim = a.ndim();
778 let axes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = (0..ndim).map(|i| i as i32).collect();
779 let low_pads = width.low_pads(ndim);
780 let high_pads = width.high_pads(ndim);
781 let value = value
782 .into()
783 .map(Ok)
784 .unwrap_or_else(|| Array::from_int(0).as_dtype(a.dtype()))?;
785 let mode = mode.into().unwrap_or(PadMode::Constant);
786
787 Array::try_from_op(|res| unsafe {
788 mlx_sys::mlx_pad(
789 res,
790 a.as_ptr(),
791 axes.as_ptr(),
792 axes.len(),
793 low_pads.as_ptr(),
794 low_pads.len(),
795 high_pads.as_ptr(),
796 high_pads.len(),
797 value.as_ptr(),
798 mode.as_c_str(),
799 stream.as_ref().as_ptr(),
800 )
801 })
802}
803
804#[generate_macro]
821#[default_device]
822pub fn stack_axis_device(
823 arrays: &[impl AsRef<Array>],
824 axis: i32,
825 #[optional] stream: impl AsRef<Stream>,
826) -> Result<Array> {
827 let c_vec = VectorArray::try_from_iter(arrays.iter())?;
828 Array::try_from_op(|res| unsafe {
829 mlx_sys::mlx_stack_axis(res, c_vec.as_ptr(), axis, stream.as_ref().as_ptr())
830 })
831}
832
833#[generate_macro]
849#[default_device]
850pub fn stack_device(
851 arrays: &[impl AsRef<Array>],
852 #[optional] stream: impl AsRef<Stream>,
853) -> Result<Array> {
854 let c_vec = VectorArray::try_from_iter(arrays.iter())?;
855 Array::try_from_op(|res| unsafe {
856 mlx_sys::mlx_stack(res, c_vec.as_ptr(), stream.as_ref().as_ptr())
857 })
858}
859
860#[generate_macro]
877#[default_device]
878pub fn swap_axes_device(
879 a: impl AsRef<Array>,
880 axis1: i32,
881 axis2: i32,
882 #[optional] stream: impl AsRef<Stream>,
883) -> Result<Array> {
884 Array::try_from_op(|res| unsafe {
885 mlx_sys::mlx_swapaxes(
886 res,
887 a.as_ref().as_ptr(),
888 axis1,
889 axis2,
890 stream.as_ref().as_ptr(),
891 )
892 })
893}
894
895#[generate_macro]
911#[default_device]
912pub fn tile_device(
913 a: impl AsRef<Array>,
914 reps: &[i32],
915 #[optional] stream: impl AsRef<Stream>,
916) -> Result<Array> {
917 Array::try_from_op(|res| unsafe {
918 mlx_sys::mlx_tile(
919 res,
920 a.as_ref().as_ptr(),
921 reps.as_ptr(),
922 reps.len(),
923 stream.as_ref().as_ptr(),
924 )
925 })
926}
927
928#[generate_macro]
950#[default_device]
951pub fn transpose_axes_device(
952 a: impl AsRef<Array>,
953 axes: &[i32],
954 #[optional] stream: impl AsRef<Stream>,
955) -> Result<Array> {
956 Array::try_from_op(|res| unsafe {
957 mlx_sys::mlx_transpose_axes(
958 res,
959 a.as_ref().as_ptr(),
960 axes.as_ptr(),
961 axes.len(),
962 stream.as_ref().as_ptr(),
963 )
964 })
965}
966
967#[generate_macro]
969#[default_device]
970pub fn transpose_device(
971 a: impl AsRef<Array>,
972 #[optional] stream: impl AsRef<Stream>,
973) -> Result<Array> {
974 Array::try_from_op(|res| unsafe {
975 mlx_sys::mlx_transpose(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
976 })
977}
978
#[cfg(test)]
mod tests {
    // Unit tests for the shape-manipulation ops defined in this module.
    use crate::{array, Array, Dtype};

    use super::*;

    #[test]
    fn test_squeeze() {
        let a = Array::zeros::<i32>(&[2, 1, 2, 1, 2, 1]).unwrap();
        assert_eq!(
            squeeze_axes(&a, &[1, 3, 5][..]).unwrap().shape(),
            &[2, 2, 2]
        );
        assert_eq!(
            squeeze_axes(&a, &[-1, -3, -5][..]).unwrap().shape(),
            &[2, 2, 2]
        );
        assert_eq!(
            squeeze_axes(&a, &[1][..]).unwrap().shape(),
            &[2, 2, 1, 2, 1]
        );
        assert_eq!(
            squeeze_axes(&a, &[-1][..]).unwrap().shape(),
            &[2, 1, 2, 1, 2]
        );

        assert!(squeeze_axes(&a, &[0][..]).is_err());
        assert!(squeeze_axes(&a, &[2][..]).is_err());
        assert!(squeeze_axes(&a, &[1, 3, 1][..]).is_err());
        assert!(squeeze_axes(&a, &[1, 3, -3][..]).is_err());
    }

    #[test]
    fn test_expand_dims() {
        let a = Array::zeros::<i32>(&[2, 2]).unwrap();
        assert_eq!(expand_dims_axes(&a, &[0][..]).unwrap().shape(), &[1, 2, 2]);
        assert_eq!(expand_dims_axes(&a, &[-1][..]).unwrap().shape(), &[2, 2, 1]);
        assert_eq!(expand_dims_axes(&a, &[1][..]).unwrap().shape(), &[2, 1, 2]);
        assert_eq!(
            expand_dims_axes(&a, &[0, 1, 2]).unwrap().shape(),
            &[1, 1, 1, 2, 2]
        );
        assert_eq!(
            expand_dims_axes(&a, &[0, 1, 2, 5, 6, 7]).unwrap().shape(),
            &[1, 1, 1, 2, 2, 1, 1, 1]
        );

        assert!(expand_dims_axes(&a, &[3]).is_err());
        assert!(expand_dims_axes(&a, &[0, 1, 0]).is_err());
        assert!(expand_dims_axes(&a, &[0, 1, -4]).is_err());
    }

    #[test]
    fn test_flatten() {
        let x = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
        assert_eq!(flatten(&x, None, None).unwrap().shape(), &[2 * 3 * 4]);

        assert_eq!(flatten(&x, 1, 1).unwrap().shape(), &[2, 3, 4]);
        assert_eq!(flatten(&x, 1, 2).unwrap().shape(), &[2, 3 * 4]);
        assert_eq!(flatten(&x, 1, 3).unwrap().shape(), &[2, 3 * 4]);
        assert_eq!(flatten(&x, 1, -1).unwrap().shape(), &[2, 3 * 4]);
        assert_eq!(flatten(&x, -2, -1).unwrap().shape(), &[2, 3 * 4]);
        assert_eq!(flatten(&x, -3, -1).unwrap().shape(), &[2 * 3 * 4]);
        assert_eq!(flatten(&x, -4, -1).unwrap().shape(), &[2 * 3 * 4]);

        assert!(flatten(&x, 2, 1).is_err());

        assert!(flatten(&x, 5, 6).is_err());

        assert!(flatten(&x, -5, -4).is_err());

        let x = Array::from_int(1);
        assert_eq!(flatten(&x, -3, -1).unwrap().shape(), &[1]);
        assert_eq!(flatten(&x, 0, 0).unwrap().shape(), &[1]);
    }

    #[test]
    fn test_unflatten() {
        let a = array!([1, 2, 3, 4]);
        let b = unflatten(&a, 0, &[2, -1]).unwrap();
        let expected = array!([[1, 2], [3, 4]]);
        assert_eq!(b, expected);
    }

    #[test]
    fn test_reshape() {
        let x = Array::from_int(1);
        assert!(reshape(&x, &[]).unwrap().shape().is_empty());
        assert!(reshape(&x, &[2]).is_err());
        let y = reshape(&x, &[1, 1, 1]).unwrap();
        assert_eq!(y.shape(), &[1, 1, 1]);
        let y = reshape(&x, &[-1, 1, 1]).unwrap();
        assert_eq!(y.shape(), &[1, 1, 1]);
        let y = reshape(&x, &[1, 1, -1]).unwrap();
        assert_eq!(y.shape(), &[1, 1, 1]);
        assert!(reshape(&x, &[1, -1, -1]).is_err());
        assert!(reshape(&x, &[2, -1]).is_err());

        let x = Array::zeros::<i32>(&[2, 2, 2]).unwrap();
        let y = reshape(&x, &[8]).unwrap();
        assert_eq!(y.shape(), &[8]);
        assert!(reshape(&x, &[7]).is_err());
        let y = reshape(&x, &[-1]).unwrap();
        assert_eq!(y.shape(), &[8]);
        let y = reshape(&x, &[-1, 2]).unwrap();
        assert_eq!(y.shape(), &[4, 2]);
        assert!(reshape(&x, &[-1, 7]).is_err());

        let x = Array::from_slice::<i32>(&[], &[0]);
        let y = reshape(&x, &[0, 0, 0]).unwrap();
        assert_eq!(y.shape(), &[0, 0, 0]);
        y.eval().unwrap();
        assert_eq!(y.size(), 0);
        assert!(reshape(&x, &[]).is_err());
        assert!(reshape(&x, &[1]).is_err());
        let y = reshape(&x, &[1, 5, 0]).unwrap();
        assert_eq!(y.shape(), &[1, 5, 0]);
    }

    #[test]
    fn test_as_strided() {
        let x = Array::from_iter(0..10, &[10]);
        let y = as_strided(&x, &[3, 3][..], &[1, 1][..], 0).unwrap();
        let expected = Array::from_slice(&[0, 1, 2, 1, 2, 3, 2, 3, 4], &[3, 3]);
        assert_eq!(y, expected);

        let y = as_strided(&x, &[3, 3][..], &[0, 3][..], 0).unwrap();
        let expected = Array::from_slice(&[0, 3, 6, 0, 3, 6, 0, 3, 6], &[3, 3]);
        assert_eq!(y, expected);

        let x = x.reshape(&[2, 5]).unwrap();
        let x = x.transpose_axes(&[1, 0][..]).unwrap();
        let y = as_strided(&x, &[3, 3][..], &[2, 1][..], 1).unwrap();
        let expected = Array::from_slice(&[5, 1, 6, 6, 2, 7, 7, 3, 8], &[3, 3]);
        assert_eq!(y, expected);
    }

    #[test]
    fn test_at_least_1d() {
        let x = Array::from_int(1);
        let out = at_least_1d(&x).unwrap();
        assert_eq!(out.ndim(), 1);
        assert_eq!(out.shape(), &[1]);

        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let out = at_least_1d(&x).unwrap();
        assert_eq!(out.ndim(), 1);
        assert_eq!(out.shape(), &[3]);

        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
        let out = at_least_1d(&x).unwrap();
        assert_eq!(out.ndim(), 2);
        assert_eq!(out.shape(), &[3, 1]);
    }

    #[test]
    fn test_at_least_2d() {
        let x = Array::from_int(1);
        let out = at_least_2d(&x).unwrap();
        assert_eq!(out.ndim(), 2);
        assert_eq!(out.shape(), &[1, 1]);

        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let out = at_least_2d(&x).unwrap();
        assert_eq!(out.ndim(), 2);
        assert_eq!(out.shape(), &[1, 3]);

        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
        let out = at_least_2d(&x).unwrap();
        assert_eq!(out.ndim(), 2);
        assert_eq!(out.shape(), &[3, 1]);
    }

    #[test]
    fn test_at_least_3d() {
        let x = Array::from_int(1);
        let out = at_least_3d(&x).unwrap();
        assert_eq!(out.ndim(), 3);
        assert_eq!(out.shape(), &[1, 1, 1]);

        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let out = at_least_3d(&x).unwrap();
        assert_eq!(out.ndim(), 3);
        assert_eq!(out.shape(), &[1, 3, 1]);

        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
        let out = at_least_3d(&x).unwrap();
        assert_eq!(out.ndim(), 3);
        assert_eq!(out.shape(), &[3, 1, 1]);
    }

    #[test]
    fn test_move_axis() {
        let a = Array::from_int(0);
        assert!(move_axis(&a, 0, 0).is_err());

        let a = Array::zeros::<i32>(&[2]).unwrap();
        assert!(move_axis(&a, 0, 1).is_err());
        assert_eq!(move_axis(&a, 0, 0).unwrap().shape(), &[2]);
        assert_eq!(move_axis(&a, -1, -1).unwrap().shape(), &[2]);

        let a = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
        assert!(move_axis(&a, 0, -4).is_err());
        assert!(move_axis(&a, 0, 3).is_err());
        assert!(move_axis(&a, 3, 0).is_err());
        assert!(move_axis(&a, -4, 0).is_err());
        assert_eq!(move_axis(&a, 0, 2).unwrap().shape(), &[3, 4, 2]);
        assert_eq!(move_axis(&a, 0, 1).unwrap().shape(), &[3, 2, 4]);
        assert_eq!(move_axis(&a, 0, -1).unwrap().shape(), &[3, 4, 2]);
        assert_eq!(move_axis(&a, -2, 2).unwrap().shape(), &[2, 4, 3]);
    }

    #[test]
    fn test_split_equal() {
        let x = Array::from_int(3);
        assert!(split(&x, 0, 0).is_err());

        let x = Array::from_slice(&[0, 1, 2], &[3]);
        assert!(split(&x, 3, 1).is_err());
        assert!(split(&x, -2, 1).is_err());

        let out = split(&x, 3, 0).unwrap();
        assert_eq!(out.len(), 3);

        let mut out = split(&x, 3, -1).unwrap();
        assert_eq!(out.len(), 3);
        for (i, a) in out.iter_mut().enumerate() {
            assert_eq!(a.shape(), &[1]);
            assert_eq!(a.dtype(), Dtype::Int32);
            assert_eq!(a.item::<i32>(), i as i32);
        }

        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5], &[2, 3]);
        let out = split(&x, 2, None).unwrap();
        assert_eq!(out[0], Array::from_slice(&[0, 1, 2], &[1, 3]));
        assert_eq!(out[1], Array::from_slice(&[3, 4, 5], &[1, 3]));

        let out = split(&x, 3, 1).unwrap();
        assert_eq!(out[0], Array::from_slice(&[0, 3], &[2, 1]));
        assert_eq!(out[1], Array::from_slice(&[1, 4], &[2, 1]));
        assert_eq!(out[2], Array::from_slice(&[2, 5], &[2, 1]));

        let x = Array::zeros::<i32>(&[8, 12]).unwrap();
        let out = split(&x, 2, None).unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].shape(), &[4, 12]);
        assert_eq!(out[1].shape(), &[4, 12]);

        let out = split(&x, 3, 1).unwrap();
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].shape(), &[8, 4]);
        assert_eq!(out[1].shape(), &[8, 4]);
        assert_eq!(out[2].shape(), &[8, 4]);
    }

    #[test]
    fn test_split() {
        let x = Array::zeros::<i32>(&[8, 12]).unwrap();

        let out = split_sections(&x, &[], None).unwrap();
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].shape(), x.shape());

        let out = split_sections(&x, &[3, 7], None).unwrap();
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].shape(), &[3, 12]);
        assert_eq!(out[1].shape(), &[4, 12]);
        assert_eq!(out[2].shape(), &[1, 12]);

        let out = split_sections(&x, &[20], None).unwrap();
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].shape(), &[8, 12]);
        assert_eq!(out[1].shape(), &[0, 12]);

        let out = split_sections(&x, &[-5], None).unwrap();
        assert_eq!(out[0].shape(), &[3, 12]);
        assert_eq!(out[1].shape(), &[5, 12]);

        let out = split_sections(&x, &[2, 8], Some(1)).unwrap();
        assert_eq!(out[0].shape(), &[8, 2]);
        assert_eq!(out[1].shape(), &[8, 6]);
        assert_eq!(out[2].shape(), &[8, 4]);

        let x = Array::from_iter(0i32..5, &[5]);
        let out = split_sections(&x, &[2, 1, 2], None).unwrap();
        assert_eq!(out[0], Array::from_slice(&[0, 1], &[2]));
        assert_eq!(out[1], Array::from_slice::<i32>(&[], &[0]));
        assert_eq!(out[2], Array::from_slice(&[1], &[1]));
        assert_eq!(out[3], Array::from_slice(&[2, 3, 4], &[3]));
    }

    #[test]
    fn test_pad() {
        let x = Array::zeros::<f32>(&[1, 2, 3]).unwrap();
        assert_eq!(pad(&x, 1, None, None).unwrap().shape(), &[3, 4, 5]);
        assert_eq!(pad(&x, (0, 1), None, None).unwrap().shape(), &[2, 3, 4]);
        assert_eq!(
            pad(&x, &[(1, 1), (1, 2), (3, 1)], None, None)
                .unwrap()
                .shape(),
            &[3, 5, 7]
        );
    }

    #[test]
    fn test_stack() {
        let x = Array::from_slice::<f32>(&[], &[0]);
        let x = vec![x];
        assert_eq!(stack_axis(&x, 0).unwrap().shape(), &[1, 0]);
        assert_eq!(stack_axis(&x, 1).unwrap().shape(), &[0, 1]);

        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let x = vec![x];
        assert_eq!(stack_axis(&x, 0).unwrap().shape(), &[1, 3]);
        assert_eq!(stack_axis(&x, 1).unwrap().shape(), &[3, 1]);

        let y = Array::from_slice(&[4, 5, 6], &[3]);
        let mut z = x;
        z.push(y);
        assert_eq!(stack(&z).unwrap().shape(), &[2, 3]);
        assert_eq!(stack_axis(&z, 1).unwrap().shape(), &[3, 2]);
        assert_eq!(stack_axis(&z, -1).unwrap().shape(), &[3, 2]);
        assert_eq!(stack_axis(&z, -2).unwrap().shape(), &[2, 3]);

        let empty: Vec<Array> = Vec::new();
        assert!(stack_axis(&empty, 0).is_err());

        let x = Array::from_slice(&[1, 2, 3], &[3])
            .as_dtype(Dtype::Float16)
            .unwrap();
        let y = Array::from_slice(&[4, 5, 6], &[3])
            .as_dtype(Dtype::Int32)
            .unwrap();
        assert_eq!(stack_axis(&[x, y], 0).unwrap().dtype(), Dtype::Float16);

        let x = Array::from_slice(&[1, 2, 3], &[3])
            .as_dtype(Dtype::Int32)
            .unwrap();
        let y = Array::from_slice(&[4, 5, 6, 7], &[4])
            .as_dtype(Dtype::Int32)
            .unwrap();
        assert!(stack_axis(&[x, y], 0).is_err());
    }

    #[test]
    fn test_swap_axes() {
        let a = Array::from_int(0);
        assert!(swap_axes(&a, 0, 0).is_err());

        let a = Array::zeros::<i32>(&[2]).unwrap();
        assert!(swap_axes(&a, 0, 1).is_err());
        assert_eq!(swap_axes(&a, 0, 0).unwrap().shape(), &[2]);
        assert_eq!(swap_axes(&a, -1, -1).unwrap().shape(), &[2]);

        let a = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
        assert!(swap_axes(&a, 0, -4).is_err());
        assert!(swap_axes(&a, 0, 3).is_err());
        assert!(swap_axes(&a, 3, 0).is_err());
        assert!(swap_axes(&a, -4, 0).is_err());
        assert_eq!(swap_axes(&a, 0, 2).unwrap().shape(), &[4, 3, 2]);
        assert_eq!(swap_axes(&a, 0, 1).unwrap().shape(), &[3, 2, 4]);
        assert_eq!(swap_axes(&a, 0, -1).unwrap().shape(), &[4, 3, 2]);
        assert_eq!(swap_axes(&a, -2, 2).unwrap().shape(), &[2, 4, 3]);
    }

    #[test]
    fn test_tile() {
        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let y = tile(&x, &[2]).unwrap();
        let expected = Array::from_slice(&[1, 2, 3, 1, 2, 3], &[6]);
        assert_eq!(y, expected);

        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        let y = tile(&x, &[2]).unwrap();
        let expected = Array::from_slice(&[1, 2, 1, 2, 3, 4, 3, 4], &[2, 4]);
        assert_eq!(y, expected);

        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        let y = tile(&x, &[4, 1]).unwrap();
        let expected =
            Array::from_slice(&[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], &[8, 2]);
        assert_eq!(y, expected);

        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        let y = tile(&x, &[2, 2]).unwrap();
        let expected =
            Array::from_slice(&[1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4], &[4, 4]);
        assert_eq!(y, expected);

        let x = Array::from_slice(&[1, 2, 3], &[3]);
        let y = tile(&x, &[2, 2, 2]).unwrap();
        let expected = Array::from_slice(
            &[
                1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
            ],
            &[2, 2, 6],
        );
        assert_eq!(y, expected);
    }

    #[test]
    fn test_transpose() {
        let x = Array::from_int(1);
        let y = transpose(&x).unwrap();
        assert!(y.shape().is_empty());
        assert_eq!(y.item::<i32>(), 1);
        assert!(transpose_axes(&x, &[0][..]).is_err());
        assert!(transpose_axes(&x, &[1][..]).is_err());

        let x = Array::from_slice(&[1], &[1]);
        let y = transpose(&x).unwrap();
        assert_eq!(y.shape(), &[1]);
        assert_eq!(y.item::<i32>(), 1);

        let y = transpose_axes(&x, &[-1][..]).unwrap();
        assert_eq!(y.shape(), &[1]);
        assert_eq!(y.item::<i32>(), 1);

        assert!(transpose_axes(&x, &[1][..]).is_err());
        assert!(transpose_axes(&x, &[0, 0][..]).is_err());

        let x = Array::from_slice::<i32>(&[], &[0]);
        let y = transpose(&x).unwrap();
        assert_eq!(y.shape(), &[0]);
        y.eval().unwrap();
        assert_eq!(y.size(), 0);

        let x = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
        let mut y = transpose(&x).unwrap();
        assert_eq!(y.shape(), &[3, 2]);
        y = transpose_axes(&x, &[-1, 0][..]).unwrap();
        assert_eq!(y.shape(), &[3, 2]);
        y = transpose_axes(&x, &[-1, -2][..]).unwrap();
        assert_eq!(y.shape(), &[3, 2]);
        y.eval().unwrap();
        assert_eq!(y, Array::from_slice(&[1, 4, 2, 5, 3, 6], &[3, 2]));

        let y = transpose_axes(&x, &[0, 1][..]).unwrap();
        assert_eq!(y.shape(), &[2, 3]);
        assert_eq!(y, x);

        let y = transpose_axes(&x, &[0, -1][..]).unwrap();
        assert_eq!(y.shape(), &[2, 3]);
        assert_eq!(y, x);

        assert!(transpose_axes(&x, &[][..]).is_err());
        assert!(transpose_axes(&x, &[0][..]).is_err());
        assert!(transpose_axes(&x, &[0, 0][..]).is_err());
        assert!(transpose_axes(&x, &[0, 0, 0][..]).is_err());
        assert!(transpose_axes(&x, &[0, 1, 1][..]).is_err());

        let x = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], &[2, 3, 2]);
        let y = transpose(&x).unwrap();
        assert_eq!(y.shape(), &[2, 3, 2]);
        let expected = Array::from_slice(&[1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12], &[2, 3, 2]);
        assert_eq!(y, expected);

        let y = transpose_axes(&x, &[0, 1, 2][..]).unwrap();
        assert_eq!(y.shape(), &[2, 3, 2]);
        assert_eq!(y, x);

        let y = transpose_axes(&x, &[1, 0, 2][..]).unwrap();
        assert_eq!(y.shape(), &[3, 2, 2]);
        let expected = Array::from_slice(&[1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12], &[3, 2, 2]);
        assert_eq!(y, expected);

        let y = transpose_axes(&x, &[0, 2, 1][..]).unwrap();
        assert_eq!(y.shape(), &[2, 2, 3]);
        let expected = Array::from_slice(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12], &[2, 2, 3]);
        assert_eq!(y, expected);

        let mut x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7], &[4, 2]);
        x = reshape(transpose(&x).unwrap(), &[2, 2, 2]).unwrap();
        let expected = Array::from_slice(&[0, 2, 4, 6, 1, 3, 5, 7], &[2, 2, 2]);
        assert_eq!(x, expected);

        let mut x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7], &[1, 4, 1, 2]);
        x = transpose_axes(&x, &[2, 1, 0, 3][..]).unwrap();
        x.eval().unwrap();
        // Swapping the two size-one axes (0 and 2) leaves both shape and logical contents
        // unchanged; previously this case was only evaluated, never checked.
        assert_eq!(x.shape(), &[1, 4, 1, 2]);
        assert_eq!(
            x,
            Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7], &[1, 4, 1, 2])
        );
    }
}