mlx_rs/ops/
factory.rs

1use crate::array::ArrayElement;
2use crate::error::Result;
3use crate::utils::guard::Guarded;
4use crate::{array::Array, stream::StreamOrDevice};
5use crate::{Dtype, Stream};
6use mlx_internal_macros::{default_device, generate_macro};
7use num_traits::NumCast;
8
9impl Array {
10    /// Construct an array of zeros returning an error if shape is invalid.
11    ///
12    /// # Params
13    ///
14    /// - shape: Desired shape
15    ///
16    /// # Example
17    ///
18    /// ```rust
19    /// use mlx_rs::{Array, StreamOrDevice};
20    /// Array::zeros_device::<f32>(&[5, 10], StreamOrDevice::default()).unwrap();
21    /// ```
22    #[default_device]
23    pub fn zeros_device<T: ArrayElement>(
24        shape: &[i32],
25        stream: impl AsRef<Stream>,
26    ) -> Result<Array> {
27        let dtype = T::DTYPE;
28        zeros_dtype_device(shape, dtype, stream)
29    }
30
31    /// Construct an array of ones returning an error if shape is invalid.
32    ///
33    /// # Params
34    ///
35    /// - shape: Desired shape
36    ///
37    /// # Example
38    ///
39    /// ```rust
40    /// use mlx_rs::{Array, StreamOrDevice};
41    /// Array::ones_device::<f32>(&[5, 10], StreamOrDevice::default()).unwrap();
42    /// ```
43    #[default_device]
44    pub fn ones_device<T: ArrayElement>(
45        shape: &[i32],
46        stream: impl AsRef<Stream>,
47    ) -> Result<Array> {
48        let dtype = T::DTYPE;
49        ones_dtype_device(shape, dtype, stream)
50    }
51
52    /// Create an identity matrix or a general diagonal matrix returning an error if params are invalid.
53    ///
54    /// # Params
55    ///
56    /// - n: number of rows in the output
57    /// - m: number of columns in the output -- equal to `n` if not specified
58    /// - k: index of the diagonal - defaults to 0 if not specified
59    ///
60    /// # Example
61    ///
62    /// ```rust
63    /// use mlx_rs::{Array, StreamOrDevice};
64    /// //  create [10, 10] array with 1's on the diagonal.
65    /// let r = Array::eye_device::<f32>(10, None, None, StreamOrDevice::default()).unwrap();
66    /// ```
67    #[default_device]
68    pub fn eye_device<T: ArrayElement>(
69        n: i32,
70        m: Option<i32>,
71        k: Option<i32>,
72        stream: impl AsRef<Stream>,
73    ) -> Result<Array> {
74        Array::try_from_op(|res| unsafe {
75            mlx_sys::mlx_eye(
76                res,
77                n,
78                m.unwrap_or(n),
79                k.unwrap_or(0),
80                T::DTYPE.into(),
81                stream.as_ref().as_ptr(),
82            )
83        })
84    }
85
86    /// Construct an array with the given value returning an error if shape is invalid.
87    ///
88    /// Constructs an array of size `shape` filled with `values`. If `values`
89    /// is an [Array] it must be [broadcasting](https://swiftpackageindex.com/ml-explore/mlx-swift/main/documentation/mlx/broadcasting) to the given `shape`.
90    ///
91    /// # Params
92    ///
93    /// - shape: shape of the output array
94    /// - values: values to be broadcast into the array
95    ///
96    /// # Example
97    ///
98    /// ```rust
99    /// use mlx_rs::{Array, StreamOrDevice, array};
100    /// //  create [5, 4] array filled with 7
101    /// let r = Array::full_device::<f32>(&[5, 4], array!(7.0f32), StreamOrDevice::default()).unwrap();
102    /// ```
103    #[default_device]
104    pub fn full_device<T: ArrayElement>(
105        shape: &[i32],
106        values: impl AsRef<Array>,
107        stream: impl AsRef<Stream>,
108    ) -> Result<Array> {
109        Array::try_from_op(|res| unsafe {
110            mlx_sys::mlx_full(
111                res,
112                shape.as_ptr(),
113                shape.len(),
114                values.as_ref().as_ptr(),
115                T::DTYPE.into(),
116                stream.as_ref().as_ptr(),
117            )
118        })
119    }
120
121    /// Create a square identity matrix returning an error if params are invalid.
122    ///
123    /// # Params
124    ///
125    /// - n: number of rows and columns in the output
126    ///
127    /// # Example
128    ///
129    /// ```rust
130    /// use mlx_rs::{Array, StreamOrDevice};
131    /// //  create [10, 10] array with 1's on the diagonal.
132    /// let r = Array::identity_device::<f32>(10, StreamOrDevice::default()).unwrap();
133    /// ```
134    #[default_device]
135    pub fn identity_device<T: ArrayElement>(n: i32, stream: impl AsRef<Stream>) -> Result<Array> {
136        Array::try_from_op(|res| unsafe {
137            mlx_sys::mlx_identity(res, n, T::DTYPE.into(), stream.as_ref().as_ptr())
138        })
139    }
140
141    /// Generates ranges of numbers.
142    ///
143    /// Generate numbers in the half-open interval `[start, stop)` in increments of `step`.
144    ///
145    /// # Params
146    ///
147    /// - `start`: Starting value which defaults to `0`.
148    /// - `stop`: Stopping value.
149    /// - `step`: Increment which defaults to `1`.
150    ///
151    /// # Example
152    ///
153    /// ```rust
154    /// use mlx_rs::{Array, StreamOrDevice};
155    ///
156    /// // Create a 1-D array with values from 0 to 50
157    /// let r = Array::arange::<_, f32>(None, 50, None);
158    /// ```
159    #[default_device]
160    pub fn arange_device<U, T>(
161        start: impl Into<Option<U>>,
162        stop: U,
163        step: impl Into<Option<U>>,
164        stream: impl AsRef<Stream>,
165    ) -> Result<Array>
166    where
167        U: NumCast,
168        T: ArrayElement,
169    {
170        let start: f64 = start.into().and_then(NumCast::from).unwrap_or(0.0);
171        let stop: f64 = NumCast::from(stop).unwrap();
172        let step: f64 = step.into().and_then(NumCast::from).unwrap_or(1.0);
173
174        Array::try_from_op(|res| unsafe {
175            mlx_sys::mlx_arange(
176                res,
177                start,
178                stop,
179                step,
180                T::DTYPE.into(),
181                stream.as_ref().as_ptr(),
182            )
183        })
184    }
185
186    /// Generate `num` evenly spaced numbers over interval `[start, stop]` returning an error if params are invalid.
187    ///
188    /// # Params
189    ///
190    /// - start: start value
191    /// - stop: stop value
192    /// - count: number of samples -- defaults to 50 if not specified
193    ///
194    /// # Example
195    ///
196    /// ```rust
197    /// use mlx_rs::{Array, StreamOrDevice};
198    /// // Create a 50 element 1-D array with values from 0 to 50
199    /// let r = Array::linspace_device::<_, f32>(0, 50, None, StreamOrDevice::default()).unwrap();
200    /// ```
201    #[default_device]
202    pub fn linspace_device<U, T>(
203        start: U,
204        stop: U,
205        count: impl Into<Option<i32>>,
206        stream: impl AsRef<Stream>,
207    ) -> Result<Array>
208    where
209        U: NumCast,
210        T: ArrayElement,
211    {
212        let count = count.into().unwrap_or(50);
213        let start_f32 = NumCast::from(start).unwrap();
214        let stop_f32 = NumCast::from(stop).unwrap();
215
216        Array::try_from_op(|res| unsafe {
217            mlx_sys::mlx_linspace(
218                res,
219                start_f32,
220                stop_f32,
221                count,
222                T::DTYPE.into(),
223                stream.as_ref().as_ptr(),
224            )
225        })
226    }
227
228    /// Repeat an array along a specified axis returning an error if params are invalid.
229    ///
230    /// # Params
231    ///
232    /// - array: array to repeat
233    /// - count: number of times to repeat
234    /// - axis: axis to repeat along
235    ///
236    /// # Example
237    ///
238    /// ```rust
239    /// use mlx_rs::{Array, StreamOrDevice};
240    /// // repeat a [2, 2] array 4 times along axis 1
241    /// let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
242    /// let r = Array::repeat_device::<i32>(source, 4, 1, StreamOrDevice::default()).unwrap();
243    /// ```
244    #[default_device]
245    pub fn repeat_device<T: ArrayElement>(
246        array: Array,
247        count: i32,
248        axis: i32,
249        stream: impl AsRef<Stream>,
250    ) -> Result<Array> {
251        Array::try_from_op(|res| unsafe {
252            mlx_sys::mlx_repeat(res, array.as_ptr(), count, axis, stream.as_ref().as_ptr())
253        })
254    }
255
256    /// Repeat a flattened array along axis 0 returning an error if params are invalid.
257    ///
258    /// # Params
259    ///
260    /// - array: array to repeat
261    /// - count: number of times to repeat
262    ///
263    /// # Example
264    ///
265    /// ```rust
266    /// use mlx_rs::{Array, StreamOrDevice};
267    /// // repeat a 4 element array 4 times along axis 0
268    /// let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
269    /// let r = Array::repeat_all_device::<i32>(source, 4, StreamOrDevice::default()).unwrap();
270    /// ```
271    #[default_device]
272    pub fn repeat_all_device<T: ArrayElement>(
273        array: Array,
274        count: i32,
275        stream: impl AsRef<Stream>,
276    ) -> Result<Array> {
277        Array::try_from_op(|res| unsafe {
278            mlx_sys::mlx_repeat_all(res, array.as_ptr(), count, stream.as_ref().as_ptr())
279        })
280    }
281
282    /// An array with ones at and below the given diagonal and zeros elsewhere.
283    ///
284    /// # Params
285    ///
286    /// - n: number of rows in the output
287    /// - m: number of columns in the output -- equal to `n` if not specified
288    /// - k: index of the diagonal -- defaults to 0 if not specified
289    ///
290    /// # Example
291    ///
292    /// ```rust
293    /// use mlx_rs::{Array, StreamOrDevice};
294    /// // [5, 5] array with the lower triangle filled with 1s
295    /// let r = Array::tri_device::<f32>(5, None, None, StreamOrDevice::default());
296    /// ```
297    #[default_device]
298    pub fn tri_device<T: ArrayElement>(
299        n: i32,
300        m: Option<i32>,
301        k: Option<i32>,
302        stream: impl AsRef<Stream>,
303    ) -> Result<Array> {
304        Array::try_from_op(|res| unsafe {
305            mlx_sys::mlx_tri(
306                res,
307                n,
308                m.unwrap_or(n),
309                k.unwrap_or(0),
310                T::DTYPE.into(),
311                stream.as_ref().as_ptr(),
312            )
313        })
314    }
315}
316
317/// See [`Array::zeros`]
318#[generate_macro]
319#[default_device]
320pub fn zeros_device<T: ArrayElement>(
321    shape: &[i32],
322    #[optional] stream: impl AsRef<Stream>,
323) -> Result<Array> {
324    Array::zeros_device::<T>(shape, stream)
325}
326
327/// An array of zeros like the input.
328#[generate_macro]
329#[default_device]
330pub fn zeros_like_device(
331    input: impl AsRef<Array>,
332    #[optional] stream: impl AsRef<Stream>,
333) -> Result<Array> {
334    let a = input.as_ref();
335    let shape = a.shape();
336    let dtype = a.dtype();
337    zeros_dtype_device(shape, dtype, stream)
338}
339
340/// Similar to [`Array::zeros`] but with a specified dtype.
341#[generate_macro]
342#[default_device]
343pub fn zeros_dtype_device(
344    shape: &[i32],
345    dtype: Dtype,
346    #[optional] stream: impl AsRef<Stream>,
347) -> Result<Array> {
348    Array::try_from_op(|res| unsafe {
349        mlx_sys::mlx_zeros(
350            res,
351            shape.as_ptr(),
352            shape.len(),
353            dtype.into(),
354            stream.as_ref().as_ptr(),
355        )
356    })
357}
358
359/// See [`Array::ones`]
360#[generate_macro]
361#[default_device]
362pub fn ones_device<T: ArrayElement>(
363    shape: &[i32],
364    #[optional] stream: impl AsRef<Stream>,
365) -> Result<Array> {
366    Array::ones_device::<T>(shape, stream)
367}
368
369/// An array of ones like the input.
370#[generate_macro]
371#[default_device]
372pub fn ones_like_device(
373    input: impl AsRef<Array>,
374    #[optional] stream: impl AsRef<Stream>,
375) -> Result<Array> {
376    let a = input.as_ref();
377    let shape = a.shape();
378    let dtype = a.dtype();
379    ones_dtype_device(shape, dtype, stream)
380}
381
382/// Similar to [`Array::ones`] but with a specified dtype.
383#[generate_macro]
384#[default_device]
385pub fn ones_dtype_device(
386    shape: &[i32],
387    dtype: Dtype,
388    #[optional] stream: impl AsRef<Stream>,
389) -> Result<Array> {
390    Array::try_from_op(|res| unsafe {
391        mlx_sys::mlx_ones(
392            res,
393            shape.as_ptr(),
394            shape.len(),
395            dtype.into(),
396            stream.as_ref().as_ptr(),
397        )
398    })
399}
400
401/// See [`Array::eye`]
402#[generate_macro]
403#[default_device]
404pub fn eye_device<T: ArrayElement>(
405    n: i32,
406    #[optional] m: Option<i32>,
407    #[optional] k: Option<i32>,
408    #[optional] stream: impl AsRef<Stream>,
409) -> Result<Array> {
410    Array::eye_device::<T>(n, m, k, stream)
411}
412
413/// See [`Array::full`]
414#[generate_macro]
415#[default_device]
416pub fn full_device<T: ArrayElement>(
417    shape: &[i32],
418    values: impl AsRef<Array>,
419    #[optional] stream: impl AsRef<Stream>,
420) -> Result<Array> {
421    Array::full_device::<T>(shape, values, stream)
422}
423
424/// See [`Array::identity`]
425#[generate_macro]
426#[default_device]
427pub fn identity_device<T: ArrayElement>(
428    n: i32,
429    #[optional] stream: impl AsRef<Stream>,
430) -> Result<Array> {
431    Array::identity_device::<T>(n, stream)
432}
433
434/// See [`Array::arange`]
435#[generate_macro]
436#[default_device]
437pub fn arange_device<U, T>(
438    #[optional] start: impl Into<Option<U>>,
439    #[named] stop: U,
440    #[optional] step: impl Into<Option<U>>,
441    #[optional] stream: impl AsRef<Stream>,
442) -> Result<Array>
443where
444    U: NumCast,
445    T: ArrayElement,
446{
447    Array::arange_device::<U, T>(start, stop, step, stream)
448}
449
450/// See [`Array::linspace`]
451#[generate_macro]
452#[default_device]
453pub fn linspace_device<U, T>(
454    start: U,
455    stop: U,
456    #[optional] count: impl Into<Option<i32>>,
457    #[optional] stream: impl AsRef<Stream>,
458) -> Result<Array>
459where
460    U: NumCast,
461    T: ArrayElement,
462{
463    Array::linspace_device::<U, T>(start, stop, count, stream)
464}
465
466/// See [`Array::repeat`]
467#[generate_macro]
468#[default_device]
469pub fn repeat_device<T: ArrayElement>(
470    array: Array,
471    count: i32,
472    axis: i32,
473    #[optional] stream: impl AsRef<Stream>,
474) -> Result<Array> {
475    Array::repeat_device::<T>(array, count, axis, stream)
476}
477
478/// See [`Array::repeat_all`]
479#[generate_macro]
480#[default_device]
481pub fn repeat_all_device<T: ArrayElement>(
482    array: Array,
483    count: i32,
484    #[optional] stream: impl AsRef<Stream>,
485) -> Result<Array> {
486    Array::repeat_all_device::<T>(array, count, stream)
487}
488
489/// See [`Array::tri`]
490#[generate_macro]
491#[default_device]
492pub fn tri_device<T: ArrayElement>(
493    n: i32,
494    #[optional] m: Option<i32>,
495    #[optional] k: Option<i32>,
496    #[optional] stream: impl AsRef<Stream>,
497) -> Result<Array> {
498    Array::tri_device::<T>(n, m, k, stream)
499}
500
501/// Zeros the array above the given diagonal
502///
503/// # Params
504///
505/// - `a`: input array
506/// - `k`: diagonal of the 2D array. Default to `0`
507/// - `stream`: stream to execute on
508#[generate_macro]
509#[default_device]
510pub fn tril_device(
511    a: impl AsRef<Array>,
512    #[optional] k: impl Into<Option<i32>>,
513    #[optional] stream: impl AsRef<Stream>,
514) -> Result<Array> {
515    let a = a.as_ref();
516    let k = k.into().unwrap_or(0);
517    Array::try_from_op(|res| unsafe {
518        mlx_sys::mlx_tril(res, a.as_ptr(), k, stream.as_ref().as_ptr())
519    })
520}
521
522/// Zeros the array below the given diagonal
523///
524/// # Params
525///
526/// - `a`: input array
527/// - `k`: diagonal of the 2D array. Default to `0`
528#[generate_macro]
529#[default_device]
530pub fn triu_device(
531    a: impl AsRef<Array>,
532    #[optional] k: impl Into<Option<i32>>,
533    #[optional] stream: impl AsRef<Stream>,
534) -> Result<Array> {
535    let a = a.as_ref();
536    let k = k.into().unwrap_or(0);
537    Array::try_from_op(|res| unsafe {
538        mlx_sys::mlx_triu(res, a.as_ptr(), k, stream.as_ref().as_ptr())
539    })
540}
541
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{array, dtype::Dtype};
    use half::f16;

    #[test]
    fn test_zeros() {
        let array = Array::zeros::<f32>(&[2, 3]).unwrap();
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        assert_eq!(data, &[0.0; 6]);
    }

    #[test]
    fn test_zeros_try() {
        let array = Array::zeros::<f32>(&[2, 3]);
        assert!(array.is_ok());

        let array = Array::zeros::<f32>(&[-1, 3]);
        assert!(array.is_err());
    }

    #[test]
    fn test_ones() {
        let array = Array::ones::<f16>(&[2, 3]).unwrap();
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Float16);

        let data: &[f16] = array.as_slice();
        assert_eq!(data, &[f16::from_f32(1.0); 6]);
    }

    #[test]
    fn test_eye() {
        let array = Array::eye::<f32>(3, None, None).unwrap();
        assert_eq!(array.shape(), &[3, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        assert_eq!(data, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_full_scalar() {
        let array = Array::full::<f32>(&[2, 3], array!(7f32)).unwrap();
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        assert_eq!(data, &[7.0; 6]);
    }

    #[test]
    fn test_full_array() {
        let source = Array::zeros_device::<f32>(&[1, 3], StreamOrDevice::cpu()).unwrap();
        let array = Array::full::<f32>(&[2, 3], source).unwrap();
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        // `float_eq!` only evaluates to a `bool`; without `assert!` the comparison result
        // would be discarded and the test could never fail.
        assert!(float_eq::float_eq!(*data, [0.0; 6], abs <= [1e-6; 6]));
    }

    #[test]
    fn test_full_try() {
        let source = Array::zeros_device::<f32>(&[1, 3], StreamOrDevice::default()).unwrap();
        let array = Array::full::<f32>(&[2, 3], source);
        assert!(array.is_ok());

        let source = Array::zeros_device::<f32>(&[1, 3], StreamOrDevice::default()).unwrap();
        let array = Array::full::<f32>(&[-1, 3], source);
        assert!(array.is_err());
    }

    #[test]
    fn test_identity() {
        let array = Array::identity::<f32>(3).unwrap();
        assert_eq!(array.shape(), &[3, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        assert_eq!(data, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_arange() {
        let array = Array::arange::<_, f32>(None, 50, None).unwrap();
        assert_eq!(array.shape(), &[50]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        let expected: Vec<f32> = (0..50).map(|x| x as f32).collect();
        assert_eq!(data, expected.as_slice());

        let array = Array::arange::<_, i32>(0, 50, None).unwrap();
        assert_eq!(array.shape(), &[50]);
        assert_eq!(array.dtype(), Dtype::Int32);

        let data: &[i32] = array.as_slice();
        let expected: Vec<i32> = (0..50).collect();
        assert_eq!(data, expected.as_slice());

        let result = Array::arange::<_, bool>(None, 50, None);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(f64::NEG_INFINITY, 50.0, None);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(0.0, f64::INFINITY, None);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(0.0, 50.0, f32::NAN);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(f32::NAN, 50.0, None);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(0.0, f32::NAN, None);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(0, i32::MAX as i64 + 1, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_linspace_int() {
        let array = Array::linspace::<_, f32>(0, 50, None).unwrap();
        assert_eq!(array.shape(), &[50]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        let expected: Vec<f32> = (0..50).map(|x| x as f32 * (50.0 / 49.0)).collect();
        assert_eq!(data, expected.as_slice());
    }

    #[test]
    fn test_linspace_float() {
        let array = Array::linspace::<_, f32>(0., 50., None).unwrap();
        assert_eq!(array.shape(), &[50]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        let expected: Vec<f32> = (0..50).map(|x| x as f32 * (50.0 / 49.0)).collect();
        assert_eq!(data, expected.as_slice());
    }

    #[test]
    fn test_linspace_try() {
        let array = Array::linspace::<_, f32>(0, 50, None);
        assert!(array.is_ok());

        let array = Array::linspace::<_, f32>(0, 50, Some(-1));
        assert!(array.is_err());
    }

    #[test]
    fn test_repeat() {
        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat::<i32>(source, 4, 1).unwrap();
        assert_eq!(array.shape(), &[2, 8]);
        assert_eq!(array.dtype(), Dtype::Int32);

        let data: &[i32] = array.as_slice();
        assert_eq!(data, [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]);
    }

    #[test]
    fn test_repeat_try() {
        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat::<i32>(source, 4, 1);
        assert!(array.is_ok());

        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat::<i32>(source, -1, 1);
        assert!(array.is_err());
    }

    #[test]
    fn test_repeat_all() {
        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat_all::<i32>(source, 4).unwrap();
        assert_eq!(array.shape(), &[16]);
        assert_eq!(array.dtype(), Dtype::Int32);

        let data: &[i32] = array.as_slice();
        assert_eq!(data, [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]);
    }

    #[test]
    fn test_repeat_all_try() {
        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat_all::<i32>(source, 4);
        assert!(array.is_ok());

        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat_all::<i32>(source, -1);
        assert!(array.is_err());
    }

    #[test]
    fn test_tri() {
        let array = Array::tri::<f32>(3, None, None).unwrap();
        assert_eq!(array.shape(), &[3, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        assert_eq!(data, &[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_tril() {
        let source = Array::ones::<f32>(&[3, 3]).unwrap();
        let lower = tril(&source, None::<i32>).unwrap();
        assert_eq!(lower.shape(), &[3, 3]);

        let data: &[f32] = lower.as_slice();
        assert_eq!(data, &[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_triu() {
        let source = Array::ones::<f32>(&[3, 3]).unwrap();
        let upper = triu(&source, None::<i32>).unwrap();
        assert_eq!(upper.shape(), &[3, 3]);

        let data: &[f32] = upper.as_slice();
        assert_eq!(data, &[1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0]);
    }
}