1use crate::array::Array;
2use crate::error::Result;
3use crate::sealed::Sealed;
4use crate::stream::StreamOrDevice;
5
6use crate::utils::guard::Guarded;
7use crate::utils::{IntoOption, ScalarOrArray, VectorArray};
8use crate::Stream;
9use mlx_internal_macros::{default_device, generate_macro};
10use smallvec::SmallVec;
11
12impl Array {
13 #[default_device]
26 pub fn abs_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
27 Array::try_from_op(|res| unsafe {
28 mlx_sys::mlx_abs(res, self.as_ptr(), stream.as_ref().as_ptr())
29 })
30 }
31
32 #[default_device]
52 pub fn add_device(
53 &self,
54 other: impl AsRef<Array>,
55 stream: impl AsRef<Stream>,
56 ) -> Result<Array> {
57 Array::try_from_op(|res| unsafe {
58 mlx_sys::mlx_add(
59 res,
60 self.as_ptr(),
61 other.as_ref().as_ptr(),
62 stream.as_ref().as_ptr(),
63 )
64 })
65 }
66
67 #[default_device]
87 pub fn subtract_device(
88 &self,
89 other: impl AsRef<Array>,
90 stream: impl AsRef<Stream>,
91 ) -> Result<Array> {
92 Array::try_from_op(|res| unsafe {
93 mlx_sys::mlx_subtract(
94 res,
95 self.as_ptr(),
96 other.as_ref().as_ptr(),
97 stream.as_ref().as_ptr(),
98 )
99 })
100 }
101
102 #[default_device]
117 pub fn negative_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
118 Array::try_from_op(|res| unsafe {
119 mlx_sys::mlx_negative(res, self.as_ptr(), stream.as_ref().as_ptr())
120 })
121 }
122
123 #[default_device]
139 pub fn multiply_device(
140 &self,
141 other: impl AsRef<Array>,
142 stream: impl AsRef<Stream>,
143 ) -> Result<Array> {
144 Array::try_from_op(|res| unsafe {
145 mlx_sys::mlx_multiply(
146 res,
147 self.as_ptr(),
148 other.as_ref().as_ptr(),
149 stream.as_ref().as_ptr(),
150 )
151 })
152 }
153
154 #[default_device]
164 pub fn nan_to_num_device(
165 &self,
166 nan: impl IntoOption<f32>,
167 pos_inf: impl IntoOption<f32>,
168 neg_inf: impl IntoOption<f32>,
169 stream: impl AsRef<Stream>,
170 ) -> Result<Array> {
171 let pos_inf = pos_inf.into_option();
172 let neg_inf = neg_inf.into_option();
173
174 let pos_inf = mlx_sys::mlx_optional_float {
175 value: pos_inf.unwrap_or(0.0),
176 has_value: pos_inf.is_some(),
177 };
178 let neg_inf = mlx_sys::mlx_optional_float {
179 value: neg_inf.unwrap_or(0.0),
180 has_value: neg_inf.is_some(),
181 };
182
183 Array::try_from_op(|res| unsafe {
184 mlx_sys::mlx_nan_to_num(
185 res,
186 self.as_ptr(),
187 nan.into_option().unwrap_or(0.),
188 pos_inf,
189 neg_inf,
190 stream.as_ref().as_ptr(),
191 )
192 })
193 }
194
195 #[default_device]
215 pub fn divide_device(
216 &self,
217 other: impl AsRef<Array>,
218 stream: impl AsRef<Stream>,
219 ) -> Result<Array> {
220 Array::try_from_op(|res| unsafe {
221 mlx_sys::mlx_divide(
222 res,
223 self.as_ptr(),
224 other.as_ref().as_ptr(),
225 stream.as_ref().as_ptr(),
226 )
227 })
228 }
229
230 #[default_device]
250 pub fn power_device(
251 &self,
252 other: impl AsRef<Array>,
253 stream: impl AsRef<Stream>,
254 ) -> Result<Array> {
255 Array::try_from_op(|res| unsafe {
256 mlx_sys::mlx_power(
257 res,
258 self.as_ptr(),
259 other.as_ref().as_ptr(),
260 stream.as_ref().as_ptr(),
261 )
262 })
263 }
264
265 #[default_device]
285 pub fn remainder_device(
286 &self,
287 other: impl AsRef<Array>,
288 stream: impl AsRef<Stream>,
289 ) -> Result<Array> {
290 Array::try_from_op(|res| unsafe {
291 mlx_sys::mlx_remainder(
292 res,
293 self.as_ptr(),
294 other.as_ref().as_ptr(),
295 stream.as_ref().as_ptr(),
296 )
297 })
298 }
299
300 #[default_device]
313 pub fn sqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
314 Array::try_from_op(|res| unsafe {
315 mlx_sys::mlx_sqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
316 })
317 }
318
319 #[default_device]
332 pub fn cos_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
333 Array::try_from_op(|res| unsafe {
334 mlx_sys::mlx_cos(res, self.as_ptr(), stream.as_ref().as_ptr())
335 })
336 }
337
338 #[default_device]
353 pub fn exp_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
354 Array::try_from_op(|res| unsafe {
355 mlx_sys::mlx_exp(res, self.as_ptr(), stream.as_ref().as_ptr())
356 })
357 }
358
359 #[default_device]
372 pub fn floor_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
373 Array::try_from_op(|res| unsafe {
374 mlx_sys::mlx_floor(res, self.as_ptr(), stream.as_ref().as_ptr())
375 })
376 }
377
378 #[default_device]
402 pub fn floor_divide_device(
403 &self,
404 other: impl AsRef<Array>,
405 stream: impl AsRef<Stream>,
406 ) -> Result<Array> {
407 Array::try_from_op(|res| unsafe {
408 mlx_sys::mlx_floor_divide(
409 res,
410 self.as_ptr(),
411 other.as_ref().as_ptr(),
412 stream.as_ref().as_ptr(),
413 )
414 })
415 }
416
417 #[default_device]
422 pub fn is_nan_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
423 Array::try_from_op(|res| unsafe {
424 mlx_sys::mlx_isnan(res, self.as_ptr(), stream.as_ref().as_ptr())
425 })
426 }
427
428 #[default_device]
433 pub fn is_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
434 Array::try_from_op(|res| unsafe {
435 mlx_sys::mlx_isinf(res, self.as_ptr(), stream.as_ref().as_ptr())
436 })
437 }
438
439 #[default_device]
444 pub fn is_finite_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
445 Array::try_from_op(|res| unsafe {
446 mlx_sys::mlx_isfinite(res, self.as_ptr(), stream.as_ref().as_ptr())
447 })
448 }
449
450 #[default_device]
455 pub fn is_neg_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
456 Array::try_from_op(|res| unsafe {
457 mlx_sys::mlx_isneginf(res, self.as_ptr(), stream.as_ref().as_ptr())
458 })
459 }
460
461 #[default_device]
466 pub fn is_pos_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
467 Array::try_from_op(|res| unsafe {
468 mlx_sys::mlx_isposinf(res, self.as_ptr(), stream.as_ref().as_ptr())
469 })
470 }
471
472 #[default_device]
485 pub fn log_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
486 Array::try_from_op(|res| unsafe {
487 mlx_sys::mlx_log(res, self.as_ptr(), stream.as_ref().as_ptr())
488 })
489 }
490
491 #[default_device]
504 pub fn log2_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
505 Array::try_from_op(|res| unsafe {
506 mlx_sys::mlx_log2(res, self.as_ptr(), stream.as_ref().as_ptr())
507 })
508 }
509
510 #[default_device]
523 pub fn log10_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
524 Array::try_from_op(|res| unsafe {
525 mlx_sys::mlx_log10(res, self.as_ptr(), stream.as_ref().as_ptr())
526 })
527 }
528
529 #[default_device]
542 pub fn log1p_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
543 Array::try_from_op(|res| unsafe {
544 mlx_sys::mlx_log1p(res, self.as_ptr(), stream.as_ref().as_ptr())
545 })
546 }
547
548 #[default_device]
578 pub fn matmul_device(
579 &self,
580 other: impl AsRef<Array>,
581 stream: impl AsRef<Stream>,
582 ) -> Result<Array> {
583 Array::try_from_op(|res| unsafe {
584 mlx_sys::mlx_matmul(
585 res,
586 self.as_ptr(),
587 other.as_ref().as_ptr(),
588 stream.as_ref().as_ptr(),
589 )
590 })
591 }
592
593 #[default_device]
606 pub fn reciprocal_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
607 Array::try_from_op(|res| unsafe {
608 mlx_sys::mlx_reciprocal(res, self.as_ptr(), stream.as_ref().as_ptr())
609 })
610 }
611
612 #[default_device]
618 pub fn round_device(
619 &self,
620 decimals: impl Into<Option<i32>>,
621 stream: impl AsRef<Stream>,
622 ) -> Result<Array> {
623 Array::try_from_op(|res| unsafe {
624 mlx_sys::mlx_round(
625 res,
626 self.as_ptr(),
627 decimals.into().unwrap_or(0),
628 stream.as_ref().as_ptr(),
629 )
630 })
631 }
632
633 #[default_device]
635 pub fn rsqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
636 Array::try_from_op(|res| unsafe {
637 mlx_sys::mlx_rsqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
638 })
639 }
640
641 #[default_device]
643 pub fn sin_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
644 Array::try_from_op(|res| unsafe {
645 mlx_sys::mlx_sin(res, self.as_ptr(), stream.as_ref().as_ptr())
646 })
647 }
648
649 #[default_device]
651 pub fn square_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
652 Array::try_from_op(|res| unsafe {
653 mlx_sys::mlx_square(res, self.as_ptr(), stream.as_ref().as_ptr())
654 })
655 }
656}
657
658#[generate_macro]
669#[default_device]
670pub fn abs_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
671 a.as_ref().abs_device(stream)
672}
673
674#[generate_macro]
676#[default_device]
677pub fn acos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
678 Array::try_from_op(|res| unsafe {
679 mlx_sys::mlx_arccos(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
680 })
681}
682
683#[generate_macro]
685#[default_device]
686pub fn acosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
687 Array::try_from_op(|res| unsafe {
688 mlx_sys::mlx_arccosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
689 })
690}
691
692#[generate_macro]
694#[default_device]
695pub fn add_device(
696 lhs: impl AsRef<Array>,
697 rhs: impl AsRef<Array>,
698 #[optional] stream: impl AsRef<Stream>,
699) -> Result<Array> {
700 lhs.as_ref().add_device(rhs, stream)
701}
702
703#[generate_macro]
705#[default_device]
706pub fn asin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
707 Array::try_from_op(|res| unsafe {
708 mlx_sys::mlx_arcsin(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
709 })
710}
711
712#[generate_macro]
714#[default_device]
715pub fn asinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
716 Array::try_from_op(|res| unsafe {
717 mlx_sys::mlx_arcsinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
718 })
719}
720
721#[generate_macro]
723#[default_device]
724pub fn atan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
725 Array::try_from_op(|res| unsafe {
726 mlx_sys::mlx_arctan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
727 })
728}
729
730#[generate_macro]
732#[default_device]
733pub fn atanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
734 Array::try_from_op(|res| unsafe {
735 mlx_sys::mlx_arctanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
736 })
737}
738
739#[generate_macro]
741#[default_device]
742pub fn ceil_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
743 Array::try_from_op(|res| unsafe {
744 mlx_sys::mlx_ceil(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
745 })
746}
747
748pub trait ClipBound<'min, 'max>: Sealed {
753 fn into_min_max(
755 self,
756 ) -> (
757 Option<impl ScalarOrArray<'min>>,
758 Option<impl ScalarOrArray<'max>>,
759 );
760}
761
762impl<'min, Min> ClipBound<'min, 'min> for (Min, ())
763where
764 Min: ScalarOrArray<'min> + Sealed,
765{
766 fn into_min_max(
767 self,
768 ) -> (
769 Option<impl ScalarOrArray<'min>>,
770 Option<impl ScalarOrArray<'min>>,
771 ) {
772 (Some(self.0), Option::<Min>::None)
773 }
774}
775
776impl<'max, Max> ClipBound<'max, 'max> for ((), Max)
777where
778 Max: ScalarOrArray<'max> + Sealed,
779{
780 fn into_min_max(
781 self,
782 ) -> (
783 Option<impl ScalarOrArray<'max>>,
784 Option<impl ScalarOrArray<'max>>,
785 ) {
786 (Option::<Max>::None, Some(self.1))
787 }
788}
789
790impl<'min, 'max, Min, Max> ClipBound<'min, 'max> for (Min, Max)
791where
792 Min: ScalarOrArray<'min> + Sealed,
793 Max: ScalarOrArray<'max> + Sealed,
794{
795 fn into_min_max(
796 self,
797 ) -> (
798 Option<impl ScalarOrArray<'min>>,
799 Option<impl ScalarOrArray<'max>>,
800 ) {
801 (Some(self.0), Some(self.1))
802 }
803}
804
805#[generate_macro]
827#[default_device]
828pub fn clip_device<'min, 'max>(
829 a: impl AsRef<Array>,
830 bound: impl ClipBound<'min, 'max>,
831 #[optional] stream: impl AsRef<Stream>,
832) -> Result<Array> {
833 let (a_min, a_max) = bound.into_min_max();
834
835 let a_min = a_min.map(|min| min.into_owned_or_ref_array());
837 let a_max = a_max.map(|max| max.into_owned_or_ref_array());
838
839 unsafe {
840 let min_ptr = match &a_min {
841 Some(a_min) => a_min.as_ref().as_ptr(),
842 None => mlx_sys::mlx_array_new(),
843 };
844 let max_ptr = match &a_max {
845 Some(a_max) => a_max.as_ref().as_ptr(),
846 None => mlx_sys::mlx_array_new(),
847 };
848
849 Array::try_from_op(|res| {
850 mlx_sys::mlx_clip(
851 res,
852 a.as_ref().as_ptr(),
853 min_ptr,
854 max_ptr,
855 stream.as_ref().as_ptr(),
856 )
857 })
858 }
859}
860
861#[generate_macro]
863#[default_device]
864pub fn cos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
865 a.as_ref().cos_device(stream)
866}
867
868#[generate_macro]
870#[default_device]
871pub fn cosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
872 Array::try_from_op(|res| unsafe {
873 mlx_sys::mlx_cosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
874 })
875}
876
877#[generate_macro]
879#[default_device]
880pub fn degrees_device(
881 a: impl AsRef<Array>,
882 #[optional] stream: impl AsRef<Stream>,
883) -> Result<Array> {
884 Array::try_from_op(|res| unsafe {
885 mlx_sys::mlx_degrees(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
886 })
887}
888
889#[generate_macro]
891#[default_device]
892pub fn divide_device(
893 a: impl AsRef<Array>,
894 b: impl AsRef<Array>,
895 #[optional] stream: impl AsRef<Stream>,
896) -> Result<Array> {
897 a.as_ref().divide_device(b, stream)
898}
899
900#[generate_macro]
907#[default_device]
908pub fn divmod_device(
909 a: impl AsRef<Array>,
910 b: impl AsRef<Array>,
911 #[optional] stream: impl AsRef<Stream>,
912) -> Result<(Array, Array)> {
913 let a_ptr = a.as_ref().as_ptr();
914 let b_ptr = b.as_ref().as_ptr();
915
916 let vec = VectorArray::try_from_op(|res| unsafe {
917 mlx_sys::mlx_divmod(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
918 })?;
919
920 let vals: SmallVec<[_; 2]> = vec.try_into_values()?;
921 let mut iter = vals.into_iter();
922 let quotient = iter.next().unwrap();
923 let remainder = iter.next().unwrap();
924
925 Ok((quotient, remainder))
926}
927
928#[generate_macro]
930#[default_device]
931pub fn erf_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
932 Array::try_from_op(|res| unsafe {
933 mlx_sys::mlx_erf(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
934 })
935}
936
937#[generate_macro]
939#[default_device]
940pub fn erfinv_device(
941 a: impl AsRef<Array>,
942 #[optional] stream: impl AsRef<Stream>,
943) -> Result<Array> {
944 Array::try_from_op(|res| unsafe {
945 mlx_sys::mlx_erfinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
946 })
947}
948
949#[generate_macro]
951#[default_device]
952pub fn exp_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
953 a.as_ref().exp_device(stream)
954}
955
956#[generate_macro]
958#[default_device]
959pub fn expm1_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
960 Array::try_from_op(|res| unsafe {
961 mlx_sys::mlx_expm1(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
962 })
963}
964
965#[generate_macro]
967#[default_device]
968pub fn floor_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
969 a.as_ref().floor_device(stream)
970}
971
972#[generate_macro]
974#[default_device]
975pub fn floor_divide_device(
976 a: impl AsRef<Array>,
977 other: impl AsRef<Array>,
978 #[optional] stream: impl AsRef<Stream>,
979) -> Result<Array> {
980 a.as_ref().floor_divide_device(other, stream)
981}
982
983#[generate_macro]
985#[default_device]
986pub fn log_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
987 a.as_ref().log_device(stream)
988}
989
990#[generate_macro]
992#[default_device]
993pub fn log10_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
994 a.as_ref().log10_device(stream)
995}
996
997#[generate_macro]
999#[default_device]
1000pub fn log1p_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1001 a.as_ref().log1p_device(stream)
1002}
1003
1004#[generate_macro]
1006#[default_device]
1007pub fn log2_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1008 a.as_ref().log2_device(stream)
1009}
1010
1011#[generate_macro]
1018#[default_device]
1019pub fn log_add_exp_device(
1020 a: impl AsRef<Array>,
1021 b: impl AsRef<Array>,
1022 #[optional] stream: impl AsRef<Stream>,
1023) -> Result<Array> {
1024 let a_ptr = a.as_ref().as_ptr();
1025 let b_ptr = b.as_ref().as_ptr();
1026
1027 Array::try_from_op(|res| unsafe {
1028 mlx_sys::mlx_logaddexp(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1029 })
1030}
1031
1032#[generate_macro]
1034#[default_device]
1035pub fn matmul_device(
1036 a: impl AsRef<Array>,
1037 b: impl AsRef<Array>,
1038 #[optional] stream: impl AsRef<Stream>,
1039) -> Result<Array> {
1040 a.as_ref().matmul_device(b, stream)
1041}
1042
1043#[generate_macro]
1048#[default_device]
1049pub fn maximum_device(
1050 a: impl AsRef<Array>,
1051 b: impl AsRef<Array>,
1052 #[optional] stream: impl AsRef<Stream>,
1053) -> Result<Array> {
1054 let a_ptr = a.as_ref().as_ptr();
1055 let b_ptr = b.as_ref().as_ptr();
1056
1057 Array::try_from_op(|res| unsafe {
1058 mlx_sys::mlx_maximum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1059 })
1060}
1061
1062#[generate_macro]
1067#[default_device]
1068pub fn minimum_device(
1069 a: impl AsRef<Array>,
1070 b: impl AsRef<Array>,
1071 #[optional] stream: impl AsRef<Stream>,
1072) -> Result<Array> {
1073 let a_ptr = a.as_ref().as_ptr();
1074 let b_ptr = b.as_ref().as_ptr();
1075
1076 Array::try_from_op(|res| unsafe {
1077 mlx_sys::mlx_minimum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1078 })
1079}
1080
1081#[generate_macro]
1083#[default_device]
1084pub fn multiply_device(
1085 a: impl AsRef<Array>,
1086 b: impl AsRef<Array>,
1087 #[optional] stream: impl AsRef<Stream>,
1088) -> Result<Array> {
1089 a.as_ref().multiply_device(b, stream)
1090}
1091
1092#[generate_macro]
1094#[default_device]
1095pub fn negative_device(
1096 a: impl AsRef<Array>,
1097 #[optional] stream: impl AsRef<Stream>,
1098) -> Result<Array> {
1099 a.as_ref().negative_device(stream)
1100}
1101
1102#[generate_macro]
1104#[default_device]
1105pub fn power_device(
1106 a: impl AsRef<Array>,
1107 b: impl AsRef<Array>,
1108 #[optional] stream: impl AsRef<Stream>,
1109) -> Result<Array> {
1110 a.as_ref().power_device(b, stream)
1111}
1112
1113#[generate_macro]
1115#[default_device]
1116pub fn radians_device(
1117 a: impl AsRef<Array>,
1118 #[optional] stream: impl AsRef<Stream>,
1119) -> Result<Array> {
1120 Array::try_from_op(|res| unsafe {
1121 mlx_sys::mlx_radians(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1122 })
1123}
1124
1125#[generate_macro]
1127#[default_device]
1128pub fn reciprocal_device(
1129 a: impl AsRef<Array>,
1130 #[optional] stream: impl AsRef<Stream>,
1131) -> Result<Array> {
1132 a.as_ref().reciprocal_device(stream)
1133}
1134
1135#[generate_macro]
1137#[default_device]
1138pub fn remainder_device(
1139 a: impl AsRef<Array>,
1140 b: impl AsRef<Array>,
1141 #[optional] stream: impl AsRef<Stream>,
1142) -> Result<Array> {
1143 a.as_ref().remainder_device(b, stream)
1144}
1145
1146#[generate_macro]
1148#[default_device]
1149pub fn round_device(
1150 a: impl AsRef<Array>,
1151 decimals: impl Into<Option<i32>>,
1152 #[optional] stream: impl AsRef<Stream>,
1153) -> Result<Array> {
1154 a.as_ref().round_device(decimals, stream)
1155}
1156
1157#[generate_macro]
1159#[default_device]
1160pub fn rsqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1161 a.as_ref().rsqrt_device(stream)
1162}
1163
1164#[generate_macro]
1170#[default_device]
1171pub fn sigmoid_device(
1172 a: impl AsRef<Array>,
1173 #[optional] stream: impl AsRef<Stream>,
1174) -> Result<Array> {
1175 Array::try_from_op(|res| unsafe {
1176 mlx_sys::mlx_sigmoid(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1177 })
1178}
1179
1180#[generate_macro]
1182#[default_device]
1183pub fn sign_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1184 Array::try_from_op(|res| unsafe {
1185 mlx_sys::mlx_sign(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1186 })
1187}
1188
1189#[generate_macro]
1191#[default_device]
1192pub fn sin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1193 a.as_ref().sin_device(stream)
1194}
1195
1196#[generate_macro]
1198#[default_device]
1199pub fn sinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1200 Array::try_from_op(|res| unsafe {
1201 mlx_sys::mlx_sinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1202 })
1203}
1204
1205#[generate_macro]
1211#[default_device]
1212pub fn softmax_device(
1213 a: impl AsRef<Array>,
1214 axes: &[i32],
1215 precise: impl Into<Option<bool>>,
1216 #[optional] stream: impl AsRef<Stream>,
1217) -> Result<Array> {
1218 let precise = precise.into().unwrap_or(false);
1219 let s = stream.as_ref().as_ptr();
1220
1221 Array::try_from_op(|res| unsafe {
1222 mlx_sys::mlx_softmax(
1223 res,
1224 a.as_ref().as_ptr(),
1225 axes.as_ptr(),
1226 axes.len(),
1227 precise,
1228 s,
1229 )
1230 })
1231}
1232
1233#[generate_macro]
1235#[default_device]
1236pub fn softmax_all_device(
1237 a: impl AsRef<Array>,
1238 #[optional] precise: impl Into<Option<bool>>,
1239 #[optional] stream: impl AsRef<Stream>,
1240) -> Result<Array> {
1241 let precise = precise.into().unwrap_or(false);
1242 let s = stream.as_ref().as_ptr();
1243
1244 Array::try_from_op(|res| unsafe {
1245 mlx_sys::mlx_softmax_all(res, a.as_ref().as_ptr(), precise, s)
1246 })
1247}
1248
1249#[generate_macro]
1251#[default_device]
1252pub fn sqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1253 a.as_ref().sqrt_device(stream)
1254}
1255
1256#[generate_macro]
1258#[default_device]
1259pub fn square_device(
1260 a: impl AsRef<Array>,
1261 #[optional] stream: impl AsRef<Stream>,
1262) -> Result<Array> {
1263 a.as_ref().square_device(stream)
1264}
1265
1266#[generate_macro]
1268#[default_device]
1269pub fn subtract_device(
1270 a: impl AsRef<Array>,
1271 b: impl AsRef<Array>,
1272 #[optional] stream: impl AsRef<Stream>,
1273) -> Result<Array> {
1274 a.as_ref().subtract_device(b, stream)
1275}
1276
1277#[generate_macro]
1279#[default_device]
1280pub fn tan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1281 Array::try_from_op(|res| unsafe {
1282 mlx_sys::mlx_tan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1283 })
1284}
1285
1286#[generate_macro]
1288#[default_device]
1289pub fn tanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1290 Array::try_from_op(|res| unsafe {
1291 mlx_sys::mlx_tanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1292 })
1293}
1294
1295#[generate_macro]
1301#[default_device]
1302pub fn block_masked_mm_device<'mo, 'lhs, 'rhs>(
1303 a: impl AsRef<Array>,
1304 b: impl AsRef<Array>,
1305 #[optional] block_size: impl Into<Option<i32>>,
1306 #[optional] mask_out: impl Into<Option<&'mo Array>>,
1307 #[optional] mask_lhs: impl Into<Option<&'lhs Array>>,
1308 #[optional] mask_rhs: impl Into<Option<&'rhs Array>>,
1309 #[optional] stream: impl AsRef<Stream>,
1310) -> Result<Array> {
1311 let a_ptr = a.as_ref().as_ptr();
1312 let b_ptr = b.as_ref().as_ptr();
1313 unsafe {
1314 let mask_out_ptr = mask_out
1315 .into()
1316 .map(|m| m.as_ptr())
1317 .unwrap_or(mlx_sys::mlx_array_new());
1318 let mask_lhs_ptr = mask_lhs
1319 .into()
1320 .map(|m| m.as_ptr())
1321 .unwrap_or(mlx_sys::mlx_array_new());
1322 let mask_rhs_ptr = mask_rhs
1323 .into()
1324 .map(|m| m.as_ptr())
1325 .unwrap_or(mlx_sys::mlx_array_new());
1326
1327 Array::try_from_op(|res| {
1328 mlx_sys::mlx_block_masked_mm(
1329 res,
1330 a_ptr,
1331 b_ptr,
1332 block_size.into().unwrap_or(32),
1333 mask_out_ptr,
1334 mask_lhs_ptr,
1335 mask_rhs_ptr,
1336 stream.as_ref().as_ptr(),
1337 )
1338 })
1339 }
1340}
1341
1342#[generate_macro]
1355#[default_device]
1356pub fn addmm_device(
1357 c: impl AsRef<Array>,
1358 a: impl AsRef<Array>,
1359 b: impl AsRef<Array>,
1360 #[optional] alpha: impl Into<Option<f32>>,
1361 #[optional] beta: impl Into<Option<f32>>,
1362 #[optional] stream: impl AsRef<Stream>,
1363) -> Result<Array> {
1364 let c_ptr = c.as_ref().as_ptr();
1365 let a_ptr = a.as_ref().as_ptr();
1366 let b_ptr = b.as_ref().as_ptr();
1367 let alpha = alpha.into().unwrap_or(1.0);
1368 let beta = beta.into().unwrap_or(1.0);
1369
1370 Array::try_from_op(|res| unsafe {
1371 mlx_sys::mlx_addmm(
1372 res,
1373 c_ptr,
1374 a_ptr,
1375 b_ptr,
1376 alpha,
1377 beta,
1378 stream.as_ref().as_ptr(),
1379 )
1380 })
1381}
1382
1383#[generate_macro]
1386#[default_device]
1387pub fn inner_device(
1388 a: impl AsRef<Array>,
1389 b: impl AsRef<Array>,
1390 #[optional] stream: impl AsRef<Stream>,
1391) -> Result<Array> {
1392 let a = a.as_ref();
1393 let b = b.as_ref();
1394 Array::try_from_op(|res| unsafe {
1395 mlx_sys::mlx_inner(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1396 })
1397}
1398
1399#[generate_macro]
1402#[default_device]
1403pub fn outer_device(
1404 a: impl AsRef<Array>,
1405 b: impl AsRef<Array>,
1406 #[optional] stream: impl AsRef<Stream>,
1407) -> Result<Array> {
1408 let a = a.as_ref();
1409 let b = b.as_ref();
1410 Array::try_from_op(|res| unsafe {
1411 mlx_sys::mlx_outer(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1412 })
1413}
1414
// Every payload is `Copy` (an `i32` or a pair of shared slices), so the enum can be
// `Copy` and comparable as well. The extra derives are purely additive for callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
/// Axis specification accepted by `tensordot_device`.
pub enum TensorDotDims<'a> {
    /// A single integer, forwarded as the `axis` argument of
    /// `mlx_sys::mlx_tensordot_along_axis`.
    Int(i32),

    /// Explicit axis lists `(axes of a, axes of b)`, forwarded to `mlx_sys::mlx_tensordot`.
    List((&'a [i32], &'a [i32])),
}
1424
1425impl From<i32> for TensorDotDims<'_> {
1426 fn from(i: i32) -> Self {
1427 TensorDotDims::Int(i)
1428 }
1429}
1430
1431impl<'a> From<(&'a [i32], &'a [i32])> for TensorDotDims<'a> {
1432 fn from((lhs, rhs): (&'a [i32], &'a [i32])) -> Self {
1433 TensorDotDims::List((lhs, rhs))
1434 }
1435}
1436
1437impl<'a, const M: usize, const N: usize> From<(&'a [i32; M], &'a [i32; N])> for TensorDotDims<'a> {
1438 fn from((lhs, rhs): (&'a [i32; M], &'a [i32; N])) -> Self {
1439 TensorDotDims::List((lhs, rhs))
1440 }
1441}
1442
1443#[generate_macro]
1453#[default_device]
1454pub fn tensordot_device<'a>(
1455 a: impl AsRef<Array>,
1456 b: impl AsRef<Array>,
1457 #[optional] axes: impl Into<TensorDotDims<'a>>,
1458 #[optional] stream: impl AsRef<Stream>,
1459) -> Result<Array> {
1460 let a = a.as_ref();
1461 let b = b.as_ref();
1462 match axes.into() {
1463 TensorDotDims::Int(dim) => Array::try_from_op(|res| unsafe {
1464 mlx_sys::mlx_tensordot_along_axis(
1465 res,
1466 a.as_ptr(),
1467 b.as_ptr(),
1468 dim,
1469 stream.as_ref().as_ptr(),
1470 )
1471 }),
1472 TensorDotDims::List((lhs, rhs)) => Array::try_from_op(|res| unsafe {
1473 mlx_sys::mlx_tensordot(
1474 res,
1475 a.as_ptr(),
1476 b.as_ptr(),
1477 lhs.as_ptr(),
1478 lhs.len(),
1479 rhs.as_ptr(),
1480 rhs.len(),
1481 stream.as_ref().as_ptr(),
1482 )
1483 }),
1484 }
1485}
1486
1487#[cfg(test)]
1488mod tests {
1489 use std::f32::consts::PI;
1490
1491 use super::*;
1492 use crate::{
1493 array, complex64,
1494 ops::{all_close, arange, broadcast_to, eye, full, linspace, ones, reshape, split_equal},
1495 transforms::eval,
1496 Dtype,
1497 };
1498 use float_eq::assert_float_eq;
1499 use pretty_assertions::assert_eq;
1500
1501 #[test]
1502 fn test_abs() {
1503 let data = [1i32, 2, -3, -4, -5];
1504 let array = Array::from_slice(&data, &[5]);
1505 let result = array.abs().unwrap();
1506
1507 let data: &[i32] = result.as_slice();
1508 assert_eq!(data, [1, 2, 3, 4, 5]);
1509
1510 let data: &[i32] = array.as_slice();
1512 assert_eq!(data, [1, 2, -3, -4, -5]);
1513 }
1514
1515 #[test]
1516 fn test_add() {
1517 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1518 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1519
1520 let c = &a + &b;
1521
1522 let c_data: &[f32] = c.as_slice();
1523 assert_eq!(c_data, &[5.0, 7.0, 9.0]);
1524
1525 let a_data: &[f32] = a.as_slice();
1527 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1528
1529 let b_data: &[f32] = b.as_slice();
1530 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1531 }
1532
1533 #[test]
1534 fn test_add_invalid_broadcast() {
1535 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1536 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1537
1538 let c = a.add(&b);
1539 assert!(c.is_err());
1540 }
1541
1542 #[test]
1543 fn test_sub() {
1544 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1545 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1546
1547 let c = &a - &b;
1548
1549 let c_data: &[f32] = c.as_slice();
1550 assert_eq!(c_data, &[-3.0, -3.0, -3.0]);
1551
1552 let a_data: &[f32] = a.as_slice();
1554 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1555
1556 let b_data: &[f32] = b.as_slice();
1557 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1558 }
1559
1560 #[test]
1561 fn test_sub_invalid_broadcast() {
1562 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1563 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1564 let c = a.subtract(&b);
1565 assert!(c.is_err());
1566 }
1567
1568 #[test]
1569 fn test_neg() {
1570 let a = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3]);
1571 let b = a.negative().unwrap();
1572
1573 let b_data: &[f32] = b.as_slice();
1574 assert_eq!(b_data, &[-1.0, -2.0, -3.0]);
1575
1576 let a_data: &[f32] = a.as_slice();
1578 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1579 }
1580
1581 #[test]
1582 fn test_neg_bool() {
1583 let a = Array::from_slice(&[true, false, true], &[3]);
1584 let b = a.negative();
1585 assert!(b.is_err());
1586 }
1587
1588 #[test]
1589 fn test_logical_not() {
1590 let a: Array = false.into();
1591 let b = a.logical_not().unwrap();
1592
1593 let b_data: &[bool] = b.as_slice();
1594 assert_eq!(b_data, [true]);
1595 }
1596
1597 #[test]
1598 fn test_mul() {
1599 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1600 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1601
1602 let c = &a * &b;
1603
1604 let c_data: &[f32] = c.as_slice();
1605 assert_eq!(c_data, &[4.0, 10.0, 18.0]);
1606
1607 let a_data: &[f32] = a.as_slice();
1609 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1610
1611 let b_data: &[f32] = b.as_slice();
1612 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1613 }
1614
1615 #[test]
1616 fn test_mul_invalid_broadcast() {
1617 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1618 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1619 let c = a.multiply(&b);
1620 assert!(c.is_err());
1621 }
1622
1623 #[test]
1624 fn test_nan_to_num() {
1625 let a = array!([1.0, 2.0, f32::NAN, 4.0, 5.0]);
1626 let b = a.nan_to_num(0.0, 1.0, 0.0).unwrap();
1627
1628 let b_data: &[f32] = b.as_slice();
1629 assert_eq!(b_data, &[1.0, 2.0, 0.0, 4.0, 5.0]);
1630 }
1631
1632 #[test]
1633 fn test_div() {
1634 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1635 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1636
1637 let c = &a / &b;
1638
1639 let c_data: &[f32] = c.as_slice();
1640 assert_eq!(c_data, &[0.25, 0.4, 0.5]);
1641
1642 let a_data: &[f32] = a.as_slice();
1644 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1645
1646 let b_data: &[f32] = b.as_slice();
1647 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1648 }
1649
1650 #[test]
1651 fn test_div_invalid_broadcast() {
1652 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1653 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1654 let c = a.divide(&b);
1655 assert!(c.is_err());
1656 }
1657
1658 #[test]
1659 fn test_pow() {
1660 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1661 let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);
1662
1663 let c = a.power(&b).unwrap();
1664
1665 let c_data: &[f32] = c.as_slice();
1666 assert_eq!(c_data, &[1.0, 8.0, 81.0]);
1667
1668 let a_data: &[f32] = a.as_slice();
1670 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1671
1672 let b_data: &[f32] = b.as_slice();
1673 assert_eq!(b_data, &[2.0, 3.0, 4.0]);
1674 }
1675
1676 #[test]
1677 fn test_pow_invalid_broadcast() {
1678 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1679 let b = Array::from_slice(&[2.0, 3.0], &[2]);
1680 let c = a.power(&b);
1681 assert!(c.is_err());
1682 }
1683
1684 #[test]
1685 fn test_rem() {
1686 let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
1687 let b = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);
1688
1689 let c = &a % &b;
1690
1691 let c_data: &[f32] = c.as_slice();
1692 assert_eq!(c_data, &[1.0, 3.0, 2.0]);
1693
1694 let a_data: &[f32] = a.as_slice();
1696 assert_eq!(a_data, &[10.0, 11.0, 12.0]);
1697
1698 let b_data: &[f32] = b.as_slice();
1699 assert_eq!(b_data, &[3.0, 4.0, 5.0]);
1700 }
1701
1702 #[test]
1703 fn test_rem_invalid_broadcast() {
1704 let a = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
1705 let b = Array::from_slice(&[3.0, 4.0], &[2]);
1706 let c = a.remainder(&b);
1707 assert!(c.is_err());
1708 }
1709
1710 #[test]
1711 fn test_sqrt() {
1712 let a = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
1713 let b = a.sqrt().unwrap();
1714
1715 let b_data: &[f32] = b.as_slice();
1716 assert_eq!(b_data, &[1.0, 2.0, 3.0]);
1717
1718 let a_data: &[f32] = a.as_slice();
1720 assert_eq!(a_data, &[1.0, 4.0, 9.0]);
1721 }
1722
1723 #[test]
1724 fn test_cos() {
1725 let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1726 let b = a.cos().unwrap();
1727
1728 let b_data: &[f32] = b.as_slice();
1729 assert_eq!(b_data, &[1.0, 0.54030234, -0.41614687]);
1730
1731 let a_data: &[f32] = a.as_slice();
1733 assert_eq!(a_data, &[0.0, 1.0, 2.0]);
1734 }
1735
1736 #[test]
1737 fn test_exp() {
1738 let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1739 let b = a.exp().unwrap();
1740
1741 let b_data: &[f32] = b.as_slice();
1742 assert_eq!(b_data, &[1.0, 2.7182817, 7.389056]);
1743
1744 let a_data: &[f32] = a.as_slice();
1746 assert_eq!(a_data, &[0.0, 1.0, 2.0]);
1747 }
1748
1749 #[test]
1750 fn test_floor() {
1751 let a = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
1752 let b = a.floor().unwrap();
1753
1754 let b_data: &[f32] = b.as_slice();
1755 assert_eq!(b_data, &[0.0, 1.0, 2.0]);
1756
1757 let a_data: &[f32] = a.as_slice();
1759 assert_eq!(a_data, &[0.1, 1.9, 2.5]);
1760 }
1761
1762 #[test]
1763 fn test_floor_complex64() {
1764 let val = complex64::new(1.0, 2.0);
1765 let a = Array::from_complex(val);
1766 let b = a.floor_device(StreamOrDevice::default());
1767 assert!(b.is_err());
1768 }
1769
1770 #[test]
1771 fn test_floor_divide() {
1772 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1773 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1774
1775 let c = a.floor_divide(&b).unwrap();
1776
1777 let c_data: &[f32] = c.as_slice();
1778 assert_eq!(c_data, &[0.0, 0.0, 0.0]);
1779
1780 let a_data: &[f32] = a.as_slice();
1782 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1783
1784 let b_data: &[f32] = b.as_slice();
1785 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1786 }
1787
1788 #[test]
1789 fn test_floor_divide_complex64() {
1790 let val = complex64::new(1.0, 2.0);
1791 let a = Array::from_complex(val);
1792 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1793 let c = a.floor_divide_device(&b, StreamOrDevice::default());
1794 assert!(c.is_err());
1795 }
1796
1797 #[test]
1798 fn test_floor_divide_invalid_broadcast() {
1799 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1800 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1801 let c = a.floor_divide_device(&b, StreamOrDevice::default());
1802 assert!(c.is_err());
1803 }
1804
1805 #[test]
1806 fn test_is_nan() {
1807 let a = Array::from_slice(&[1.0, f32::NAN, 3.0], &[3]);
1808 let b = a.is_nan().unwrap();
1809
1810 let b_data: &[bool] = b.as_slice();
1811 assert_eq!(b_data, &[false, true, false]);
1812 }
1813
1814 #[test]
1815 fn test_is_inf() {
1816 let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1817 let b = a.is_inf().unwrap();
1818
1819 let b_data: &[bool] = b.as_slice();
1820 assert_eq!(b_data, &[false, true, false]);
1821 }
1822
1823 #[test]
1824 fn test_is_finite() {
1825 let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1826 let b = a.is_finite().unwrap();
1827
1828 let b_data: &[bool] = b.as_slice();
1829 assert_eq!(b_data, &[true, false, true]);
1830 }
1831
1832 #[test]
1833 fn test_is_neg_inf() {
1834 let a = Array::from_slice(&[1.0, f32::NEG_INFINITY, 3.0], &[3]);
1835 let b = a.is_neg_inf().unwrap();
1836
1837 let b_data: &[bool] = b.as_slice();
1838 assert_eq!(b_data, &[false, true, false]);
1839 }
1840
1841 #[test]
1842 fn test_is_pos_inf() {
1843 let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
1844 let b = a.is_pos_inf().unwrap();
1845
1846 let b_data: &[bool] = b.as_slice();
1847 assert_eq!(b_data, &[false, true, false]);
1848 }
1849
1850 #[test]
1851 fn test_log() {
1852 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1853 let b = a.log().unwrap();
1854
1855 let b_data: &[f32] = b.as_slice();
1856 assert_eq!(b_data, &[0.0, 0.6931472, 1.0986123]);
1857
1858 let a_data: &[f32] = a.as_slice();
1860 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1861 }
1862
1863 #[test]
1864 fn test_log2() {
1865 let a = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
1866 let b = a.log2().unwrap();
1867
1868 let b_data: &[f32] = b.as_slice();
1869 assert_eq!(b_data, &[0.0, 1.0, 2.0, 3.0]);
1870
1871 let a_data: &[f32] = a.as_slice();
1873 assert_eq!(a_data, &[1.0, 2.0, 4.0, 8.0]);
1874 }
1875
1876 #[test]
1877 fn test_log10() {
1878 let a = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
1879 let b = a.log10().unwrap();
1880
1881 let b_data: &[f32] = b.as_slice();
1882 assert_eq!(b_data, &[0.0, 1.0, 2.0]);
1883
1884 let a_data: &[f32] = a.as_slice();
1886 assert_eq!(a_data, &[1.0, 10.0, 100.0]);
1887 }
1888
1889 #[test]
1890 fn test_log1p() {
1891 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1892 let b = a.log1p().unwrap();
1893
1894 let b_data: &[f32] = b.as_slice();
1895 assert_eq!(b_data, &[0.6931472, 1.0986123, 1.3862944]);
1896
1897 let a_data: &[f32] = a.as_slice();
1899 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1900 }
1901
1902 #[test]
1903 fn test_matmul() {
1904 let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1905 let b = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);
1906
1907 let c = a.matmul(&b).unwrap();
1908
1909 assert_eq!(c.shape(), &[2, 3]);
1910 let c_data: &[f32] = c.as_slice();
1911 assert_eq!(c_data, &[9.0, 39.5, 4.0, 13.0, 116.5, 12.0]);
1912
1913 let a_data: &[i32] = a.as_slice();
1915 assert_eq!(a_data, &[1, 2, 3, 4]);
1916
1917 let b_data: &[f32] = b.as_slice();
1918 assert_eq!(b_data, &[-5.0, 37.5, 4., 7., 1., 0.]);
1919 }
1920
1921 #[test]
1922 fn test_matmul_ndim_zero() {
1923 let a: Array = 1.0.into();
1924 let b = Array::from_slice::<i32>(&[1], &[1]);
1925 let c = a.matmul(&b);
1926 assert!(c.is_err());
1927 }
1928
1929 #[test]
1930 fn test_matmul_ndim_one() {
1931 let a = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
1932 let b = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
1933 let c = a.matmul(&b);
1934 assert!(c.is_ok());
1935 }
1936
1937 #[test]
1938 fn test_matmul_dim_mismatch() {
1939 let a = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
1940 let b = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], &[2, 5]);
1941 let c = a.matmul(&b);
1942 assert!(c.is_err());
1943 }
1944
1945 #[test]
1946 fn test_matmul_non_float_output_type() {
1947 let a = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
1948 let b = Array::from_slice(&[5, 37, 4, 7, 1, 0], &[2, 3]);
1949
1950 let c = a.matmul(&b);
1951 assert!(c.is_err());
1952 }
1953
1954 #[test]
1955 fn test_reciprocal() {
1956 let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
1957 let b = a.reciprocal().unwrap();
1958
1959 let b_data: &[f32] = b.as_slice();
1960 assert_eq!(b_data, &[1.0, 0.5, 0.25]);
1961
1962 let a_data: &[f32] = a.as_slice();
1964 assert_eq!(a_data, &[1.0, 2.0, 4.0]);
1965 }
1966
1967 #[test]
1968 fn test_round() {
1969 let a = Array::from_slice(&[1.1, 2.9, 3.5], &[3]);
1970 let b = a.round(None).unwrap();
1971
1972 let b_data: &[f32] = b.as_slice();
1973 assert_eq!(b_data, &[1.0, 3.0, 4.0]);
1974
1975 let a_data: &[f32] = a.as_slice();
1977 assert_eq!(a_data, &[1.1, 2.9, 3.5]);
1978 }
1979
1980 #[test]
1981 fn test_rsqrt() {
1982 let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
1983 let b = a.rsqrt().unwrap();
1984
1985 let b_data: &[f32] = b.as_slice();
1986 assert_eq!(b_data, &[1.0, 0.70710677, 0.5]);
1987
1988 let a_data: &[f32] = a.as_slice();
1990 assert_eq!(a_data, &[1.0, 2.0, 4.0]);
1991 }
1992
1993 #[test]
1994 fn test_sin() {
1995 let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
1996 let b = a.sin().unwrap();
1997
1998 let b_data: &[f32] = b.as_slice();
1999 assert_eq!(b_data, &[0.0, 0.841471, 0.9092974]);
2000
2001 let a_data: &[f32] = a.as_slice();
2003 assert_eq!(a_data, &[0.0, 1.0, 2.0]);
2004 }
2005
2006 #[test]
2007 fn test_square() {
2008 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
2009 let b = a.square().unwrap();
2010
2011 let b_data: &[f32] = b.as_slice();
2012 assert_eq!(b_data, &[1.0, 4.0, 9.0]);
2013
2014 let a_data: &[f32] = a.as_slice();
2016 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
2017 }
2018
2019 #[test]
2022 fn test_unary_neg() {
2023 let x = array!(1.0);
2024 assert_eq!(negative(&x).unwrap().item::<f32>(), -1.0);
2025 assert_eq!((-x).item::<f32>(), -1.0);
2026
2027 assert_eq!(-array!(), array!());
2029
2030 let x = array!(true);
2032 assert!(negative(&x).is_err());
2033 }
2034
2035 #[test]
2036 fn test_unary_abs() {
2037 let x = array!([-1.0, 0.0, 1.0]);
2038 assert_eq!(abs(&x).unwrap(), array!([1.0, 0.0, 1.0]));
2039
2040 assert_eq!(abs(array!()).unwrap(), array!());
2042
2043 let x = array!([-1, 0, 1]);
2045 assert_eq!(abs(&x).unwrap(), array!([1, 0, 1]));
2046
2047 let x = array!([1u32, 0, 1]);
2049 assert_eq!(abs(&x).unwrap(), array!([1u32, 0, 1]));
2050
2051 let x = array!([false, true]);
2053 assert_eq!(abs(&x).unwrap(), array!([false, true]));
2054 }
2055
2056 #[test]
2057 fn test_unary_sign() {
2058 let x = array!([-1.0, 0.0, 1.0]);
2059 assert_eq!(sign(&x).unwrap(), x);
2060
2061 assert_eq!(sign(array!()).unwrap(), array!());
2063
2064 let x = array!([-1, 0, 1]);
2066 assert_eq!(sign(&x).unwrap(), x);
2067
2068 let x = array!([1u32, 0, 1]);
2070 assert_eq!(sign(&x).unwrap(), x);
2071
2072 let x = array!([false, true]);
2074 assert_eq!(sign(&x).unwrap(), x);
2075 }
2076
2077 const NEG_INF: f32 = f32::NEG_INFINITY;
2078
2079 #[test]
2080 fn test_unary_floor_ceil() {
2081 let x = array![1.0];
2082 assert_eq!(floor(&x).unwrap().item::<f32>(), 1.0);
2083 assert_eq!(ceil(&x).unwrap().item::<f32>(), 1.0);
2084
2085 let x = array![1.5];
2086 assert_eq!(floor(&x).unwrap().item::<f32>(), 1.0);
2087 assert_eq!(ceil(&x).unwrap().item::<f32>(), 2.0);
2088
2089 let x = array![-1.5];
2090 assert_eq!(floor(&x).unwrap().item::<f32>(), -2.0);
2091 assert_eq!(ceil(&x).unwrap().item::<f32>(), -1.0);
2092
2093 let x = array![NEG_INF];
2094 assert_eq!(floor(&x).unwrap().item::<f32>(), NEG_INF);
2095 assert_eq!(ceil(&x).unwrap().item::<f32>(), NEG_INF);
2096
2097 let x = array!([1.0, 1.0]).as_type::<complex64>().unwrap();
2098 assert!(floor(&x).is_err());
2099 assert!(ceil(&x).is_err());
2100 }
2101
2102 #[test]
2103 fn test_unary_round() {
2104 let x = array!([0.5, -0.5, 1.5, -1.5, 2.3, 2.6]);
2105 assert_eq!(round(&x, None).unwrap(), array!([0, 0, 2, -2, 2, 3]));
2106
2107 let x = array!([11, 222, 32]);
2108 assert_eq!(round(&x, -1).unwrap(), array!([10, 220, 30]));
2109 }
2110
2111 #[test]
2112 fn test_unary_exp() {
2113 let x = array![0.0];
2114 assert_eq!(exp(&x).unwrap().item::<f32>(), 1.0);
2115
2116 let x = array![2.0];
2117 assert_float_eq! {
2118 exp(&x).unwrap().item::<f32>(),
2119 2.0f32.exp(),
2120 abs <= 1e-5
2121 };
2122
2123 assert_eq!(exp(array!()).unwrap(), array!());
2124
2125 let x = array![NEG_INF];
2126 assert_eq!(exp(&x).unwrap().item::<f32>(), 0.0);
2127
2128 let x = array![2];
2130 assert_eq!(x.dtype(), Dtype::Int32);
2131 assert_float_eq! {
2132 exp(&x).unwrap().item::<f32>(),
2133 2.0f32.exp(),
2134 abs <= 1e-5
2135 };
2136
2137 let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2139 let res = exp(&x).unwrap();
2140 let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.exp())).unwrap();
2141 assert!(all_close(&res, &expected, None, None, None)
2142 .unwrap()
2143 .item::<bool>());
2144
2145 let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2146 let x = split_equal(&data, 2, 1).unwrap();
2147 let expected = Array::from_slice(&[0.0f32.exp(), 2.0f32.exp()], &[2, 1]);
2148 assert!(all_close(exp(&x[0]).unwrap(), &expected, None, None, None)
2149 .unwrap()
2150 .item::<bool>());
2151 }
2152
2153 #[test]
2154 fn test_unary_expm1() {
2155 let x = array![-1.0];
2156 assert_float_eq! {
2157 expm1(&x).unwrap().item::<f32>(),
2158 (-1.0f32).exp_m1(),
2159 abs <= 1e-5
2160 };
2161
2162 let x = array![1.0];
2163 assert_float_eq! {
2164 expm1(&x).unwrap().item::<f32>(),
2165 1.0f32.exp_m1(),
2166 abs <= 1e-5
2167 };
2168
2169 let x = array![1];
2171 assert_eq!(expm1(&x).unwrap().dtype(), Dtype::Float32);
2172 assert_float_eq! {
2173 expm1(&x).unwrap().item::<f32>(),
2174 1.0f32.exp_m1(),
2175 abs <= 1e-5
2176 };
2177 }
2178
2179 #[test]
2180 fn test_unary_sin() {
2181 let x = array![0.0];
2182 assert_eq!(sin(&x).unwrap().item::<f32>(), 0.0);
2183
2184 let x = array![std::f32::consts::PI / 2.0];
2185 assert_float_eq! {
2186 sin(&x).unwrap().item::<f32>(),
2187 (std::f32::consts::PI / 2.0f32).sin(),
2188 abs <= 1e-5
2189 };
2190
2191 assert_eq!(sin(array!()).unwrap(), array!());
2192
2193 let x = array![0];
2195 assert_eq!(x.dtype(), Dtype::Int32);
2196 assert_float_eq! {
2197 sin(&x).unwrap().item::<f32>(),
2198 0.0f32.sin(),
2199 abs <= 1e-5
2200 };
2201
2202 let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2204 let res = sin(&x).unwrap();
2205 let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.sin())).unwrap();
2206 assert!(all_close(&res, &expected, None, None, None)
2207 .unwrap()
2208 .item::<bool>());
2209
2210 let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2211 let x = split_equal(&data, 2, 1).unwrap();
2212 let expected = Array::from_slice(&[0.0f32.sin(), 2.0f32.sin()], &[2, 1]);
2213 assert!(all_close(sin(&x[0]).unwrap(), &expected, None, None, None)
2214 .unwrap()
2215 .item::<bool>());
2216 }
2217
2218 #[test]
2219 fn test_unary_cos() {
2220 let x = array![0.0];
2221 assert_float_eq! {
2222 cos(&x).unwrap().item::<f32>(),
2223 0.0f32.cos(),
2224 abs <= 1e-5
2225 };
2226
2227 let x = array![std::f32::consts::PI / 2.0];
2228 assert_float_eq! {
2229 cos(&x).unwrap().item::<f32>(),
2230 (std::f32::consts::PI / 2.0f32).cos(),
2231 abs <= 1e-5
2232 };
2233
2234 assert_eq!(cos(array!()).unwrap(), array!());
2235
2236 let x = array![0];
2238 assert_eq!(x.dtype(), Dtype::Int32);
2239 assert_float_eq! {
2240 cos(&x).unwrap().item::<f32>(),
2241 0.0f32.cos(),
2242 abs <= 1e-5
2243 };
2244
2245 let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2247 let res = cos(&x).unwrap();
2248 let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.cos())).unwrap();
2249 assert!(all_close(&res, &expected, None, None, None)
2250 .unwrap()
2251 .item::<bool>());
2252
2253 let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
2254 let x = split_equal(&data, 2, 1).unwrap();
2255 let expected = Array::from_slice(&[0.0f32.cos(), 2.0f32.cos()], &[2, 1]);
2256 assert!(all_close(cos(&x[0]).unwrap(), &expected, None, None, None)
2257 .unwrap()
2258 .item::<bool>());
2259 }
2260
2261 #[test]
2262 fn test_unary_degrees() {
2263 let x = array![0.0];
2264 assert_eq!(degrees(&x).unwrap().item::<f32>(), 0.0);
2265
2266 let x = array![std::f32::consts::PI / 2.0];
2267 assert_eq!(degrees(&x).unwrap().item::<f32>(), 90.0);
2268
2269 assert_eq!(degrees(array!()).unwrap(), array!());
2270
2271 let x = array![0];
2273 assert_eq!(x.dtype(), Dtype::Int32);
2274 assert_eq!(degrees(&x).unwrap().item::<f32>(), 0.0);
2275
2276 let x = broadcast_to(&array!(std::f32::consts::PI / 2.0), &[2, 2, 2]).unwrap();
2278 let res = degrees(&x).unwrap();
2279 let expected = Array::full::<f32>(&[2, 2, 2], array!(90.0)).unwrap();
2280 assert!(all_close(&res, &expected, None, None, None)
2281 .unwrap()
2282 .item::<bool>());
2283
2284 let angles = Array::from_slice(&[0.0, PI / 2.0, PI, 1.5 * PI], &[2, 2]);
2285 let x = split_equal(&angles, 2, 1).unwrap();
2286 let expected = Array::from_slice(&[0.0, 180.0], &[2, 1]);
2287 assert!(
2288 all_close(degrees(&x[0]).unwrap(), &expected, None, None, None)
2289 .unwrap()
2290 .item::<bool>()
2291 );
2292 }
2293
2294 #[test]
2295 fn test_unary_radians() {
2296 let x = array![0.0];
2297 assert_eq!(radians(&x).unwrap().item::<f32>(), 0.0);
2298
2299 let x = array![90.0];
2300 assert_eq!(
2301 radians(&x).unwrap().item::<f32>(),
2302 std::f32::consts::PI / 2.0
2303 );
2304
2305 assert_eq!(radians(array!()).unwrap(), array!());
2306
2307 let x = array![90];
2309 assert_eq!(x.dtype(), Dtype::Int32);
2310 assert_eq!(
2311 radians(&x).unwrap().item::<f32>(),
2312 std::f32::consts::PI / 2.0
2313 );
2314
2315 let x = broadcast_to(&array!(90.0), &[2, 2, 2]).unwrap();
2317 let res = radians(&x).unwrap();
2318 let expected = Array::full::<f32>(&[2, 2, 2], array!(std::f32::consts::PI / 2.0)).unwrap();
2319 assert!(all_close(&res, &expected, None, None, None)
2320 .unwrap()
2321 .item::<bool>());
2322
2323 let angles = Array::from_slice(&[0.0, 90.0, 180.0, 270.0], &[2, 2]);
2324 let x = split_equal(&angles, 2, 1).unwrap();
2325 let expected = Array::from_slice(&[0.0, PI], &[2, 1]);
2326 assert!(
2327 all_close(radians(&x[0]).unwrap(), &expected, None, None, None)
2328 .unwrap()
2329 .item::<bool>()
2330 );
2331 }
2332
2333 #[test]
2334 fn test_unary_log() {
2335 let x = array![0.0];
2336 assert_eq!(log(&x).unwrap().item::<f32>(), NEG_INF);
2337
2338 let x = array![1.0];
2339 assert_eq!(log(&x).unwrap().item::<f32>(), 0.0);
2340
2341 let x = array![1];
2343 assert_eq!(log(&x).unwrap().dtype(), Dtype::Float32);
2344 assert_eq!(log(&x).unwrap().item::<f32>(), 0.0);
2345
2346 let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2348 let res = log(&x).unwrap();
2349 let expected = Array::full::<f32>(&[2, 2, 2], array!(0.0)).unwrap();
2350 assert!(all_close(&res, &expected, None, None, None)
2351 .unwrap()
2352 .item::<bool>());
2353
2354 let data = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
2355 let x = split_equal(&data, 2, 1).unwrap();
2356 let expected = Array::from_slice(&[1.0f32.ln(), 3.0f32.ln()], &[2, 1]);
2357 assert!(all_close(log(&x[0]).unwrap(), &expected, None, None, None)
2358 .unwrap()
2359 .item::<bool>());
2360 }
2361
2362 #[test]
2363 fn test_unary_log2() {
2364 let x = array![0.0];
2365 assert_eq!(log2(&x).unwrap().item::<f32>(), NEG_INF);
2366
2367 let x = array![1.0];
2368 assert_eq!(log2(&x).unwrap().item::<f32>(), 0.0);
2369
2370 let x = array![1024.0];
2371 assert_eq!(log2(&x).unwrap().item::<f32>(), 10.0);
2372 }
2373
2374 #[test]
2375 fn test_unary_log10() {
2376 let x = array![0.0];
2377 assert_eq!(log10(&x).unwrap().item::<f32>(), NEG_INF);
2378
2379 let x = array![1.0];
2380 assert_eq!(log10(&x).unwrap().item::<f32>(), 0.0);
2381
2382 let x = array![1000.0];
2383 assert_eq!(log10(&x).unwrap().item::<f32>(), 3.0);
2384 }
2385
2386 #[test]
2387 fn test_unary_log1p() {
2388 let x = array![-1.0];
2389 assert_float_eq! {
2390 log1p(&x).unwrap().item::<f32>(),
2391 (-1.0f32).ln_1p(),
2392 abs <= 1e-5
2393 };
2394
2395 let x = array![1.0];
2396 assert_float_eq! {
2397 log1p(&x).unwrap().item::<f32>(),
2398 1.0f32.ln_1p(),
2399 abs <= 1e-5
2400 };
2401
2402 let x = array![1];
2404 assert_eq!(log1p(&x).unwrap().dtype(), Dtype::Float32);
2405 assert_float_eq! {
2406 log1p(&x).unwrap().item::<f32>(),
2407 1.0f32.ln_1p(),
2408 abs <= 1e-5
2409 };
2410
2411 let x = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
2413 let res = log1p(&x).unwrap();
2414 let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.ln_1p())).unwrap();
2415 assert!(all_close(&res, &expected, None, None, None)
2416 .unwrap()
2417 .item::<bool>());
2418
2419 let data = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
2420 let x = split_equal(&data, 2, 1).unwrap();
2421 let expected = Array::from_slice(&[1.0f32.ln_1p(), 3.0f32.ln_1p()], &[2, 1]);
2422 assert!(
2423 all_close(log1p(&x[0]).unwrap(), &expected, None, None, None)
2424 .unwrap()
2425 .item::<bool>()
2426 );
2427 }
2428
2429 #[test]
2430 fn test_unary_sigmoid() {
2431 let x = array![0.0];
2432 assert_float_eq! {
2433 sigmoid(&x).unwrap().item::<f32>(),
2434 0.5,
2435 abs <= 1e-5
2436 };
2437
2438 let x = array![0];
2440 assert_eq!(sigmoid(&x).unwrap().dtype(), Dtype::Float32);
2441 assert_float_eq! {
2442 sigmoid(&x).unwrap().item::<f32>(),
2443 0.5,
2444 abs <= 1e-5
2445 };
2446
2447 let inf = f32::INFINITY;
2448 let x = array![inf];
2449 assert_eq!(sigmoid(&x).unwrap().item::<f32>(), 1.0);
2450
2451 let x = array![-inf];
2452 assert_eq!(sigmoid(&x).unwrap().item::<f32>(), 0.0);
2453 }
2454
2455 #[test]
2456 fn test_unary_square() {
2457 let x = array![3.0];
2458 assert_eq!(square(&x).unwrap().item::<f32>(), 9.0);
2459
2460 let x = array![2];
2461 assert_eq!(square(&x).unwrap().item::<i32>(), 4);
2462
2463 let x = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
2464 assert!(all_close(
2465 square(&x).unwrap(),
2466 Array::full::<f32>(&[3, 3], array!(4.0)).unwrap(),
2467 None,
2468 None,
2469 None
2470 )
2471 .unwrap()
2472 .item::<bool>());
2473 }
2474
2475 #[test]
2476 fn test_unary_sqrt_rsqrt() {
2477 let x = array![4.0];
2478 assert_eq!(sqrt(&x).unwrap().item::<f32>(), 2.0);
2479 assert_eq!(rsqrt(&x).unwrap().item::<f32>(), 0.5);
2480
2481 let x = Array::full::<f32>(&[3, 3], array!(9.0)).unwrap();
2482 assert!(all_close(
2483 sqrt(&x).unwrap(),
2484 Array::full::<f32>(&[3, 3], array!(3.0)).unwrap(),
2485 None,
2486 None,
2487 None
2488 )
2489 .unwrap()
2490 .item::<bool>());
2491
2492 let x = array![4i32];
2493 assert_eq!(sqrt(&x).unwrap().item::<f32>(), 2.0);
2494 assert_eq!(rsqrt(&x).unwrap().item::<f32>(), 0.5);
2495 }
2496
2497 #[test]
2498 fn test_unary_reciprocal() {
2499 let x = array![8.0];
2500 assert_eq!(reciprocal(&x).unwrap().item::<f32>(), 0.125);
2501
2502 let x = array![2];
2503 let out = reciprocal(&x).unwrap();
2504 assert_eq!(out.dtype(), Dtype::Float32);
2505 assert_eq!(out.item::<f32>(), 0.5);
2506
2507 let x = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
2508 assert!(all_close(
2509 reciprocal(&x).unwrap(),
2510 Array::full::<f32>(&[3, 3], array!(0.5)).unwrap(),
2511 None,
2512 None,
2513 None
2514 )
2515 .unwrap()
2516 .item::<bool>());
2517 }
2518
2519 #[test]
2520 fn test_binary_add() {
2521 let x = array![1.0];
2522 let y = array![1.0];
2523 let z = add(&x, &y).unwrap();
2524 assert_eq!(z.item::<f32>(), 2.0);
2525
2526 let z = &x + y;
2527 assert_eq!(z.item::<f32>(), 2.0);
2528
2529 let z = add(z, &x).unwrap();
2530 assert_eq!(z.item::<f32>(), 3.0);
2531
2532 let mut out = x.deep_clone();
2534 for _ in 0..10 {
2535 out = add(&out, &x).unwrap();
2536 }
2537 assert_eq!(out.item::<f32>(), 11.0);
2538
2539 let x = array!([1.0, 2.0, 3.0]);
2541 let y = array!([1.0, 2.0, 3.0]);
2542 let z = add(&x, &y).unwrap();
2543 assert_eq!(z.shape(), &[3]);
2544 assert_eq!(z, array!([2.0, 4.0, 6.0]));
2545
2546 let x = array!([1.0, 2.0, 3.0]);
2548 let y = &x + 2.0;
2549 assert_eq!(y.dtype(), Dtype::Float32);
2550 assert_eq!(y, array!([3.0, 4.0, 5.0]));
2551 let y = &x + 2.0;
2552 assert_eq!(y.dtype(), Dtype::Float32);
2553 assert_eq!(y, array!([3.0, 4.0, 5.0]));
2554
2555 let y = x + 2;
2557 assert_eq!(y.dtype(), Dtype::Float32);
2558
2559 let y = array!([1, 2, 3]) + 2.0;
2560 assert_eq!(y.dtype(), Dtype::Float32);
2561 assert_eq!(y, array!([3.0, 4.0, 5.0]));
2563
2564 let x = broadcast_to(&array!(1.0), &[10]).unwrap();
2566 let y = broadcast_to(&array!(2.0), &[10]).unwrap();
2567 let z = add(&x, &y).unwrap();
2568 assert_eq!(z, full::<f32>(&[10], array!(3.0)).unwrap());
2569
2570 let x = Array::from_slice(&[1.0, 2.0], &[1, 2]);
2571 let y = Array::from_slice(&[1.0, 2.0], &[2, 1]);
2572 let z = add(&x, &y).unwrap();
2573 assert_eq!(z.shape(), &[2, 2]);
2574 assert_eq!(z, Array::from_slice(&[2.0, 3.0, 3.0, 4.0], &[2, 2]));
2575
2576 let x = ones::<f32>(&[3, 2, 1]).unwrap();
2577 let z = x + 2.0;
2578 assert_eq!(z.shape(), &[3, 2, 1]);
2579 let expected = Array::from_slice(&[3.0, 3.0, 3.0, 3.0, 3.0, 3.0], &[3, 2, 1]);
2580 assert_eq!(z, expected);
2581
2582 let x = array!();
2584 let y = array!();
2585 let z = x + y;
2586 z.eval().unwrap();
2587 assert_eq!(z.size(), 0);
2588 assert_eq!(z.shape(), &[0]);
2589 }
2590
2591 #[test]
2592 fn test_binary_sub() {
2593 let x = array!([3.0, 2.0, 1.0]);
2594 let y = array!([1.0, 1.0, 1.0]);
2595 assert_eq!(x - y, array!([2.0, 1.0, 0.0]));
2596 }
2597
2598 #[test]
2599 fn test_binary_mul() {
2600 let x = array!([1.0, 2.0, 3.0]);
2601 let y = array!([2.0, 2.0, 2.0]);
2602 assert_eq!(x * y, array!([2.0, 4.0, 6.0]));
2603 }
2604
2605 #[test]
2606 fn test_binary_div() {
2607 let x = array![1.0];
2608 let y = array![1.0];
2609 assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 1.0);
2610
2611 let x = array![1.0];
2612 let y = array![0.5];
2613 assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 2.0);
2614
2615 let x = array![1.0];
2616 let y = array![4.0];
2617 assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 0.25);
2618
2619 let x = array![true];
2620 let y = array![true];
2621 assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 1.0);
2622
2623 let x = array![false];
2624 let y = array![true];
2625 assert_eq!(divide(&x, &y).unwrap().item::<f32>(), 0.0);
2626
2627 let x = array![true];
2628 let y = array![false];
2629 assert!(divide(&x, &y).unwrap().item::<f32>().is_infinite());
2630
2631 let x = array![false];
2632 let y = array![false];
2633 assert!(divide(&x, &y).unwrap().item::<f32>().is_nan());
2634 }
2635
2636 #[test]
2637 fn test_binary_maximum_minimum() {
2638 let x = array![1.0];
2639 let y = array![0.0];
2640 assert_eq!(maximum(&x, &y).unwrap().item::<f32>(), 1.0);
2641 assert_eq!(minimum(&x, &y).unwrap().item::<f32>(), 0.0);
2642
2643 let y = array![2.0];
2644 assert_eq!(maximum(&x, &y).unwrap().item::<f32>(), 2.0);
2645 assert_eq!(minimum(&x, &y).unwrap().item::<f32>(), 1.0);
2646 }
2647
2648 #[test]
2649 fn test_binary_logaddexp() {
2650 let x = array![0.0];
2651 let y = array![0.0];
2652 assert_float_eq! {
2653 log_add_exp(&x, &y).unwrap().item::<f32>(),
2654 2.0f32.ln(),
2655 abs <= 1e-5
2656 };
2657
2658 let x = array!([0u32]);
2659 let y = array!([10000u32]);
2660 assert_eq!(log_add_exp(&x, &y).unwrap().item::<f32>(), 10000.0);
2661
2662 let x = array![f32::INFINITY];
2663 let y = array![3.0];
2664 assert_eq!(log_add_exp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2665
2666 let x = array![f32::NEG_INFINITY];
2667 let y = array![3.0];
2668 assert_eq!(log_add_exp(&x, &y).unwrap().item::<f32>(), 3.0);
2669
2670 let x = array![f32::NEG_INFINITY];
2671 let y = array![f32::NEG_INFINITY];
2672 assert_eq!(
2673 log_add_exp(&x, &y).unwrap().item::<f32>(),
2674 f32::NEG_INFINITY
2675 );
2676
2677 let x = array![f32::INFINITY];
2678 let y = array![f32::INFINITY];
2679 assert_eq!(log_add_exp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2680
2681 let x = array![f32::NEG_INFINITY];
2682 let y = array![f32::INFINITY];
2683 assert_eq!(log_add_exp(&x, &y).unwrap().item::<f32>(), f32::INFINITY);
2684 }
2685
2686 #[test]
2687 fn test_basic_clip() {
2688 let a = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
2689 let expected = array!([2.0, 4.0, 3.0, 6.0, 5.0]);
2690 let clipped = clip(&a, (array!(2.0), array!(6.0))).unwrap();
2691 assert_eq!(clipped, expected);
2692
2693 let clipped = clip(&a, (2.0, 6.0)).unwrap();
2695 assert_eq!(clipped, expected);
2696 }
2697
2698 #[test]
2699 fn test_clip_with_only_min() {
2700 let a = array!([-1.0, 1.0, 0.0, 5.0]);
2701 let expected = array!([0.0, 1.0, 0.0, 5.0]);
2702 let clipped = clip(&a, (array!(0.0), ())).unwrap();
2703 assert_eq!(clipped, expected);
2704
2705 let clipped = clip(&a, (0.0, ())).unwrap();
2707 assert_eq!(clipped, expected);
2708 }
2709
2710 #[test]
2711 fn test_clip_with_only_max() {
2712 let a = array!([2.0, 3.0, 4.0, 5.0]);
2713 let expected = array!([2.0, 3.0, 4.0, 4.0]);
2714 let clipped = clip(&a, ((), array!(4.0))).unwrap();
2715 assert_eq!(clipped, expected);
2716
2717 let clipped = clip(&a, ((), 4.0)).unwrap();
2719 assert_eq!(clipped, expected);
2720 }
2721
2722 #[test]
2723 fn test_tensordot() {
2724 let x = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
2725 let y = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[4, 3, 2]).unwrap();
2726 let z = tensordot(&x, &y, (&[1i32, 0], &[0i32, 1])).unwrap();
2727 let expected = Array::from_slice(
2728 &[4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306],
2729 &[5, 2],
2730 );
2731 assert_eq!(z, expected);
2732
2733 let x = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[3, 4, 5, 6]).unwrap();
2734 let y = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[6, 4, 5, 3]).unwrap();
2735 assert!(tensordot(&x, &y, (&[2, 1, 3], &[1, 2, 0])).is_err());
2736
2737 let x = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
2738 let y = reshape(arange::<_, f32>(None, 120.0, None).unwrap(), &[4, 5, 6]).unwrap();
2739
2740 let z = tensordot(&x, &y, 2).unwrap();
2741 let expected = Array::from_slice(
2742 &[
2743 14820.0, 15010.0, 15200.0, 15390.0, 15580.0, 15770.0, 37620.0, 38210.0, 38800.0,
2744 39390.0, 39980.0, 40570.0, 60420.0, 61410.0, 62400.0, 63390.0, 64380.0, 65370.0,
2745 ],
2746 &[3, 6],
2747 );
2748 assert_eq!(z, expected);
2749 }
2750
2751 #[test]
2752 fn test_outer() {
2753 let x = arange::<_, f32>(1.0, 5.0, None).unwrap();
2754 let y = arange::<_, f32>(1.0, 4.0, None).unwrap();
2755 let z = outer(&x, &y).unwrap();
2756 let expected = Array::from_slice(
2757 &[1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0],
2758 &[4, 3],
2759 );
2760 assert_eq!(z, expected);
2761
2762 let x = ones::<f32>(&[5]).unwrap();
2763 let y = linspace::<_, f32>(-2.0, 2.0, 5).unwrap();
2764 let z = outer(&x, &y).unwrap();
2765 let expected = Array::from_slice(
2766 &[
2767 -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
2768 -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
2769 ],
2770 &[5, 5],
2771 );
2772 assert_eq!(z, expected);
2773 }
2774
2775 #[test]
2776 fn test_inner() {
2777 let x = reshape(arange::<_, f32>(None, 5.0, None).unwrap(), &[1, 5]).unwrap();
2778 let y = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[2, 3]).unwrap();
2779 assert!(inner(&x, &y).is_err());
2780
2781 let x = array!([1.0, 2.0, 3.0]);
2782 let y = array!([0.0, 1.0, 0.0]);
2783 let z = inner(&x, &y).unwrap();
2784 assert_eq!(z.item::<f32>(), 2.0);
2785
2786 let x = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[2, 3, 4]).unwrap();
2787 let y = arange::<_, f32>(None, 4.0, None).unwrap();
2788 let z = inner(&x, &y).unwrap();
2789 let expected = Array::from_slice(&[14.0, 38.0, 62.0, 86.0, 110.0, 134.0], &[2, 3]);
2790 assert_eq!(z, expected);
2791
2792 let x = reshape(arange::<_, f32>(None, 2.0, None).unwrap(), &[1, 1, 2]).unwrap();
2793 let y = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[3, 2]).unwrap();
2794 let z = inner(&x, &y).unwrap();
2795 let expected = Array::from_slice(&[1.0, 3.0, 5.0], &[1, 1, 3]);
2796 assert_eq!(z, expected);
2797
2798 let x = eye::<f32>(2, None, None).unwrap();
2799 let y = Array::from_f32(7.0);
2800 let z = inner(&x, &y).unwrap();
2801 let expected = Array::from_slice(&[7.0, 0.0, 0.0, 7.0], &[2, 2]);
2802 assert_eq!(z, expected);
2803 }
2804
2805 #[test]
2806 fn test_divmod() {
2807 let x = array!([1.0, 2.0, 3.0]);
2808 let y = array!([1.0, 1.0, 1.0]);
2809 let out = divmod(&x, &y).unwrap();
2810 assert_eq!(out.0, array!([1.0, 2.0, 3.0]));
2811 assert_eq!(out.1, array!([0.0, 0.0, 0.0]));
2812
2813 let x = array!([5.0, 6.0, 7.0]);
2814 let y = array!([2.0, 2.0, 2.0]);
2815 let out = divmod(&x, &y).unwrap();
2816 assert_eq!(out.0, array!([2.0, 3.0, 3.0]));
2817 assert_eq!(out.1, array!([1.0, 0.0, 1.0]));
2818
2819 let x = array!([5.0, 6.0, 7.0]);
2820 let y = array!([2.0, 2.0, 2.0]);
2821 let out = divmod(&x, &y).unwrap();
2822 assert_eq!(out.0, array!([2.0, 3.0, 3.0]));
2823 assert_eq!(out.1, array!([1.0, 0.0, 1.0]));
2824
2825 let x = array![complex64::new(1.0, 0.0)];
2826 let y = array![complex64::new(2.0, 0.0)];
2827 assert!(divmod(&x, &y).is_err());
2828
2829 let x = array![1.0];
2831 let y = array![2.0];
2832 let (quo, rem) = divmod(&x, &y).unwrap();
2833 eval([&quo, &rem]).unwrap();
2834 assert_eq!(quo.item::<f32>(), 0.0);
2835 assert_eq!(rem.item::<f32>(), 1.0);
2836
2837 let x = array![1.0];
2839 let y = array![2.0];
2840 let (quo, rem) = divmod(&x, &y).unwrap();
2841 let z = quo + rem;
2842 assert_eq!(z.item::<f32>(), 1.0);
2843
2844 let mut out_holder = {
2846 let (quo, _) = divmod(&x, &y).unwrap();
2847 vec![quo]
2848 };
2849 eval(out_holder.iter()).unwrap();
2850 assert_eq!(out_holder[0].item::<f32>(), 0.0);
2851
2852 out_holder.clear();
2854 let out_holder = {
2855 let (_, rem) = divmod(&x, &y).unwrap();
2856 vec![rem]
2857 };
2858 eval(out_holder.iter()).unwrap();
2859 assert_eq!(out_holder[0].item::<f32>(), 1.0);
2860 }
2861}