mlx_rs/ops/
shapes.rs

1use mlx_internal_macros::{default_device, generate_macro};
2use smallvec::SmallVec;
3
4use crate::{
5    constants::DEFAULT_STACK_VEC_LEN,
6    error::Result,
7    utils::{guard::Guarded, IntoOption, VectorArray},
8    Array, Stream,
9};
10
11impl Array {
12    /// See [`expand_dims()`].
13    #[default_device]
14    pub fn expand_dims_device(&self, axis: i32, stream: impl AsRef<Stream>) -> Result<Array> {
15        expand_dims_device(self, axis, stream)
16    }
17
18    /// See [`expand_dims_axes()`].
19    #[default_device]
20    pub fn expand_dims_axes_device(
21        &self,
22        axes: &[i32],
23        stream: impl AsRef<Stream>,
24    ) -> Result<Array> {
25        expand_dims_axes_device(self, axes, stream)
26    }
27
28    /// See [`flatten`].
29    #[default_device]
30    pub fn flatten_device(
31        &self,
32        start_axis: impl Into<Option<i32>>,
33        end_axis: impl Into<Option<i32>>,
34        stream: impl AsRef<Stream>,
35    ) -> Result<Array> {
36        flatten_device(self, start_axis, end_axis, stream)
37    }
38
39    /// See [`reshape`].
40    #[default_device]
41    pub fn reshape_device(&self, shape: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
42        reshape_device(self, shape, stream)
43    }
44
45    /// See [`squeeze_axes()`].
46    #[default_device]
47    pub fn squeeze_axes_device(&self, axes: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
48        squeeze_axes_device(self, axes, stream)
49    }
50
51    /// See [`squeeze()`].
52    #[default_device]
53    pub fn squeeze_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
54        squeeze_device(self, stream)
55    }
56
57    /// See [`as_strided`]
58    #[default_device]
59    pub fn as_strided_device<'a>(
60        &'a self,
61        shape: impl IntoOption<&'a [i32]>,
62        strides: impl IntoOption<&'a [i64]>,
63        offset: impl Into<Option<usize>>,
64        stream: impl AsRef<Stream>,
65    ) -> Result<Array> {
66        as_strided_device(self, shape, strides, offset, stream)
67    }
68
69    /// See [`at_least_1d`]
70    #[default_device]
71    pub fn at_least_1d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
72        at_least_1d_device(self, stream)
73    }
74
75    /// See [`at_least_2d`]
76    #[default_device]
77    pub fn at_least_2d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
78        at_least_2d_device(self, stream)
79    }
80
81    /// See [`at_least_3d`]
82    #[default_device]
83    pub fn at_least_3d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
84        at_least_3d_device(self, stream)
85    }
86
87    /// See [`move_axis`]
88    #[default_device]
89    pub fn move_axis_device(
90        &self,
91        src: i32,
92        dst: i32,
93        stream: impl AsRef<Stream>,
94    ) -> Result<Array> {
95        move_axis_device(self, src, dst, stream)
96    }
97
98    /// See [`split`]
99    #[default_device]
100    pub fn split_axis_device(
101        &self,
102        indices: &[i32],
103        axis: impl Into<Option<i32>>,
104        stream: impl AsRef<Stream>,
105    ) -> Result<Vec<Array>> {
106        split_sections_device(self, indices, axis, stream)
107    }
108
109    /// See [`split`]
110    #[default_device]
111    pub fn split_device(
112        &self,
113        num_parts: i32,
114        axis: impl Into<Option<i32>>,
115        stream: impl AsRef<Stream>,
116    ) -> Result<Vec<Array>> {
117        split_device(self, num_parts, axis, stream)
118    }
119
120    /// See [`swap_axes`]
121    #[default_device]
122    pub fn swap_axes_device(
123        &self,
124        axis1: i32,
125        axis2: i32,
126        stream: impl AsRef<Stream>,
127    ) -> Result<Array> {
128        swap_axes_device(self, axis1, axis2, stream)
129    }
130
131    /// See [`transpose_axes`]
132    #[default_device]
133    pub fn transpose_axes_device(&self, axes: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
134        transpose_axes_device(self, axes, stream)
135    }
136
137    /// See [`transpose`]
138    #[default_device]
139    pub fn transpose_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
140        transpose_device(self, stream)
141    }
142
143    /// [`transpose_axes`] and unwrap the result.
144    pub fn t(&self) -> Array {
145        self.transpose().unwrap()
146    }
147}
148
149fn resolve_strides(
150    shape: &[i32],
151    strides: Option<&[i64]>,
152) -> SmallVec<[i64; DEFAULT_STACK_VEC_LEN]> {
153    match strides {
154        Some(strides) => SmallVec::from_slice(strides),
155        None => {
156            let result = shape
157                .iter()
158                .rev()
159                .scan(1, |acc, &dim| {
160                    let result = *acc;
161                    *acc *= dim as i64;
162                    Some(result)
163                })
164                .collect::<SmallVec<[i64; DEFAULT_STACK_VEC_LEN]>>();
165            result.into_iter().rev().collect()
166        }
167    }
168}
169
170/// Broadcast a vector of arrays against one another. Returns an error if the shapes are
171/// broadcastable.
172///
173/// # Params
174///
175/// - `arrays`: The arrays to broadcast.
176#[generate_macro]
177#[default_device]
178pub fn broadcast_arrays_device(
179    arrays: &[impl AsRef<Array>],
180    #[optional] stream: impl AsRef<Stream>,
181) -> Result<Vec<Array>> {
182    let c_vec = VectorArray::try_from_iter(arrays.iter())?;
183    Vec::<Array>::try_from_op(|res| unsafe {
184        mlx_sys::mlx_broadcast_arrays(res, c_vec.as_ptr(), stream.as_ref().as_ptr())
185    })
186}
187
188/// Create a view into the array with the given shape and strides.
189///
190/// # Example
191///
192/// ```rust
193/// use mlx_rs::{Array, ops::*};
194///
195/// let x = Array::from_iter(0..10, &[10]);
196/// let y = as_strided(&x, &[3, 3], &[1, 1], 0);
197/// ```
198#[generate_macro]
199#[default_device]
200pub fn as_strided_device<'a>(
201    a: impl AsRef<Array>,
202    #[optional] shape: impl IntoOption<&'a [i32]>,
203    #[optional] strides: impl IntoOption<&'a [i64]>,
204    #[optional] offset: impl Into<Option<usize>>,
205    #[optional] stream: impl AsRef<Stream>,
206) -> Result<Array> {
207    let a = a.as_ref();
208    let shape = shape.into_option().unwrap_or(a.shape());
209    let resolved_strides = resolve_strides(shape, strides.into_option());
210    let offset = offset.into().unwrap_or(0);
211
212    Array::try_from_op(|res| unsafe {
213        mlx_sys::mlx_as_strided(
214            res,
215            a.as_ptr(),
216            shape.as_ptr(),
217            shape.len(),
218            resolved_strides.as_ptr(),
219            resolved_strides.len(),
220            offset,
221            stream.as_ref().as_ptr(),
222        )
223    })
224}
225
226/// Broadcast an array to the given shape. Returns an error if the shapes are not broadcastable.
227///
228/// # Params
229///
230/// - `a`: The input array.
231/// - `shape`: The shape to broadcast to.
232///
233/// # Example
234///
235/// ```rust
236/// use mlx_rs::{Array, ops::*};
237///
238/// let x = Array::from_f32(2.3);
239/// let result = broadcast_to(&x, &[1, 1]);
240/// ```
241#[generate_macro]
242#[default_device]
243pub fn broadcast_to_device(
244    a: impl AsRef<Array>,
245    shape: &[i32],
246    #[optional] stream: impl AsRef<Stream>,
247) -> Result<Array> {
248    Array::try_from_op(|res| unsafe {
249        mlx_sys::mlx_broadcast_to(
250            res,
251            a.as_ref().as_ptr(),
252            shape.as_ptr(),
253            shape.len(),
254            stream.as_ref().as_ptr(),
255        )
256    })
257}
258
259/// Concatenate the arrays along the given axis. Returns an error if the shapes are invalid.
260///
261/// # Params
262///
263/// - `arrays`: The arrays to concatenate.
264/// - `axis`: The axis to concatenate along.
265///
266/// # Example
267///
268/// ```rust
269/// use mlx_rs::{Array, ops::*};
270///
271/// let x = Array::from_iter(0..4, &[2, 2]);
272/// let y = Array::from_iter(4..8, &[2, 2]);
273/// let result = concatenate_axis(&[x, y], 0);
274/// ```
275#[generate_macro]
276#[default_device]
277pub fn concatenate_axis_device(
278    arrays: &[impl AsRef<Array>],
279    axis: i32,
280    #[optional] stream: impl AsRef<Stream>,
281) -> Result<Array> {
282    let c_arrays = VectorArray::try_from_iter(arrays.iter())?;
283    Array::try_from_op(|res| unsafe {
284        mlx_sys::mlx_concatenate_axis(res, c_arrays.as_ptr(), axis, stream.as_ref().as_ptr())
285    })
286}
287
288/// Concatenate the arrays along the first axis. Returns an error if the shapes are invalid.
289#[generate_macro]
290#[default_device]
291pub fn concatenate_device(
292    arrays: &[impl AsRef<Array>],
293    #[optional] stream: impl AsRef<Stream>,
294) -> Result<Array> {
295    let c_arrays = VectorArray::try_from_iter(arrays.iter())?;
296    Array::try_from_op(|res| unsafe {
297        mlx_sys::mlx_concatenate(res, c_arrays.as_ptr(), stream.as_ref().as_ptr())
298    })
299}
300
301/// Add a size one dimension at the given axis, returns an error if the axes are invalid.
302///
303/// # Params
304///
305/// - `a`: The input array.
306/// - `axes`: The index of the inserted dimensions.
307///
308/// # Example
309///
310/// ```rust
311/// use mlx_rs::{Array, ops::*};
312///
313/// let x = Array::zeros::<i32>(&[2, 2]).unwrap();
314/// let result = expand_dims_axes(&x, &[0]);
315/// ```
316#[generate_macro]
317#[default_device]
318pub fn expand_dims_axes_device(
319    a: impl AsRef<Array>,
320    axes: &[i32],
321    #[optional] stream: impl AsRef<Stream>,
322) -> Result<Array> {
323    Array::try_from_op(|res| unsafe {
324        mlx_sys::mlx_expand_dims_axes(
325            res,
326            a.as_ref().as_ptr(),
327            axes.as_ptr(),
328            axes.len(),
329            stream.as_ref().as_ptr(),
330        )
331    })
332}
333
334/// Similar to [`expand_dims_axes`], but only takes a single axis.
335#[generate_macro]
336#[default_device]
337pub fn expand_dims_device(
338    a: impl AsRef<Array>,
339    axis: i32,
340    #[optional] stream: impl AsRef<Stream>,
341) -> Result<Array> {
342    Array::try_from_op(|res| unsafe {
343        mlx_sys::mlx_expand_dims(res, a.as_ref().as_ptr(), axis, stream.as_ref().as_ptr())
344    })
345}
346
347/// Flatten an array. Returns an error if the axes are invalid.
348///
349/// The axes flattened will be between `start_axis` and `end_axis`, inclusive. Negative axes are
350/// supported. After converting negative axis to positive, axes outside the valid range will be
351/// clamped to a valid value, `start_axis` to `0` and `end_axis` to `ndim - 1`.
352///
353/// # Params
354///
355/// - `a`: The input array.
356/// - `start_axis`: The first axis to flatten. Default is `0` if not provided.
357/// - `end_axis`: The last axis to flatten. Default is `-1` if not provided.
358///
359/// # Example
360///
361/// ```rust
362/// use mlx_rs::{Array, ops::*};
363///
364/// let x = Array::zeros::<i32>(&[2, 2, 2]).unwrap();
365/// let y = flatten(&x, None, None);
366/// ```
367#[generate_macro]
368#[default_device]
369pub fn flatten_device(
370    a: impl AsRef<Array>,
371    #[optional] start_axis: impl Into<Option<i32>>,
372    #[optional] end_axis: impl Into<Option<i32>>,
373    #[optional] stream: impl AsRef<Stream>,
374) -> Result<Array> {
375    let start_axis = start_axis.into().unwrap_or(0);
376    let end_axis = end_axis.into().unwrap_or(-1);
377
378    Array::try_from_op(|res| unsafe {
379        mlx_sys::mlx_flatten(
380            res,
381            a.as_ref().as_ptr(),
382            start_axis,
383            end_axis,
384            stream.as_ref().as_ptr(),
385        )
386    })
387}
388
389/// Unflatten an axis of an array to a shape.
390///
391/// # Params
392///
393/// - `a`: input array
394/// - `axis`: axis to unflatten
395/// - `shape`: shape to unflatten into
396#[generate_macro]
397#[default_device]
398pub fn unflatten_device(
399    a: impl AsRef<Array>,
400    axis: i32,
401    shape: &[i32],
402    #[optional] stream: impl AsRef<Stream>,
403) -> Result<Array> {
404    Array::try_from_op(|res| unsafe {
405        mlx_sys::mlx_unflatten(
406            res,
407            a.as_ref().as_ptr(),
408            axis,
409            shape.as_ptr(),
410            shape.len(),
411            stream.as_ref().as_ptr(),
412        )
413    })
414}
415
416/// Reshape an array while preserving the size. Returns an error if the new shape is invalid.
417///
418/// # Params
419///
420/// - `a`: The input array.
421/// - `shape`: New shape.
422///
423/// # Example
424///
425/// ```rust
426/// use mlx_rs::{Array, ops::*};
427///
428/// let x = Array::zeros::<i32>(&[2, 2]).unwrap();
429/// let result = reshape(&x, &[4]);
430/// ```
431#[generate_macro]
432#[default_device]
433pub fn reshape_device(
434    a: impl AsRef<Array>,
435    shape: &[i32],
436    #[optional] stream: impl AsRef<Stream>,
437) -> Result<Array> {
438    Array::try_from_op(|res| unsafe {
439        mlx_sys::mlx_reshape(
440            res,
441            a.as_ref().as_ptr(),
442            shape.as_ptr(),
443            shape.len(),
444            stream.as_ref().as_ptr(),
445        )
446    })
447}
448
449/// Remove length one axes from an array. Returns an error if the axes are invalid.
450///
451/// # Params
452///
453/// - `a`: The input array.
454/// - `axes`: Axes to remove. If `None`, all length one axes will be removed.
455///
456/// # Example
457///
458/// ```rust
459/// use mlx_rs::{Array, ops::*};
460///
461/// let x = Array::zeros::<i32>(&[1, 2, 1, 3]).unwrap();
462/// let result = squeeze(&x);
463/// ```
464#[generate_macro]
465#[default_device]
466pub fn squeeze_axes_device(
467    a: impl AsRef<Array>,
468    axes: &[i32],
469    #[optional] stream: impl AsRef<Stream>,
470) -> Result<Array> {
471    let a = a.as_ref();
472    Array::try_from_op(|res| unsafe {
473        mlx_sys::mlx_squeeze_axes(
474            res,
475            a.as_ptr(),
476            axes.as_ptr(),
477            axes.len(),
478            stream.as_ref().as_ptr(),
479        )
480    })
481}
482
483/// Similar to [`squeeze_axes`], but removes all length one axes.
484#[generate_macro]
485#[default_device]
486pub fn squeeze_device(
487    a: impl AsRef<Array>,
488    #[optional] stream: impl AsRef<Stream>,
489) -> Result<Array> {
490    let a = a.as_ref();
491    Array::try_from_op(|res| unsafe {
492        mlx_sys::mlx_squeeze(res, a.as_ptr(), stream.as_ref().as_ptr())
493    })
494}
495
496/// Convert array to have at least one dimension.
497///
498/// # Params
499///
500/// - `a`: The input array.
501///
502/// # Example
503///
504/// ```rust
505/// use mlx_rs::{Array, ops::*};
506///
507/// let x = Array::from_int(1);
508/// let out = at_least_1d(&x);
509/// ```
510#[generate_macro]
511#[default_device]
512pub fn at_least_1d_device(
513    a: impl AsRef<Array>,
514    #[optional] stream: impl AsRef<Stream>,
515) -> Result<Array> {
516    Array::try_from_op(|res| unsafe {
517        mlx_sys::mlx_atleast_1d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
518    })
519}
520
521/// Convert array to have at least two dimensions.
522///
523/// # Params
524///
525/// - `a`: The input array.
526///
527/// # Example
528///
529/// ```rust
530/// use mlx_rs::{Array, ops::*};
531///
532/// let x = Array::from_int(1);
533/// let out = at_least_2d(&x);
534/// ```
535#[generate_macro]
536#[default_device]
537pub fn at_least_2d_device(
538    a: impl AsRef<Array>,
539    #[optional] stream: impl AsRef<Stream>,
540) -> Result<Array> {
541    Array::try_from_op(|res| unsafe {
542        mlx_sys::mlx_atleast_2d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
543    })
544}
545
546/// Convert array to have at least three dimensions.
547///
548/// # Params
549///
550/// - `a`: The input array.
551///
552/// # Example
553///
554/// ```rust
555/// use mlx_rs::{Array, ops::*};
556///
557/// let x = Array::from_int(1);
558/// let out = at_least_3d(&x);
559/// ```
560#[generate_macro]
561#[default_device]
562pub fn at_least_3d_device(
563    a: impl AsRef<Array>,
564    #[optional] stream: impl AsRef<Stream>,
565) -> Result<Array> {
566    Array::try_from_op(|res| unsafe {
567        mlx_sys::mlx_atleast_3d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
568    })
569}
570
571/// Move an axis to a new position. Returns an error if the axes are invalid.
572///
573/// # Params
574///
575/// - `a`: The input array.
576/// - `src`: Specifies the source axis.
577/// - `dst`: Specifies the destination axis.
578///
579/// # Example
580///
581/// ```rust
582/// use mlx_rs::{Array, ops::*};
583///
584/// let a = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
585/// let result = move_axis(&a, 0, 2);
586/// ```
587#[generate_macro]
588#[default_device]
589pub fn move_axis_device(
590    a: impl AsRef<Array>,
591    src: i32,
592    dst: i32,
593    #[optional] stream: impl AsRef<Stream>,
594) -> Result<Array> {
595    Array::try_from_op(|res| unsafe {
596        mlx_sys::mlx_moveaxis(res, a.as_ref().as_ptr(), src, dst, stream.as_ref().as_ptr())
597    })
598}
599
600/// Split an array along a given axis. Returns an error if the indices are invalid.
601///
602/// # Params
603///
604/// - `a`: The input array.
605/// - `indices`: The indices to split at.
606/// - `axis`: The axis to split along. Default is `0` if not provided.
607///
608/// # Example
609///
610/// ```rust
611/// use mlx_rs::{Array, ops::*};
612///
613/// let a = Array::from_iter(0..10, &[10]);
614/// let result = split_sections(&a, &[3, 7], 0);
615/// ```
616#[generate_macro]
617#[default_device]
618pub fn split_sections_device(
619    a: impl AsRef<Array>,
620    indices: &[i32],
621    #[optional] axis: impl Into<Option<i32>>,
622    #[optional] stream: impl AsRef<Stream>,
623) -> Result<Vec<Array>> {
624    let axis = axis.into().unwrap_or(0);
625    Vec::<Array>::try_from_op(|res| unsafe {
626        mlx_sys::mlx_split_sections(
627            res,
628            a.as_ref().as_ptr(),
629            indices.as_ptr(),
630            indices.len(),
631            axis,
632            stream.as_ref().as_ptr(),
633        )
634    })
635}
636
637/// Split an array into equal parts along a given axis. Returns an error if the array cannot be
638/// split into equal parts.
639///
640/// # Params
641///
642/// - `a`: The input array.
643/// - `num_parts`: The number of parts to split into.
644/// - `axis`: The axis to split along. Default is `0` if not provided.
645///
646/// # Example
647///
648/// ```rust
649/// use mlx_rs::{Array, ops::*};
650///
651/// let a = Array::from_iter(0..10, &[10]);
652/// let result = split(&a, 2, 0);
653/// ```
654#[generate_macro]
655#[default_device]
656pub fn split_device(
657    a: impl AsRef<Array>,
658    num_parts: i32,
659    #[optional] axis: impl Into<Option<i32>>,
660    #[optional] stream: impl AsRef<Stream>,
661) -> Result<Vec<Array>> {
662    let axis = axis.into().unwrap_or(0);
663    Vec::<Array>::try_from_op(|res| unsafe {
664        mlx_sys::mlx_split(
665            res,
666            a.as_ref().as_ptr(),
667            num_parts,
668            axis,
669            stream.as_ref().as_ptr(),
670        )
671    })
672}
673
/// Number of padding values to add to the edges of each axis.
///
/// Usually built implicitly through one of the `From` conversions in this block: from an `i32`,
/// from an `(i32, i32)` pair, or from a slice/array of `(i32, i32)` pairs. The type is `Copy`,
/// so a single width can be reused across several padding calls.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PadWidth<'a> {
    /// `(before, after)` values applied to every axis.
    Same((i32, i32)),

    /// One `(before, after)` pair per axis, in axis order.
    Widths(&'a [(i32, i32)]),
}

