mlx_rs/ops/
shapes.rs

1use std::borrow::Cow;
2
3use mlx_internal_macros::{default_device, generate_macro};
4use smallvec::SmallVec;
5
6use crate::{
7    constants::DEFAULT_STACK_VEC_LEN,
8    error::Result,
9    utils::{guard::Guarded, IntoOption, VectorArray},
10    Array, Stream, StreamOrDevice,
11};
12
13impl Array {
14    /// See [`expand_dims`].
15    #[default_device]
16    pub fn expand_dims_device(&self, axes: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
17        expand_dims_device(self, axes, stream)
18    }
19
20    /// See [`flatten`].
21    #[default_device]
22    pub fn flatten_device(
23        &self,
24        start_axis: impl Into<Option<i32>>,
25        end_axis: impl Into<Option<i32>>,
26        stream: impl AsRef<Stream>,
27    ) -> Result<Array> {
28        flatten_device(self, start_axis, end_axis, stream)
29    }
30
31    /// See [`reshape`].
32    #[default_device]
33    pub fn reshape_device(&self, shape: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
34        reshape_device(self, shape, stream)
35    }
36
37    /// See [`squeeze`].
38    #[default_device]
39    pub fn squeeze_device<'a>(
40        &'a self,
41        axes: impl IntoOption<&'a [i32]>,
42        stream: impl AsRef<Stream>,
43    ) -> Result<Array> {
44        squeeze_device(self, axes, stream)
45    }
46
47    /// See [`as_strided`]
48    #[default_device]
49    pub fn as_strided_device<'a>(
50        &'a self,
51        shape: impl IntoOption<&'a [i32]>,
52        strides: impl IntoOption<&'a [i64]>,
53        offset: impl Into<Option<usize>>,
54        stream: impl AsRef<Stream>,
55    ) -> Result<Array> {
56        as_strided_device(self, shape, strides, offset, stream)
57    }
58
59    /// See [`at_least_1d`]
60    #[default_device]
61    pub fn at_least_1d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
62        at_least_1d_device(self, stream)
63    }
64
65    /// See [`at_least_2d`]
66    #[default_device]
67    pub fn at_least_2d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
68        at_least_2d_device(self, stream)
69    }
70
71    /// See [`at_least_3d`]
72    #[default_device]
73    pub fn at_least_3d_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
74        at_least_3d_device(self, stream)
75    }
76
77    /// See [`move_axis`]
78    #[default_device]
79    pub fn move_axis_device(
80        &self,
81        src: i32,
82        dst: i32,
83        stream: impl AsRef<Stream>,
84    ) -> Result<Array> {
85        move_axis_device(self, src, dst, stream)
86    }
87
88    /// See [`split`]
89    #[default_device]
90    pub fn split_device(
91        &self,
92        indices: &[i32],
93        axis: impl Into<Option<i32>>,
94        stream: impl AsRef<Stream>,
95    ) -> Result<Vec<Array>> {
96        split_device(self, indices, axis, stream)
97    }
98
99    /// See [`split_equal`]
100    #[default_device]
101    pub fn split_equal_device(
102        &self,
103        num_parts: i32,
104        axis: impl Into<Option<i32>>,
105        stream: impl AsRef<Stream>,
106    ) -> Result<Vec<Array>> {
107        split_equal_device(self, num_parts, axis, stream)
108    }
109
110    /// See [`swap_axes`]
111    #[default_device]
112    pub fn swap_axes_device(
113        &self,
114        axis1: i32,
115        axis2: i32,
116        stream: impl AsRef<Stream>,
117    ) -> Result<Array> {
118        swap_axes_device(self, axis1, axis2, stream)
119    }
120
121    /// See [`transpose`]
122    #[default_device]
123    pub fn transpose_device(&self, axes: &[i32], stream: impl AsRef<Stream>) -> Result<Array> {
124        transpose_device(self, axes, stream)
125    }
126
127    /// See [`transpose_all`]
128    #[default_device]
129    pub fn transpose_all_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
130        transpose_all_device(self, stream)
131    }
132
133    /// [`transpose`] and unwrap the result.
134    pub fn t(&self) -> Array {
135        self.transpose_all().unwrap()
136    }
137}
138
139fn axes_or_default_to_all_size_one_axes<'a>(
140    axes: impl IntoOption<&'a [i32]>,
141    shape: &[i32],
142) -> Cow<'a, [i32]> {
143    match axes.into_option() {
144        Some(axes) => Cow::Borrowed(axes),
145        None => shape
146            .iter()
147            .enumerate()
148            .filter_map(|(i, &dim)| if dim == 1 { Some(i as i32) } else { None })
149            .collect(),
150    }
151}
152
153fn resolve_strides(
154    shape: &[i32],
155    strides: Option<&[i64]>,
156) -> SmallVec<[i64; DEFAULT_STACK_VEC_LEN]> {
157    match strides {
158        Some(strides) => SmallVec::from_slice(strides),
159        None => {
160            let result = shape
161                .iter()
162                .rev()
163                .scan(1, |acc, &dim| {
164                    let result = *acc;
165                    *acc *= dim as i64;
166                    Some(result)
167                })
168                .collect::<SmallVec<[i64; DEFAULT_STACK_VEC_LEN]>>();
169            result.into_iter().rev().collect()
170        }
171    }
172}
173
174/// Broadcast a vector of arrays against one another. Returns an error if the shapes are
175/// broadcastable.
176///
177/// # Params
178///
179/// - `arrays`: The arrays to broadcast.
180#[generate_macro]
181#[default_device]
182pub fn broadcast_arrays_device(
183    arrays: &[impl AsRef<Array>],
184    #[optional] stream: impl AsRef<Stream>,
185) -> Result<Vec<Array>> {
186    let c_vec = VectorArray::try_from_iter(arrays.iter())?;
187    Vec::<Array>::try_from_op(|res| unsafe {
188        mlx_sys::mlx_broadcast_arrays(res, c_vec.as_ptr(), stream.as_ref().as_ptr())
189    })
190}
191
192/// Create a view into the array with the given shape and strides.
193///
194/// # Example
195///
196/// ```rust
197/// use mlx_rs::{Array, ops::*};
198///
199/// let x = Array::from_iter(0..10, &[10]);
200/// let y = as_strided(&x, &[3, 3], &[1, 1], 0);
201/// ```
202#[generate_macro]
203#[default_device]
204pub fn as_strided_device<'a>(
205    a: impl AsRef<Array>,
206    #[optional] shape: impl IntoOption<&'a [i32]>,
207    #[optional] strides: impl IntoOption<&'a [i64]>,
208    #[optional] offset: impl Into<Option<usize>>,
209    #[optional] stream: impl AsRef<Stream>,
210) -> Result<Array> {
211    let a = a.as_ref();
212    let shape = shape.into_option().unwrap_or(a.shape());
213    let resolved_strides = resolve_strides(shape, strides.into_option());
214    let offset = offset.into().unwrap_or(0);
215
216    Array::try_from_op(|res| unsafe {
217        mlx_sys::mlx_as_strided(
218            res,
219            a.as_ptr(),
220            shape.as_ptr(),
221            shape.len(),
222            resolved_strides.as_ptr(),
223            resolved_strides.len(),
224            offset,
225            stream.as_ref().as_ptr(),
226        )
227    })
228}
229
230/// Broadcast an array to the given shape. Returns an error if the shapes are not broadcastable.
231///
232/// # Params
233///
234/// - `a`: The input array.
235/// - `shape`: The shape to broadcast to.
236///
237/// # Example
238///
239/// ```rust
240/// use mlx_rs::{Array, ops::*};
241///
242/// let x = Array::from_f32(2.3);
243/// let result = broadcast_to(&x, &[1, 1]);
244/// ```
245#[generate_macro]
246#[default_device]
247pub fn broadcast_to_device(
248    a: impl AsRef<Array>,
249    shape: &[i32],
250    #[optional] stream: impl AsRef<Stream>,
251) -> Result<Array> {
252    Array::try_from_op(|res| unsafe {
253        mlx_sys::mlx_broadcast_to(
254            res,
255            a.as_ref().as_ptr(),
256            shape.as_ptr(),
257            shape.len(),
258            stream.as_ref().as_ptr(),
259        )
260    })
261}
262
263/// Concatenate the arrays along the given axis. Returns an error if the shapes are invalid.
264///
265/// # Params
266///
267/// - `arrays`: The arrays to concatenate.
268/// - `axis`: The axis to concatenate along.
269///
270/// # Example
271///
272/// ```rust
273/// use mlx_rs::{Array, ops::*};
274///
275/// let x = Array::from_iter(0..4, &[2, 2]);
276/// let y = Array::from_iter(4..8, &[2, 2]);
277/// let result = concatenate(&[x, y], 0);
278/// ```
279#[generate_macro]
280#[default_device]
281pub fn concatenate_device(
282    arrays: &[impl AsRef<Array>],
283    #[optional] axis: impl Into<Option<i32>>,
284    #[optional] stream: impl AsRef<Stream>,
285) -> Result<Array> {
286    let axis = axis.into().unwrap_or(0);
287    let c_arrays = VectorArray::try_from_iter(arrays.iter())?;
288    Array::try_from_op(|res| unsafe {
289        mlx_sys::mlx_concatenate(res, c_arrays.as_ptr(), axis, stream.as_ref().as_ptr())
290    })
291}
292
293/// Add a size one dimension at the given axis, returns an error if the axes are invalid.
294///
295/// # Params
296///
297/// - `a`: The input array.
298/// - `axes`: The index of the inserted dimensions.
299///
300/// # Example
301///
302/// ```rust
303/// use mlx_rs::{Array, ops::*};
304///
305/// let x = Array::zeros::<i32>(&[2, 2]).unwrap();
306/// let result = expand_dims(&x, &[0]);
307/// ```
308#[generate_macro]
309#[default_device]
310pub fn expand_dims_device(
311    a: impl AsRef<Array>,
312    axes: &[i32],
313    #[optional] stream: impl AsRef<Stream>,
314) -> Result<Array> {
315    Array::try_from_op(|res| unsafe {
316        mlx_sys::mlx_expand_dims(
317            res,
318            a.as_ref().as_ptr(),
319            axes.as_ptr(),
320            axes.len(),
321            stream.as_ref().as_ptr(),
322        )
323    })
324}
325
326/// Flatten an array. Returns an error if the axes are invalid.
327///
328/// The axes flattened will be between `start_axis` and `end_axis`, inclusive. Negative axes are
329/// supported. After converting negative axis to positive, axes outside the valid range will be
330/// clamped to a valid value, `start_axis` to `0` and `end_axis` to `ndim - 1`.
331///
332/// # Params
333///
334/// - `a`: The input array.
335/// - `start_axis`: The first axis to flatten. Default is `0` if not provided.
336/// - `end_axis`: The last axis to flatten. Default is `-1` if not provided.
337///
338/// # Example
339///
340/// ```rust
341/// use mlx_rs::{Array, ops::*};
342///
343/// let x = Array::zeros::<i32>(&[2, 2, 2]).unwrap();
344/// let y = flatten(&x, None, None);
345/// ```
346#[generate_macro]
347#[default_device]
348pub fn flatten_device(
349    a: impl AsRef<Array>,
350    #[optional] start_axis: impl Into<Option<i32>>,
351    #[optional] end_axis: impl Into<Option<i32>>,
352    #[optional] stream: impl AsRef<Stream>,
353) -> Result<Array> {
354    let start_axis = start_axis.into().unwrap_or(0);
355    let end_axis = end_axis.into().unwrap_or(-1);
356
357    Array::try_from_op(|res| unsafe {
358        mlx_sys::mlx_flatten(
359            res,
360            a.as_ref().as_ptr(),
361            start_axis,
362            end_axis,
363            stream.as_ref().as_ptr(),
364        )
365    })
366}
367
368/// Unflatten an axis of an array to a shape.
369///
370/// # Params
371///
372/// - `a`: input array
373/// - `axis`: axis to unflatten
374/// - `shape`: shape to unflatten into
375#[generate_macro]
376#[default_device]
377pub fn unflatten_device(
378    a: impl AsRef<Array>,
379    axis: i32,
380    shape: &[i32],
381    #[optional] stream: impl AsRef<Stream>,
382) -> Result<Array> {
383    Array::try_from_op(|res| unsafe {
384        mlx_sys::mlx_unflatten(
385            res,
386            a.as_ref().as_ptr(),
387            axis,
388            shape.as_ptr(),
389            shape.len(),
390            stream.as_ref().as_ptr(),
391        )
392    })
393}
394
395/// Reshape an array while preserving the size. Returns an error if the new shape is invalid.
396///
397/// # Params
398///
399/// - `a`: The input array.
400/// - `shape`: New shape.
401///
402/// # Example
403///
404/// ```rust
405/// use mlx_rs::{Array, ops::*};
406///
407/// let x = Array::zeros::<i32>(&[2, 2]).unwrap();
408/// let result = reshape(&x, &[4]);
409/// ```
410#[generate_macro]
411#[default_device]
412pub fn reshape_device(
413    a: impl AsRef<Array>,
414    shape: &[i32],
415    #[optional] stream: impl AsRef<Stream>,
416) -> Result<Array> {
417    Array::try_from_op(|res| unsafe {
418        mlx_sys::mlx_reshape(
419            res,
420            a.as_ref().as_ptr(),
421            shape.as_ptr(),
422            shape.len(),
423            stream.as_ref().as_ptr(),
424        )
425    })
426}
427
428/// Remove length one axes from an array. Returns an error if the axes are invalid.
429///
430/// # Params
431///
432/// - `a`: The input array.
433/// - `axes`: Axes to remove. If `None`, all length one axes will be removed.
434///
435/// # Example
436///
437/// ```rust
438/// use mlx_rs::{Array, ops::*};
439///
440/// let x = Array::zeros::<i32>(&[1, 2, 1, 3]).unwrap();
441/// let result = squeeze(&x, None);
442/// ```
443#[generate_macro]
444#[default_device]
445pub fn squeeze_device<'a>(
446    a: impl AsRef<Array>,
447    #[optional] axes: impl IntoOption<&'a [i32]>,
448    #[optional] stream: impl AsRef<Stream>,
449) -> Result<Array> {
450    let a = a.as_ref();
451    let axes = axes_or_default_to_all_size_one_axes(axes, a.shape());
452    Array::try_from_op(|res| unsafe {
453        mlx_sys::mlx_squeeze(
454            res,
455            a.as_ptr(),
456            axes.as_ptr(),
457            axes.len(),
458            stream.as_ref().as_ptr(),
459        )
460    })
461}
462
463/// Convert array to have at least one dimension.
464///
465/// # Params
466///
467/// - `a`: The input array.
468///
469/// # Example
470///
471/// ```rust
472/// use mlx_rs::{Array, ops::*};
473///
474/// let x = Array::from_int(1);
475/// let out = at_least_1d(&x);
476/// ```
477#[generate_macro]
478#[default_device]
479pub fn at_least_1d_device(
480    a: impl AsRef<Array>,
481    #[optional] stream: impl AsRef<Stream>,
482) -> Result<Array> {
483    Array::try_from_op(|res| unsafe {
484        mlx_sys::mlx_atleast_1d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
485    })
486}
487
488/// Convert array to have at least two dimensions.
489///
490/// # Params
491///
492/// - `a`: The input array.
493///
494/// # Example
495///
496/// ```rust
497/// use mlx_rs::{Array, ops::*};
498///
499/// let x = Array::from_int(1);
500/// let out = at_least_2d(&x);
501/// ```
502#[generate_macro]
503#[default_device]
504pub fn at_least_2d_device(
505    a: impl AsRef<Array>,
506    #[optional] stream: impl AsRef<Stream>,
507) -> Result<Array> {
508    Array::try_from_op(|res| unsafe {
509        mlx_sys::mlx_atleast_2d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
510    })
511}
512
513/// Convert array to have at least three dimensions.
514///
515/// # Params
516///
517/// - `a`: The input array.
518///
519/// # Example
520///
521/// ```rust
522/// use mlx_rs::{Array, ops::*};
523///
524/// let x = Array::from_int(1);
525/// let out = at_least_3d(&x);
526/// ```
527#[generate_macro]
528#[default_device]
529pub fn at_least_3d_device(
530    a: impl AsRef<Array>,
531    #[optional] stream: impl AsRef<Stream>,
532) -> Result<Array> {
533    Array::try_from_op(|res| unsafe {
534        mlx_sys::mlx_atleast_3d(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
535    })
536}
537
538/// Move an axis to a new position. Returns an error if the axes are invalid.
539///
540/// # Params
541///
542/// - `a`: The input array.
543/// - `src`: Specifies the source axis.
544/// - `dst`: Specifies the destination axis.
545///
546/// # Example
547///
548/// ```rust
549/// use mlx_rs::{Array, ops::*};
550///
551/// let a = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
552/// let result = move_axis(&a, 0, 2);
553/// ```
554#[generate_macro]
555#[default_device]
556pub fn move_axis_device(
557    a: impl AsRef<Array>,
558    src: i32,
559    dst: i32,
560    #[optional] stream: impl AsRef<Stream>,
561) -> Result<Array> {
562    Array::try_from_op(|res| unsafe {
563        mlx_sys::mlx_moveaxis(res, a.as_ref().as_ptr(), src, dst, stream.as_ref().as_ptr())
564    })
565}
566
567/// Split an array along a given axis. Returns an error if the indices are invalid.
568///
569/// # Params
570///
571/// - `a`: The input array.
572/// - `indices`: The indices to split at.
573/// - `axis`: The axis to split along. Default is `0` if not provided.
574///
575/// # Example
576///
577/// ```rust
578/// use mlx_rs::{Array, ops::*};
579///
580/// let a = Array::from_iter(0..10, &[10]);
581/// let result = split(&a, &[3, 7], 0);
582/// ```
583#[generate_macro]
584#[default_device]
585pub fn split_device(
586    a: impl AsRef<Array>,
587    indices: &[i32],
588    #[optional] axis: impl Into<Option<i32>>,
589    #[optional] stream: impl AsRef<Stream>,
590) -> Result<Vec<Array>> {
591    let axis = axis.into().unwrap_or(0);
592    Vec::<Array>::try_from_op(|res| unsafe {
593        mlx_sys::mlx_split(
594            res,
595            a.as_ref().as_ptr(),
596            indices.as_ptr(),
597            indices.len(),
598            axis,
599            stream.as_ref().as_ptr(),
600        )
601    })
602}
603
604/// Split an array into equal parts along a given axis. Returns an error if the array cannot be
605/// split into equal parts.
606///
607/// # Params
608///
609/// - `a`: The input array.
610/// - `num_parts`: The number of parts to split into.
611/// - `axis`: The axis to split along. Default is `0` if not provided.
612///
613/// # Example
614///
615/// ```rust
616/// use mlx_rs::{Array, ops::*};
617///
618/// let a = Array::from_iter(0..10, &[10]);
619/// let result = split_equal(&a, 2, 0);
620/// ```
621#[generate_macro]
622#[default_device]
623pub fn split_equal_device(
624    a: impl AsRef<Array>,
625    num_parts: i32,
626    #[optional] axis: impl Into<Option<i32>>,
627    #[optional] stream: impl AsRef<Stream>,
628) -> Result<Vec<Array>> {
629    let axis = axis.into().unwrap_or(0);
630    Vec::<Array>::try_from_op(|res| unsafe {
631        mlx_sys::mlx_split_equal_parts(
632            res,
633            a.as_ref().as_ptr(),
634            num_parts,
635            axis,
636            stream.as_ref().as_ptr(),
637        )
638    })
639}
640
/// Number of padding values to add to the edges of each axis.
///
/// It can be built from:
///
/// - an `i32` `w`, which becomes `Same((w, w))`;
/// - a `(before, after)` tuple, which becomes `Same((before, after))`;
/// - a slice or array reference of `(before, after)` tuples, which becomes `Widths(..)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PadWidth<'a> {
    /// `(before, after)` values applied to every axis.
    Same((i32, i32)),

    /// One `(before, after)` pair per axis.
    Widths(&'a [(i32, i32)]),
}

impl From<i32> for PadWidth<'_> {
    /// Pad every axis by `width` on both sides.
    fn from(width: i32) -> Self {
        PadWidth::Same((width, width))
    }
}

