1use crate::array::Array;
2use crate::error::Result;
3use crate::utils::guard::Guarded;
4use crate::Stream;
5use mlx_internal_macros::{default_device, generate_macro};
6
7impl Array {
8 #[default_device]
29 pub fn eq_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
30 Array::try_from_op(|res| unsafe {
31 mlx_sys::mlx_equal(
32 res,
33 self.as_ptr(),
34 other.as_ref().as_ptr(),
35 stream.as_ref().as_ptr(),
36 )
37 })
38 }
39
40 #[default_device]
61 pub fn le_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
62 Array::try_from_op(|res| unsafe {
63 mlx_sys::mlx_less_equal(
64 res,
65 self.as_ptr(),
66 other.as_ref().as_ptr(),
67 stream.as_ref().as_ptr(),
68 )
69 })
70 }
71
72 #[default_device]
93 pub fn ge_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
94 Array::try_from_op(|res| unsafe {
95 mlx_sys::mlx_greater_equal(
96 res,
97 self.as_ptr(),
98 other.as_ref().as_ptr(),
99 stream.as_ref().as_ptr(),
100 )
101 })
102 }
103
104 #[default_device]
125 pub fn ne_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
126 Array::try_from_op(|res| unsafe {
127 mlx_sys::mlx_not_equal(
128 res,
129 self.as_ptr(),
130 other.as_ref().as_ptr(),
131 stream.as_ref().as_ptr(),
132 )
133 })
134 }
135
136 #[default_device]
156 pub fn lt_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
157 Array::try_from_op(|res| unsafe {
158 mlx_sys::mlx_less(
159 res,
160 self.as_ptr(),
161 other.as_ref().as_ptr(),
162 stream.as_ref().as_ptr(),
163 )
164 })
165 }
166
167 #[default_device]
187 pub fn gt_device(&self, other: impl AsRef<Array>, stream: impl AsRef<Stream>) -> Result<Array> {
188 Array::try_from_op(|res| unsafe {
189 mlx_sys::mlx_greater(
190 res,
191 self.as_ptr(),
192 other.as_ref().as_ptr(),
193 stream.as_ref().as_ptr(),
194 )
195 })
196 }
197
198 #[default_device]
218 pub fn logical_and_device(
219 &self,
220 other: impl AsRef<Array>,
221 stream: impl AsRef<Stream>,
222 ) -> Result<Array> {
223 Array::try_from_op(|res| unsafe {
224 mlx_sys::mlx_logical_and(
225 res,
226 self.as_ptr(),
227 other.as_ref().as_ptr(),
228 stream.as_ref().as_ptr(),
229 )
230 })
231 }
232
233 #[default_device]
253 pub fn logical_or_device(
254 &self,
255 other: impl AsRef<Array>,
256 stream: impl AsRef<Stream>,
257 ) -> Result<Array> {
258 Array::try_from_op(|res| unsafe {
259 mlx_sys::mlx_logical_or(
260 res,
261 self.as_ptr(),
262 other.as_ref().as_ptr(),
263 stream.as_ref().as_ptr(),
264 )
265 })
266 }
267
268 #[default_device]
281 pub fn logical_not_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
282 Array::try_from_op(|res| unsafe {
283 mlx_sys::mlx_logical_not(res, self.as_ptr(), stream.as_ref().as_ptr())
284 })
285 }
286
287 #[default_device]
315 pub fn all_close_device(
316 &self,
317 other: impl AsRef<Array>,
318 rtol: impl Into<Option<f64>>,
319 atol: impl Into<Option<f64>>,
320 equal_nan: impl Into<Option<bool>>,
321 stream: impl AsRef<Stream>,
322 ) -> Result<Array> {
323 Array::try_from_op(|res| unsafe {
324 mlx_sys::mlx_allclose(
325 res,
326 self.as_ptr(),
327 other.as_ref().as_ptr(),
328 rtol.into().unwrap_or(1e-5),
329 atol.into().unwrap_or(1e-8),
330 equal_nan.into().unwrap_or(false),
331 stream.as_ref().as_ptr(),
332 )
333 })
334 }
335
336 #[default_device]
349 pub fn is_close_device(
350 &self,
351 other: impl AsRef<Array>,
352 rtol: impl Into<Option<f64>>,
353 atol: impl Into<Option<f64>>,
354 equal_nan: impl Into<Option<bool>>,
355 stream: impl AsRef<Stream>,
356 ) -> Result<Array> {
357 Array::try_from_op(|res| unsafe {
358 mlx_sys::mlx_isclose(
359 res,
360 self.as_ptr(),
361 other.as_ref().as_ptr(),
362 rtol.into().unwrap_or(1e-5),
363 atol.into().unwrap_or(1e-8),
364 equal_nan.into().unwrap_or(false),
365 stream.as_ref().as_ptr(),
366 )
367 })
368 }
369
370 #[default_device]
392 pub fn array_eq_device(
393 &self,
394 other: impl AsRef<Array>,
395 equal_nan: impl Into<Option<bool>>,
396 stream: impl AsRef<Stream>,
397 ) -> Result<Array> {
398 Array::try_from_op(|res| unsafe {
399 mlx_sys::mlx_array_equal(
400 res,
401 self.as_ptr(),
402 other.as_ref().as_ptr(),
403 equal_nan.into().unwrap_or(false),
404 stream.as_ref().as_ptr(),
405 )
406 })
407 }
408
409 #[default_device]
430 pub fn any_axes_device(
431 &self,
432 axes: &[i32],
433 keep_dims: impl Into<Option<bool>>,
434 stream: impl AsRef<Stream>,
435 ) -> Result<Array> {
436 Array::try_from_op(|res| unsafe {
437 mlx_sys::mlx_any_axes(
438 res,
439 self.as_ptr(),
440 axes.as_ptr(),
441 axes.len(),
442 keep_dims.into().unwrap_or(false),
443 stream.as_ref().as_ptr(),
444 )
445 })
446 }
447
448 #[default_device]
450 pub fn any_axis_device(
451 &self,
452 axis: i32,
453 keep_dims: impl Into<Option<bool>>,
454 stream: impl AsRef<Stream>,
455 ) -> Result<Array> {
456 Array::try_from_op(|res| unsafe {
457 mlx_sys::mlx_any_axis(
458 res,
459 self.as_ptr(),
460 axis,
461 keep_dims.into().unwrap_or(false),
462 stream.as_ref().as_ptr(),
463 )
464 })
465 }
466
467 #[default_device]
469 pub fn any_device(
470 &self,
471 keep_dims: impl Into<Option<bool>>,
472 stream: impl AsRef<Stream>,
473 ) -> Result<Array> {
474 Array::try_from_op(|res| unsafe {
475 mlx_sys::mlx_any(
476 res,
477 self.as_ptr(),
478 keep_dims.into().unwrap_or(false),
479 stream.as_ref().as_ptr(),
480 )
481 })
482 }
483}
484
485#[generate_macro]
487#[default_device]
488pub fn any_axes_device(
489 array: impl AsRef<Array>,
490 axes: &[i32],
491 #[optional] keep_dims: impl Into<Option<bool>>,
492 #[optional] stream: impl AsRef<Stream>,
493) -> Result<Array> {
494 array.as_ref().any_axes_device(axes, keep_dims, stream)
495}
496
497#[generate_macro]
499#[default_device]
500pub fn any_axis_device(
501 array: impl AsRef<Array>,
502 axis: i32,
503 #[optional] keep_dims: impl Into<Option<bool>>,
504 #[optional] stream: impl AsRef<Stream>,
505) -> Result<Array> {
506 array.as_ref().any_axis_device(axis, keep_dims, stream)
507}
508
509#[generate_macro]
511#[default_device]
512pub fn any_device(
513 array: impl AsRef<Array>,
514 #[optional] keep_dims: impl Into<Option<bool>>,
515 #[optional] stream: impl AsRef<Stream>,
516) -> Result<Array> {
517 array.as_ref().any_device(keep_dims, stream)
518}
519
520#[generate_macro]
522#[default_device]
523pub fn logical_and_device(
524 a: impl AsRef<Array>,
525 b: impl AsRef<Array>,
526 #[optional] stream: impl AsRef<Stream>,
527) -> Result<Array> {
528 a.as_ref().logical_and_device(b, stream)
529}
530
531#[generate_macro]
533#[default_device]
534pub fn logical_or_device(
535 a: impl AsRef<Array>,
536 b: impl AsRef<Array>,
537 #[optional] stream: impl AsRef<Stream>,
538) -> Result<Array> {
539 a.as_ref().logical_or_device(b, stream)
540}
541
542#[generate_macro]
544#[default_device]
545pub fn logical_not_device(
546 a: impl AsRef<Array>,
547 #[optional] stream: impl AsRef<Stream>,
548) -> Result<Array> {
549 a.as_ref().logical_not_device(stream)
550}
551
552#[generate_macro]
554#[default_device]
555pub fn all_close_device(
556 a: impl AsRef<Array>,
557 b: impl AsRef<Array>,
558 #[optional] rtol: impl Into<Option<f64>>,
559 #[optional] atol: impl Into<Option<f64>>,
560 #[optional] equal_nan: impl Into<Option<bool>>,
561 #[optional] stream: impl AsRef<Stream>,
562) -> Result<Array> {
563 a.as_ref()
564 .all_close_device(b, rtol, atol, equal_nan, stream)
565}
566
567#[generate_macro]
569#[default_device]
570pub fn is_close_device(
571 a: impl AsRef<Array>,
572 b: impl AsRef<Array>,
573 #[optional] rtol: impl Into<Option<f64>>,
574 #[optional] atol: impl Into<Option<f64>>,
575 #[optional] equal_nan: impl Into<Option<bool>>,
576 #[optional] stream: impl AsRef<Stream>,
577) -> Result<Array> {
578 a.as_ref().is_close_device(b, rtol, atol, equal_nan, stream)
579}
580
581#[generate_macro]
583#[default_device]
584pub fn array_eq_device(
585 a: impl AsRef<Array>,
586 b: impl AsRef<Array>,
587 #[optional] equal_nan: impl Into<Option<bool>>,
588 #[optional] stream: impl AsRef<Stream>,
589) -> Result<Array> {
590 a.as_ref().array_eq_device(b, equal_nan, stream)
591}
592
593#[generate_macro]
595#[default_device]
596pub fn eq_device(
597 a: impl AsRef<Array>,
598 b: impl AsRef<Array>,
599 #[optional] stream: impl AsRef<Stream>,
600) -> Result<Array> {
601 a.as_ref().eq_device(b, stream)
602}
603
604#[generate_macro]
606#[default_device]
607pub fn le_device(
608 a: impl AsRef<Array>,
609 b: impl AsRef<Array>,
610 #[optional] stream: impl AsRef<Stream>,
611) -> Result<Array> {
612 a.as_ref().le_device(b, stream)
613}
614
615#[generate_macro]
617#[default_device]
618pub fn ge_device(
619 a: impl AsRef<Array>,
620 b: impl AsRef<Array>,
621 #[optional] stream: impl AsRef<Stream>,
622) -> Result<Array> {
623 a.as_ref().ge_device(b, stream)
624}
625
626#[generate_macro]
628#[default_device]
629pub fn ne_device(
630 a: impl AsRef<Array>,
631 b: impl AsRef<Array>,
632 #[optional] stream: impl AsRef<Stream>,
633) -> Result<Array> {
634 a.as_ref().ne_device(b, stream)
635}
636
637#[generate_macro]
639#[default_device]
640pub fn lt_device(
641 a: impl AsRef<Array>,
642 b: impl AsRef<Array>,
643 #[optional] stream: impl AsRef<Stream>,
644) -> Result<Array> {
645 a.as_ref().lt_device(b, stream)
646}
647
648#[generate_macro]
650#[default_device]
651pub fn gt_device(
652 a: impl AsRef<Array>,
653 b: impl AsRef<Array>,
654 #[optional] stream: impl AsRef<Stream>,
655) -> Result<Array> {
656 a.as_ref().gt_device(b, stream)
657}
658
659#[generate_macro]
663#[default_device]
664pub fn is_nan_device(
665 array: impl AsRef<Array>,
666 #[optional] stream: impl AsRef<Stream>,
667) -> Result<Array> {
668 Array::try_from_op(|res| unsafe {
669 mlx_sys::mlx_isnan(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
670 })
671}
672
673#[generate_macro]
675#[default_device]
676pub fn is_inf_device(
677 array: impl AsRef<Array>,
678 #[optional] stream: impl AsRef<Stream>,
679) -> Result<Array> {
680 Array::try_from_op(|res| unsafe {
681 mlx_sys::mlx_isinf(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
682 })
683}
684
685#[generate_macro]
687#[default_device]
688pub fn is_pos_inf_device(
689 array: impl AsRef<Array>,
690 #[optional] stream: impl AsRef<Stream>,
691) -> Result<Array> {
692 Array::try_from_op(|res| unsafe {
693 mlx_sys::mlx_isposinf(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
694 })
695}
696
697#[generate_macro]
699#[default_device]
700pub fn is_neg_inf_device(
701 array: impl AsRef<Array>,
702 #[optional] stream: impl AsRef<Stream>,
703) -> Result<Array> {
704 Array::try_from_op(|res| unsafe {
705 mlx_sys::mlx_isneginf(res, array.as_ref().as_ptr(), stream.as_ref().as_ptr())
706 })
707}
708
709#[default_device]
722pub fn r#where_device(
723 condition: impl AsRef<Array>,
724 a: impl AsRef<Array>,
725 b: impl AsRef<Array>,
726 stream: impl AsRef<Stream>,
727) -> Result<Array> {
728 Array::try_from_op(|res| unsafe {
729 mlx_sys::mlx_where(
730 res,
731 condition.as_ref().as_ptr(),
732 a.as_ref().as_ptr(),
733 b.as_ref().as_ptr(),
734 stream.as_ref().as_ptr(),
735 )
736 })
737}
738
739#[generate_macro]
741#[default_device]
742pub fn which_device(
743 condition: impl AsRef<Array>,
744 a: impl AsRef<Array>,
745 b: impl AsRef<Array>,
746 #[optional] stream: impl AsRef<Stream>,
747) -> Result<Array> {
748 r#where_device(condition, a, b, stream)
749}
750
#[cfg(test)]
mod tests {
    use crate::{array, Dtype};

    use super::*;

    /// Two 1-D integer arrays of lengths 3 and 4, which cannot be broadcast together.
    fn mismatched_ints() -> (Array, Array) {
        (
            Array::from_slice(&[1, 2, 3], &[3]),
            Array::from_slice(&[1, 2, 3, 4], &[4]),
        )
    }

    /// Two 1-D boolean arrays of lengths 3 and 4, which cannot be broadcast together.
    fn mismatched_bools() -> (Array, Array) {
        (
            Array::from_slice(&[true, false, true], &[3]),
            Array::from_slice(&[true, true, false, true], &[4]),
        )
    }

    /// The integers 0 through 11 laid out as a 3x4 matrix.
    fn three_by_four() -> Array {
        Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[3, 4])
    }

    #[test]
    fn test_eq() {
        let lhs = Array::from_slice(&[1, 2, 3], &[3]);
        let rhs = Array::from_slice(&[1, 2, 3], &[3]);
        let out = lhs.eq(&rhs).unwrap();

        let out_values: &[bool] = out.as_slice();
        assert_eq!(out_values, [true, true, true]);

        // The operands must be left untouched by the comparison.
        let lhs_values: &[i32] = lhs.as_slice();
        assert_eq!(lhs_values, [1, 2, 3]);
        let rhs_values: &[i32] = rhs.as_slice();
        assert_eq!(rhs_values, [1, 2, 3]);
    }

    #[test]
    fn test_eq_invalid_broadcast() {
        let (lhs, rhs) = mismatched_ints();
        assert!(lhs.eq(&rhs).is_err());
    }

    #[test]
    fn test_le() {
        let lhs = Array::from_slice(&[1, 2, 3], &[3]);
        let rhs = Array::from_slice(&[1, 2, 3], &[3]);
        let out = lhs.le(&rhs).unwrap();

        let out_values: &[bool] = out.as_slice();
        assert_eq!(out_values, [true, true, true]);

        // The operands must be left untouched by the comparison.
        let lhs_values: &[i32] = lhs.as_slice();
        assert_eq!(lhs_values, [1, 2, 3]);
        let rhs_values: &[i32] = rhs.as_slice();
        assert_eq!(rhs_values, [1, 2, 3]);
    }

    #[test]
    fn test_le_invalid_broadcast() {
        let (lhs, rhs) = mismatched_ints();
        assert!(lhs.le(&rhs).is_err());
    }

    #[test]
    fn test_ge() {
        let lhs = Array::from_slice(&[1, 2, 3], &[3]);
        let rhs = Array::from_slice(&[1, 2, 3], &[3]);
        let out = lhs.ge(&rhs).unwrap();

        let out_values: &[bool] = out.as_slice();
        assert_eq!(out_values, [true, true, true]);

        // The operands must be left untouched by the comparison.
        let lhs_values: &[i32] = lhs.as_slice();
        assert_eq!(lhs_values, [1, 2, 3]);
        let rhs_values: &[i32] = rhs.as_slice();
        assert_eq!(rhs_values, [1, 2, 3]);
    }

    #[test]
    fn test_ge_invalid_broadcast() {
        let (lhs, rhs) = mismatched_ints();
        assert!(lhs.ge(&rhs).is_err());
    }

    #[test]
    fn test_ne() {
        let lhs = Array::from_slice(&[1, 2, 3], &[3]);
        let rhs = Array::from_slice(&[1, 2, 3], &[3]);
        let out = lhs.ne(&rhs).unwrap();

        let out_values: &[bool] = out.as_slice();
        assert_eq!(out_values, [false, false, false]);

        // The operands must be left untouched by the comparison.
        let lhs_values: &[i32] = lhs.as_slice();
        assert_eq!(lhs_values, [1, 2, 3]);
        let rhs_values: &[i32] = rhs.as_slice();
        assert_eq!(rhs_values, [1, 2, 3]);
    }

    #[test]
    fn test_ne_invalid_broadcast() {
        let (lhs, rhs) = mismatched_ints();
        assert!(lhs.ne(&rhs).is_err());
    }

    #[test]
    fn test_lt() {
        let lhs = Array::from_slice(&[1, 0, 3], &[3]);
        let rhs = Array::from_slice(&[1, 2, 3], &[3]);
        let out = lhs.lt(&rhs).unwrap();

        let out_values: &[bool] = out.as_slice();
        assert_eq!(out_values, [false, true, false]);

        // The operands must be left untouched by the comparison.
        let lhs_values: &[i32] = lhs.as_slice();
        assert_eq!(lhs_values, [1, 0, 3]);
        let rhs_values: &[i32] = rhs.as_slice();
        assert_eq!(rhs_values, [1, 2, 3]);
    }

    #[test]
    fn test_lt_invalid_broadcast() {
        let (lhs, rhs) = mismatched_ints();
        assert!(lhs.lt(&rhs).is_err());
    }

    #[test]
    fn test_gt() {
        let lhs = Array::from_slice(&[1, 4, 3], &[3]);
        let rhs = Array::from_slice(&[1, 2, 3], &[3]);
        let out = lhs.gt(&rhs).unwrap();

        let out_values: &[bool] = out.as_slice();
        assert_eq!(out_values, [false, true, false]);

        // The operands must be left untouched by the comparison.
        let lhs_values: &[i32] = lhs.as_slice();
        assert_eq!(lhs_values, [1, 4, 3]);
        let rhs_values: &[i32] = rhs.as_slice();
        assert_eq!(rhs_values, [1, 2, 3]);
    }

    #[test]
    fn test_gt_invalid_broadcast() {
        let (lhs, rhs) = mismatched_ints();
        assert!(lhs.gt(&rhs).is_err());
    }

    #[test]
    fn test_logical_and() {
        let lhs = Array::from_slice(&[true, false, true], &[3]);
        let rhs = Array::from_slice(&[true, true, false], &[3]);
        let out = lhs.logical_and(&rhs).unwrap();

        let out_values: &[bool] = out.as_slice();
        assert_eq!(out_values, [true, false, false]);

        // The operands must be left untouched.
        let lhs_values: &[bool] = lhs.as_slice();
        assert_eq!(lhs_values, [true, false, true]);
        let rhs_values: &[bool] = rhs.as_slice();
        assert_eq!(rhs_values, [true, true, false]);
    }

    #[test]
    fn test_logical_and_invalid_broadcast() {
        let (lhs, rhs) = mismatched_bools();
        assert!(lhs.logical_and(&rhs).is_err());
    }

    #[test]
    fn test_logical_or() {
        let lhs = Array::from_slice(&[true, false, true], &[3]);
        let rhs = Array::from_slice(&[true, true, false], &[3]);
        let out = lhs.logical_or(&rhs).unwrap();

        let out_values: &[bool] = out.as_slice();
        assert_eq!(out_values, [true, true, true]);

        // The operands must be left untouched.
        let lhs_values: &[bool] = lhs.as_slice();
        assert_eq!(lhs_values, [true, false, true]);
        let rhs_values: &[bool] = rhs.as_slice();
        assert_eq!(rhs_values, [true, true, false]);
    }

    #[test]
    fn test_logical_or_invalid_broadcast() {
        let (lhs, rhs) = mismatched_bools();
        assert!(lhs.logical_or(&rhs).is_err());
    }

    #[test]
    fn test_all_close() {
        let via_sqrt = Array::from_slice(&[0., 1., 2., 3.], &[4]).sqrt().unwrap();
        let via_power = Array::from_slice(&[0., 1., 2., 3.], &[4])
            .power(array!(0.5))
            .unwrap();
        let verdict = via_sqrt.all_close(&via_power, 1e-5, None, None).unwrap();

        let verdict_values: &[bool] = verdict.as_slice();
        assert_eq!(verdict_values, [true]);
    }

    #[test]
    fn test_all_close_invalid_broadcast() {
        let four = Array::from_slice(&[0., 1., 2., 3.], &[4]);
        let five = Array::from_slice(&[0., 1., 2., 3., 4.], &[5]);
        assert!(four.all_close(&five, 1e-5, None, None).is_err());
    }

    #[test]
    fn test_is_close_false() {
        let exact = Array::from_slice(&[1., 2., 3.], &[3]);
        let scaled = Array::from_slice(&[1.1, 2.2, 3.3], &[3]);
        let out = exact.is_close(&scaled, None, None, false).unwrap();

        let out_values: &[bool] = out.as_slice();
        assert_eq!(out_values, [false, false, false]);
    }

    #[test]
    fn test_is_close_true() {
        let exact = Array::from_slice(&[1., 2., 3.], &[3]);
        let scaled = Array::from_slice(&[1.1, 2.2, 3.3], &[3]);
        let out = exact.is_close(&scaled, 0.1, 0.2, true).unwrap();

        let out_values: &[bool] = out.as_slice();
        assert_eq!(out_values, [true, true, true]);
    }

    #[test]
    fn test_is_close_invalid_broadcast() {
        let three = Array::from_slice(&[1., 2., 3.], &[3]);
        let four = Array::from_slice(&[1.1, 2.2, 3.3, 4.4], &[4]);
        assert!(three.is_close(&four, None, None, false).is_err());
    }

    #[test]
    fn test_array_eq() {
        let ints = Array::from_slice(&[0, 1, 2, 3], &[4]);
        let floats = Array::from_slice(&[0., 1., 2., 3.], &[4]);
        let verdict = ints.array_eq(&floats, None).unwrap();

        let verdict_values: &[bool] = verdict.as_slice();
        assert_eq!(verdict_values, [true]);
    }

    #[test]
    fn test_any() {
        let matrix = three_by_four();
        let reduced = matrix.any_axes(&[0][..], None).unwrap();

        let reduced_values: &[bool] = reduced.as_slice();
        assert_eq!(reduced_values, &[true, true, true, true]);
    }

    #[test]
    fn test_any_empty_axes() {
        let matrix = three_by_four();
        let reduced = matrix.any_axes(&[][..], None).unwrap();

        let reduced_values: &[bool] = reduced.as_slice();
        assert_eq!(
            reduced_values,
            &[false, true, true, true, true, true, true, true, true, true, true, true]
        );
    }

    #[test]
    fn test_any_out_of_bounds() {
        let flat = Array::from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], &[12]);
        assert!(flat.any_axes(&[1][..], None).is_err());
    }

    #[test]
    fn test_any_duplicate_axes() {
        let matrix = three_by_four();
        assert!(matrix.any_axes(&[0, 0][..], None).is_err());
    }

    #[test]
    fn test_which() {
        let mask = Array::from_slice(&[true, false, true], &[3]);
        let on_true = Array::from_slice(&[1, 2, 3], &[3]);
        let on_false = Array::from_slice(&[4, 5, 6], &[3]);
        let picked = which(&mask, &on_true, &on_false).unwrap();

        let picked_values: &[i32] = picked.as_slice();
        assert_eq!(picked_values, [1, 5, 3]);
    }

    #[test]
    fn test_which_invalid_broadcast() {
        let mask = Array::from_slice(&[true, false, true], &[3]);
        let on_true = Array::from_slice(&[1, 2, 3], &[3]);
        let on_false = Array::from_slice(&[4, 5, 6, 7], &[4]);
        assert!(which(&mask, &on_true, &on_false).is_err());
    }

    #[test]
    fn test_unary_logical_not() {
        let falsy = array!(false);
        assert!(logical_not(&falsy).unwrap().item::<bool>());

        let nonzero_float = array!(1.0);
        let negated = logical_not(&nonzero_float).unwrap();
        assert_eq!(negated.dtype(), Dtype::Bool);
        assert!(!negated.item::<bool>());

        let zero_int = array!(0);
        let negated = logical_not(&zero_int).unwrap();
        assert_eq!(negated.dtype(), Dtype::Bool);
        assert!(negated.item::<bool>());
    }

    #[test]
    fn test_unary_logical_and() {
        let left = array!(true);
        let right = array!(true);
        assert!(logical_and(&left, &right).unwrap().item::<bool>());

        let left = array!(1.0);
        let right = array!(1.0);
        let both = logical_and(&left, &right).unwrap();
        assert_eq!(both.dtype(), Dtype::Bool);
        assert!(both.item::<bool>());

        let left = array!(0);
        let right = array!(1.0);
        let both = logical_and(&left, &right).unwrap();
        assert_eq!(both.dtype(), Dtype::Bool);
        assert!(!both.item::<bool>());
    }

    #[test]
    fn test_unary_logical_or() {
        let left = array!(false);
        let right = array!(false);
        assert!(!logical_or(&left, &right).unwrap().item::<bool>());

        let left = array!(1.0);
        let right = array!(1.0);
        let either = logical_or(&left, &right).unwrap();
        assert_eq!(either.dtype(), Dtype::Bool);
        assert!(either.item::<bool>());

        let left = array!(0);
        let right = array!(1.0);
        let either = logical_or(&left, &right).unwrap();
        assert_eq!(either.dtype(), Dtype::Bool);
        assert!(either.item::<bool>());
    }
}