/// A single integer `n` pads every axis by `n` on both sides.
impl From<i32> for PadWidth<'_> {
    fn from(width: i32) -> Self {
        PadWidth::Same((width, width))
    }
}

/// A `(before, after)` pair is applied to every axis.
impl From<(i32, i32)> for PadWidth<'_> {
    fn from(width: (i32, i32)) -> Self {
        PadWidth::Same(width)
    }
}

/// A slice of pairs gives one `(before, after)` pair per axis.
impl<'a> From<&'a [(i32, i32)]> for PadWidth<'a> {
    fn from(widths: &'a [(i32, i32)]) -> Self {
        PadWidth::Widths(widths)
    }
}

/// A fixed-size array of pairs gives one `(before, after)` pair per axis.
impl<'a, const N: usize> From<&'a [(i32, i32); N]> for PadWidth<'a> {
    fn from(widths: &'a [(i32, i32); N]) -> Self {
        PadWidth::Widths(widths)
    }
}
707
708impl PadWidth<'_> {
709    fn low_pads(&self, ndim: usize) -> SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> {
710        match self {
711            PadWidth::Same((low, _high)) => (0..ndim).map(|_| *low).collect(),
712            PadWidth::Widths(widths) => widths.iter().map(|(low, _high)| *low).collect(),
713        }
714    }
715
716    fn high_pads(&self, ndim: usize) -> SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> {
717        match self {
718            PadWidth::Same((_low, high)) => (0..ndim).map(|_| *high).collect(),
719            PadWidth::Widths(widths) => widths.iter().map(|(_low, high)| *high).collect(),
720        }
721    }
722}
723
/// The padding mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PadMode {
    /// Pad with a constant value.
    Constant,

    /// Pad with the edge value.
    Edge,
}