impl From<(i32, i32)> for PadWidth<'_> {
    /// Pad every axis by the same `(before, after)` pair.
    fn from(width: (i32, i32)) -> Self {
        PadWidth::Same(width)
    }
}

impl<'a> From<&'a [(i32, i32)]> for PadWidth<'a> {
    /// Use one `(before, after)` pair per axis.
    fn from(widths: &'a [(i32, i32)]) -> Self {
        PadWidth::Widths(widths)
    }
}

impl<'a, const N: usize> From<&'a [(i32, i32); N]> for PadWidth<'a> {
    /// Use one `(before, after)` pair per axis; the array reference coerces to a slice.
    fn from(widths: &'a [(i32, i32); N]) -> Self {
        PadWidth::Widths(widths)
    }
}
674
675impl PadWidth<'_> {
676    fn low_pads(&self, ndim: usize) -> SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> {
677        match self {
678            PadWidth::Same((low, _high)) => (0..ndim).map(|_| *low).collect(),
679            PadWidth::Widths(widths) => widths.iter().map(|(low, _high)| *low).collect(),
680        }
681    }
682
683    fn high_pads(&self, ndim: usize) -> SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> {
684        match self {
685            PadWidth::Same((_low, high)) => (0..ndim).map(|_| *high).collect(),
686            PadWidth::Widths(widths) => widths.iter().map(|(_low, high)| *high).collect(),
687        }
688    }
689}
690
/// The padding mode used by `pad`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PadMode {
    /// Pad with a constant value.
    Constant,

    /// Pad with the edge value.
    Edge,
}

