1use crate::array::Array;
2use crate::error::Result;
3use crate::utils::axes_or_default_to_all;
4use crate::utils::guard::Guarded;
5use crate::Stream;
6use mlx_internal_macros::{default_device, generate_macro};
7
8impl Array {
9 #[default_device]
27 pub fn all_axes_device(
28 &self,
29 axes: &[i32],
30 keep_dims: impl Into<Option<bool>>,
31 stream: impl AsRef<Stream>,
32 ) -> Result<Array> {
33 Array::try_from_op(|res| unsafe {
34 mlx_sys::mlx_all_axes(
35 res,
36 self.as_ptr(),
37 axes.as_ptr(),
38 axes.len(),
39 keep_dims.into().unwrap_or(false),
40 stream.as_ref().as_ptr(),
41 )
42 })
43 }
44
45 #[default_device]
47 pub fn all_axis_device(
48 &self,
49 axis: i32,
50 keep_dims: impl Into<Option<bool>>,
51 stream: impl AsRef<Stream>,
52 ) -> Result<Array> {
53 Array::try_from_op(|res| unsafe {
54 mlx_sys::mlx_all_axis(
55 res,
56 self.as_ptr(),
57 axis,
58 keep_dims.into().unwrap_or(false),
59 stream.as_ref().as_ptr(),
60 )
61 })
62 }
63
64 #[default_device]
66 pub fn all_device(
67 &self,
68 keep_dims: impl Into<Option<bool>>,
69 stream: impl AsRef<Stream>,
70 ) -> Result<Array> {
71 Array::try_from_op(|res| unsafe {
72 mlx_sys::mlx_all(
73 res,
74 self.as_ptr(),
75 keep_dims.into().unwrap_or(false),
76 stream.as_ref().as_ptr(),
77 )
78 })
79 }
80
81 #[default_device]
98 pub fn prod_axes_device(
99 &self,
100 axes: &[i32],
101 keep_dims: impl Into<Option<bool>>,
102 stream: impl AsRef<Stream>,
103 ) -> Result<Array> {
104 Array::try_from_op(|res| unsafe {
105 mlx_sys::mlx_prod_axes(
106 res,
107 self.as_ptr(),
108 axes.as_ptr(),
109 axes.len(),
110 keep_dims.into().unwrap_or(false),
111 stream.as_ref().as_ptr(),
112 )
113 })
114 }
115
116 #[default_device]
118 pub fn prod_axis_device(
119 &self,
120 axis: i32,
121 keep_dims: impl Into<Option<bool>>,
122 stream: impl AsRef<Stream>,
123 ) -> Result<Array> {
124 Array::try_from_op(|res| unsafe {
125 mlx_sys::mlx_prod_axis(
126 res,
127 self.as_ptr(),
128 axis,
129 keep_dims.into().unwrap_or(false),
130 stream.as_ref().as_ptr(),
131 )
132 })
133 }
134
135 #[default_device]
137 pub fn prod_device(
138 &self,
139 keep_dims: impl Into<Option<bool>>,
140 stream: impl AsRef<Stream>,
141 ) -> Result<Array> {
142 Array::try_from_op(|res| unsafe {
143 mlx_sys::mlx_prod(
144 res,
145 self.as_ptr(),
146 keep_dims.into().unwrap_or(false),
147 stream.as_ref().as_ptr(),
148 )
149 })
150 }
151
152 #[default_device]
169 pub fn max_axes_device(
170 &self,
171 axes: &[i32],
172 keep_dims: impl Into<Option<bool>>,
173 stream: impl AsRef<Stream>,
174 ) -> Result<Array> {
175 Array::try_from_op(|res| unsafe {
176 mlx_sys::mlx_max_axes(
177 res,
178 self.as_ptr(),
179 axes.as_ptr(),
180 axes.len(),
181 keep_dims.into().unwrap_or(false),
182 stream.as_ref().as_ptr(),
183 )
184 })
185 }
186
187 #[default_device]
189 pub fn max_axis_device(
190 &self,
191 axis: i32,
192 keep_dims: impl Into<Option<bool>>,
193 stream: impl AsRef<Stream>,
194 ) -> Result<Array> {
195 Array::try_from_op(|res| unsafe {
196 mlx_sys::mlx_max_axis(
197 res,
198 self.as_ptr(),
199 axis,
200 keep_dims.into().unwrap_or(false),
201 stream.as_ref().as_ptr(),
202 )
203 })
204 }
205
206 #[default_device]
208 pub fn max_device(
209 &self,
210 keep_dims: impl Into<Option<bool>>,
211 stream: impl AsRef<Stream>,
212 ) -> Result<Array> {
213 Array::try_from_op(|res| unsafe {
214 mlx_sys::mlx_max(
215 res,
216 self.as_ptr(),
217 keep_dims.into().unwrap_or(false),
218 stream.as_ref().as_ptr(),
219 )
220 })
221 }
222
223 #[default_device]
240 pub fn sum_axes_device(
241 &self,
242 axes: &[i32],
243 keep_dims: impl Into<Option<bool>>,
244 stream: impl AsRef<Stream>,
245 ) -> Result<Array> {
246 Array::try_from_op(|res| unsafe {
247 mlx_sys::mlx_sum_axes(
248 res,
249 self.as_ptr(),
250 axes.as_ptr(),
251 axes.len(),
252 keep_dims.into().unwrap_or(false),
253 stream.as_ref().as_ptr(),
254 )
255 })
256 }
257
258 #[default_device]
260 pub fn sum_axis_device(
261 &self,
262 axis: i32,
263 keep_dims: impl Into<Option<bool>>,
264 stream: impl AsRef<Stream>,
265 ) -> Result<Array> {
266 Array::try_from_op(|res| unsafe {
267 mlx_sys::mlx_sum_axis(
268 res,
269 self.as_ptr(),
270 axis,
271 keep_dims.into().unwrap_or(false),
272 stream.as_ref().as_ptr(),
273 )
274 })
275 }
276
277 #[default_device]
279 pub fn sum_device(
280 &self,
281 keep_dims: impl Into<Option<bool>>,
282 stream: impl AsRef<Stream>,
283 ) -> Result<Array> {
284 Array::try_from_op(|res| unsafe {
285 mlx_sys::mlx_sum(
286 res,
287 self.as_ptr(),
288 keep_dims.into().unwrap_or(false),
289 stream.as_ref().as_ptr(),
290 )
291 })
292 }
293
294 #[default_device]
311 pub fn mean_axes_device(
312 &self,
313 axes: &[i32],
314 keep_dims: impl Into<Option<bool>>,
315 stream: impl AsRef<Stream>,
316 ) -> Result<Array> {
317 let axes = axes_or_default_to_all(axes, self.ndim() as i32);
318 Array::try_from_op(|res| unsafe {
319 mlx_sys::mlx_mean_axes(
320 res,
321 self.as_ptr(),
322 axes.as_ptr(),
323 axes.len(),
324 keep_dims.into().unwrap_or(false),
325 stream.as_ref().as_ptr(),
326 )
327 })
328 }
329
330 #[default_device]
332 pub fn mean_axis_device(
333 &self,
334 axis: i32,
335 keep_dims: impl Into<Option<bool>>,
336 stream: impl AsRef<Stream>,
337 ) -> Result<Array> {
338 Array::try_from_op(|res| unsafe {
339 mlx_sys::mlx_mean_axis(
340 res,
341 self.as_ptr(),
342 axis,
343 keep_dims.into().unwrap_or(false),
344 stream.as_ref().as_ptr(),
345 )
346 })
347 }
348
349 #[default_device]
351 pub fn mean_device(
352 &self,
353 keep_dims: impl Into<Option<bool>>,
354 stream: impl AsRef<Stream>,
355 ) -> Result<Array> {
356 Array::try_from_op(|res| unsafe {
357 mlx_sys::mlx_mean(
358 res,
359 self.as_ptr(),
360 keep_dims.into().unwrap_or(false),
361 stream.as_ref().as_ptr(),
362 )
363 })
364 }
365
366 #[default_device]
383 pub fn min_axes_device(
384 &self,
385 axes: &[i32],
386 keep_dims: impl Into<Option<bool>>,
387 stream: impl AsRef<Stream>,
388 ) -> Result<Array> {
389 Array::try_from_op(|res| unsafe {
390 mlx_sys::mlx_min_axes(
391 res,
392 self.as_ptr(),
393 axes.as_ptr(),
394 axes.len(),
395 keep_dims.into().unwrap_or(false),
396 stream.as_ref().as_ptr(),
397 )
398 })
399 }
400
401 #[default_device]
403 pub fn min_axis_device(
404 &self,
405 axis: i32,
406 keep_dims: impl Into<Option<bool>>,
407 stream: impl AsRef<Stream>,
408 ) -> Result<Array> {
409 Array::try_from_op(|res| unsafe {
410 mlx_sys::mlx_min_axis(
411 res,
412 self.as_ptr(),
413 axis,
414 keep_dims.into().unwrap_or(false),
415 stream.as_ref().as_ptr(),
416 )
417 })
418 }
419
420 #[default_device]
422 pub fn min_device(
423 &self,
424 keep_dims: impl Into<Option<bool>>,
425 stream: impl AsRef<Stream>,
426 ) -> Result<Array> {
427 Array::try_from_op(|res| unsafe {
428 mlx_sys::mlx_min(
429 res,
430 self.as_ptr(),
431 keep_dims.into().unwrap_or(false),
432 stream.as_ref().as_ptr(),
433 )
434 })
435 }
436
437 #[default_device]
445 pub fn var_axes_device(
446 &self,
447 axes: &[i32],
448 keep_dims: impl Into<Option<bool>>,
449 ddof: impl Into<Option<i32>>,
450 stream: impl AsRef<Stream>,
451 ) -> Result<Array> {
452 Array::try_from_op(|res| unsafe {
453 mlx_sys::mlx_var_axes(
454 res,
455 self.as_ptr(),
456 axes.as_ptr(),
457 axes.len(),
458 keep_dims.into().unwrap_or(false),
459 ddof.into().unwrap_or(0),
460 stream.as_ref().as_ptr(),
461 )
462 })
463 }
464
465 #[default_device]
467 pub fn var_axis_device(
468 &self,
469 axis: i32,
470 keep_dims: impl Into<Option<bool>>,
471 ddof: impl Into<Option<i32>>,
472 stream: impl AsRef<Stream>,
473 ) -> Result<Array> {
474 Array::try_from_op(|res| unsafe {
475 mlx_sys::mlx_var_axis(
476 res,
477 self.as_ptr(),
478 axis,
479 keep_dims.into().unwrap_or(false),
480 ddof.into().unwrap_or(0),
481 stream.as_ref().as_ptr(),
482 )
483 })
484 }
485
486 #[default_device]
488 pub fn var_device(
489 &self,
490 keep_dims: impl Into<Option<bool>>,
491 ddof: impl Into<Option<i32>>,
492 stream: impl AsRef<Stream>,
493 ) -> Result<Array> {
494 Array::try_from_op(|res| unsafe {
495 mlx_sys::mlx_var(
496 res,
497 self.as_ptr(),
498 keep_dims.into().unwrap_or(false),
499 ddof.into().unwrap_or(0),
500 stream.as_ref().as_ptr(),
501 )
502 })
503 }
504
505 #[default_device]
514 pub fn logsumexp_axes_device(
515 &self,
516 axes: &[i32],
517 keep_dims: impl Into<Option<bool>>,
518 stream: impl AsRef<Stream>,
519 ) -> Result<Array> {
520 Array::try_from_op(|res| unsafe {
521 mlx_sys::mlx_logsumexp_axes(
522 res,
523 self.as_ptr(),
524 axes.as_ptr(),
525 axes.len(),
526 keep_dims.into().unwrap_or(false),
527 stream.as_ref().as_ptr(),
528 )
529 })
530 }
531
532 #[default_device]
534 pub fn logsumexp_axis_device(
535 &self,
536 axis: i32,
537 keep_dims: impl Into<Option<bool>>,
538 stream: impl AsRef<Stream>,
539 ) -> Result<Array> {
540 Array::try_from_op(|res| unsafe {
541 mlx_sys::mlx_logsumexp_axis(
542 res,
543 self.as_ptr(),
544 axis,
545 keep_dims.into().unwrap_or(false),
546 stream.as_ref().as_ptr(),
547 )
548 })
549 }
550
551 #[default_device]
553 pub fn logsumexp_device(
554 &self,
555 keep_dims: impl Into<Option<bool>>,
556 stream: impl AsRef<Stream>,
557 ) -> Result<Array> {
558 Array::try_from_op(|res| unsafe {
559 mlx_sys::mlx_logsumexp(
560 res,
561 self.as_ptr(),
562 keep_dims.into().unwrap_or(false),
563 stream.as_ref().as_ptr(),
564 )
565 })
566 }
567}
568
569#[generate_macro]
571#[default_device]
572pub fn all_axes_device(
573 array: impl AsRef<Array>,
574 axes: &[i32],
575 #[optional] keep_dims: impl Into<Option<bool>>,
576 #[optional] stream: impl AsRef<Stream>,
577) -> Result<Array> {
578 array.as_ref().all_axes_device(axes, keep_dims, stream)
579}
580
581#[generate_macro]
583#[default_device]
584pub fn all_axis_device(
585 array: impl AsRef<Array>,
586 axis: i32,
587 #[optional] keep_dims: impl Into<Option<bool>>,
588 #[optional] stream: impl AsRef<Stream>,
589) -> Result<Array> {
590 array.as_ref().all_axis_device(axis, keep_dims, stream)
591}
592
593#[generate_macro]
595#[default_device]
596pub fn all_device(
597 array: impl AsRef<Array>,
598 #[optional] keep_dims: impl Into<Option<bool>>,
599 #[optional] stream: impl AsRef<Stream>,
600) -> Result<Array> {
601 array.as_ref().all_device(keep_dims, stream)
602}
603
604#[generate_macro]
606#[default_device]
607pub fn prod_axes_device(
608 array: impl AsRef<Array>,
609 axes: &[i32],
610 #[optional] keep_dims: impl Into<Option<bool>>,
611 #[optional] stream: impl AsRef<Stream>,
612) -> Result<Array> {
613 array.as_ref().prod_axes_device(axes, keep_dims, stream)
614}
615
616#[generate_macro]
618#[default_device]
619pub fn prod_axis_device(
620 array: impl AsRef<Array>,
621 axis: i32,
622 #[optional] keep_dims: impl Into<Option<bool>>,
623 #[optional] stream: impl AsRef<Stream>,
624) -> Result<Array> {
625 array.as_ref().prod_axis_device(axis, keep_dims, stream)
626}
627
628#[generate_macro]
630#[default_device]
631pub fn prod_device(
632 array: impl AsRef<Array>,
633 #[optional] keep_dims: impl Into<Option<bool>>,
634 #[optional] stream: impl AsRef<Stream>,
635) -> Result<Array> {
636 array.as_ref().prod_device(keep_dims, stream)
637}
638
639#[generate_macro]
641#[default_device]
642pub fn max_axes_device(
643 array: impl AsRef<Array>,
644 axes: &[i32],
645 #[optional] keep_dims: impl Into<Option<bool>>,
646 #[optional] stream: impl AsRef<Stream>,
647) -> Result<Array> {
648 array.as_ref().max_axes_device(axes, keep_dims, stream)
649}
650
651#[generate_macro]
653#[default_device]
654pub fn max_axis_device(
655 array: impl AsRef<Array>,
656 axis: i32,
657 #[optional] keep_dims: impl Into<Option<bool>>,
658 #[optional] stream: impl AsRef<Stream>,
659) -> Result<Array> {
660 array.as_ref().max_axis_device(axis, keep_dims, stream)
661}
662
663#[generate_macro]
665#[default_device]
666pub fn max_device(
667 array: impl AsRef<Array>,
668 #[optional] keep_dims: impl Into<Option<bool>>,
669 #[optional] stream: impl AsRef<Stream>,
670) -> Result<Array> {
671 array.as_ref().max_device(keep_dims, stream)
672}
673
674#[generate_macro]
684#[default_device]
685pub fn std_axes_device(
686 a: impl AsRef<Array>,
687 axes: &[i32],
688 #[optional] keep_dims: impl Into<Option<bool>>,
689 #[optional] ddof: impl Into<Option<i32>>,
690 #[optional] stream: impl AsRef<Stream>,
691) -> Result<Array> {
692 let a = a.as_ref();
693 let keep_dims = keep_dims.into().unwrap_or(false);
694 let ddof = ddof.into().unwrap_or(0);
695 Array::try_from_op(|res| unsafe {
696 mlx_sys::mlx_std_axes(
697 res,
698 a.as_ptr(),
699 axes.as_ptr(),
700 axes.len(),
701 keep_dims,
702 ddof,
703 stream.as_ref().as_ptr(),
704 )
705 })
706}
707
708#[generate_macro]
710#[default_device]
711pub fn std_axis_device(
712 a: impl AsRef<Array>,
713 axis: i32,
714 #[optional] keep_dims: impl Into<Option<bool>>,
715 #[optional] ddof: impl Into<Option<i32>>,
716 #[optional] stream: impl AsRef<Stream>,
717) -> Result<Array> {
718 let a = a.as_ref();
719 let keep_dims = keep_dims.into().unwrap_or(false);
720 let ddof = ddof.into().unwrap_or(0);
721 Array::try_from_op(|res| unsafe {
722 mlx_sys::mlx_std_axis(
723 res,
724 a.as_ptr(),
725 axis,
726 keep_dims,
727 ddof,
728 stream.as_ref().as_ptr(),
729 )
730 })
731}
732
733#[generate_macro]
735#[default_device]
736pub fn std_device(
737 a: impl AsRef<Array>,
738 #[optional] keep_dims: impl Into<Option<bool>>,
739 #[optional] ddof: impl Into<Option<i32>>,
740 #[optional] stream: impl AsRef<Stream>,
741) -> Result<Array> {
742 let a = a.as_ref();
743 let keep_dims = keep_dims.into().unwrap_or(false);
744 let ddof = ddof.into().unwrap_or(0);
745 Array::try_from_op(|res| unsafe {
746 mlx_sys::mlx_std(res, a.as_ptr(), keep_dims, ddof, stream.as_ref().as_ptr())
747 })
748}
749
750#[generate_macro]
752#[default_device]
753pub fn sum_axes_device(
754 array: impl AsRef<Array>,
755 axes: &[i32],
756 #[optional] keep_dims: impl Into<Option<bool>>,
757 #[optional] stream: impl AsRef<Stream>,
758) -> Result<Array> {
759 array.as_ref().sum_axes_device(axes, keep_dims, stream)
760}
761
762#[generate_macro]
764#[default_device]
765pub fn sum_axis_device(
766 array: impl AsRef<Array>,
767 axis: i32,
768 #[optional] keep_dims: impl Into<Option<bool>>,
769 #[optional] stream: impl AsRef<Stream>,
770) -> Result<Array> {
771 array.as_ref().sum_axis_device(axis, keep_dims, stream)
772}
773
774#[generate_macro]
776#[default_device]
777pub fn sum_device(
778 array: impl AsRef<Array>,
779 #[optional] keep_dims: impl Into<Option<bool>>,
780 #[optional] stream: impl AsRef<Stream>,
781) -> Result<Array> {
782 array.as_ref().sum_device(keep_dims, stream)
783}
784
785#[generate_macro]
787#[default_device]
788pub fn mean_axes_device(
789 array: impl AsRef<Array>,
790 axes: &[i32],
791 #[optional] keep_dims: impl Into<Option<bool>>,
792 #[optional] stream: impl AsRef<Stream>,
793) -> Result<Array> {
794 array.as_ref().mean_axes_device(axes, keep_dims, stream)
795}
796
797#[generate_macro]
799#[default_device]
800pub fn mean_axis_device(
801 array: impl AsRef<Array>,
802 axis: i32,
803 #[optional] keep_dims: impl Into<Option<bool>>,
804 #[optional] stream: impl AsRef<Stream>,
805) -> Result<Array> {
806 array.as_ref().mean_axis_device(axis, keep_dims, stream)
807}
808
809#[generate_macro]
811#[default_device]
812pub fn mean_device(
813 array: impl AsRef<Array>,
814 #[optional] keep_dims: impl Into<Option<bool>>,
815 #[optional] stream: impl AsRef<Stream>,
816) -> Result<Array> {
817 array.as_ref().mean_device(keep_dims, stream)
818}
819
820#[generate_macro]
822#[default_device]
823pub fn min_axes_device(
824 array: impl AsRef<Array>,
825 axes: &[i32],
826 #[optional] keep_dims: impl Into<Option<bool>>,
827 #[optional] stream: impl AsRef<Stream>,
828) -> Result<Array> {
829 array.as_ref().min_axes_device(axes, keep_dims, stream)
830}
831
832#[generate_macro]
834#[default_device]
835pub fn min_axis_device(
836 array: impl AsRef<Array>,
837 axis: i32,
838 #[optional] keep_dims: impl Into<Option<bool>>,
839 #[optional] stream: impl AsRef<Stream>,
840) -> Result<Array> {
841 array.as_ref().min_axis_device(axis, keep_dims, stream)
842}
843
844#[generate_macro]
846#[default_device]
847pub fn min_device(
848 array: impl AsRef<Array>,
849 #[optional] keep_dims: impl Into<Option<bool>>,
850 #[optional] stream: impl AsRef<Stream>,
851) -> Result<Array> {
852 array.as_ref().min_device(keep_dims, stream)
853}
854
855#[generate_macro]
857#[default_device]
858pub fn var_axes_device(
859 array: impl AsRef<Array>,
860 axes: &[i32],
861 #[optional] keep_dims: impl Into<Option<bool>>,
862 #[optional] ddof: impl Into<Option<i32>>,
863 #[optional] stream: impl AsRef<Stream>,
864) -> Result<Array> {
865 array
866 .as_ref()
867 .var_axes_device(axes, keep_dims, ddof, stream)
868}
869
870#[generate_macro]
872#[default_device]
873pub fn var_axis_device(
874 array: impl AsRef<Array>,
875 axis: i32,
876 #[optional] keep_dims: impl Into<Option<bool>>,
877 #[optional] ddof: impl Into<Option<i32>>,
878 #[optional] stream: impl AsRef<Stream>,
879) -> Result<Array> {
880 array
881 .as_ref()
882 .var_axis_device(axis, keep_dims, ddof, stream)
883}
884
885#[generate_macro]
887#[default_device]
888pub fn var_device(
889 array: impl AsRef<Array>,
890 #[optional] keep_dims: impl Into<Option<bool>>,
891 #[optional] ddof: impl Into<Option<i32>>,
892 #[optional] stream: impl AsRef<Stream>,
893) -> Result<Array> {
894 array.as_ref().var_device(keep_dims, ddof, stream)
895}
896
897#[generate_macro]
899#[default_device]
900pub fn logsumexp_axes_device(
901 array: impl AsRef<Array>,
902 axes: &[i32],
903 #[optional] keep_dims: impl Into<Option<bool>>,
904 #[optional] stream: impl AsRef<Stream>,
905) -> Result<Array> {
906 array
907 .as_ref()
908 .logsumexp_axes_device(axes, keep_dims, stream)
909}
910
911#[generate_macro]
913#[default_device]
914pub fn logsumexp_axis_device(
915 array: impl AsRef<Array>,
916 axis: i32,
917 #[optional] keep_dims: impl Into<Option<bool>>,
918 #[optional] stream: impl AsRef<Stream>,
919) -> Result<Array> {
920 array
921 .as_ref()
922 .logsumexp_axis_device(axis, keep_dims, stream)
923}
924
925#[generate_macro]
927#[default_device]
928pub fn logsumexp_device(
929 array: impl AsRef<Array>,
930 #[optional] keep_dims: impl Into<Option<bool>>,
931 #[optional] stream: impl AsRef<Stream>,
932) -> Result<Array> {
933 array.as_ref().logsumexp_device(keep_dims, stream)
934}
935
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Asserts that `actual` and `expected` have equal length and agree element-wise
    /// within `tol`.
    ///
    /// Used for results that go through `exp`/`log`/`sqrt`, where bit-exact equality
    /// depends on the backend's math library.
    fn assert_all_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len());
        for (index, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() <= tol,
                "element {index}: got {got}, expected {want} (tolerance {tol})"
            );
        }
    }

    #[test]
    fn test_all() {
        let array = Array::from_slice(&[true, false, true, false], &[2, 2]);

        assert_eq!(array.all(None).unwrap().item::<bool>(), false);
        assert_eq!(array.all(true).unwrap().shape(), &[1, 1]);
        assert_eq!(array.all_axes(&[0, 1], None).unwrap().item::<bool>(), false);

        let result = array.all_axis(0, None).unwrap();
        assert_eq!(result.as_slice::<bool>(), &[true, false]);

        let result = array.all_axis(1, None).unwrap();
        assert_eq!(result.as_slice::<bool>(), &[false, false]);
    }

    #[test]
    fn test_all_empty_axes() {
        let array = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4]);
        let all = array.all_axes(&[], None).unwrap();

        let results: &[bool] = all.as_slice();
        assert_eq!(
            results,
            &[false, true, true, true, true, true, true, true, true, true, true, true]
        );
    }

    #[test]
    fn test_prod() {
        let x = Array::from_slice(&[1, 2, 3, 3], &[2, 2]);
        assert_eq!(x.prod(None).unwrap().item::<i32>(), 18);

        let y = x.prod(true).unwrap();
        assert_eq!(y.item::<i32>(), 18);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.prod_axis(0, None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[3, 6]);

        let result = x.prod_axis(1, None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[2, 9])
    }

    #[test]
    fn test_prod_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.prod_axes(&[], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_max() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.max(None).unwrap().item::<i32>(), 4);
        let y = x.max(true).unwrap();
        assert_eq!(y.item::<i32>(), 4);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.max_axis(0, None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[3, 4]);

        let result = x.max_axis(1, None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[2, 4]);
    }

    #[test]
    fn test_max_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.max_axes(&[], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_sum() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.sum_axis(0, None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[9, 17]);
    }

    #[test]
    fn test_sum_keep_dims() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);

        // Reducing axis 0 with `keep_dims` leaves a leading size-1 dimension.
        let result = array.sum_axis(0, true).unwrap();
        assert_eq!(result.shape(), &[1, 2]);
        assert_eq!(result.as_slice::<i32>(), &[9, 17]);

        let result = array.sum(true).unwrap();
        assert_eq!(result.shape(), &[1, 1]);
        assert_eq!(result.item::<i32>(), 26);
    }

    #[test]
    fn test_sum_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.sum_axes(&[], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_mean() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.mean(None).unwrap().item::<f32>(), 2.5);
        let y = x.mean(true).unwrap();
        assert_eq!(y.item::<f32>(), 2.5);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.mean_axis(0, None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[2.0, 3.0]);

        let result = x.mean_axis(1, None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[1.5, 3.5]);
    }

    #[test]
    fn test_mean_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.mean_axes(&[], None).unwrap();

        let results: &[f32] = result.as_slice();
        assert_eq!(results, &[5.0, 8.0, 4.0, 9.0]);
    }

    #[test]
    fn test_mean_out_of_bounds() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.mean_axis(2, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_min() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.min(None).unwrap().item::<i32>(), 1);
        let y = x.min(true).unwrap();
        assert_eq!(y.item::<i32>(), 1);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.min_axis(0, None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[1, 2]);

        let result = x.min_axis(1, None).unwrap();
        assert_eq!(result.as_slice::<i32>(), &[1, 3]);
    }

    #[test]
    fn test_min_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.min_axes(&[], None).unwrap();

        let results: &[i32] = result.as_slice();
        assert_eq!(results, &[5, 8, 4, 9]);
    }

    #[test]
    fn test_var() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
        assert_eq!(x.var(None, None).unwrap().item::<f32>(), 1.25);
        let y = x.var(true, None).unwrap();
        assert_eq!(y.item::<f32>(), 1.25);
        assert_eq!(y.shape(), &[1, 1]);

        let result = x.var_axis(0, None, None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[1.0, 1.0]);

        let result = x.var_axis(1, None, None).unwrap();
        assert_eq!(result.as_slice::<f32>(), &[0.25, 0.25]);

        let x = Array::from_slice(&[1.0, 2.0], &[2]);
        let out = x.var(None, Some(3)).unwrap();
        assert_eq!(out.item::<f32>(), f32::INFINITY);
    }

    #[test]
    fn test_var_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.var_axes(&[], None, 0).unwrap();

        let results: &[f32] = result.as_slice();
        assert_eq!(results, &[0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_std() {
        let x = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);

        // With ddof = 0 this is the square root of the variance checked in `test_var`.
        let result = std(&x, None, None).unwrap();
        assert_all_close(&[result.item::<f32>()], &[1.25f32.sqrt()], 1e-6);

        let result = std(&x, true, None).unwrap();
        assert_eq!(result.shape(), &[1, 1]);

        let result = std_axes(&x, &[0, 1], None, None).unwrap();
        assert_all_close(&[result.item::<f32>()], &[1.25f32.sqrt()], 1e-6);

        let result = std_axis(&x, 0, None, None).unwrap();
        assert_all_close(result.as_slice::<f32>(), &[1.0, 1.0], 1e-6);

        let result = std_axis(&x, 1, None, None).unwrap();
        assert_all_close(result.as_slice::<f32>(), &[0.5, 0.5], 1e-6);

        // ddof = 1: squared deviations (sum 5) are divided by N - 1 = 3.
        let result = std(&x, None, 1).unwrap();
        assert_all_close(&[result.item::<f32>()], &[(5.0f32 / 3.0).sqrt()], 1e-5);
    }

    #[test]
    fn test_log_sum_exp() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.logsumexp_axis(0, None).unwrap();

        let results: &[f32] = result.as_slice();
        assert_all_close(results, &[5.3132615, 9.313262], 1e-5);
    }

    #[test]
    fn test_log_sum_exp_all() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let expected = [5.0f32, 8.0, 4.0, 9.0]
            .iter()
            .map(|v| v.exp())
            .sum::<f32>()
            .ln();

        let result = array.logsumexp(None).unwrap();
        assert_all_close(&[result.item::<f32>()], &[expected], 1e-4);

        let result = array.logsumexp(true).unwrap();
        assert_eq!(result.shape(), &[1, 1]);

        let result = array.logsumexp_axes(&[0, 1], None).unwrap();
        assert_all_close(&[result.item::<f32>()], &[expected], 1e-4);
    }

    #[test]
    fn test_log_sum_exp_empty_axes() {
        let array = Array::from_slice(&[5, 8, 4, 9], &[2, 2]);
        let result = array.logsumexp_axes(&[], None).unwrap();

        let results: &[f32] = result.as_slice();
        assert_eq!(results, &[5.0, 8.0, 4.0, 9.0]);
    }
}
1135}