impl PadMode {
    /// Return the NUL-terminated mode name passed to `mlx_sys::mlx_pad`.
    ///
    /// The pointer refers to a `'static` byte-string literal, so it remains valid for the whole
    /// program. The return type is `c_char` rather than `i8`, so it matches the C signature on
    /// targets where `c_char` is unsigned (e.g. aarch64 Linux). Producing the pointer is safe;
    /// only dereferencing it requires `unsafe`.
    fn as_c_str(&self) -> *const std::ffi::c_char {
        let name: &'static [u8] = match self {
            PadMode::Constant => b"constant\0",
            PadMode::Edge => b"edge\0",
        };
        name.as_ptr().cast()
    }
}
745
746/// Pad an array with a constant value. Returns an error if the width is invalid.
747///
748/// # Params
749///
750/// - `a`: The input array.
751/// - `width`: Number of padded values to add to the edges of each axis:`((before_1, after_1),
752///   (before_2, after_2), ..., (before_N, after_N))`. If a single pair of integers is passed then
753///   `(before_i, after_i)` are all the same. If a single integer or tuple with a single integer is
754///   passed then all axes are extended by the same number on each side.
755/// - `value`: The value to pad the array with. Default is `0` if not provided.
756/// - `mode`: The padding mode. Default is `PadMode::Constant` if not provided.
757///
758/// # Example
759///
760/// ```rust
761/// use mlx_rs::{Array, ops::*};
762///
763/// let a = Array::from_iter(0..4, &[2, 2]);
764/// let result = pad(&a, 1, Array::from_int(0), None);
765/// ```
766#[generate_macro]
767#[default_device]
768pub fn pad_device<'a>(
769    a: impl AsRef<Array>,
770    #[optional] width: impl Into<PadWidth<'a>>,
771    #[optional] value: impl Into<Option<Array>>,
772    #[optional] mode: impl Into<Option<PadMode>>,
773    #[optional] stream: impl AsRef<Stream>,
774) -> Result<Array> {
775    let a = a.as_ref();
776    let width = width.into();
777    let ndim = a.ndim();
778    let axes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = (0..ndim).map(|i| i as i32).collect();
779    let low_pads = width.low_pads(ndim);
780    let high_pads = width.high_pads(ndim);
781    let value = value
782        .into()
783        .map(Ok)
784        .unwrap_or_else(|| Array::from_int(0).as_dtype(a.dtype()))?;
785    let mode = mode.into().unwrap_or(PadMode::Constant);
786
787    Array::try_from_op(|res| unsafe {
788        mlx_sys::mlx_pad(
789            res,
790            a.as_ptr(),
791            axes.as_ptr(),
792            axes.len(),
793            low_pads.as_ptr(),
794            low_pads.len(),
795            high_pads.as_ptr(),
796            high_pads.len(),
797            value.as_ptr(),
798            mode.as_c_str(),
799            stream.as_ref().as_ptr(),
800        )
801    })
802}
803
804/// Stacks the arrays along a new axis. Returns an error if the arguments are invalid.
805///
806/// # Params
807///
808/// - `arrays`: The input arrays.
809/// - `axis`: The axis in the result array along which the input arrays are stacked.
810///
811/// # Example
812///
813/// ```rust
814/// use mlx_rs::{Array, ops::*};
815///
816/// let a = Array::from_iter(0..4, &[2, 2]);
817/// let b = Array::from_iter(4..8, &[2, 2]);
818/// let result = stack_axis(&[&a, &b], 0);
819/// ```
820#[generate_macro]
821#[default_device]
822pub fn stack_axis_device(
823    arrays: &[impl AsRef<Array>],
824    axis: i32,
825    #[optional] stream: impl AsRef<Stream>,
826) -> Result<Array> {
827    let c_vec = VectorArray::try_from_iter(arrays.iter())?;
828    Array::try_from_op(|res| unsafe {
829        mlx_sys::mlx_stack_axis(res, c_vec.as_ptr(), axis, stream.as_ref().as_ptr())
830    })
831}
832
833/// Stacks the arrays along a new axis. Returns an error if the arrays have different shapes.
834///
835/// # Params
836///
837/// - `arrays`: The input arrays.
838///
839/// # Example
840///
841/// ```rust
842/// use mlx_rs::{Array, ops::*};
843///
844/// let a = Array::from_iter(0..4, &[2, 2]);
845/// let b = Array::from_iter(4..8, &[2, 2]);
846/// let result = stack(&[&a, &b]);
847/// ```
848#[generate_macro]
849#[default_device]
850pub fn stack_device(
851    arrays: &[impl AsRef<Array>],
852    #[optional] stream: impl AsRef<Stream>,
853) -> Result<Array> {
854    let c_vec = VectorArray::try_from_iter(arrays.iter())?;
855    Array::try_from_op(|res| unsafe {
856        mlx_sys::mlx_stack(res, c_vec.as_ptr(), stream.as_ref().as_ptr())
857    })
858}
859
860/// Swap two axes of an array. Returns an error if the axes are invalid.
861///
862/// # Params
863///
864/// - `a`: The input array.
865/// - `axis1`: The first axis.
866/// - `axis2`: The second axis.
867///
868/// # Example
869///
870/// ```rust
871/// use mlx_rs::{Array, ops::*};
872///
873/// let a = Array::from_iter(0..6, &[2, 3]);
874/// let result = swap_axes(&a, 0, 1);
875/// ```
876#[generate_macro]
877#[default_device]
878pub fn swap_axes_device(
879    a: impl AsRef<Array>,
880    axis1: i32,
881    axis2: i32,
882    #[optional] stream: impl AsRef<Stream>,
883) -> Result<Array> {
884    Array::try_from_op(|res| unsafe {
885        mlx_sys::mlx_swapaxes(
886            res,
887            a.as_ref().as_ptr(),
888            axis1,
889            axis2,
890            stream.as_ref().as_ptr(),
891        )
892    })
893}
894
895/// Construct an array by repeating `a` the number of times given by `reps`.
896///
897/// # Params
898///
899/// - `a`: The input array.
900/// - `reps`: The number of repetitions along each axis.
901///
902/// # Example
903///
904/// ```rust
905/// use mlx_rs::{Array, ops::*};
906///
907/// let x = Array::from_slice(&[1, 2, 3], &[3]);
908/// let y = tile(&x, &[2]);
909/// ```
910#[generate_macro]
911#[default_device]
912pub fn tile_device(
913    a: impl AsRef<Array>,
914    reps: &[i32],
915    #[optional] stream: impl AsRef<Stream>,
916) -> Result<Array> {
917    Array::try_from_op(|res| unsafe {
918        mlx_sys::mlx_tile(
919            res,
920            a.as_ref().as_ptr(),
921            reps.as_ptr(),
922            reps.len(),
923            stream.as_ref().as_ptr(),
924        )
925    })
926}
927
928/// Transpose the dimensions of the array. Returns an error if the axes are invalid.
929///
930/// # Params
931///
932/// - `a`: The input array.
933/// - `axes`: Specifies the source axis for each axis in the new array. The default is to reverse
934///  the axes.
935///
936/// # Example
937///
938/// ```rust
939/// use mlx_rs::{Array, ops::*};
940///
941/// let x = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
942/// let y1 = transpose_axes(&x, &[0, 1]).unwrap();
943/// let y2 = transpose(&x).unwrap();
944/// ```
945///
946/// # See also
947///
948/// - [`transpose`]
949#[generate_macro]
950#[default_device]
951pub fn transpose_axes_device(
952    a: impl AsRef<Array>,
953    axes: &[i32],
954    #[optional] stream: impl AsRef<Stream>,
955) -> Result<Array> {
956    Array::try_from_op(|res| unsafe {
957        mlx_sys::mlx_transpose_axes(
958            res,
959            a.as_ref().as_ptr(),
960            axes.as_ptr(),
961            axes.len(),
962            stream.as_ref().as_ptr(),
963        )
964    })
965}
966
967/// Transpose with all axes reversed
968#[generate_macro]
969#[default_device]
970pub fn transpose_device(
971    a: impl AsRef<Array>,
972    #[optional] stream: impl AsRef<Stream>,
973) -> Result<Array> {
974    Array::try_from_op(|res| unsafe {
975        mlx_sys::mlx_transpose(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
976    })
977}
978
979// The unit tests below are adapted from
980// https://github.com/ml-explore/mlx/blob/main/tests/ops_tests.cpp
981#[cfg(test)]
982mod tests {
983    use crate::{array, Array, Dtype};
984
985    use super::*;
986
987    #[test]
988    fn test_squeeze() {
989        let a = Array::zeros::<i32>(&[2, 1, 2, 1, 2, 1]).unwrap();
990        assert_eq!(
991            squeeze_axes(&a, &[1, 3, 5][..]).unwrap().shape(),
992            &[2, 2, 2]
993        );
994        assert_eq!(
995            squeeze_axes(&a, &[-1, -3, -5][..]).unwrap().shape(),
996            &[2, 2, 2]
997        );
998        assert_eq!(
999            squeeze_axes(&a, &[1][..]).unwrap().shape(),
1000            &[2, 2, 1, 2, 1]
1001        );
1002        assert_eq!(
1003            squeeze_axes(&a, &[-1][..]).unwrap().shape(),
1004            &[2, 1, 2, 1, 2]
1005        );
1006
1007        assert!(squeeze_axes(&a, &[0][..]).is_err());
1008        assert!(squeeze_axes(&a, &[2][..]).is_err());
1009        assert!(squeeze_axes(&a, &[1, 3, 1][..]).is_err());
1010        assert!(squeeze_axes(&a, &[1, 3, -3][..]).is_err());
1011    }
1012
1013    #[test]
1014    fn test_expand_dims() {
1015        let a = Array::zeros::<i32>(&[2, 2]).unwrap();
1016        assert_eq!(expand_dims_axes(&a, &[0][..]).unwrap().shape(), &[1, 2, 2]);
1017        assert_eq!(expand_dims_axes(&a, &[-1][..]).unwrap().shape(), &[2, 2, 1]);
1018        assert_eq!(expand_dims_axes(&a, &[1][..]).unwrap().shape(), &[2, 1, 2]);
1019        assert_eq!(
1020            expand_dims_axes(&a, &[0, 1, 2]).unwrap().shape(),
1021            &[1, 1, 1, 2, 2]
1022        );
1023        assert_eq!(
1024            expand_dims_axes(&a, &[0, 1, 2, 5, 6, 7]).unwrap().shape(),
1025            &[1, 1, 1, 2, 2, 1, 1, 1]
1026        );
1027
1028        assert!(expand_dims_axes(&a, &[3]).is_err());
1029        assert!(expand_dims_axes(&a, &[0, 1, 0]).is_err());
1030        assert!(expand_dims_axes(&a, &[0, 1, -4]).is_err());
1031    }
1032
1033    #[test]
1034    fn test_flatten() {
1035        let x = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
1036        assert_eq!(flatten(&x, None, None).unwrap().shape(), &[2 * 3 * 4]);
1037
1038        assert_eq!(flatten(&x, 1, 1).unwrap().shape(), &[2, 3, 4]);
1039        assert_eq!(flatten(&x, 1, 2).unwrap().shape(), &[2, 3 * 4]);
1040        assert_eq!(flatten(&x, 1, 3).unwrap().shape(), &[2, 3 * 4]);
1041        assert_eq!(flatten(&x, 1, -1).unwrap().shape(), &[2, 3 * 4]);
1042        assert_eq!(flatten(&x, -2, -1).unwrap().shape(), &[2, 3 * 4]);
1043        assert_eq!(flatten(&x, -3, -1).unwrap().shape(), &[2 * 3 * 4]);
1044        assert_eq!(flatten(&x, -4, -1).unwrap().shape(), &[2 * 3 * 4]);
1045
1046        assert!(flatten(&x, 2, 1).is_err());
1047
1048        assert!(flatten(&x, 5, 6).is_err());
1049
1050        assert!(flatten(&x, -5, -4).is_err());
1051
1052        let x = Array::from_int(1);
1053        assert_eq!(flatten(&x, -3, -1).unwrap().shape(), &[1]);
1054        assert_eq!(flatten(&x, 0, 0).unwrap().shape(), &[1]);
1055    }
1056
1057    #[test]
1058    fn test_unflatten() {
1059        let a = array!([1, 2, 3, 4]);
1060        let b = unflatten(&a, 0, &[2, -1]).unwrap();
1061        let expected = array!([[1, 2], [3, 4]]);
1062        assert_eq!(b, expected);
1063    }
1064
1065    #[test]
1066    fn test_reshape() {
1067        let x = Array::from_int(1);
1068        assert!(reshape(&x, &[]).unwrap().shape().is_empty());
1069        assert!(reshape(&x, &[2]).is_err());
1070        let y = reshape(&x, &[1, 1, 1]).unwrap();
1071        assert_eq!(y.shape(), &[1, 1, 1]);
1072        let y = reshape(&x, &[-1, 1, 1]).unwrap();
1073        assert_eq!(y.shape(), &[1, 1, 1]);
1074        let y = reshape(&x, &[1, 1, -1]).unwrap();
1075        assert_eq!(y.shape(), &[1, 1, 1]);
1076        assert!(reshape(&x, &[1, -1, -1]).is_err());
1077        assert!(reshape(&x, &[2, -1]).is_err());
1078
1079        let x = Array::zeros::<i32>(&[2, 2, 2]).unwrap();
1080        let y = reshape(&x, &[8]).unwrap();
1081        assert_eq!(y.shape(), &[8]);
1082        assert!(reshape(&x, &[7]).is_err());
1083        let y = reshape(&x, &[-1]).unwrap();
1084        assert_eq!(y.shape(), &[8]);
1085        let y = reshape(&x, &[-1, 2]).unwrap();
1086        assert_eq!(y.shape(), &[4, 2]);
1087        assert!(reshape(&x, &[-1, 7]).is_err());
1088
1089        let x = Array::from_slice::<i32>(&[], &[0]);
1090        let y = reshape(&x, &[0, 0, 0]).unwrap();
1091        assert_eq!(y.shape(), &[0, 0, 0]);
1092        y.eval().unwrap();
1093        assert_eq!(y.size(), 0);
1094        assert!(reshape(&x, &[]).is_err());
1095        assert!(reshape(&x, &[1]).is_err());
1096        let y = reshape(&x, &[1, 5, 0]).unwrap();
1097        assert_eq!(y.shape(), &[1, 5, 0]);
1098    }
1099
1100    #[test]
1101    fn test_as_strided() {
1102        let x = Array::from_iter(0..10, &[10]);
1103        let y = as_strided(&x, &[3, 3][..], &[1, 1][..], 0).unwrap();
1104        let expected = Array::from_slice(&[0, 1, 2, 1, 2, 3, 2, 3, 4], &[3, 3]);
1105        assert_eq!(y, expected);
1106
1107        let y = as_strided(&x, &[3, 3][..], &[0, 3][..], 0).unwrap();
1108        let expected = Array::from_slice(&[0, 3, 6, 0, 3, 6, 0, 3, 6], &[3, 3]);
1109        assert_eq!(y, expected);
1110
1111        let x = x.reshape(&[2, 5]).unwrap();
1112        let x = x.transpose_axes(&[1, 0][..]).unwrap();
1113        let y = as_strided(&x, &[3, 3][..], &[2, 1][..], 1).unwrap();
1114        let expected = Array::from_slice(&[5, 1, 6, 6, 2, 7, 7, 3, 8], &[3, 3]);
1115        assert_eq!(y, expected);
1116    }
1117
1118    #[test]
1119    fn test_at_least_1d() {
1120        let x = Array::from_int(1);
1121        let out = at_least_1d(&x).unwrap();
1122        assert_eq!(out.ndim(), 1);
1123        assert_eq!(out.shape(), &[1]);
1124
1125        let x = Array::from_slice(&[1, 2, 3], &[3]);
1126        let out = at_least_1d(&x).unwrap();
1127        assert_eq!(out.ndim(), 1);
1128        assert_eq!(out.shape(), &[3]);
1129
1130        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
1131        let out = at_least_1d(&x).unwrap();
1132        assert_eq!(out.ndim(), 2);
1133        assert_eq!(out.shape(), &[3, 1]);
1134    }
1135
1136    #[test]
1137    fn test_at_least_2d() {
1138        let x = Array::from_int(1);
1139        let out = at_least_2d(&x).unwrap();
1140        assert_eq!(out.ndim(), 2);
1141        assert_eq!(out.shape(), &[1, 1]);
1142
1143        let x = Array::from_slice(&[1, 2, 3], &[3]);
1144        let out = at_least_2d(&x).unwrap();
1145        assert_eq!(out.ndim(), 2);
1146        assert_eq!(out.shape(), &[1, 3]);
1147
1148        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
1149        let out = at_least_2d(&x).unwrap();
1150        assert_eq!(out.ndim(), 2);
1151        assert_eq!(out.shape(), &[3, 1]);
1152    }
1153
1154    #[test]
1155    fn test_at_least_3d() {
1156        let x = Array::from_int(1);
1157        let out = at_least_3d(&x).unwrap();
1158        assert_eq!(out.ndim(), 3);
1159        assert_eq!(out.shape(), &[1, 1, 1]);
1160
1161        let x = Array::from_slice(&[1, 2, 3], &[3]);
1162        let out = at_least_3d(&x).unwrap();
1163        assert_eq!(out.ndim(), 3);
1164        assert_eq!(out.shape(), &[1, 3, 1]);
1165
1166        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
1167        let out = at_least_3d(&x).unwrap();
1168        assert_eq!(out.ndim(), 3);
1169        assert_eq!(out.shape(), &[3, 1, 1]);
1170    }
1171
1172    #[test]
1173    fn test_move_axis() {
1174        let a = Array::from_int(0);
1175        assert!(move_axis(&a, 0, 0).is_err());
1176
1177        let a = Array::zeros::<i32>(&[2]).unwrap();
1178        assert!(move_axis(&a, 0, 1).is_err());
1179        assert_eq!(move_axis(&a, 0, 0).unwrap().shape(), &[2]);
1180        assert_eq!(move_axis(&a, -1, -1).unwrap().shape(), &[2]);
1181
1182        let a = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
1183        assert!(move_axis(&a, 0, -4).is_err());
1184        assert!(move_axis(&a, 0, 3).is_err());
1185        assert!(move_axis(&a, 3, 0).is_err());
1186        assert!(move_axis(&a, -4, 0).is_err());
1187        assert_eq!(move_axis(&a, 0, 2).unwrap().shape(), &[3, 4, 2]);
1188        assert_eq!(move_axis(&a, 0, 1).unwrap().shape(), &[3, 2, 4]);
1189        assert_eq!(move_axis(&a, 0, -1).unwrap().shape(), &[3, 4, 2]);
1190        assert_eq!(move_axis(&a, -2, 2).unwrap().shape(), &[2, 4, 3]);
1191    }
1192
1193    #[test]
1194    fn test_split_equal() {
1195        let x = Array::from_int(3);
1196        assert!(split(&x, 0, 0).is_err());
1197
1198        let x = Array::from_slice(&[0, 1, 2], &[3]);
1199        assert!(split(&x, 3, 1).is_err());
1200        assert!(split(&x, -2, 1).is_err());
1201
1202        let out = split(&x, 3, 0).unwrap();
1203        assert_eq!(out.len(), 3);
1204
1205        let mut out = split(&x, 3, -1).unwrap();
1206        assert_eq!(out.len(), 3);
1207        for (i, a) in out.iter_mut().enumerate() {
1208            assert_eq!(a.shape(), &[1]);
1209            assert_eq!(a.dtype(), Dtype::Int32);
1210            assert_eq!(a.item::<i32>(), i as i32);
1211        }
1212
1213        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5], &[2, 3]);
1214        let out = split(&x, 2, None).unwrap();
1215        assert_eq!(out[0], Array::from_slice(&[0, 1, 2], &[1, 3]));
1216        assert_eq!(out[1], Array::from_slice(&[3, 4, 5], &[1, 3]));
1217
1218        let out = split(&x, 3, 1).unwrap();
1219        assert_eq!(out[0], Array::from_slice(&[0, 3], &[2, 1]));
1220        assert_eq!(out[1], Array::from_slice(&[1, 4], &[2, 1]));
1221        assert_eq!(out[2], Array::from_slice(&[2, 5], &[2, 1]));
1222
1223        let x = Array::zeros::<i32>(&[8, 12]).unwrap();
1224        let out = split(&x, 2, None).unwrap();
1225        assert_eq!(out.len(), 2);
1226        assert_eq!(out[0].shape(), &[4, 12]);
1227        assert_eq!(out[1].shape(), &[4, 12]);
1228
1229        let out = split(&x, 3, 1).unwrap();
1230        assert_eq!(out.len(), 3);
1231        assert_eq!(out[0].shape(), &[8, 4]);
1232        assert_eq!(out[1].shape(), &[8, 4]);
1233        assert_eq!(out[2].shape(), &[8, 4]);
1234    }
1235
1236    #[test]
1237    fn test_split() {
1238        let x = Array::zeros::<i32>(&[8, 12]).unwrap();
1239
1240        let out = split_sections(&x, &[], None).unwrap();
1241        assert_eq!(out.len(), 1);
1242        assert_eq!(out[0].shape(), x.shape());
1243
1244        let out = split_sections(&x, &[3, 7], None).unwrap();
1245        assert_eq!(out.len(), 3);
1246        assert_eq!(out[0].shape(), &[3, 12]);
1247        assert_eq!(out[1].shape(), &[4, 12]);
1248        assert_eq!(out[2].shape(), &[1, 12]);
1249
1250        let out = split_sections(&x, &[20], None).unwrap();
1251        assert_eq!(out.len(), 2);
1252        assert_eq!(out[0].shape(), &[8, 12]);
1253        assert_eq!(out[1].shape(), &[0, 12]);
1254
1255        let out = split_sections(&x, &[-5], None).unwrap();
1256        assert_eq!(out[0].shape(), &[3, 12]);
1257        assert_eq!(out[1].shape(), &[5, 12]);
1258
1259        let out = split_sections(&x, &[2, 8], Some(1)).unwrap();
1260        assert_eq!(out[0].shape(), &[8, 2]);
1261        assert_eq!(out[1].shape(), &[8, 6]);
1262        assert_eq!(out[2].shape(), &[8, 4]);
1263
1264        let x = Array::from_iter(0i32..5, &[5]);
1265        let out = split_sections(&x, &[2, 1, 2], None).unwrap();
1266        assert_eq!(out[0], Array::from_slice(&[0, 1], &[2]));
1267        assert_eq!(out[1], Array::from_slice::<i32>(&[], &[0]));
1268        assert_eq!(out[2], Array::from_slice(&[1], &[1]));
1269        assert_eq!(out[3], Array::from_slice(&[2, 3, 4], &[3]));
1270    }
1271
1272    #[test]
1273    fn test_pad() {
1274        let x = Array::zeros::<f32>(&[1, 2, 3]).unwrap();
1275        assert_eq!(pad(&x, 1, None, None).unwrap().shape(), &[3, 4, 5]);
1276        assert_eq!(pad(&x, (0, 1), None, None).unwrap().shape(), &[2, 3, 4]);
1277        assert_eq!(
1278            pad(&x, &[(1, 1), (1, 2), (3, 1)], None, None)
1279                .unwrap()
1280                .shape(),
1281            &[3, 5, 7]
1282        );
1283    }
1284
1285    #[test]
1286    fn test_stack() {
1287        let x = Array::from_slice::<f32>(&[], &[0]);
1288        let x = vec![x];
1289        assert_eq!(stack_axis(&x, 0).unwrap().shape(), &[1, 0]);
1290        assert_eq!(stack_axis(&x, 1).unwrap().shape(), &[0, 1]);
1291
1292        let x = Array::from_slice(&[1, 2, 3], &[3]);
1293        let x = vec![x];
1294        assert_eq!(stack_axis(&x, 0).unwrap().shape(), &[1, 3]);
1295        assert_eq!(stack_axis(&x, 1).unwrap().shape(), &[3, 1]);
1296
1297        let y = Array::from_slice(&[4, 5, 6], &[3]);
1298        let mut z = x;
1299        z.push(y);
1300        assert_eq!(stack(&z).unwrap().shape(), &[2, 3]);
1301        assert_eq!(stack_axis(&z, 1).unwrap().shape(), &[3, 2]);
1302        assert_eq!(stack_axis(&z, -1).unwrap().shape(), &[3, 2]);
1303        assert_eq!(stack_axis(&z, -2).unwrap().shape(), &[2, 3]);
1304
1305        let empty: Vec<Array> = Vec::new();
1306        assert!(stack_axis(&empty, 0).is_err());
1307
1308        let x = Array::from_slice(&[1, 2, 3], &[3])
1309            .as_dtype(Dtype::Float16)
1310            .unwrap();
1311        let y = Array::from_slice(&[4, 5, 6], &[3])
1312            .as_dtype(Dtype::Int32)
1313            .unwrap();
1314        assert_eq!(stack_axis(&[x, y], 0).unwrap().dtype(), Dtype::Float16);
1315
1316        let x = Array::from_slice(&[1, 2, 3], &[3])
1317            .as_dtype(Dtype::Int32)
1318            .unwrap();
1319        let y = Array::from_slice(&[4, 5, 6, 7], &[4])
1320            .as_dtype(Dtype::Int32)
1321            .unwrap();
1322        assert!(stack_axis(&[x, y], 0).is_err());
1323    }
1324
1325    #[test]
1326    fn test_swap_axes() {
1327        let a = Array::from_int(0);
1328        assert!(swap_axes(&a, 0, 0).is_err());
1329
1330        let a = Array::zeros::<i32>(&[2]).unwrap();
1331        assert!(swap_axes(&a, 0, 1).is_err());
1332        assert_eq!(swap_axes(&a, 0, 0).unwrap().shape(), &[2]);
1333        assert_eq!(swap_axes(&a, -1, -1).unwrap().shape(), &[2]);
1334
1335        let a = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
1336        assert!(swap_axes(&a, 0, -4).is_err());
1337        assert!(swap_axes(&a, 0, 3).is_err());
1338        assert!(swap_axes(&a, 3, 0).is_err());
1339        assert!(swap_axes(&a, -4, 0).is_err());
1340        assert_eq!(swap_axes(&a, 0, 2).unwrap().shape(), &[4, 3, 2]);
1341        assert_eq!(swap_axes(&a, 0, 1).unwrap().shape(), &[3, 2, 4]);
1342        assert_eq!(swap_axes(&a, 0, -1).unwrap().shape(), &[4, 3, 2]);
1343        assert_eq!(swap_axes(&a, -2, 2).unwrap().shape(), &[2, 4, 3]);
1344    }
1345
1346    #[test]
1347    fn test_tile() {
1348        let x = Array::from_slice(&[1, 2, 3], &[3]);
1349        let y = tile(&x, &[2]).unwrap();
1350        let expected = Array::from_slice(&[1, 2, 3, 1, 2, 3], &[6]);
1351        assert_eq!(y, expected);
1352
1353        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1354        let y = tile(&x, &[2]).unwrap();
1355        let expected = Array::from_slice(&[1, 2, 1, 2, 3, 4, 3, 4], &[2, 4]);
1356        assert_eq!(y, expected);
1357
1358        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1359        let y = tile(&x, &[4, 1]).unwrap();
1360        let expected =
1361            Array::from_slice(&[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], &[8, 2]);
1362        assert_eq!(y, expected);
1363
1364        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1365        let y = tile(&x, &[2, 2]).unwrap();
1366        let expected =
1367            Array::from_slice(&[1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4], &[4, 4]);
1368        assert_eq!(y, expected);
1369
1370        let x = Array::from_slice(&[1, 2, 3], &[3]);
1371        let y = tile(&x, &[2, 2, 2]).unwrap();
1372        let expected = Array::from_slice(
1373            &[
1374                1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
1375            ],
1376            &[2, 2, 6],
1377        );
1378        assert_eq!(y, expected);
1379    }
1380
1381    #[test]
1382    fn test_transpose() {
1383        let x = Array::from_int(1);
1384        let y = transpose(&x).unwrap();
1385        assert!(y.shape().is_empty());
1386        assert_eq!(y.item::<i32>(), 1);
1387        assert!(transpose_axes(&x, &[0][..]).is_err());
1388        assert!(transpose_axes(&x, &[1][..]).is_err());
1389
1390        let x = Array::from_slice(&[1], &[1]);
1391        let y = transpose(&x).unwrap();
1392        assert_eq!(y.shape(), &[1]);
1393        assert_eq!(y.item::<i32>(), 1);
1394
1395        let y = transpose_axes(&x, &[-1][..]).unwrap();
1396        assert_eq!(y.shape(), &[1]);
1397        assert_eq!(y.item::<i32>(), 1);
1398
1399        assert!(transpose_axes(&x, &[1][..]).is_err());
1400        assert!(transpose_axes(&x, &[0, 0][..]).is_err());
1401
1402        let x = Array::from_slice::<i32>(&[], &[0]);
1403        let y = transpose(&x).unwrap();
1404        assert_eq!(y.shape(), &[0]);
1405        y.eval().unwrap();
1406        assert_eq!(y.size(), 0);
1407
1408        let x = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
1409        let mut y = transpose(&x).unwrap();
1410        assert_eq!(y.shape(), &[3, 2]);
1411        y = transpose_axes(&x, &[-1, 0][..]).unwrap();
1412        assert_eq!(y.shape(), &[3, 2]);
1413        y = transpose_axes(&x, &[-1, -2][..]).unwrap();
1414        assert_eq!(y.shape(), &[3, 2]);
1415        y.eval().unwrap();
1416        assert_eq!(y, Array::from_slice(&[1, 4, 2, 5, 3, 6], &[3, 2]));
1417
1418        let y = transpose_axes(&x, &[0, 1][..]).unwrap();
1419        assert_eq!(y.shape(), &[2, 3]);
1420        assert_eq!(y, x);
1421
1422        let y = transpose_axes(&x, &[0, -1][..]).unwrap();
1423        assert_eq!(y.shape(), &[2, 3]);
1424        assert_eq!(y, x);
1425
1426        assert!(transpose_axes(&x, &[][..]).is_err());
1427        assert!(transpose_axes(&x, &[0][..]).is_err());
1428        assert!(transpose_axes(&x, &[0, 0][..]).is_err());
1429        assert!(transpose_axes(&x, &[0, 0, 0][..]).is_err());
1430        assert!(transpose_axes(&x, &[0, 1, 1][..]).is_err());
1431
1432        let x = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], &[2, 3, 2]);
1433        let y = transpose(&x).unwrap();
1434        assert_eq!(y.shape(), &[2, 3, 2]);
1435        let expected = Array::from_slice(&[1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12], &[2, 3, 2]);
1436        assert_eq!(y, expected);
1437
1438        let y = transpose_axes(&x, &[0, 1, 2][..]).unwrap();
1439        assert_eq!(y.shape(), &[2, 3, 2]);
1440        assert_eq!(y, x);
1441
1442        let y = transpose_axes(&x, &[1, 0, 2][..]).unwrap();
1443        assert_eq!(y.shape(), &[3, 2, 2]);
1444        let expected = Array::from_slice(&[1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12], &[3, 2, 2]);
1445        assert_eq!(y, expected);
1446
1447        let y = transpose_axes(&x, &[0, 2, 1][..]).unwrap();
1448        assert_eq!(y.shape(), &[2, 2, 3]);
1449        let expected = Array::from_slice(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12], &[2, 2, 3]);
1450        assert_eq!(y, expected);
1451
1452        let mut x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7], &[4, 2]);
1453        x = reshape(transpose(&x).unwrap(), &[2, 2, 2]).unwrap();
1454        let expected = Array::from_slice(&[0, 2, 4, 6, 1, 3, 5, 7], &[2, 2, 2]);
1455        assert_eq!(x, expected);
1456
1457        let mut x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7], &[1, 4, 1, 2]);
1458        // assert!(x.flags().row_contiguous);
1459        x = transpose_axes(&x, &[2, 1, 0, 3][..]).unwrap();
1460        x.eval().unwrap();
1461        // assert!(x.flags().row_contiguous);
1462    }
1463}