impl PadMode {
    /// Return a pointer to the NUL-terminated mode name (`"constant"` or `"edge"`) that is
    /// handed to the MLX C API.
    ///
    /// The pointer refers to `'static` data, so it never dangles. Creating it is safe; only
    /// reading through it requires `unsafe`. The return type is the platform `c_char`, whose
    /// signedness varies by target, rather than a hard-coded `i8`.
    fn as_c_str(&self) -> *const std::os::raw::c_char {
        // Byte-string literals are not implicitly NUL-terminated, hence the explicit `\0`.
        static CONSTANT: &[u8] = b"constant\0";
        static EDGE: &[u8] = b"edge\0";

        let name: &'static [u8] = match self {
            PadMode::Constant => CONSTANT,
            PadMode::Edge => EDGE,
        };
        name.as_ptr().cast()
    }
}
712
713/// Pad an array with a constant value. Returns an error if the width is invalid.
714///
715/// # Params
716///
717/// - `a`: The input array.
718/// - `width`: Number of padded values to add to the edges of each axis:`((before_1, after_1),
719///   (before_2, after_2), ..., (before_N, after_N))`. If a single pair of integers is passed then
720///   `(before_i, after_i)` are all the same. If a single integer or tuple with a single integer is
721///   passed then all axes are extended by the same number on each side.
722/// - `value`: The value to pad the array with. Default is `0` if not provided.
723/// - `mode`: The padding mode. Default is `PadMode::Constant` if not provided.
724///
725/// # Example
726///
727/// ```rust
728/// use mlx_rs::{Array, ops::*};
729///
730/// let a = Array::from_iter(0..4, &[2, 2]);
731/// let result = pad(&a, 1, Array::from_int(0), None);
732/// ```
733#[generate_macro]
734#[default_device]
735pub fn pad_device<'a>(
736    a: impl AsRef<Array>,
737    #[optional] width: impl Into<PadWidth<'a>>,
738    #[optional] value: impl Into<Option<Array>>,
739    #[optional] mode: impl Into<Option<PadMode>>,
740    #[optional] stream: impl AsRef<Stream>,
741) -> Result<Array> {
742    let a = a.as_ref();
743    let width = width.into();
744    let ndim = a.ndim();
745    let axes: SmallVec<[i32; DEFAULT_STACK_VEC_LEN]> = (0..ndim).map(|i| i as i32).collect();
746    let low_pads = width.low_pads(ndim);
747    let high_pads = width.high_pads(ndim);
748    let value = value
749        .into()
750        .map(Ok)
751        .unwrap_or_else(|| Array::from_int(0).as_dtype(a.dtype()))?;
752    let mode = mode.into().unwrap_or(PadMode::Constant);
753
754    Array::try_from_op(|res| unsafe {
755        mlx_sys::mlx_pad(
756            res,
757            a.as_ptr(),
758            axes.as_ptr(),
759            axes.len(),
760            low_pads.as_ptr(),
761            low_pads.len(),
762            high_pads.as_ptr(),
763            high_pads.len(),
764            value.as_ptr(),
765            mode.as_c_str(),
766            stream.as_ref().as_ptr(),
767        )
768    })
769}
770
771/// Stacks the arrays along a new axis. Returns an error if the arguments are invalid.
772///
773/// # Params
774///
775/// - `arrays`: The input arrays.
776/// - `axis`: The axis in the result array along which the input arrays are stacked.
777///
778/// # Example
779///
780/// ```rust
781/// use mlx_rs::{Array, ops::*};
782///
783/// let a = Array::from_iter(0..4, &[2, 2]);
784/// let b = Array::from_iter(4..8, &[2, 2]);
785/// let result = stack(&[&a, &b], 0);
786/// ```
787#[generate_macro]
788#[default_device]
789pub fn stack_device(
790    arrays: &[impl AsRef<Array>],
791    axis: i32,
792    #[optional] stream: impl AsRef<Stream>,
793) -> Result<Array> {
794    let c_vec = VectorArray::try_from_iter(arrays.iter())?;
795    Array::try_from_op(|res| unsafe {
796        mlx_sys::mlx_stack(res, c_vec.as_ptr(), axis, stream.as_ref().as_ptr())
797    })
798}
799
800/// Stacks the arrays along a new axis. Returns an error if the arrays have different shapes.
801///
802/// # Params
803///
804/// - `arrays`: The input arrays.
805///
806/// # Example
807///
808/// ```rust
809/// use mlx_rs::{Array, ops::*};
810///
811/// let a = Array::from_iter(0..4, &[2, 2]);
812/// let b = Array::from_iter(4..8, &[2, 2]);
813/// let result = stack_all(&[&a, &b]);
814/// ```
815#[generate_macro]
816#[default_device]
817pub fn stack_all_device(
818    arrays: &[impl AsRef<Array>],
819    #[optional] stream: impl AsRef<Stream>,
820) -> Result<Array> {
821    let c_vec = VectorArray::try_from_iter(arrays.iter())?;
822    Array::try_from_op(|res| unsafe {
823        mlx_sys::mlx_stack_all(res, c_vec.as_ptr(), stream.as_ref().as_ptr())
824    })
825}
826
827/// Swap two axes of an array. Returns an error if the axes are invalid.
828///
829/// # Params
830///
831/// - `a`: The input array.
832/// - `axis1`: The first axis.
833/// - `axis2`: The second axis.
834///
835/// # Example
836///
837/// ```rust
838/// use mlx_rs::{Array, ops::*};
839///
840/// let a = Array::from_iter(0..6, &[2, 3]);
841/// let result = swap_axes(&a, 0, 1);
842/// ```
843#[generate_macro]
844#[default_device]
845pub fn swap_axes_device(
846    a: impl AsRef<Array>,
847    axis1: i32,
848    axis2: i32,
849    #[optional] stream: impl AsRef<Stream>,
850) -> Result<Array> {
851    Array::try_from_op(|res| unsafe {
852        mlx_sys::mlx_swapaxes(
853            res,
854            a.as_ref().as_ptr(),
855            axis1,
856            axis2,
857            stream.as_ref().as_ptr(),
858        )
859    })
860}
861
862/// Construct an array by repeating `a` the number of times given by `reps`.
863///
864/// # Params
865///
866/// - `a`: The input array.
867/// - `reps`: The number of repetitions along each axis.
868///
869/// # Example
870///
871/// ```rust
872/// use mlx_rs::{Array, ops::*};
873///
874/// let x = Array::from_slice(&[1, 2, 3], &[3]);
875/// let y = tile(&x, &[2]);
876/// ```
877#[generate_macro]
878#[default_device]
879pub fn tile_device(
880    a: impl AsRef<Array>,
881    reps: &[i32],
882    #[optional] stream: impl AsRef<Stream>,
883) -> Result<Array> {
884    Array::try_from_op(|res| unsafe {
885        mlx_sys::mlx_tile(
886            res,
887            a.as_ref().as_ptr(),
888            reps.as_ptr(),
889            reps.len(),
890            stream.as_ref().as_ptr(),
891        )
892    })
893}
894
895/// Transpose the dimensions of the array. Returns an error if the axes are invalid.
896///
897/// # Params
898///
899/// - `a`: The input array.
900/// - `axes`: Specifies the source axis for each axis in the new array. The default is to reverse
901///  the axes.
902///
903/// # Example
904///
905/// ```rust
906/// use mlx_rs::{Array, ops::*};
907///
908/// let x = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
909/// let y1 = transpose(&x, &[0, 1]).unwrap();
910/// let y2 = transpose_all(&x).unwrap();
911/// ```
912///
913/// # See also
914///
915/// - [`transpose_all`]
916#[generate_macro]
917#[default_device]
918pub fn transpose_device(
919    a: impl AsRef<Array>,
920    axes: &[i32],
921    #[optional] stream: impl AsRef<Stream>,
922) -> Result<Array> {
923    Array::try_from_op(|res| unsafe {
924        mlx_sys::mlx_transpose(
925            res,
926            a.as_ref().as_ptr(),
927            axes.as_ptr(),
928            axes.len(),
929            stream.as_ref().as_ptr(),
930        )
931    })
932}
933
934/// Transpose with all axes reversed
935#[generate_macro]
936#[default_device]
937pub fn transpose_all_device(
938    a: impl AsRef<Array>,
939    #[optional] stream: impl AsRef<Stream>,
940) -> Result<Array> {
941    Array::try_from_op(|res| unsafe {
942        mlx_sys::mlx_transpose_all(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
943    })
944}
945
946// The unit tests below are adapted from
947// https://github.com/ml-explore/mlx/blob/main/tests/ops_tests.cpp
948#[cfg(test)]
949mod tests {
950    use crate::{array, Array, Dtype};
951
952    use super::*;
953
954    #[test]
955    fn test_squeeze() {
956        let a = Array::zeros::<i32>(&[2, 1, 2, 1, 2, 1]).unwrap();
957        assert_eq!(squeeze(&a, &[1, 3, 5][..]).unwrap().shape(), &[2, 2, 2]);
958        assert_eq!(squeeze(&a, &[-1, -3, -5][..]).unwrap().shape(), &[2, 2, 2]);
959        assert_eq!(squeeze(&a, &[1][..]).unwrap().shape(), &[2, 2, 1, 2, 1]);
960        assert_eq!(squeeze(&a, &[-1][..]).unwrap().shape(), &[2, 1, 2, 1, 2]);
961
962        assert!(squeeze(&a, &[0][..]).is_err());
963        assert!(squeeze(&a, &[2][..]).is_err());
964        assert!(squeeze(&a, &[1, 3, 1][..]).is_err());
965        assert!(squeeze(&a, &[1, 3, -3][..]).is_err());
966    }
967
968    #[test]
969    fn test_expand_dims() {
970        let a = Array::zeros::<i32>(&[2, 2]).unwrap();
971        assert_eq!(expand_dims(&a, &[0][..]).unwrap().shape(), &[1, 2, 2]);
972        assert_eq!(expand_dims(&a, &[-1][..]).unwrap().shape(), &[2, 2, 1]);
973        assert_eq!(expand_dims(&a, &[1][..]).unwrap().shape(), &[2, 1, 2]);
974        assert_eq!(
975            expand_dims(&a, &[0, 1, 2]).unwrap().shape(),
976            &[1, 1, 1, 2, 2]
977        );
978        assert_eq!(
979            expand_dims(&a, &[0, 1, 2, 5, 6, 7]).unwrap().shape(),
980            &[1, 1, 1, 2, 2, 1, 1, 1]
981        );
982
983        assert!(expand_dims(&a, &[3]).is_err());
984        assert!(expand_dims(&a, &[0, 1, 0]).is_err());
985        assert!(expand_dims(&a, &[0, 1, -4]).is_err());
986    }
987
988    #[test]
989    fn test_flatten() {
990        let x = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
991        assert_eq!(flatten(&x, None, None).unwrap().shape(), &[2 * 3 * 4]);
992
993        assert_eq!(flatten(&x, 1, 1).unwrap().shape(), &[2, 3, 4]);
994        assert_eq!(flatten(&x, 1, 2).unwrap().shape(), &[2, 3 * 4]);
995        assert_eq!(flatten(&x, 1, 3).unwrap().shape(), &[2, 3 * 4]);
996        assert_eq!(flatten(&x, 1, -1).unwrap().shape(), &[2, 3 * 4]);
997        assert_eq!(flatten(&x, -2, -1).unwrap().shape(), &[2, 3 * 4]);
998        assert_eq!(flatten(&x, -3, -1).unwrap().shape(), &[2 * 3 * 4]);
999        assert_eq!(flatten(&x, -4, -1).unwrap().shape(), &[2 * 3 * 4]);
1000
1001        assert!(flatten(&x, 2, 1).is_err());
1002
1003        assert!(flatten(&x, 5, 6).is_err());
1004
1005        assert!(flatten(&x, -5, -4).is_err());
1006
1007        let x = Array::from_int(1);
1008        assert_eq!(flatten(&x, -3, -1).unwrap().shape(), &[1]);
1009        assert_eq!(flatten(&x, 0, 0).unwrap().shape(), &[1]);
1010    }
1011
1012    #[test]
1013    fn test_unflatten() {
1014        let a = array!([1, 2, 3, 4]);
1015        let b = unflatten(&a, 0, &[2, -1]).unwrap();
1016        let expected = array!([[1, 2], [3, 4]]);
1017        assert_eq!(b, expected);
1018    }
1019
1020    #[test]
1021    fn test_reshape() {
1022        let x = Array::from_int(1);
1023        assert!(reshape(&x, &[]).unwrap().shape().is_empty());
1024        assert!(reshape(&x, &[2]).is_err());
1025        let y = reshape(&x, &[1, 1, 1]).unwrap();
1026        assert_eq!(y.shape(), &[1, 1, 1]);
1027        let y = reshape(&x, &[-1, 1, 1]).unwrap();
1028        assert_eq!(y.shape(), &[1, 1, 1]);
1029        let y = reshape(&x, &[1, 1, -1]).unwrap();
1030        assert_eq!(y.shape(), &[1, 1, 1]);
1031        assert!(reshape(&x, &[1, -1, -1]).is_err());
1032        assert!(reshape(&x, &[2, -1]).is_err());
1033
1034        let x = Array::zeros::<i32>(&[2, 2, 2]).unwrap();
1035        let y = reshape(&x, &[8]).unwrap();
1036        assert_eq!(y.shape(), &[8]);
1037        assert!(reshape(&x, &[7]).is_err());
1038        let y = reshape(&x, &[-1]).unwrap();
1039        assert_eq!(y.shape(), &[8]);
1040        let y = reshape(&x, &[-1, 2]).unwrap();
1041        assert_eq!(y.shape(), &[4, 2]);
1042        assert!(reshape(&x, &[-1, 7]).is_err());
1043
1044        let x = Array::from_slice::<i32>(&[], &[0]);
1045        let y = reshape(&x, &[0, 0, 0]).unwrap();
1046        assert_eq!(y.shape(), &[0, 0, 0]);
1047        y.eval().unwrap();
1048        assert_eq!(y.size(), 0);
1049        assert!(reshape(&x, &[]).is_err());
1050        assert!(reshape(&x, &[1]).is_err());
1051        let y = reshape(&x, &[1, 5, 0]).unwrap();
1052        assert_eq!(y.shape(), &[1, 5, 0]);
1053    }
1054
1055    #[test]
1056    fn test_as_strided() {
1057        let x = Array::from_iter(0..10, &[10]);
1058        let y = as_strided(&x, &[3, 3][..], &[1, 1][..], 0).unwrap();
1059        let expected = Array::from_slice(&[0, 1, 2, 1, 2, 3, 2, 3, 4], &[3, 3]);
1060        assert_eq!(y, expected);
1061
1062        let y = as_strided(&x, &[3, 3][..], &[0, 3][..], 0).unwrap();
1063        let expected = Array::from_slice(&[0, 3, 6, 0, 3, 6, 0, 3, 6], &[3, 3]);
1064        assert_eq!(y, expected);
1065
1066        let x = x.reshape(&[2, 5]).unwrap();
1067        let x = x.transpose(&[1, 0][..]).unwrap();
1068        let y = as_strided(&x, &[3, 3][..], &[2, 1][..], 1).unwrap();
1069        let expected = Array::from_slice(&[5, 1, 6, 6, 2, 7, 7, 3, 8], &[3, 3]);
1070        assert_eq!(y, expected);
1071    }
1072
1073    #[test]
1074    fn test_at_least_1d() {
1075        let x = Array::from_int(1);
1076        let out = at_least_1d(&x).unwrap();
1077        assert_eq!(out.ndim(), 1);
1078        assert_eq!(out.shape(), &[1]);
1079
1080        let x = Array::from_slice(&[1, 2, 3], &[3]);
1081        let out = at_least_1d(&x).unwrap();
1082        assert_eq!(out.ndim(), 1);
1083        assert_eq!(out.shape(), &[3]);
1084
1085        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
1086        let out = at_least_1d(&x).unwrap();
1087        assert_eq!(out.ndim(), 2);
1088        assert_eq!(out.shape(), &[3, 1]);
1089    }
1090
1091    #[test]
1092    fn test_at_least_2d() {
1093        let x = Array::from_int(1);
1094        let out = at_least_2d(&x).unwrap();
1095        assert_eq!(out.ndim(), 2);
1096        assert_eq!(out.shape(), &[1, 1]);
1097
1098        let x = Array::from_slice(&[1, 2, 3], &[3]);
1099        let out = at_least_2d(&x).unwrap();
1100        assert_eq!(out.ndim(), 2);
1101        assert_eq!(out.shape(), &[1, 3]);
1102
1103        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
1104        let out = at_least_2d(&x).unwrap();
1105        assert_eq!(out.ndim(), 2);
1106        assert_eq!(out.shape(), &[3, 1]);
1107    }
1108
1109    #[test]
1110    fn test_at_least_3d() {
1111        let x = Array::from_int(1);
1112        let out = at_least_3d(&x).unwrap();
1113        assert_eq!(out.ndim(), 3);
1114        assert_eq!(out.shape(), &[1, 1, 1]);
1115
1116        let x = Array::from_slice(&[1, 2, 3], &[3]);
1117        let out = at_least_3d(&x).unwrap();
1118        assert_eq!(out.ndim(), 3);
1119        assert_eq!(out.shape(), &[1, 3, 1]);
1120
1121        let x = Array::from_slice(&[1, 2, 3], &[3, 1]);
1122        let out = at_least_3d(&x).unwrap();
1123        assert_eq!(out.ndim(), 3);
1124        assert_eq!(out.shape(), &[3, 1, 1]);
1125    }
1126
1127    #[test]
1128    fn test_move_axis() {
1129        let a = Array::from_int(0);
1130        assert!(move_axis(&a, 0, 0).is_err());
1131
1132        let a = Array::zeros::<i32>(&[2]).unwrap();
1133        assert!(move_axis(&a, 0, 1).is_err());
1134        assert_eq!(move_axis(&a, 0, 0).unwrap().shape(), &[2]);
1135        assert_eq!(move_axis(&a, -1, -1).unwrap().shape(), &[2]);
1136
1137        let a = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
1138        assert!(move_axis(&a, 0, -4).is_err());
1139        assert!(move_axis(&a, 0, 3).is_err());
1140        assert!(move_axis(&a, 3, 0).is_err());
1141        assert!(move_axis(&a, -4, 0).is_err());
1142        assert_eq!(move_axis(&a, 0, 2).unwrap().shape(), &[3, 4, 2]);
1143        assert_eq!(move_axis(&a, 0, 1).unwrap().shape(), &[3, 2, 4]);
1144        assert_eq!(move_axis(&a, 0, -1).unwrap().shape(), &[3, 4, 2]);
1145        assert_eq!(move_axis(&a, -2, 2).unwrap().shape(), &[2, 4, 3]);
1146    }
1147
1148    #[test]
1149    fn test_split_equal() {
1150        let x = Array::from_int(3);
1151        assert!(split_equal(&x, 0, 0).is_err());
1152
1153        let x = Array::from_slice(&[0, 1, 2], &[3]);
1154        assert!(split_equal(&x, 3, 1).is_err());
1155        assert!(split_equal(&x, -2, 1).is_err());
1156
1157        let out = split_equal(&x, 3, 0).unwrap();
1158        assert_eq!(out.len(), 3);
1159
1160        let mut out = split_equal(&x, 3, -1).unwrap();
1161        assert_eq!(out.len(), 3);
1162        for (i, a) in out.iter_mut().enumerate() {
1163            assert_eq!(a.shape(), &[1]);
1164            assert_eq!(a.dtype(), Dtype::Int32);
1165            assert_eq!(a.item::<i32>(), i as i32);
1166        }
1167
1168        let x = Array::from_slice(&[0, 1, 2, 3, 4, 5], &[2, 3]);
1169        let out = split_equal(&x, 2, None).unwrap();
1170        assert_eq!(out[0], Array::from_slice(&[0, 1, 2], &[1, 3]));
1171        assert_eq!(out[1], Array::from_slice(&[3, 4, 5], &[1, 3]));
1172
1173        let out = split_equal(&x, 3, 1).unwrap();
1174        assert_eq!(out[0], Array::from_slice(&[0, 3], &[2, 1]));
1175        assert_eq!(out[1], Array::from_slice(&[1, 4], &[2, 1]));
1176        assert_eq!(out[2], Array::from_slice(&[2, 5], &[2, 1]));
1177
1178        let x = Array::zeros::<i32>(&[8, 12]).unwrap();
1179        let out = split_equal(&x, 2, None).unwrap();
1180        assert_eq!(out.len(), 2);
1181        assert_eq!(out[0].shape(), &[4, 12]);
1182        assert_eq!(out[1].shape(), &[4, 12]);
1183
1184        let out = split_equal(&x, 3, 1).unwrap();
1185        assert_eq!(out.len(), 3);
1186        assert_eq!(out[0].shape(), &[8, 4]);
1187        assert_eq!(out[1].shape(), &[8, 4]);
1188        assert_eq!(out[2].shape(), &[8, 4]);
1189    }
1190
1191    #[test]
1192    fn test_split() {
1193        let x = Array::zeros::<i32>(&[8, 12]).unwrap();
1194
1195        let out = split(&x, &[], None).unwrap();
1196        assert_eq!(out.len(), 1);
1197        assert_eq!(out[0].shape(), x.shape());
1198
1199        let out = split(&x, &[3, 7], None).unwrap();
1200        assert_eq!(out.len(), 3);
1201        assert_eq!(out[0].shape(), &[3, 12]);
1202        assert_eq!(out[1].shape(), &[4, 12]);
1203        assert_eq!(out[2].shape(), &[1, 12]);
1204
1205        let out = split(&x, &[20], None).unwrap();
1206        assert_eq!(out.len(), 2);
1207        assert_eq!(out[0].shape(), &[8, 12]);
1208        assert_eq!(out[1].shape(), &[0, 12]);
1209
1210        let out = split(&x, &[-5], None).unwrap();
1211        assert_eq!(out[0].shape(), &[3, 12]);
1212        assert_eq!(out[1].shape(), &[5, 12]);
1213
1214        let out = split(&x, &[2, 8], Some(1)).unwrap();
1215        assert_eq!(out[0].shape(), &[8, 2]);
1216        assert_eq!(out[1].shape(), &[8, 6]);
1217        assert_eq!(out[2].shape(), &[8, 4]);
1218
1219        let x = Array::from_iter(0i32..5, &[5]);
1220        let out = split(&x, &[2, 1, 2], None).unwrap();
1221        assert_eq!(out[0], Array::from_slice(&[0, 1], &[2]));
1222        assert_eq!(out[1], Array::from_slice::<i32>(&[], &[0]));
1223        assert_eq!(out[2], Array::from_slice(&[1], &[1]));
1224        assert_eq!(out[3], Array::from_slice(&[2, 3, 4], &[3]));
1225    }
1226
1227    #[test]
1228    fn test_pad() {
1229        let x = Array::zeros::<f32>(&[1, 2, 3]).unwrap();
1230        assert_eq!(pad(&x, 1, None, None).unwrap().shape(), &[3, 4, 5]);
1231        assert_eq!(pad(&x, (0, 1), None, None).unwrap().shape(), &[2, 3, 4]);
1232        assert_eq!(
1233            pad(&x, &[(1, 1), (1, 2), (3, 1)], None, None)
1234                .unwrap()
1235                .shape(),
1236            &[3, 5, 7]
1237        );
1238    }
1239
1240    #[test]
1241    fn test_stack() {
1242        let x = Array::from_slice::<f32>(&[], &[0]);
1243        let x = vec![x];
1244        assert_eq!(stack(&x, 0).unwrap().shape(), &[1, 0]);
1245        assert_eq!(stack(&x, 1).unwrap().shape(), &[0, 1]);
1246
1247        let x = Array::from_slice(&[1, 2, 3], &[3]);
1248        let x = vec![x];
1249        assert_eq!(stack(&x, 0).unwrap().shape(), &[1, 3]);
1250        assert_eq!(stack(&x, 1).unwrap().shape(), &[3, 1]);
1251
1252        let y = Array::from_slice(&[4, 5, 6], &[3]);
1253        let mut z = x;
1254        z.push(y);
1255        assert_eq!(stack_all(&z).unwrap().shape(), &[2, 3]);
1256        assert_eq!(stack(&z, 1).unwrap().shape(), &[3, 2]);
1257        assert_eq!(stack(&z, -1).unwrap().shape(), &[3, 2]);
1258        assert_eq!(stack(&z, -2).unwrap().shape(), &[2, 3]);
1259
1260        let empty: Vec<Array> = Vec::new();
1261        assert!(stack(&empty, 0).is_err());
1262
1263        let x = Array::from_slice(&[1, 2, 3], &[3])
1264            .as_dtype(Dtype::Float16)
1265            .unwrap();
1266        let y = Array::from_slice(&[4, 5, 6], &[3])
1267            .as_dtype(Dtype::Int32)
1268            .unwrap();
1269        assert_eq!(stack(&[x, y], 0).unwrap().dtype(), Dtype::Float16);
1270
1271        let x = Array::from_slice(&[1, 2, 3], &[3])
1272            .as_dtype(Dtype::Int32)
1273            .unwrap();
1274        let y = Array::from_slice(&[4, 5, 6, 7], &[4])
1275            .as_dtype(Dtype::Int32)
1276            .unwrap();
1277        assert!(stack(&[x, y], 0).is_err());
1278    }
1279
1280    #[test]
1281    fn test_swap_axes() {
1282        let a = Array::from_int(0);
1283        assert!(swap_axes(&a, 0, 0).is_err());
1284
1285        let a = Array::zeros::<i32>(&[2]).unwrap();
1286        assert!(swap_axes(&a, 0, 1).is_err());
1287        assert_eq!(swap_axes(&a, 0, 0).unwrap().shape(), &[2]);
1288        assert_eq!(swap_axes(&a, -1, -1).unwrap().shape(), &[2]);
1289
1290        let a = Array::zeros::<i32>(&[2, 3, 4]).unwrap();
1291        assert!(swap_axes(&a, 0, -4).is_err());
1292        assert!(swap_axes(&a, 0, 3).is_err());
1293        assert!(swap_axes(&a, 3, 0).is_err());
1294        assert!(swap_axes(&a, -4, 0).is_err());
1295        assert_eq!(swap_axes(&a, 0, 2).unwrap().shape(), &[4, 3, 2]);
1296        assert_eq!(swap_axes(&a, 0, 1).unwrap().shape(), &[3, 2, 4]);
1297        assert_eq!(swap_axes(&a, 0, -1).unwrap().shape(), &[4, 3, 2]);
1298        assert_eq!(swap_axes(&a, -2, 2).unwrap().shape(), &[2, 4, 3]);
1299    }
1300
1301    #[test]
1302    fn test_tile() {
1303        let x = Array::from_slice(&[1, 2, 3], &[3]);
1304        let y = tile(&x, &[2]).unwrap();
1305        let expected = Array::from_slice(&[1, 2, 3, 1, 2, 3], &[6]);
1306        assert_eq!(y, expected);
1307
1308        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1309        let y = tile(&x, &[2]).unwrap();
1310        let expected = Array::from_slice(&[1, 2, 1, 2, 3, 4, 3, 4], &[2, 4]);
1311        assert_eq!(y, expected);
1312
1313        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1314        let y = tile(&x, &[4, 1]).unwrap();
1315        let expected =
1316            Array::from_slice(&[1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], &[8, 2]);
1317        assert_eq!(y, expected);
1318
1319        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1320        let y = tile(&x, &[2, 2]).unwrap();
1321        let expected =
1322            Array::from_slice(&[1, 2, 1, 2, 3, 4, 3, 4, 1, 2, 1, 2, 3, 4, 3, 4], &[4, 4]);
1323        assert_eq!(y, expected);
1324
1325        let x = Array::from_slice(&[1, 2, 3], &[3]);
1326        let y = tile(&x, &[2, 2, 2]).unwrap();
1327        let expected = Array::from_slice(
1328            &[
1329                1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3,
1330            ],
1331            &[2, 2, 6],
1332        );
1333        assert_eq!(y, expected);
1334    }
1335
1336    #[test]
1337    fn test_transpose() {
1338        let x = Array::from_int(1);
1339        let y = transpose_all(&x).unwrap();
1340        assert!(y.shape().is_empty());
1341        assert_eq!(y.item::<i32>(), 1);
1342        assert!(transpose(&x, &[0][..]).is_err());
1343        assert!(transpose(&x, &[1][..]).is_err());
1344
1345        let x = Array::from_slice(&[1], &[1]);
1346        let y = transpose_all(&x).unwrap();
1347        assert_eq!(y.shape(), &[1]);
1348        assert_eq!(y.item::<i32>(), 1);
1349
1350        let y = transpose(&x, &[-1][..]).unwrap();
1351        assert_eq!(y.shape(), &[1]);
1352        assert_eq!(y.item::<i32>(), 1);
1353
1354        assert!(transpose(&x, &[1][..]).is_err());
1355        assert!(transpose(&x, &[0, 0][..]).is_err());
1356
1357        let x = Array::from_slice::<i32>(&[], &[0]);
1358        let y = transpose_all(&x).unwrap();
1359        assert_eq!(y.shape(), &[0]);
1360        y.eval().unwrap();
1361        assert_eq!(y.size(), 0);
1362
1363        let x = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
1364        let mut y = transpose_all(&x).unwrap();
1365        assert_eq!(y.shape(), &[3, 2]);
1366        y = transpose(&x, &[-1, 0][..]).unwrap();
1367        assert_eq!(y.shape(), &[3, 2]);
1368        y = transpose(&x, &[-1, -2][..]).unwrap();
1369        assert_eq!(y.shape(), &[3, 2]);
1370        y.eval().unwrap();
1371        assert_eq!(y, Array::from_slice(&[1, 4, 2, 5, 3, 6], &[3, 2]));
1372
1373        let y = transpose(&x, &[0, 1][..]).unwrap();
1374        assert_eq!(y.shape(), &[2, 3]);
1375        assert_eq!(y, x);
1376
1377        let y = transpose(&x, &[0, -1][..]).unwrap();
1378        assert_eq!(y.shape(), &[2, 3]);
1379        assert_eq!(y, x);
1380
1381        assert!(transpose(&x, &[][..]).is_err());
1382        assert!(transpose(&x, &[0][..]).is_err());
1383        assert!(transpose(&x, &[0, 0][..]).is_err());
1384        assert!(transpose(&x, &[0, 0, 0][..]).is_err());
1385        assert!(transpose(&x, &[0, 1, 1][..]).is_err());
1386
1387        let x = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], &[2, 3, 2]);
1388        let y = transpose_all(&x).unwrap();
1389        assert_eq!(y.shape(), &[2, 3, 2]);
1390        let expected = Array::from_slice(&[1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12], &[2, 3, 2]);
1391        assert_eq!(y, expected);
1392
1393        let y = transpose(&x, &[0, 1, 2][..]).unwrap();
1394        assert_eq!(y.shape(), &[2, 3, 2]);
1395        assert_eq!(y, x);
1396
1397        let y = transpose(&x, &[1, 0, 2][..]).unwrap();
1398        assert_eq!(y.shape(), &[3, 2, 2]);
1399        let expected = Array::from_slice(&[1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12], &[3, 2, 2]);
1400        assert_eq!(y, expected);
1401
1402        let y = transpose(&x, &[0, 2, 1][..]).unwrap();
1403        assert_eq!(y.shape(), &[2, 2, 3]);
1404        let expected = Array::from_slice(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12], &[2, 2, 3]);
1405        assert_eq!(y, expected);
1406
1407        let mut x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7], &[4, 2]);
1408        x = reshape(transpose_all(&x).unwrap(), &[2, 2, 2]).unwrap();
1409        let expected = Array::from_slice(&[0, 2, 4, 6, 1, 3, 5, 7], &[2, 2, 2]);
1410        assert_eq!(x, expected);
1411
1412        let mut x = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7], &[1, 4, 1, 2]);
1413        // assert!(x.flags().row_contiguous);
1414        x = transpose(&x, &[2, 1, 0, 3][..]).unwrap();
1415        x.eval().unwrap();
1416        // assert!(x.flags().row_contiguous);
1417    }
1418}