1use crate::array::ArrayElement;
2use crate::error::Result;
3use crate::utils::guard::Guarded;
4use crate::{array::Array, stream::StreamOrDevice};
5use crate::{Dtype, Stream};
6use mlx_internal_macros::{default_device, generate_macro};
7use num_traits::NumCast;
8
9impl Array {
10 #[default_device]
23 pub fn zeros_device<T: ArrayElement>(
24 shape: &[i32],
25 stream: impl AsRef<Stream>,
26 ) -> Result<Array> {
27 let dtype = T::DTYPE;
28 zeros_dtype_device(shape, dtype, stream)
29 }
30
31 #[default_device]
44 pub fn ones_device<T: ArrayElement>(
45 shape: &[i32],
46 stream: impl AsRef<Stream>,
47 ) -> Result<Array> {
48 let dtype = T::DTYPE;
49 ones_dtype_device(shape, dtype, stream)
50 }
51
52 #[default_device]
68 pub fn eye_device<T: ArrayElement>(
69 n: i32,
70 m: Option<i32>,
71 k: Option<i32>,
72 stream: impl AsRef<Stream>,
73 ) -> Result<Array> {
74 Array::try_from_op(|res| unsafe {
75 mlx_sys::mlx_eye(
76 res,
77 n,
78 m.unwrap_or(n),
79 k.unwrap_or(0),
80 T::DTYPE.into(),
81 stream.as_ref().as_ptr(),
82 )
83 })
84 }
85
86 #[default_device]
104 pub fn full_device<T: ArrayElement>(
105 shape: &[i32],
106 values: impl AsRef<Array>,
107 stream: impl AsRef<Stream>,
108 ) -> Result<Array> {
109 Array::try_from_op(|res| unsafe {
110 mlx_sys::mlx_full(
111 res,
112 shape.as_ptr(),
113 shape.len(),
114 values.as_ref().as_ptr(),
115 T::DTYPE.into(),
116 stream.as_ref().as_ptr(),
117 )
118 })
119 }
120
121 #[default_device]
135 pub fn identity_device<T: ArrayElement>(n: i32, stream: impl AsRef<Stream>) -> Result<Array> {
136 Array::try_from_op(|res| unsafe {
137 mlx_sys::mlx_identity(res, n, T::DTYPE.into(), stream.as_ref().as_ptr())
138 })
139 }
140
141 #[default_device]
160 pub fn arange_device<U, T>(
161 start: impl Into<Option<U>>,
162 stop: U,
163 step: impl Into<Option<U>>,
164 stream: impl AsRef<Stream>,
165 ) -> Result<Array>
166 where
167 U: NumCast,
168 T: ArrayElement,
169 {
170 let start: f64 = start.into().and_then(NumCast::from).unwrap_or(0.0);
171 let stop: f64 = NumCast::from(stop).unwrap();
172 let step: f64 = step.into().and_then(NumCast::from).unwrap_or(1.0);
173
174 Array::try_from_op(|res| unsafe {
175 mlx_sys::mlx_arange(
176 res,
177 start,
178 stop,
179 step,
180 T::DTYPE.into(),
181 stream.as_ref().as_ptr(),
182 )
183 })
184 }
185
186 #[default_device]
202 pub fn linspace_device<U, T>(
203 start: U,
204 stop: U,
205 count: impl Into<Option<i32>>,
206 stream: impl AsRef<Stream>,
207 ) -> Result<Array>
208 where
209 U: NumCast,
210 T: ArrayElement,
211 {
212 let count = count.into().unwrap_or(50);
213 let start_f32 = NumCast::from(start).unwrap();
214 let stop_f32 = NumCast::from(stop).unwrap();
215
216 Array::try_from_op(|res| unsafe {
217 mlx_sys::mlx_linspace(
218 res,
219 start_f32,
220 stop_f32,
221 count,
222 T::DTYPE.into(),
223 stream.as_ref().as_ptr(),
224 )
225 })
226 }
227
228 #[default_device]
245 pub fn repeat_device<T: ArrayElement>(
246 array: Array,
247 count: i32,
248 axis: i32,
249 stream: impl AsRef<Stream>,
250 ) -> Result<Array> {
251 Array::try_from_op(|res| unsafe {
252 mlx_sys::mlx_repeat(res, array.as_ptr(), count, axis, stream.as_ref().as_ptr())
253 })
254 }
255
256 #[default_device]
272 pub fn repeat_all_device<T: ArrayElement>(
273 array: Array,
274 count: i32,
275 stream: impl AsRef<Stream>,
276 ) -> Result<Array> {
277 Array::try_from_op(|res| unsafe {
278 mlx_sys::mlx_repeat_all(res, array.as_ptr(), count, stream.as_ref().as_ptr())
279 })
280 }
281
282 #[default_device]
298 pub fn tri_device<T: ArrayElement>(
299 n: i32,
300 m: Option<i32>,
301 k: Option<i32>,
302 stream: impl AsRef<Stream>,
303 ) -> Result<Array> {
304 Array::try_from_op(|res| unsafe {
305 mlx_sys::mlx_tri(
306 res,
307 n,
308 m.unwrap_or(n),
309 k.unwrap_or(0),
310 T::DTYPE.into(),
311 stream.as_ref().as_ptr(),
312 )
313 })
314 }
315}
316
317#[generate_macro]
319#[default_device]
320pub fn zeros_device<T: ArrayElement>(
321 shape: &[i32],
322 #[optional] stream: impl AsRef<Stream>,
323) -> Result<Array> {
324 Array::zeros_device::<T>(shape, stream)
325}
326
327#[generate_macro]
329#[default_device]
330pub fn zeros_like_device(
331 input: impl AsRef<Array>,
332 #[optional] stream: impl AsRef<Stream>,
333) -> Result<Array> {
334 let a = input.as_ref();
335 let shape = a.shape();
336 let dtype = a.dtype();
337 zeros_dtype_device(shape, dtype, stream)
338}
339
340#[generate_macro]
342#[default_device]
343pub fn zeros_dtype_device(
344 shape: &[i32],
345 dtype: Dtype,
346 #[optional] stream: impl AsRef<Stream>,
347) -> Result<Array> {
348 Array::try_from_op(|res| unsafe {
349 mlx_sys::mlx_zeros(
350 res,
351 shape.as_ptr(),
352 shape.len(),
353 dtype.into(),
354 stream.as_ref().as_ptr(),
355 )
356 })
357}
358
359#[generate_macro]
361#[default_device]
362pub fn ones_device<T: ArrayElement>(
363 shape: &[i32],
364 #[optional] stream: impl AsRef<Stream>,
365) -> Result<Array> {
366 Array::ones_device::<T>(shape, stream)
367}
368
369#[generate_macro]
371#[default_device]
372pub fn ones_like_device(
373 input: impl AsRef<Array>,
374 #[optional] stream: impl AsRef<Stream>,
375) -> Result<Array> {
376 let a = input.as_ref();
377 let shape = a.shape();
378 let dtype = a.dtype();
379 ones_dtype_device(shape, dtype, stream)
380}
381
382#[generate_macro]
384#[default_device]
385pub fn ones_dtype_device(
386 shape: &[i32],
387 dtype: Dtype,
388 #[optional] stream: impl AsRef<Stream>,
389) -> Result<Array> {
390 Array::try_from_op(|res| unsafe {
391 mlx_sys::mlx_ones(
392 res,
393 shape.as_ptr(),
394 shape.len(),
395 dtype.into(),
396 stream.as_ref().as_ptr(),
397 )
398 })
399}
400
#[generate_macro]
/// Creates a 2-D array with ones on one diagonal and zeros elsewhere, using the dtype of `T`.
///
/// Forwards to [`Array::eye_device`]: `m` of `None` means `n`, and `k` of `None` means `0`
/// (the main diagonal).
#[default_device]
pub fn eye_device<T: ArrayElement>(
    n: i32,
    #[optional] m: Option<i32>,
    #[optional] k: Option<i32>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::eye_device::<T>(n, m, k, stream)
}
412
#[generate_macro]
/// Creates an array of the given `shape` filled from `values`, using the dtype of `T`.
///
/// Forwards to [`Array::full_device`].
#[default_device]
pub fn full_device<T: ArrayElement>(
    shape: &[i32],
    values: impl AsRef<Array>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::full_device::<T>(shape, values, stream)
}
423
#[generate_macro]
/// Creates an `n` x `n` identity matrix using the dtype of `T`.
///
/// Forwards to [`Array::identity_device`].
#[default_device]
pub fn identity_device<T: ArrayElement>(
    n: i32,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::identity_device::<T>(n, stream)
}
433
434#[generate_macro]
436#[default_device]
437pub fn arange_device<U, T>(
438 #[optional] start: impl Into<Option<U>>,
439 #[named] stop: U,
440 #[optional] step: impl Into<Option<U>>,
441 #[optional] stream: impl AsRef<Stream>,
442) -> Result<Array>
443where
444 U: NumCast,
445 T: ArrayElement,
446{
447 Array::arange_device::<U, T>(start, stop, step, stream)
448}
449
450#[generate_macro]
452#[default_device]
453pub fn linspace_device<U, T>(
454 start: U,
455 stop: U,
456 #[optional] count: impl Into<Option<i32>>,
457 #[optional] stream: impl AsRef<Stream>,
458) -> Result<Array>
459where
460 U: NumCast,
461 T: ArrayElement,
462{
463 Array::linspace_device::<U, T>(start, stop, count, stream)
464}
465
#[generate_macro]
/// Repeats every element of `array` `count` times along `axis`.
///
/// Forwards to [`Array::repeat_device`]; `T` is only passed through as a type argument.
#[default_device]
pub fn repeat_device<T: ArrayElement>(
    array: Array,
    count: i32,
    axis: i32,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::repeat_device::<T>(array, count, axis, stream)
}
477
#[generate_macro]
/// Repeats every element of `array` `count` times, producing a flat 1-D result.
///
/// Forwards to [`Array::repeat_all_device`]; `T` is only passed through as a type argument.
#[default_device]
pub fn repeat_all_device<T: ArrayElement>(
    array: Array,
    count: i32,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::repeat_all_device::<T>(array, count, stream)
}
488
#[generate_macro]
/// Creates a 2-D array of ones at and below a diagonal and zeros above it, using the
/// dtype of `T`.
///
/// Forwards to [`Array::tri_device`]: `m` of `None` means `n`, and `k` of `None` means `0`.
#[default_device]
pub fn tri_device<T: ArrayElement>(
    n: i32,
    #[optional] m: Option<i32>,
    #[optional] k: Option<i32>,
    #[optional] stream: impl AsRef<Stream>,
) -> Result<Array> {
    Array::tri_device::<T>(n, m, k, stream)
}
500
501#[generate_macro]
509#[default_device]
510pub fn tril_device(
511 a: impl AsRef<Array>,
512 #[optional] k: impl Into<Option<i32>>,
513 #[optional] stream: impl AsRef<Stream>,
514) -> Result<Array> {
515 let a = a.as_ref();
516 let k = k.into().unwrap_or(0);
517 Array::try_from_op(|res| unsafe {
518 mlx_sys::mlx_tril(res, a.as_ptr(), k, stream.as_ref().as_ptr())
519 })
520}
521
522#[generate_macro]
529#[default_device]
530pub fn triu_device(
531 a: impl AsRef<Array>,
532 #[optional] k: impl Into<Option<i32>>,
533 #[optional] stream: impl AsRef<Stream>,
534) -> Result<Array> {
535 let a = a.as_ref();
536 let k = k.into().unwrap_or(0);
537 Array::try_from_op(|res| unsafe {
538 mlx_sys::mlx_triu(res, a.as_ptr(), k, stream.as_ref().as_ptr())
539 })
540}
541
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{array, dtype::Dtype};
    use half::f16;

    #[test]
    fn test_zeros() {
        let array = Array::zeros::<f32>(&[2, 3]).unwrap();
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        assert_eq!(data, &[0.0; 6]);
    }

    #[test]
    fn test_zeros_try() {
        let array = Array::zeros::<f32>(&[2, 3]);
        assert!(array.is_ok());

        let array = Array::zeros::<f32>(&[-1, 3]);
        assert!(array.is_err());
    }

    #[test]
    fn test_ones() {
        let array = Array::ones::<f16>(&[2, 3]).unwrap();
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Float16);

        let data: &[f16] = array.as_slice();
        assert_eq!(data, &[f16::from_f32(1.0); 6]);
    }

    #[test]
    fn test_eye() {
        let array = Array::eye::<f32>(3, None, None).unwrap();
        assert_eq!(array.shape(), &[3, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        assert_eq!(data, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_full_scalar() {
        let array = Array::full::<f32>(&[2, 3], array!(7f32)).unwrap();
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        assert_eq!(data, &[7.0; 6]);
    }

    #[test]
    fn test_full_array() {
        let source = Array::zeros_device::<f32>(&[1, 3], StreamOrDevice::cpu()).unwrap();
        let array = Array::full::<f32>(&[2, 3], source).unwrap();
        assert_eq!(array.shape(), &[2, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        // `float_eq!` only evaluates to a `bool`; without `assert!` the comparison
        // result would be discarded and the test could never fail here.
        assert!(float_eq::float_eq!(*data, [0.0; 6], abs <= [1e-6; 6]));
    }

    #[test]
    fn test_full_try() {
        let source = Array::zeros_device::<f32>(&[1, 3], StreamOrDevice::default()).unwrap();
        let array = Array::full::<f32>(&[2, 3], source);
        assert!(array.is_ok());

        let source = Array::zeros_device::<f32>(&[1, 3], StreamOrDevice::default()).unwrap();
        let array = Array::full::<f32>(&[-1, 3], source);
        assert!(array.is_err());
    }

    #[test]
    fn test_identity() {
        let array = Array::identity::<f32>(3).unwrap();
        assert_eq!(array.shape(), &[3, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        assert_eq!(data, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_arange() {
        let array = Array::arange::<_, f32>(None, 50, None).unwrap();
        assert_eq!(array.shape(), &[50]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        let expected: Vec<f32> = (0..50).map(|x| x as f32).collect();
        assert_eq!(data, expected.as_slice());

        let array = Array::arange::<_, i32>(0, 50, None).unwrap();
        assert_eq!(array.shape(), &[50]);
        assert_eq!(array.dtype(), Dtype::Int32);

        let data: &[i32] = array.as_slice();
        let expected: Vec<i32> = (0..50).collect();
        assert_eq!(data, expected.as_slice());

        let result = Array::arange::<_, bool>(None, 50, None);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(f64::NEG_INFINITY, 50.0, None);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(0.0, f64::INFINITY, None);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(0.0, 50.0, f32::NAN);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(f32::NAN, 50.0, None);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(0.0, f32::NAN, None);
        assert!(result.is_err());

        let result = Array::arange::<_, f32>(0, i32::MAX as i64 + 1, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_linspace_int() {
        let array = Array::linspace::<_, f32>(0, 50, None).unwrap();
        assert_eq!(array.shape(), &[50]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        let expected: Vec<f32> = (0..50).map(|x| x as f32 * (50.0 / 49.0)).collect();
        assert_eq!(data, expected.as_slice());
    }

    #[test]
    fn test_linspace_float() {
        let array = Array::linspace::<_, f32>(0., 50., None).unwrap();
        assert_eq!(array.shape(), &[50]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        let expected: Vec<f32> = (0..50).map(|x| x as f32 * (50.0 / 49.0)).collect();
        assert_eq!(data, expected.as_slice());
    }

    #[test]
    fn test_linspace_try() {
        let array = Array::linspace::<_, f32>(0, 50, None);
        assert!(array.is_ok());

        let array = Array::linspace::<_, f32>(0, 50, Some(-1));
        assert!(array.is_err());
    }

    #[test]
    fn test_repeat() {
        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat::<i32>(source, 4, 1).unwrap();
        assert_eq!(array.shape(), &[2, 8]);
        assert_eq!(array.dtype(), Dtype::Int32);

        let data: &[i32] = array.as_slice();
        assert_eq!(data, [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]);
    }

    #[test]
    fn test_repeat_try() {
        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat::<i32>(source, 4, 1);
        assert!(array.is_ok());

        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat::<i32>(source, -1, 1);
        assert!(array.is_err());
    }

    #[test]
    fn test_repeat_all() {
        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat_all::<i32>(source, 4).unwrap();
        assert_eq!(array.shape(), &[16]);
        assert_eq!(array.dtype(), Dtype::Int32);

        let data: &[i32] = array.as_slice();
        assert_eq!(data, [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]);
    }

    #[test]
    fn test_repeat_all_try() {
        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat_all::<i32>(source, 4);
        assert!(array.is_ok());

        let source = Array::from_slice(&[0, 1, 2, 3], &[2, 2]);
        let array = Array::repeat_all::<i32>(source, -1);
        assert!(array.is_err());
    }

    #[test]
    fn test_tri() {
        let array = Array::tri::<f32>(3, None, None).unwrap();
        assert_eq!(array.shape(), &[3, 3]);
        assert_eq!(array.dtype(), Dtype::Float32);

        let data: &[f32] = array.as_slice();
        assert_eq!(data, &[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0]);
    }
}