1use crate::array::Array;
2use crate::error::Result;
3use crate::sealed::Sealed;
4
5use crate::utils::guard::Guarded;
6use crate::utils::{IntoOption, ScalarOrArray, VectorArray};
7use crate::Stream;
8use mlx_internal_macros::{default_device, generate_macro};
9use smallvec::SmallVec;
10
11impl Array {
12 #[default_device]
25 pub fn abs_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
26 Array::try_from_op(|res| unsafe {
27 mlx_sys::mlx_abs(res, self.as_ptr(), stream.as_ref().as_ptr())
28 })
29 }
30
31 #[default_device]
51 pub fn add_device(
52 &self,
53 other: impl AsRef<Array>,
54 stream: impl AsRef<Stream>,
55 ) -> Result<Array> {
56 Array::try_from_op(|res| unsafe {
57 mlx_sys::mlx_add(
58 res,
59 self.as_ptr(),
60 other.as_ref().as_ptr(),
61 stream.as_ref().as_ptr(),
62 )
63 })
64 }
65
66 #[default_device]
86 pub fn subtract_device(
87 &self,
88 other: impl AsRef<Array>,
89 stream: impl AsRef<Stream>,
90 ) -> Result<Array> {
91 Array::try_from_op(|res| unsafe {
92 mlx_sys::mlx_subtract(
93 res,
94 self.as_ptr(),
95 other.as_ref().as_ptr(),
96 stream.as_ref().as_ptr(),
97 )
98 })
99 }
100
101 #[default_device]
116 pub fn negative_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
117 Array::try_from_op(|res| unsafe {
118 mlx_sys::mlx_negative(res, self.as_ptr(), stream.as_ref().as_ptr())
119 })
120 }
121
122 #[default_device]
138 pub fn multiply_device(
139 &self,
140 other: impl AsRef<Array>,
141 stream: impl AsRef<Stream>,
142 ) -> Result<Array> {
143 Array::try_from_op(|res| unsafe {
144 mlx_sys::mlx_multiply(
145 res,
146 self.as_ptr(),
147 other.as_ref().as_ptr(),
148 stream.as_ref().as_ptr(),
149 )
150 })
151 }
152
153 #[default_device]
163 pub fn nan_to_num_device(
164 &self,
165 nan: impl IntoOption<f32>,
166 pos_inf: impl IntoOption<f32>,
167 neg_inf: impl IntoOption<f32>,
168 stream: impl AsRef<Stream>,
169 ) -> Result<Array> {
170 let pos_inf = pos_inf.into_option();
171 let neg_inf = neg_inf.into_option();
172
173 let pos_inf = mlx_sys::mlx_optional_float {
174 value: pos_inf.unwrap_or(0.0),
175 has_value: pos_inf.is_some(),
176 };
177 let neg_inf = mlx_sys::mlx_optional_float {
178 value: neg_inf.unwrap_or(0.0),
179 has_value: neg_inf.is_some(),
180 };
181
182 Array::try_from_op(|res| unsafe {
183 mlx_sys::mlx_nan_to_num(
184 res,
185 self.as_ptr(),
186 nan.into_option().unwrap_or(0.),
187 pos_inf,
188 neg_inf,
189 stream.as_ref().as_ptr(),
190 )
191 })
192 }
193
194 #[default_device]
214 pub fn divide_device(
215 &self,
216 other: impl AsRef<Array>,
217 stream: impl AsRef<Stream>,
218 ) -> Result<Array> {
219 Array::try_from_op(|res| unsafe {
220 mlx_sys::mlx_divide(
221 res,
222 self.as_ptr(),
223 other.as_ref().as_ptr(),
224 stream.as_ref().as_ptr(),
225 )
226 })
227 }
228
229 #[default_device]
249 pub fn power_device(
250 &self,
251 other: impl AsRef<Array>,
252 stream: impl AsRef<Stream>,
253 ) -> Result<Array> {
254 Array::try_from_op(|res| unsafe {
255 mlx_sys::mlx_power(
256 res,
257 self.as_ptr(),
258 other.as_ref().as_ptr(),
259 stream.as_ref().as_ptr(),
260 )
261 })
262 }
263
264 #[default_device]
284 pub fn remainder_device(
285 &self,
286 other: impl AsRef<Array>,
287 stream: impl AsRef<Stream>,
288 ) -> Result<Array> {
289 Array::try_from_op(|res| unsafe {
290 mlx_sys::mlx_remainder(
291 res,
292 self.as_ptr(),
293 other.as_ref().as_ptr(),
294 stream.as_ref().as_ptr(),
295 )
296 })
297 }
298
299 #[default_device]
312 pub fn sqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
313 Array::try_from_op(|res| unsafe {
314 mlx_sys::mlx_sqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
315 })
316 }
317
318 #[default_device]
331 pub fn cos_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
332 Array::try_from_op(|res| unsafe {
333 mlx_sys::mlx_cos(res, self.as_ptr(), stream.as_ref().as_ptr())
334 })
335 }
336
337 #[default_device]
352 pub fn exp_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
353 Array::try_from_op(|res| unsafe {
354 mlx_sys::mlx_exp(res, self.as_ptr(), stream.as_ref().as_ptr())
355 })
356 }
357
358 #[default_device]
371 pub fn floor_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
372 Array::try_from_op(|res| unsafe {
373 mlx_sys::mlx_floor(res, self.as_ptr(), stream.as_ref().as_ptr())
374 })
375 }
376
377 #[default_device]
401 pub fn floor_divide_device(
402 &self,
403 other: impl AsRef<Array>,
404 stream: impl AsRef<Stream>,
405 ) -> Result<Array> {
406 Array::try_from_op(|res| unsafe {
407 mlx_sys::mlx_floor_divide(
408 res,
409 self.as_ptr(),
410 other.as_ref().as_ptr(),
411 stream.as_ref().as_ptr(),
412 )
413 })
414 }
415
416 #[default_device]
421 pub fn is_nan_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
422 Array::try_from_op(|res| unsafe {
423 mlx_sys::mlx_isnan(res, self.as_ptr(), stream.as_ref().as_ptr())
424 })
425 }
426
427 #[default_device]
432 pub fn is_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
433 Array::try_from_op(|res| unsafe {
434 mlx_sys::mlx_isinf(res, self.as_ptr(), stream.as_ref().as_ptr())
435 })
436 }
437
438 #[default_device]
443 pub fn is_finite_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
444 Array::try_from_op(|res| unsafe {
445 mlx_sys::mlx_isfinite(res, self.as_ptr(), stream.as_ref().as_ptr())
446 })
447 }
448
449 #[default_device]
454 pub fn is_neg_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
455 Array::try_from_op(|res| unsafe {
456 mlx_sys::mlx_isneginf(res, self.as_ptr(), stream.as_ref().as_ptr())
457 })
458 }
459
460 #[default_device]
465 pub fn is_pos_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
466 Array::try_from_op(|res| unsafe {
467 mlx_sys::mlx_isposinf(res, self.as_ptr(), stream.as_ref().as_ptr())
468 })
469 }
470
471 #[default_device]
484 pub fn log_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
485 Array::try_from_op(|res| unsafe {
486 mlx_sys::mlx_log(res, self.as_ptr(), stream.as_ref().as_ptr())
487 })
488 }
489
490 #[default_device]
503 pub fn log2_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
504 Array::try_from_op(|res| unsafe {
505 mlx_sys::mlx_log2(res, self.as_ptr(), stream.as_ref().as_ptr())
506 })
507 }
508
509 #[default_device]
522 pub fn log10_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
523 Array::try_from_op(|res| unsafe {
524 mlx_sys::mlx_log10(res, self.as_ptr(), stream.as_ref().as_ptr())
525 })
526 }
527
528 #[default_device]
541 pub fn log1p_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
542 Array::try_from_op(|res| unsafe {
543 mlx_sys::mlx_log1p(res, self.as_ptr(), stream.as_ref().as_ptr())
544 })
545 }
546
547 #[default_device]
577 pub fn matmul_device(
578 &self,
579 other: impl AsRef<Array>,
580 stream: impl AsRef<Stream>,
581 ) -> Result<Array> {
582 Array::try_from_op(|res| unsafe {
583 mlx_sys::mlx_matmul(
584 res,
585 self.as_ptr(),
586 other.as_ref().as_ptr(),
587 stream.as_ref().as_ptr(),
588 )
589 })
590 }
591
592 #[default_device]
605 pub fn reciprocal_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
606 Array::try_from_op(|res| unsafe {
607 mlx_sys::mlx_reciprocal(res, self.as_ptr(), stream.as_ref().as_ptr())
608 })
609 }
610
611 #[default_device]
617 pub fn round_device(
618 &self,
619 decimals: impl Into<Option<i32>>,
620 stream: impl AsRef<Stream>,
621 ) -> Result<Array> {
622 Array::try_from_op(|res| unsafe {
623 mlx_sys::mlx_round(
624 res,
625 self.as_ptr(),
626 decimals.into().unwrap_or(0),
627 stream.as_ref().as_ptr(),
628 )
629 })
630 }
631
632 #[default_device]
634 pub fn rsqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
635 Array::try_from_op(|res| unsafe {
636 mlx_sys::mlx_rsqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
637 })
638 }
639
640 #[default_device]
642 pub fn sin_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
643 Array::try_from_op(|res| unsafe {
644 mlx_sys::mlx_sin(res, self.as_ptr(), stream.as_ref().as_ptr())
645 })
646 }
647
648 #[default_device]
650 pub fn square_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
651 Array::try_from_op(|res| unsafe {
652 mlx_sys::mlx_square(res, self.as_ptr(), stream.as_ref().as_ptr())
653 })
654 }
655}
656
657#[generate_macro]
668#[default_device]
669pub fn abs_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
670 a.as_ref().abs_device(stream)
671}
672
673#[generate_macro]
675#[default_device]
676pub fn acos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
677 Array::try_from_op(|res| unsafe {
678 mlx_sys::mlx_arccos(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
679 })
680}
681
682#[generate_macro]
684#[default_device]
685pub fn acosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
686 Array::try_from_op(|res| unsafe {
687 mlx_sys::mlx_arccosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
688 })
689}
690
691#[generate_macro]
693#[default_device]
694pub fn add_device(
695 lhs: impl AsRef<Array>,
696 rhs: impl AsRef<Array>,
697 #[optional] stream: impl AsRef<Stream>,
698) -> Result<Array> {
699 lhs.as_ref().add_device(rhs, stream)
700}
701
702#[generate_macro]
704#[default_device]
705pub fn asin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
706 Array::try_from_op(|res| unsafe {
707 mlx_sys::mlx_arcsin(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
708 })
709}
710
711#[generate_macro]
713#[default_device]
714pub fn asinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
715 Array::try_from_op(|res| unsafe {
716 mlx_sys::mlx_arcsinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
717 })
718}
719
720#[generate_macro]
722#[default_device]
723pub fn atan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
724 Array::try_from_op(|res| unsafe {
725 mlx_sys::mlx_arctan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
726 })
727}
728
729#[generate_macro]
731#[default_device]
732pub fn atanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
733 Array::try_from_op(|res| unsafe {
734 mlx_sys::mlx_arctanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
735 })
736}
737
738#[generate_macro]
740#[default_device]
741pub fn ceil_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
742 Array::try_from_op(|res| unsafe {
743 mlx_sys::mlx_ceil(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
744 })
745}
746
747pub trait ClipBound<'min, 'max>: Sealed {
752 fn into_min_max(
754 self,
755 ) -> (
756 Option<impl ScalarOrArray<'min>>,
757 Option<impl ScalarOrArray<'max>>,
758 );
759}
760
761impl<'min, Min> ClipBound<'min, 'min> for (Min, ())
762where
763 Min: ScalarOrArray<'min> + Sealed,
764{
765 fn into_min_max(
766 self,
767 ) -> (
768 Option<impl ScalarOrArray<'min>>,
769 Option<impl ScalarOrArray<'min>>,
770 ) {
771 (Some(self.0), Option::<Min>::None)
772 }
773}
774
775impl<'max, Max> ClipBound<'max, 'max> for ((), Max)
776where
777 Max: ScalarOrArray<'max> + Sealed,
778{
779 fn into_min_max(
780 self,
781 ) -> (
782 Option<impl ScalarOrArray<'max>>,
783 Option<impl ScalarOrArray<'max>>,
784 ) {
785 (Option::<Max>::None, Some(self.1))
786 }
787}
788
789impl<'min, 'max, Min, Max> ClipBound<'min, 'max> for (Min, Max)
790where
791 Min: ScalarOrArray<'min> + Sealed,
792 Max: ScalarOrArray<'max> + Sealed,
793{
794 fn into_min_max(
795 self,
796 ) -> (
797 Option<impl ScalarOrArray<'min>>,
798 Option<impl ScalarOrArray<'max>>,
799 ) {
800 (Some(self.0), Some(self.1))
801 }
802}
803
804#[generate_macro]
826#[default_device]
827pub fn clip_device<'min, 'max>(
828 a: impl AsRef<Array>,
829 bound: impl ClipBound<'min, 'max>,
830 #[optional] stream: impl AsRef<Stream>,
831) -> Result<Array> {
832 let (a_min, a_max) = bound.into_min_max();
833
834 let a_min = a_min.map(|min| min.into_owned_or_ref_array());
836 let a_max = a_max.map(|max| max.into_owned_or_ref_array());
837
838 unsafe {
839 let min_ptr = match &a_min {
840 Some(a_min) => a_min.as_ref().as_ptr(),
841 None => mlx_sys::mlx_array_new(),
842 };
843 let max_ptr = match &a_max {
844 Some(a_max) => a_max.as_ref().as_ptr(),
845 None => mlx_sys::mlx_array_new(),
846 };
847
848 Array::try_from_op(|res| {
849 mlx_sys::mlx_clip(
850 res,
851 a.as_ref().as_ptr(),
852 min_ptr,
853 max_ptr,
854 stream.as_ref().as_ptr(),
855 )
856 })
857 }
858}
859
860#[generate_macro]
862#[default_device]
863pub fn cos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
864 a.as_ref().cos_device(stream)
865}
866
867#[generate_macro]
869#[default_device]
870pub fn cosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
871 Array::try_from_op(|res| unsafe {
872 mlx_sys::mlx_cosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
873 })
874}
875
876#[generate_macro]
878#[default_device]
879pub fn degrees_device(
880 a: impl AsRef<Array>,
881 #[optional] stream: impl AsRef<Stream>,
882) -> Result<Array> {
883 Array::try_from_op(|res| unsafe {
884 mlx_sys::mlx_degrees(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
885 })
886}
887
888#[generate_macro]
890#[default_device]
891pub fn divide_device(
892 a: impl AsRef<Array>,
893 b: impl AsRef<Array>,
894 #[optional] stream: impl AsRef<Stream>,
895) -> Result<Array> {
896 a.as_ref().divide_device(b, stream)
897}
898
899#[generate_macro]
906#[default_device]
907pub fn divmod_device(
908 a: impl AsRef<Array>,
909 b: impl AsRef<Array>,
910 #[optional] stream: impl AsRef<Stream>,
911) -> Result<(Array, Array)> {
912 let a_ptr = a.as_ref().as_ptr();
913 let b_ptr = b.as_ref().as_ptr();
914
915 let vec = VectorArray::try_from_op(|res| unsafe {
916 mlx_sys::mlx_divmod(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
917 })?;
918
919 let vals: SmallVec<[_; 2]> = vec.try_into_values()?;
920 let mut iter = vals.into_iter();
921 let quotient = iter.next().unwrap();
922 let remainder = iter.next().unwrap();
923
924 Ok((quotient, remainder))
925}
926
927#[generate_macro]
929#[default_device]
930pub fn erf_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
931 Array::try_from_op(|res| unsafe {
932 mlx_sys::mlx_erf(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
933 })
934}
935
936#[generate_macro]
938#[default_device]
939pub fn erfinv_device(
940 a: impl AsRef<Array>,
941 #[optional] stream: impl AsRef<Stream>,
942) -> Result<Array> {
943 Array::try_from_op(|res| unsafe {
944 mlx_sys::mlx_erfinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
945 })
946}
947
948#[generate_macro]
950#[default_device]
951pub fn exp_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
952 a.as_ref().exp_device(stream)
953}
954
955#[generate_macro]
957#[default_device]
958pub fn expm1_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
959 Array::try_from_op(|res| unsafe {
960 mlx_sys::mlx_expm1(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
961 })
962}
963
964#[generate_macro]
966#[default_device]
967pub fn floor_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
968 a.as_ref().floor_device(stream)
969}
970
971#[generate_macro]
973#[default_device]
974pub fn floor_divide_device(
975 a: impl AsRef<Array>,
976 other: impl AsRef<Array>,
977 #[optional] stream: impl AsRef<Stream>,
978) -> Result<Array> {
979 a.as_ref().floor_divide_device(other, stream)
980}
981
982#[generate_macro]
984#[default_device]
985pub fn log_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
986 a.as_ref().log_device(stream)
987}
988
989#[generate_macro]
991#[default_device]
992pub fn log10_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
993 a.as_ref().log10_device(stream)
994}
995
996#[generate_macro]
998#[default_device]
999pub fn log1p_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1000 a.as_ref().log1p_device(stream)
1001}
1002
1003#[generate_macro]
1005#[default_device]
1006pub fn log2_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1007 a.as_ref().log2_device(stream)
1008}
1009
1010#[generate_macro]
1017#[default_device]
1018pub fn logaddexp_device(
1019 a: impl AsRef<Array>,
1020 b: impl AsRef<Array>,
1021 #[optional] stream: impl AsRef<Stream>,
1022) -> Result<Array> {
1023 let a_ptr = a.as_ref().as_ptr();
1024 let b_ptr = b.as_ref().as_ptr();
1025
1026 Array::try_from_op(|res| unsafe {
1027 mlx_sys::mlx_logaddexp(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1028 })
1029}
1030
1031#[generate_macro]
1033#[default_device]
1034pub fn matmul_device(
1035 a: impl AsRef<Array>,
1036 b: impl AsRef<Array>,
1037 #[optional] stream: impl AsRef<Stream>,
1038) -> Result<Array> {
1039 a.as_ref().matmul_device(b, stream)
1040}
1041
1042#[generate_macro]
1047#[default_device]
1048pub fn maximum_device(
1049 a: impl AsRef<Array>,
1050 b: impl AsRef<Array>,
1051 #[optional] stream: impl AsRef<Stream>,
1052) -> Result<Array> {
1053 let a_ptr = a.as_ref().as_ptr();
1054 let b_ptr = b.as_ref().as_ptr();
1055
1056 Array::try_from_op(|res| unsafe {
1057 mlx_sys::mlx_maximum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1058 })
1059}
1060
1061#[generate_macro]
1066#[default_device]
1067pub fn minimum_device(
1068 a: impl AsRef<Array>,
1069 b: impl AsRef<Array>,
1070 #[optional] stream: impl AsRef<Stream>,
1071) -> Result<Array> {
1072 let a_ptr = a.as_ref().as_ptr();
1073 let b_ptr = b.as_ref().as_ptr();
1074
1075 Array::try_from_op(|res| unsafe {
1076 mlx_sys::mlx_minimum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1077 })
1078}
1079
1080#[generate_macro]
1082#[default_device]
1083pub fn multiply_device(
1084 a: impl AsRef<Array>,
1085 b: impl AsRef<Array>,
1086 #[optional] stream: impl AsRef<Stream>,
1087) -> Result<Array> {
1088 a.as_ref().multiply_device(b, stream)
1089}
1090
1091#[generate_macro]
1093#[default_device]
1094pub fn negative_device(
1095 a: impl AsRef<Array>,
1096 #[optional] stream: impl AsRef<Stream>,
1097) -> Result<Array> {
1098 a.as_ref().negative_device(stream)
1099}
1100
1101#[generate_macro]
1103#[default_device]
1104pub fn power_device(
1105 a: impl AsRef<Array>,
1106 b: impl AsRef<Array>,
1107 #[optional] stream: impl AsRef<Stream>,
1108) -> Result<Array> {
1109 a.as_ref().power_device(b, stream)
1110}
1111
1112#[generate_macro]
1114#[default_device]
1115pub fn radians_device(
1116 a: impl AsRef<Array>,
1117 #[optional] stream: impl AsRef<Stream>,
1118) -> Result<Array> {
1119 Array::try_from_op(|res| unsafe {
1120 mlx_sys::mlx_radians(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1121 })
1122}
1123
1124#[generate_macro]
1126#[default_device]
1127pub fn reciprocal_device(
1128 a: impl AsRef<Array>,
1129 #[optional] stream: impl AsRef<Stream>,
1130) -> Result<Array> {
1131 a.as_ref().reciprocal_device(stream)
1132}
1133
1134#[generate_macro]
1136#[default_device]
1137pub fn remainder_device(
1138 a: impl AsRef<Array>,
1139 b: impl AsRef<Array>,
1140 #[optional] stream: impl AsRef<Stream>,
1141) -> Result<Array> {
1142 a.as_ref().remainder_device(b, stream)
1143}
1144
1145#[generate_macro]
1147#[default_device]
1148pub fn round_device(
1149 a: impl AsRef<Array>,
1150 decimals: impl Into<Option<i32>>,
1151 #[optional] stream: impl AsRef<Stream>,
1152) -> Result<Array> {
1153 a.as_ref().round_device(decimals, stream)
1154}
1155
1156#[generate_macro]
1158#[default_device]
1159pub fn rsqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1160 a.as_ref().rsqrt_device(stream)
1161}
1162
1163#[generate_macro]
1169#[default_device]
1170pub fn sigmoid_device(
1171 a: impl AsRef<Array>,
1172 #[optional] stream: impl AsRef<Stream>,
1173) -> Result<Array> {
1174 Array::try_from_op(|res| unsafe {
1175 mlx_sys::mlx_sigmoid(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1176 })
1177}
1178
1179#[generate_macro]
1181#[default_device]
1182pub fn sign_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1183 Array::try_from_op(|res| unsafe {
1184 mlx_sys::mlx_sign(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1185 })
1186}
1187
1188#[generate_macro]
1190#[default_device]
1191pub fn sin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1192 a.as_ref().sin_device(stream)
1193}
1194
1195#[generate_macro]
1197#[default_device]
1198pub fn sinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1199 Array::try_from_op(|res| unsafe {
1200 mlx_sys::mlx_sinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1201 })
1202}
1203
1204#[generate_macro]
1210#[default_device]
1211pub fn softmax_axes_device(
1212 a: impl AsRef<Array>,
1213 axes: &[i32],
1214 precise: impl Into<Option<bool>>,
1215 #[optional] stream: impl AsRef<Stream>,
1216) -> Result<Array> {
1217 let precise = precise.into().unwrap_or(false);
1218 let s = stream.as_ref().as_ptr();
1219
1220 Array::try_from_op(|res| unsafe {
1221 mlx_sys::mlx_softmax_axes(
1222 res,
1223 a.as_ref().as_ptr(),
1224 axes.as_ptr(),
1225 axes.len(),
1226 precise,
1227 s,
1228 )
1229 })
1230}
1231
1232#[generate_macro]
1234#[default_device]
1235pub fn softmax_axis_device(
1236 a: impl AsRef<Array>,
1237 axis: i32,
1238 precise: impl Into<Option<bool>>,
1239 #[optional] stream: impl AsRef<Stream>,
1240) -> Result<Array> {
1241 let precise = precise.into().unwrap_or(false);
1242 let s = stream.as_ref().as_ptr();
1243
1244 Array::try_from_op(|res| unsafe {
1245 mlx_sys::mlx_softmax_axis(res, a.as_ref().as_ptr(), axis, precise, s)
1246 })
1247}
1248
1249#[generate_macro]
1251#[default_device]
1252pub fn softmax_device(
1253 a: impl AsRef<Array>,
1254 precise: impl Into<Option<bool>>,
1255 #[optional] stream: impl AsRef<Stream>,
1256) -> Result<Array> {
1257 let precise = precise.into().unwrap_or(false);
1258 let s = stream.as_ref().as_ptr();
1259
1260 Array::try_from_op(|res| unsafe { mlx_sys::mlx_softmax(res, a.as_ref().as_ptr(), precise, s) })
1261}
1262
1263#[generate_macro]
1265#[default_device]
1266pub fn sqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1267 a.as_ref().sqrt_device(stream)
1268}
1269
1270#[generate_macro]
1272#[default_device]
1273pub fn square_device(
1274 a: impl AsRef<Array>,
1275 #[optional] stream: impl AsRef<Stream>,
1276) -> Result<Array> {
1277 a.as_ref().square_device(stream)
1278}
1279
1280#[generate_macro]
1282#[default_device]
1283pub fn subtract_device(
1284 a: impl AsRef<Array>,
1285 b: impl AsRef<Array>,
1286 #[optional] stream: impl AsRef<Stream>,
1287) -> Result<Array> {
1288 a.as_ref().subtract_device(b, stream)
1289}
1290
1291#[generate_macro]
1293#[default_device]
1294pub fn tan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1295 Array::try_from_op(|res| unsafe {
1296 mlx_sys::mlx_tan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1297 })
1298}
1299
1300#[generate_macro]
1302#[default_device]
1303pub fn tanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1304 Array::try_from_op(|res| unsafe {
1305 mlx_sys::mlx_tanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1306 })
1307}
1308
1309#[generate_macro]
1315#[default_device]
1316pub fn block_masked_mm_device<'mo, 'lhs, 'rhs>(
1317 a: impl AsRef<Array>,
1318 b: impl AsRef<Array>,
1319 #[optional] block_size: impl Into<Option<i32>>,
1320 #[optional] mask_out: impl Into<Option<&'mo Array>>,
1321 #[optional] mask_lhs: impl Into<Option<&'lhs Array>>,
1322 #[optional] mask_rhs: impl Into<Option<&'rhs Array>>,
1323 #[optional] stream: impl AsRef<Stream>,
1324) -> Result<Array> {
1325 let a_ptr = a.as_ref().as_ptr();
1326 let b_ptr = b.as_ref().as_ptr();
1327 unsafe {
1328 let mask_out_ptr = mask_out
1329 .into()
1330 .map(|m| m.as_ptr())
1331 .unwrap_or(mlx_sys::mlx_array_new());
1332 let mask_lhs_ptr = mask_lhs
1333 .into()
1334 .map(|m| m.as_ptr())
1335 .unwrap_or(mlx_sys::mlx_array_new());
1336 let mask_rhs_ptr = mask_rhs
1337 .into()
1338 .map(|m| m.as_ptr())
1339 .unwrap_or(mlx_sys::mlx_array_new());
1340
1341 Array::try_from_op(|res| {
1342 mlx_sys::mlx_block_masked_mm(
1343 res,
1344 a_ptr,
1345 b_ptr,
1346 block_size.into().unwrap_or(32),
1347 mask_out_ptr,
1348 mask_lhs_ptr,
1349 mask_rhs_ptr,
1350 stream.as_ref().as_ptr(),
1351 )
1352 })
1353 }
1354}
1355
1356#[generate_macro]
1369#[default_device]
1370pub fn addmm_device(
1371 c: impl AsRef<Array>,
1372 a: impl AsRef<Array>,
1373 b: impl AsRef<Array>,
1374 #[optional] alpha: impl Into<Option<f32>>,
1375 #[optional] beta: impl Into<Option<f32>>,
1376 #[optional] stream: impl AsRef<Stream>,
1377) -> Result<Array> {
1378 let c_ptr = c.as_ref().as_ptr();
1379 let a_ptr = a.as_ref().as_ptr();
1380 let b_ptr = b.as_ref().as_ptr();
1381 let alpha = alpha.into().unwrap_or(1.0);
1382 let beta = beta.into().unwrap_or(1.0);
1383
1384 Array::try_from_op(|res| unsafe {
1385 mlx_sys::mlx_addmm(
1386 res,
1387 c_ptr,
1388 a_ptr,
1389 b_ptr,
1390 alpha,
1391 beta,
1392 stream.as_ref().as_ptr(),
1393 )
1394 })
1395}
1396
1397#[generate_macro]
1400#[default_device]
1401pub fn inner_device(
1402 a: impl AsRef<Array>,
1403 b: impl AsRef<Array>,
1404 #[optional] stream: impl AsRef<Stream>,
1405) -> Result<Array> {
1406 let a = a.as_ref();
1407 let b = b.as_ref();
1408 Array::try_from_op(|res| unsafe {
1409 mlx_sys::mlx_inner(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1410 })
1411}
1412
1413#[generate_macro]
1416#[default_device]
1417pub fn outer_device(
1418 a: impl AsRef<Array>,
1419 b: impl AsRef<Array>,
1420 #[optional] stream: impl AsRef<Stream>,
1421) -> Result<Array> {
1422 let a = a.as_ref();
1423 let b = b.as_ref();
1424 Array::try_from_op(|res| unsafe {
1425 mlx_sys::mlx_outer(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1426 })
1427}
1428
1429#[generate_macro]
1431#[default_device]
1432pub fn tensordot_axes_device(
1433 a: impl AsRef<Array>,
1434 b: impl AsRef<Array>,
1435 axes_a: &[i32],
1436 axes_b: &[i32],
1437 #[optional] stream: impl AsRef<Stream>,
1438) -> Result<Array> {
1439 let a = a.as_ref();
1440 let b = b.as_ref();
1441 Array::try_from_op(|res| unsafe {
1442 mlx_sys::mlx_tensordot(
1443 res,
1444 a.as_ptr(),
1445 b.as_ptr(),
1446 axes_a.as_ptr(),
1447 axes_a.len(),
1448 axes_b.as_ptr(),
1449 axes_b.len(),
1450 stream.as_ref().as_ptr(),
1451 )
1452 })
1453}
1454
1455#[generate_macro]
1457#[default_device]
1458pub fn tensordot_axis_device(
1459 a: impl AsRef<Array>,
1460 b: impl AsRef<Array>,
1461 axis: i32,
1462 #[optional] stream: impl AsRef<Stream>,
1463) -> Result<Array> {
1464 let a = a.as_ref();
1465 let b = b.as_ref();
1466 Array::try_from_op(|res| unsafe {
1467 mlx_sys::mlx_tensordot_axis(res, a.as_ptr(), b.as_ptr(), axis, stream.as_ref().as_ptr())
1468 })
1469}
1470
1471#[cfg(test)]
1472mod tests {
1473 use std::f32::consts::PI;
1474
1475 use super::*;
1476 use crate::{
1477 array, complex64,
1478 ops::{all_close, arange, broadcast_to, eye, full, linspace, ones, reshape, split},
1479 transforms::eval,
1480 Dtype, StreamOrDevice,
1481 };
1482 use float_eq::assert_float_eq;
1483 use pretty_assertions::assert_eq;
1484
1485 #[test]
1486 fn test_abs() {
1487 let data = [1i32, 2, -3, -4, -5];
1488 let array = Array::from_slice(&data, &[5]);
1489 let result = array.abs().unwrap();
1490
1491 let data: &[i32] = result.as_slice();
1492 assert_eq!(data, [1, 2, 3, 4, 5]);
1493
1494 let data: &[i32] = array.as_slice();
1496 assert_eq!(data, [1, 2, -3, -4, -5]);
1497 }
1498
1499 #[test]
1500 fn test_add() {
1501 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1502 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1503
1504 let c = &a + &b;
1505
1506 let c_data: &[f32] = c.as_slice();
1507 assert_eq!(c_data, &[5.0, 7.0, 9.0]);
1508
1509 let a_data: &[f32] = a.as_slice();
1511 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1512
1513 let b_data: &[f32] = b.as_slice();
1514 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1515 }
1516
1517 #[test]
1518 fn test_add_invalid_broadcast() {
1519 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1520 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1521
1522 let c = a.add(&b);
1523 assert!(c.is_err());
1524 }
1525
1526 #[test]
1527 fn test_sub() {
1528 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1529 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1530
1531 let c = &a - &b;
1532
1533 let c_data: &[f32] = c.as_slice();
1534 assert_eq!(c_data, &[-3.0, -3.0, -3.0]);
1535
1536 let a_data: &[f32] = a.as_slice();
1538 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1539
1540 let b_data: &[f32] = b.as_slice();
1541 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1542 }
1543
1544 #[test]
1545 fn test_sub_invalid_broadcast() {
1546 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1547 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1548 let c = a.subtract(&b);
1549 assert!(c.is_err());
1550 }
1551
1552 #[test]
1553 fn test_neg() {
1554 let a = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3]);
1555 let b = a.negative().unwrap();
1556
1557 let b_data: &[f32] = b.as_slice();
1558 assert_eq!(b_data, &[-1.0, -2.0, -3.0]);
1559
1560 let a_data: &[f32] = a.as_slice();
1562 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1563 }
1564
1565 #[test]
1566 fn test_neg_bool() {
1567 let a = Array::from_slice(&[true, false, true], &[3]);
1568 let b = a.negative();
1569 assert!(b.is_err());
1570 }
1571
1572 #[test]
1573 fn test_logical_not() {
1574 let a: Array = false.into();
1575 let b = a.logical_not().unwrap();
1576
1577 let b_data: &[bool] = b.as_slice();
1578 assert_eq!(b_data, [true]);
1579 }
1580
1581 #[test]
1582 fn test_mul() {
1583 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1584 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1585
1586 let c = &a * &b;
1587
1588 let c_data: &[f32] = c.as_slice();
1589 assert_eq!(c_data, &[4.0, 10.0, 18.0]);
1590
1591 let a_data: &[f32] = a.as_slice();
1593 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1594
1595 let b_data: &[f32] = b.as_slice();
1596 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1597 }
1598
1599 #[test]
1600 fn test_mul_invalid_broadcast() {
1601 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1602 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1603 let c = a.multiply(&b);
1604 assert!(c.is_err());
1605 }
1606
1607 #[test]
1608 fn test_nan_to_num() {
1609 let a = array!([1.0, 2.0, f32::NAN, 4.0, 5.0]);
1610 let b = a.nan_to_num(0.0, 1.0, 0.0).unwrap();
1611
1612 let b_data: &[f32] = b.as_slice();
1613 assert_eq!(b_data, &[1.0, 2.0, 0.0, 4.0, 5.0]);
1614 }
1615
1616 #[test]
1617 fn test_div() {
1618 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1619 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1620
1621 let c = &a / &b;
1622
1623 let c_data: &[f32] = c.as_slice();
1624 assert_eq!(c_data, &[0.25, 0.4, 0.5]);
1625
1626 let a_data: &[f32] = a.as_slice();
1628 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1629
1630 let b_data: &[f32] = b.as_slice();
1631 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1632 }
1633
1634 #[test]
1635 fn test_div_invalid_broadcast() {
1636 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1637 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1638 let c = a.divide(&b);
1639 assert!(c.is_err());
1640 }
1641
1642 #[test]
1643 fn test_pow() {
1644 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1645 let b = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);
1646
1647 let c = a.power(&b).unwrap();
1648
1649 let c_data: &[f32] = c.as_slice();
1650 assert_eq!(c_data, &[1.0, 8.0, 81.0]);
1651
1652 let a_data: &[f32] = a.as_slice();
1654 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1655
1656 let b_data: &[f32] = b.as_slice();
1657 assert_eq!(b_data, &[2.0, 3.0, 4.0]);
1658 }
1659
#[test]
fn test_pow_invalid_broadcast() {
    // Operands of shape [3] and [2]: the power operation is expected to fail.
    let base = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let exponent = Array::from_slice(&[2.0, 3.0], &[2]);

    let raised = base.power(&exponent);
    assert!(raised.is_err());
}
1667
#[test]
fn test_rem() {
    let dividend = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
    let divisor = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);

    // Element-wise remainder through the `%` operator on borrowed operands.
    let remainder = &dividend % &divisor;
    let remainder_values: &[f32] = remainder.as_slice();
    assert_eq!(remainder_values, &[1.0, 3.0, 2.0]);

    // Both operands must still hold their original contents.
    let dividend_values: &[f32] = dividend.as_slice();
    assert_eq!(dividend_values, &[10.0, 11.0, 12.0]);

    let divisor_values: &[f32] = divisor.as_slice();
    assert_eq!(divisor_values, &[3.0, 4.0, 5.0]);
}
1685
#[test]
fn test_rem_invalid_broadcast() {
    // Operands of shape [3] and [2]: the remainder operation is expected to fail.
    let dividend = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
    let divisor = Array::from_slice(&[3.0, 4.0], &[2]);

    let remainder = dividend.remainder(&divisor);
    assert!(remainder.is_err());
}
1693
#[test]
fn test_sqrt() {
    // Perfect squares, so the roots are exactly representable.
    let squares = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
    let roots = squares.sqrt().unwrap();

    let root_values: &[f32] = roots.as_slice();
    assert_eq!(root_values, &[1.0, 2.0, 3.0]);

    // The input must still hold its original contents.
    let square_values: &[f32] = squares.as_slice();
    assert_eq!(square_values, &[1.0, 4.0, 9.0]);
}
1706
#[test]
fn test_cos() {
    let angles = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
    let cosines = angles.cos().unwrap();

    // Compared with a tolerance rather than bit-for-bit.
    let expected_cosines = array!([1.0, 0.54030234, -0.41614687]);
    assert_array_all_close!(cosines, expected_cosines);

    // The input must still hold its original contents.
    let expected_angles = array!([0.0, 1.0, 2.0]);
    assert_array_all_close!(angles, expected_angles);
}
1719
#[test]
fn test_exp() {
    let exponents = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
    let powers = exponents.exp().unwrap();

    // Compared with a tolerance rather than bit-for-bit.
    let expected_powers = array!([1.0, 2.7182817, 7.389056]);
    assert_array_all_close!(powers, expected_powers);

    // The input must still hold its original contents.
    let expected_exponents = array!([0.0, 1.0, 2.0]);
    assert_array_all_close!(exponents, expected_exponents);
}
1732
#[test]
fn test_floor() {
    let values = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
    let floored = values.floor().unwrap();

    let floored_values: &[f32] = floored.as_slice();
    assert_eq!(floored_values, &[0.0, 1.0, 2.0]);

    // The input must still hold its original contents.
    let original_values: &[f32] = values.as_slice();
    assert_eq!(original_values, &[0.1, 1.9, 2.5]);
}
1745
#[test]
fn test_floor_complex64() {
    // Flooring a complex64 scalar is expected to be rejected.
    let complex_scalar = Array::from_complex(complex64::new(1.0, 2.0));

    let floored = complex_scalar.floor_device(StreamOrDevice::default());
    assert!(floored.is_err());
}
1753
#[test]
fn test_floor_divide() {
    let numerator = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let denominator = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);

    // Every numerator is smaller than its denominator, so each quotient floors to zero.
    let quotient = numerator.floor_divide(&denominator).unwrap();
    let quotient_values: &[f32] = quotient.as_slice();
    assert_eq!(quotient_values, &[0.0, 0.0, 0.0]);

    // Both operands must still hold their original contents.
    let numerator_values: &[f32] = numerator.as_slice();
    assert_eq!(numerator_values, &[1.0, 2.0, 3.0]);

    let denominator_values: &[f32] = denominator.as_slice();
    assert_eq!(denominator_values, &[4.0, 5.0, 6.0]);
}
1771
#[test]
fn test_floor_divide_complex64() {
    // Floor division with a complex64 left operand is expected to be rejected.
    let complex_scalar = Array::from_complex(complex64::new(1.0, 2.0));
    let divisor = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);

    let quotient = complex_scalar.floor_divide_device(&divisor, StreamOrDevice::default());
    assert!(quotient.is_err());
}
1780
#[test]
fn test_floor_divide_invalid_broadcast() {
    // Operands of shape [3] and [2]: the floor division is expected to fail.
    let numerator = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let denominator = Array::from_slice(&[4.0, 5.0], &[2]);

    let quotient = numerator.floor_divide_device(&denominator, StreamOrDevice::default());
    assert!(quotient.is_err());
}
1788
#[test]
fn test_is_nan() {
    let values = Array::from_slice(&[1.0, f32::NAN, 3.0], &[3]);
    let nan_mask = values.is_nan().unwrap();

    // Only the middle element is NaN.
    let flags: &[bool] = nan_mask.as_slice();
    assert_eq!(flags, &[false, true, false]);
}
1797
#[test]
fn test_is_inf() {
    let values = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
    let inf_mask = values.is_inf().unwrap();

    // Only the middle element is infinite.
    let flags: &[bool] = inf_mask.as_slice();
    assert_eq!(flags, &[false, true, false]);
}
1806
#[test]
fn test_is_finite() {
    let values = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
    let finite_mask = values.is_finite().unwrap();

    // The infinite middle element is the only non-finite one.
    let flags: &[bool] = finite_mask.as_slice();
    assert_eq!(flags, &[true, false, true]);
}
1815
#[test]
fn test_is_neg_inf() {
    let values = Array::from_slice(&[1.0, f32::NEG_INFINITY, 3.0], &[3]);
    let neg_inf_mask = values.is_neg_inf().unwrap();

    // Only the middle element is negative infinity.
    let flags: &[bool] = neg_inf_mask.as_slice();
    assert_eq!(flags, &[false, true, false]);
}
1824
#[test]
fn test_is_pos_inf() {
    let values = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
    let pos_inf_mask = values.is_pos_inf().unwrap();

    // Only the middle element is positive infinity.
    let flags: &[bool] = pos_inf_mask.as_slice();
    assert_eq!(flags, &[false, true, false]);
}
1833
#[test]
fn test_log() {
    let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let b = a.log().unwrap();

    // Natural-log results are compared with a tolerance, the same way `test_cos`
    // and `test_exp` do: the last bit of a transcendental function may differ
    // between backends, so bit-for-bit equality makes the test brittle.
    let b_expected = array!([0.0, 0.6931472, 1.0986123]);
    assert_array_all_close!(b, b_expected);

    // The input must still hold its original contents (exact comparison is fine here).
    let a_data: &[f32] = a.as_slice();
    assert_eq!(a_data, &[1.0, 2.0, 3.0]);
}
1846
#[test]
fn test_log2() {
    // Powers of two, so the base-2 logarithms are whole numbers.
    let powers_of_two = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
    let logs = powers_of_two.log2().unwrap();

    let log_values: &[f32] = logs.as_slice();
    assert_eq!(log_values, &[0.0, 1.0, 2.0, 3.0]);

    // The input must still hold its original contents.
    let input_values: &[f32] = powers_of_two.as_slice();
    assert_eq!(input_values, &[1.0, 2.0, 4.0, 8.0]);
}
1859
#[test]
fn test_log10() {
    // Powers of ten, so the base-10 logarithms are whole numbers.
    let powers_of_ten = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
    let logs = powers_of_ten.log10().unwrap();

    let log_values: &[f32] = logs.as_slice();
    assert_eq!(log_values, &[0.0, 1.0, 2.0]);

    // The input must still hold its original contents.
    let input_values: &[f32] = powers_of_ten.as_slice();
    assert_eq!(input_values, &[1.0, 10.0, 100.0]);
}
1872
#[test]
fn test_log1p() {
    let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let b = a.log1p().unwrap();

    // ln(1 + x) results are compared with a tolerance, the same way `test_cos`
    // and `test_exp` do: the last bit of a transcendental function may differ
    // between backends, so bit-for-bit equality makes the test brittle.
    let b_expected = array!([0.6931472, 1.0986123, 1.3862944]);
    assert_array_all_close!(b, b_expected);

    // The input must still hold its original contents (exact comparison is fine here).
    let a_data: &[f32] = a.as_slice();
    assert_eq!(a_data, &[1.0, 2.0, 3.0]);
}
1885
#[test]
fn test_matmul() {
    // A 2x2 integer matrix times a 2x3 float matrix.
    let left = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
    let right = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);

    let product = left.matmul(&right).unwrap();

    // The product has shape 2x3 and is read back as f32.
    assert_eq!(product.shape(), &[2, 3]);
    let product_values: &[f32] = product.as_slice();
    assert_eq!(product_values, &[9.0, 39.5, 4.0, 13.0, 116.5, 12.0]);

    // Both operands must still hold their original contents.
    let left_values: &[i32] = left.as_slice();
    assert_eq!(left_values, &[1, 2, 3, 4]);

    let right_values: &[f32] = right.as_slice();
    assert_eq!(right_values, &[-5.0, 37.5, 4., 7., 1., 0.]);
}
1904
#[test]
fn test_matmul_ndim_zero() {
    // A scalar (zero-dimensional) left operand is expected to be rejected.
    let scalar: Array = 1.0.into();
    let vector = Array::from_slice::<i32>(&[1], &[1]);

    let product = scalar.matmul(&vector);
    assert!(product.is_err());
}
1912
#[test]
fn test_matmul_ndim_one() {
    // Two one-dimensional operands of equal length are accepted.
    let left = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
    let right = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);

    let product = left.matmul(&right);
    assert!(product.is_ok());
}
1920
#[test]
fn test_matmul_dim_mismatch() {
    // 2x3 times 2x5: the inner dimensions (3 vs 2) disagree, so matmul must fail.
    let left = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
    let right = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], &[2, 5]);

    let product = left.matmul(&right);
    assert!(product.is_err());
}
1928
#[test]
fn test_matmul_non_float_output_type() {
    // Both operands are integer matrices; matmul is expected to reject them.
    let left = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
    let right = Array::from_slice(&[5, 37, 4, 7, 1, 0], &[2, 3]);

    let product = left.matmul(&right);
    assert!(product.is_err());
}
1937
#[test]
fn test_reciprocal() {
    // Powers of two, so the reciprocals are exactly representable.
    let values = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
    let inverses = values.reciprocal().unwrap();

    let inverse_values: &[f32] = inverses.as_slice();
    assert_eq!(inverse_values, &[1.0, 0.5, 0.25]);

    // The input must still hold its original contents.
    let original_values: &[f32] = values.as_slice();
    assert_eq!(original_values, &[1.0, 2.0, 4.0]);
}
1950
#[test]
fn test_round() {
    let values = Array::from_slice(&[1.1, 2.9, 3.5], &[3]);

    // `None` leaves the number of decimals unspecified.
    let rounded = values.round(None).unwrap();
    let rounded_values: &[f32] = rounded.as_slice();
    assert_eq!(rounded_values, &[1.0, 3.0, 4.0]);

    // The input must still hold its original contents.
    let original_values: &[f32] = values.as_slice();
    assert_eq!(original_values, &[1.1, 2.9, 3.5]);
}
1963
#[test]
fn test_rsqrt() {
    let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
    let b = a.rsqrt().unwrap();

    // 1/sqrt(2) is irrational, so the result is compared with a tolerance, the
    // same way `test_cos` and `test_exp` do, instead of requiring one exact
    // rounding of the last bit.
    let b_expected = array!([1.0, 0.70710677, 0.5]);
    assert_array_all_close!(b, b_expected);

    // The input must still hold its original contents (exact comparison is fine here).
    let a_data: &[f32] = a.as_slice();
    assert_eq!(a_data, &[1.0, 2.0, 4.0]);
}
1976
#[test]
fn test_sin() {
    let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
    let b = a.sin().unwrap();

    // Sine results are compared with a tolerance, the same way `test_cos` and
    // `test_exp` do: the last bit of a transcendental function may differ
    // between backends, so bit-for-bit equality makes the test brittle.
    let b_expected = array!([0.0, 0.841471, 0.9092974]);
    assert_array_all_close!(b, b_expected);

    // The input must still hold its original contents (exact comparison is fine here).
    let a_data: &[f32] = a.as_slice();
    assert_eq!(a_data, &[0.0, 1.0, 2.0]);
}
1989
#[test]
fn test_square() {
    let values = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let squares = values.square().unwrap();

    let square_values: &[f32] = squares.as_slice();
    assert_eq!(square_values, &[1.0, 4.0, 9.0]);

    // The input must still hold its original contents.
    let original_values: &[f32] = values.as_slice();
    assert_eq!(original_values, &[1.0, 2.0, 3.0]);
}
2002
#[test]
fn test_unary_neg() {
    // Scalar input: the free function and the unary `-` operator must agree.
    let one = array!(1.0);
    let via_function = negative(&one).unwrap();
    assert_eq!(via_function.item::<f32>(), -1.0);
    let via_operator = -one;
    assert_eq!(via_operator.item::<f32>(), -1.0);

    // Negating an empty array gives an empty array.
    assert_eq!(-array!(), array!());

    // A boolean array is expected to be rejected.
    let flag = array!(true);
    assert!(negative(&flag).is_err());
}
2018
#[test]
fn test_unary_abs() {
    // Floating-point input.
    let floats = array!([-1.0, 0.0, 1.0]);
    assert_eq!(abs(&floats).unwrap(), array!([1.0, 0.0, 1.0]));

    // Empty input gives an empty output.
    assert_eq!(abs(array!()).unwrap(), array!());

    // Signed integer input.
    let signed = array!([-1, 0, 1]);
    assert_eq!(abs(&signed).unwrap(), array!([1, 0, 1]));

    // Unsigned integer input comes back unchanged.
    let unsigned = array!([1u32, 0, 1]);
    assert_eq!(abs(&unsigned).unwrap(), array!([1u32, 0, 1]));

    // Boolean input comes back unchanged.
    let flags = array!([false, true]);
    assert_eq!(abs(&flags).unwrap(), array!([false, true]));
}
2039
#[test]
fn test_unary_sign() {
    // Every non-empty input below already consists of its own signs, so
    // `sign` must return an array equal to the input.
    let floats = array!([-1.0, 0.0, 1.0]);
    assert_eq!(sign(&floats).unwrap(), floats);

    // Empty input gives an empty output.
    assert_eq!(sign(array!()).unwrap(), array!());

    let signed = array!([-1, 0, 1]);
    assert_eq!(sign(&signed).unwrap(), signed);

    let unsigned = array!([1u32, 0, 1]);
    assert_eq!(sign(&unsigned).unwrap(), unsigned);

    let flags = array!([false, true]);
    assert_eq!(sign(&flags).unwrap(), flags);
}
2060
// Shorthand for `f32::NEG_INFINITY`, used by the unary-op tests in this module.
const NEG_INF: f32 = f32::NEG_INFINITY;
2062
#[test]
fn test_unary_floor_ceil() {
    // Small readers that apply the op and pull the scalar result back as f32.
    let floor_of = |input: &Array| floor(input).unwrap().item::<f32>();
    let ceil_of = |input: &Array| ceil(input).unwrap().item::<f32>();

    // A whole number is its own floor and ceiling.
    let whole = array!(1.0);
    assert_eq!(floor_of(&whole), 1.0);
    assert_eq!(ceil_of(&whole), 1.0);

    // A positive fraction.
    let positive = array!(1.5);
    assert_eq!(floor_of(&positive), 1.0);
    assert_eq!(ceil_of(&positive), 2.0);

    // A negative fraction.
    let negative_fraction = array!(-1.5);
    assert_eq!(floor_of(&negative_fraction), -2.0);
    assert_eq!(ceil_of(&negative_fraction), -1.0);

    // Negative infinity passes through both ops.
    let neg_inf = array!(NEG_INF);
    assert_eq!(floor_of(&neg_inf), NEG_INF);
    assert_eq!(ceil_of(&neg_inf), NEG_INF);

    // Complex input is expected to be rejected by both ops.
    let complex = array!([1.0, 1.0]).as_type::<complex64>().unwrap();
    assert!(floor(&complex).is_err());
    assert!(ceil(&complex).is_err());
}
2085
#[test]
fn test_unary_round() {
    // Without a decimals argument; note the expected results for the .5 inputs.
    let halves = array!([0.5, -0.5, 1.5, -1.5, 2.3, 2.6]);
    let rounded = round(&halves, None).unwrap();
    assert_eq!(rounded, array!([0, 0, 2, -2, 2, 3]));

    // With decimals = -1 on integer input, values are rounded to tens.
    let integers = array!([11, 222, 32]);
    let rounded_to_tens = round(&integers, -1).unwrap();
    assert_eq!(rounded_to_tens, array!([10, 220, 30]));
}
2094
#[test]
fn test_unary_exp() {
    // exp(0) is exactly one.
    let zero = array!(0.0);
    assert_eq!(exp(&zero).unwrap().item::<f32>(), 1.0);

    let two = array!(2.0);
    assert_float_eq!(exp(&two).unwrap().item::<f32>(), 2.0f32.exp(), abs <= 1e-5);

    // Empty input gives an empty output.
    assert_eq!(exp(array!()).unwrap(), array!());

    // exp(-inf) is zero.
    let neg_inf = array!(NEG_INF);
    assert_eq!(exp(&neg_inf).unwrap().item::<f32>(), 0.0);

    // Int32 input: the result is read back as f32.
    let int_two = array!(2);
    assert_eq!(int_two.dtype(), Dtype::Int32);
    assert_float_eq!(
        exp(&int_two).unwrap().item::<f32>(),
        2.0f32.exp(),
        abs <= 1e-5
    );

    // Input produced by `broadcast_to`.
    let broadcasted = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let actual = exp(&broadcasted).unwrap();
    let wanted = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.exp())).unwrap();
    let close = all_close(&actual, &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // Input produced by `split`: the first piece of a 2x2 matrix split along axis 1.
    let matrix = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let pieces = split(&matrix, 2, 1).unwrap();
    let wanted = Array::from_slice(&[0.0f32.exp(), 2.0f32.exp()], &[2, 1]);
    let close = all_close(exp(&pieces[0]).unwrap(), &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2136
#[test]
fn test_unary_expm1() {
    // Negative input, checked against std's `exp_m1`.
    let minus_one = array!(-1.0);
    assert_float_eq!(
        expm1(&minus_one).unwrap().item::<f32>(),
        (-1.0f32).exp_m1(),
        abs <= 1e-5
    );

    // Positive input.
    let one = array!(1.0);
    assert_float_eq!(
        expm1(&one).unwrap().item::<f32>(),
        1.0f32.exp_m1(),
        abs <= 1e-5
    );

    // Integer input: the result dtype is Float32.
    let int_one = array!(1);
    assert_eq!(expm1(&int_one).unwrap().dtype(), Dtype::Float32);
    assert_float_eq!(
        expm1(&int_one).unwrap().item::<f32>(),
        1.0f32.exp_m1(),
        abs <= 1e-5
    );
}
2162
#[test]
fn test_unary_sin() {
    // sin(0) is exactly zero.
    let zero = array!(0.0);
    assert_eq!(sin(&zero).unwrap().item::<f32>(), 0.0);

    let half_pi = array!(std::f32::consts::PI / 2.0);
    assert_float_eq!(
        sin(&half_pi).unwrap().item::<f32>(),
        (std::f32::consts::PI / 2.0f32).sin(),
        abs <= 1e-5
    );

    // Empty input gives an empty output.
    assert_eq!(sin(array!()).unwrap(), array!());

    // Int32 input: the result is read back as f32.
    let int_zero = array!(0);
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    assert_float_eq!(
        sin(&int_zero).unwrap().item::<f32>(),
        0.0f32.sin(),
        abs <= 1e-5
    );

    // Input produced by `broadcast_to`.
    let broadcasted = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let actual = sin(&broadcasted).unwrap();
    let wanted = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.sin())).unwrap();
    let close = all_close(&actual, &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // Input produced by `split`: the first piece of a 2x2 matrix split along axis 1.
    let matrix = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let pieces = split(&matrix, 2, 1).unwrap();
    let wanted = Array::from_slice(&[0.0f32.sin(), 2.0f32.sin()], &[2, 1]);
    let close = all_close(sin(&pieces[0]).unwrap(), &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2201
#[test]
fn test_unary_cos() {
    let zero = array!(0.0);
    assert_float_eq!(
        cos(&zero).unwrap().item::<f32>(),
        0.0f32.cos(),
        abs <= 1e-5
    );

    let half_pi = array!(std::f32::consts::PI / 2.0);
    assert_float_eq!(
        cos(&half_pi).unwrap().item::<f32>(),
        (std::f32::consts::PI / 2.0f32).cos(),
        abs <= 1e-5
    );

    // Empty input gives an empty output.
    assert_eq!(cos(array!()).unwrap(), array!());

    // Int32 input: the result is read back as f32.
    let int_zero = array!(0);
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    assert_float_eq!(
        cos(&int_zero).unwrap().item::<f32>(),
        0.0f32.cos(),
        abs <= 1e-5
    );

    // Input produced by `broadcast_to`.
    let broadcasted = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let actual = cos(&broadcasted).unwrap();
    let wanted = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.cos())).unwrap();
    let close = all_close(&actual, &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // Input produced by `split`: the first piece of a 2x2 matrix split along axis 1.
    let matrix = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let pieces = split(&matrix, 2, 1).unwrap();
    let wanted = Array::from_slice(&[0.0f32.cos(), 2.0f32.cos()], &[2, 1]);
    let close = all_close(cos(&pieces[0]).unwrap(), &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2244
#[test]
fn test_unary_degrees() {
    // Zero radians is zero degrees.
    let zero = array!(0.0);
    assert_eq!(degrees(&zero).unwrap().item::<f32>(), 0.0);

    // A quarter turn.
    let half_pi = array!(std::f32::consts::PI / 2.0);
    assert_eq!(degrees(&half_pi).unwrap().item::<f32>(), 90.0);

    // Empty input gives an empty output.
    assert_eq!(degrees(array!()).unwrap(), array!());

    // Int32 input: the result is read back as f32.
    let int_zero = array!(0);
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    assert_eq!(degrees(&int_zero).unwrap().item::<f32>(), 0.0);

    // Input produced by `broadcast_to`.
    let broadcasted = broadcast_to(&array!(std::f32::consts::PI / 2.0), &[2, 2, 2]).unwrap();
    let actual = degrees(&broadcasted).unwrap();
    let wanted = Array::full::<f32>(&[2, 2, 2], array!(90.0)).unwrap();
    let close = all_close(&actual, &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // Input produced by `split`: the first piece holds the angles 0 and PI.
    let radian_grid = Array::from_slice(&[0.0, PI / 2.0, PI, 1.5 * PI], &[2, 2]);
    let pieces = split(&radian_grid, 2, 1).unwrap();
    let wanted = Array::from_slice(&[0.0, 180.0], &[2, 1]);
    let close = all_close(degrees(&pieces[0]).unwrap(), &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2277
#[test]
fn test_unary_radians() {
    // Zero degrees is zero radians.
    let zero = array!(0.0);
    assert_eq!(radians(&zero).unwrap().item::<f32>(), 0.0);

    // A quarter turn.
    let right_angle = array!(90.0);
    let converted = radians(&right_angle).unwrap().item::<f32>();
    assert_eq!(converted, std::f32::consts::PI / 2.0);

    // Empty input gives an empty output.
    assert_eq!(radians(array!()).unwrap(), array!());

    // Int32 input: the result is read back as f32.
    let int_right_angle = array!(90);
    assert_eq!(int_right_angle.dtype(), Dtype::Int32);
    let converted = radians(&int_right_angle).unwrap().item::<f32>();
    assert_eq!(converted, std::f32::consts::PI / 2.0);

    // Input produced by `broadcast_to`.
    let broadcasted = broadcast_to(&array!(90.0), &[2, 2, 2]).unwrap();
    let actual = radians(&broadcasted).unwrap();
    let wanted = Array::full::<f32>(&[2, 2, 2], array!(std::f32::consts::PI / 2.0)).unwrap();
    let close = all_close(&actual, &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // Input produced by `split`: the first piece holds the angles 0 and 180 degrees.
    let degree_grid = Array::from_slice(&[0.0, 90.0, 180.0, 270.0], &[2, 2]);
    let pieces = split(&degree_grid, 2, 1).unwrap();
    let wanted = Array::from_slice(&[0.0, PI], &[2, 1]);
    let close = all_close(radians(&pieces[0]).unwrap(), &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2316
#[test]
fn test_unary_log() {
    // log(0) is negative infinity.
    let zero = array!(0.0);
    assert_eq!(log(&zero).unwrap().item::<f32>(), NEG_INF);

    // log(1) is exactly zero.
    let one = array!(1.0);
    assert_eq!(log(&one).unwrap().item::<f32>(), 0.0);

    // Integer input: the result dtype is Float32.
    let int_one = array!(1);
    assert_eq!(log(&int_one).unwrap().dtype(), Dtype::Float32);
    assert_eq!(log(&int_one).unwrap().item::<f32>(), 0.0);

    // Input produced by `broadcast_to`.
    let broadcasted = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let actual = log(&broadcasted).unwrap();
    let wanted = Array::full::<f32>(&[2, 2, 2], array!(0.0)).unwrap();
    let close = all_close(&actual, &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // Input produced by `split`: the first piece of a 2x2 matrix split along axis 1.
    let matrix = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    let pieces = split(&matrix, 2, 1).unwrap();
    let wanted = Array::from_slice(&[1.0f32.ln(), 3.0f32.ln()], &[2, 1]);
    let close = all_close(log(&pieces[0]).unwrap(), &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2345
#[test]
fn test_unary_log2() {
    // log2(0) is negative infinity.
    let zero = array!(0.0);
    assert_eq!(log2(&zero).unwrap().item::<f32>(), NEG_INF);

    // log2(1) is zero.
    let one = array!(1.0);
    assert_eq!(log2(&one).unwrap().item::<f32>(), 0.0);

    // 1024 = 2^10.
    let kilo = array!(1024.0);
    assert_eq!(log2(&kilo).unwrap().item::<f32>(), 10.0);
}
2357
#[test]
fn test_unary_log10() {
    // log10(0) is negative infinity.
    let zero = array!(0.0);
    assert_eq!(log10(&zero).unwrap().item::<f32>(), NEG_INF);

    // log10(1) is zero.
    let one = array!(1.0);
    assert_eq!(log10(&one).unwrap().item::<f32>(), 0.0);

    // 1000 = 10^3.
    let thousand = array!(1000.0);
    assert_eq!(log10(&thousand).unwrap().item::<f32>(), 3.0);
}
2369
#[test]
fn test_unary_log1p() {
    // Input -1, checked against std's `ln_1p`.
    let minus_one = array!(-1.0);
    assert_float_eq!(
        log1p(&minus_one).unwrap().item::<f32>(),
        (-1.0f32).ln_1p(),
        abs <= 1e-5
    );

    let one = array!(1.0);
    assert_float_eq!(
        log1p(&one).unwrap().item::<f32>(),
        1.0f32.ln_1p(),
        abs <= 1e-5
    );

    // Integer input: the result dtype is Float32.
    let int_one = array!(1);
    assert_eq!(log1p(&int_one).unwrap().dtype(), Dtype::Float32);
    assert_float_eq!(
        log1p(&int_one).unwrap().item::<f32>(),
        1.0f32.ln_1p(),
        abs <= 1e-5
    );

    // Input produced by `broadcast_to`.
    let broadcasted = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let actual = log1p(&broadcasted).unwrap();
    let wanted = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.ln_1p())).unwrap();
    let close = all_close(&actual, &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // Input produced by `split`: the first piece of a 2x2 matrix split along axis 1.
    let matrix = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    let pieces = split(&matrix, 2, 1).unwrap();
    let wanted = Array::from_slice(&[1.0f32.ln_1p(), 3.0f32.ln_1p()], &[2, 1]);
    let close = all_close(log1p(&pieces[0]).unwrap(), &wanted, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2412
#[test]
fn test_unary_sigmoid() {
    // sigmoid(0) is one half.
    let zero = array!(0.0);
    assert_float_eq!(sigmoid(&zero).unwrap().item::<f32>(), 0.5, abs <= 1e-5);

    // Integer input: the result dtype is Float32.
    let int_zero = array!(0);
    assert_eq!(sigmoid(&int_zero).unwrap().dtype(), Dtype::Float32);
    assert_float_eq!(sigmoid(&int_zero).unwrap().item::<f32>(), 0.5, abs <= 1e-5);

    // The infinities give exactly 1 and 0.
    let pos_inf = f32::INFINITY;
    let at_pos_inf = array!(pos_inf);
    assert_eq!(sigmoid(&at_pos_inf).unwrap().item::<f32>(), 1.0);

    let at_neg_inf = array!(-pos_inf);
    assert_eq!(sigmoid(&at_neg_inf).unwrap().item::<f32>(), 0.0);
}
2438
#[test]
fn test_unary_square() {
    // Float scalar.
    let three = array!(3.0);
    assert_eq!(square(&three).unwrap().item::<f32>(), 9.0);

    // Integer scalar: the result is read back as i32.
    let int_two = array!(2);
    assert_eq!(square(&int_two).unwrap().item::<i32>(), 4);

    // A 3x3 matrix filled with 2.0 squares to one filled with 4.0.
    let twos = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
    let squared = square(&twos).unwrap();
    let fours = Array::full::<f32>(&[3, 3], array!(4.0)).unwrap();
    let close = all_close(squared, fours, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2458
#[test]
fn test_unary_sqrt_rsqrt() {
    // Float scalar: sqrt(4) = 2 and 1/sqrt(4) = 0.5.
    let four = array!(4.0);
    assert_eq!(sqrt(&four).unwrap().item::<f32>(), 2.0);
    assert_eq!(rsqrt(&four).unwrap().item::<f32>(), 0.5);

    // A 3x3 matrix filled with 9.0 has square root filled with 3.0.
    let nines = Array::full::<f32>(&[3, 3], array!(9.0)).unwrap();
    let roots = sqrt(&nines).unwrap();
    let threes = Array::full::<f32>(&[3, 3], array!(3.0)).unwrap();
    let close = all_close(roots, threes, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // Integer scalar: the results are read back as f32.
    let int_four = array!(4i32);
    assert_eq!(sqrt(&int_four).unwrap().item::<f32>(), 2.0);
    assert_eq!(rsqrt(&int_four).unwrap().item::<f32>(), 0.5);
}
2480
#[test]
fn test_unary_reciprocal() {
    // Float scalar.
    let eight = array!(8.0);
    assert_eq!(reciprocal(&eight).unwrap().item::<f32>(), 0.125);

    // Integer scalar: the result dtype is Float32.
    let int_two = array!(2);
    let inverse = reciprocal(&int_two).unwrap();
    assert_eq!(inverse.dtype(), Dtype::Float32);
    assert_eq!(inverse.item::<f32>(), 0.5);

    // A 3x3 matrix filled with 2.0 inverts to one filled with 0.5.
    let twos = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
    let inverted = reciprocal(&twos).unwrap();
    let halves = Array::full::<f32>(&[3, 3], array!(0.5)).unwrap();
    let close = all_close(inverted, halves, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2502
#[test]
fn test_binary_add() {
    // Scalar + scalar through the free function.
    let x = array![1.0];
    let y = array![1.0];
    let z = add(&x, &y).unwrap();
    assert_eq!(z.item::<f32>(), 2.0);

    // Same sum through the `+` operator (this consumes `y`).
    let z = &x + y;
    assert_eq!(z.item::<f32>(), 2.0);

    // A previous result can be fed back in by value.
    let z = add(z, &x).unwrap();
    assert_eq!(z.item::<f32>(), 3.0);

    // Repeated accumulation, starting from a deep copy of `x`.
    let mut out = x.deep_clone();
    for _ in 0..10 {
        out = add(&out, &x).unwrap();
    }
    assert_eq!(out.item::<f32>(), 11.0);

    // Vector + vector.
    let x = array!([1.0, 2.0, 3.0]);
    let y = array!([1.0, 2.0, 3.0]);
    let z = add(&x, &y).unwrap();
    assert_eq!(z.shape(), &[3]);
    assert_eq!(z, array!([2.0, 4.0, 6.0]));

    // Vector + float scalar. (This check used to appear twice verbatim; the
    // copy-pasted duplicate added no coverage and has been removed.)
    let x = array!([1.0, 2.0, 3.0]);
    let y = &x + 2.0;
    assert_eq!(y.dtype(), Dtype::Float32);
    assert_eq!(y, array!([3.0, 4.0, 5.0]));

    // Float vector + integer scalar: the result stays Float32.
    let y = x + 2;
    assert_eq!(y.dtype(), Dtype::Float32);

    // Integer vector + float scalar: the result is Float32.
    let y = array!([1, 2, 3]) + 2.0;
    assert_eq!(y.dtype(), Dtype::Float32);
    assert_eq!(y, array!([3.0, 4.0, 5.0]));

    // Operands produced by `broadcast_to`.
    let x = broadcast_to(&array!(1.0), &[10]).unwrap();
    let y = broadcast_to(&array!(2.0), &[10]).unwrap();
    let z = add(&x, &y).unwrap();
    assert_eq!(z, full::<f32>(&[10], array!(3.0)).unwrap());

    // Shapes [1, 2] and [2, 1] broadcast to [2, 2].
    let x = Array::from_slice(&[1.0, 2.0], &[1, 2]);
    let y = Array::from_slice(&[1.0, 2.0], &[2, 1]);
    let z = add(&x, &y).unwrap();
    assert_eq!(z.shape(), &[2, 2]);
    assert_eq!(z, Array::from_slice(&[2.0, 3.0, 3.0, 4.0], &[2, 2]));

    // A scalar added to a [3, 2, 1] array keeps the array's shape.
    let x = ones::<f32>(&[3, 2, 1]).unwrap();
    let z = x + 2.0;
    assert_eq!(z.shape(), &[3, 2, 1]);
    let expected = Array::from_slice(&[3.0, 3.0, 3.0, 3.0, 3.0, 3.0], &[3, 2, 1]);
    assert_eq!(z, expected);

    // Empty + empty evaluates to an empty array of shape [0].
    let x = array!();
    let y = array!();
    let z = x + y;
    z.eval().unwrap();
    assert_eq!(z.size(), 0);
    assert_eq!(z.shape(), &[0]);
}
2574
#[test]
fn test_binary_sub() {
    // Element-wise subtraction through the `-` operator on owned operands.
    let minuend = array!([3.0, 2.0, 1.0]);
    let subtrahend = array!([1.0, 1.0, 1.0]);

    let difference = minuend - subtrahend;
    assert_eq!(difference, array!([2.0, 1.0, 0.0]));
}
2581
#[test]
fn test_binary_mul() {
    // Element-wise multiplication through the `*` operator on owned operands.
    let values = array!([1.0, 2.0, 3.0]);
    let factors = array!([2.0, 2.0, 2.0]);

    let product = values * factors;
    assert_eq!(product, array!([2.0, 4.0, 6.0]));
}
2588
#[test]
fn test_binary_div() {
    // Divides two scalar arrays and reads the quotient back as f32.
    let quotient_of = |num: &Array, den: &Array| divide(num, den).unwrap().item::<f32>();

    // Float operands.
    assert_eq!(quotient_of(&array!(1.0), &array!(1.0)), 1.0);
    assert_eq!(quotient_of(&array!(1.0), &array!(0.5)), 2.0);
    assert_eq!(quotient_of(&array!(1.0), &array!(4.0)), 0.25);

    // Boolean operands: the quotient is still read back as f32.
    assert_eq!(quotient_of(&array!(true), &array!(true)), 1.0);
    assert_eq!(quotient_of(&array!(false), &array!(true)), 0.0);

    // true / false is infinite, false / false is NaN.
    assert!(quotient_of(&array!(true), &array!(false)).is_infinite());
    assert!(quotient_of(&array!(false), &array!(false)).is_nan());
}
2619
#[test]
fn test_binary_maximum_minimum() {
    let one = array!(1.0);

    // Against a smaller value, `one` is the maximum.
    let smaller = array!(0.0);
    assert_eq!(maximum(&one, &smaller).unwrap().item::<f32>(), 1.0);
    assert_eq!(minimum(&one, &smaller).unwrap().item::<f32>(), 0.0);

    // Against a larger value, `one` is the minimum.
    let larger = array!(2.0);
    assert_eq!(maximum(&one, &larger).unwrap().item::<f32>(), 2.0);
    assert_eq!(minimum(&one, &larger).unwrap().item::<f32>(), 1.0);
}
2631
#[test]
fn test_binary_logaddexp() {
    // Applies logaddexp to two arrays and reads the scalar result back as f32.
    let lae = |lhs: &Array, rhs: &Array| logaddexp(lhs, rhs).unwrap().item::<f32>();

    // log(e^0 + e^0) = ln 2.
    assert_float_eq!(lae(&array!(0.0), &array!(0.0)), 2.0f32.ln(), abs <= 1e-5);

    // Unsigned inputs far apart: the result equals the larger operand.
    let small = array!([0u32]);
    let large = array!([10000u32]);
    assert_eq!(lae(&small, &large), 10000.0);

    // Infinite operands.
    assert_eq!(lae(&array!(f32::INFINITY), &array!(3.0)), f32::INFINITY);
    assert_eq!(lae(&array!(f32::NEG_INFINITY), &array!(3.0)), 3.0);
    assert_eq!(
        lae(&array!(f32::NEG_INFINITY), &array!(f32::NEG_INFINITY)),
        f32::NEG_INFINITY
    );
    assert_eq!(
        lae(&array!(f32::INFINITY), &array!(f32::INFINITY)),
        f32::INFINITY
    );
    assert_eq!(
        lae(&array!(f32::NEG_INFINITY), &array!(f32::INFINITY)),
        f32::INFINITY
    );
}
2666
#[test]
fn test_basic_clip() {
    let values = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
    let wanted = array!([2.0, 4.0, 3.0, 6.0, 5.0]);

    // Bounds supplied as arrays.
    let with_array_bounds = clip(&values, (array!(2.0), array!(6.0))).unwrap();
    assert_eq!(with_array_bounds, wanted);

    // Bounds supplied as plain scalars.
    let with_scalar_bounds = clip(&values, (2.0, 6.0)).unwrap();
    assert_eq!(with_scalar_bounds, wanted);
}
2678
#[test]
fn test_clip_with_only_min() {
    let values = array!([-1.0, 1.0, 0.0, 5.0]);
    let wanted = array!([0.0, 1.0, 0.0, 5.0]);

    // `()` in the second slot means no upper bound; lower bound given as an array.
    let with_array_min = clip(&values, (array!(0.0), ())).unwrap();
    assert_eq!(with_array_min, wanted);

    // Lower bound given as a plain scalar.
    let with_scalar_min = clip(&values, (0.0, ())).unwrap();
    assert_eq!(with_scalar_min, wanted);
}
2690
#[test]
fn test_clip_with_only_max() {
    let values = array!([2.0, 3.0, 4.0, 5.0]);
    let wanted = array!([2.0, 3.0, 4.0, 4.0]);

    // `()` in the first slot means no lower bound; upper bound given as an array.
    let with_array_max = clip(&values, ((), array!(4.0))).unwrap();
    assert_eq!(with_array_max, wanted);

    // Upper bound given as a plain scalar.
    let with_scalar_max = clip(&values, ((), 4.0)).unwrap();
    assert_eq!(with_scalar_max, wanted);
}
2702
#[test]
fn test_tensordot() {
    // Contraction over explicitly listed axis pairs.
    let lhs_flat = arange::<_, f32>(None, 60.0, None).unwrap();
    let lhs = reshape(lhs_flat, &[3, 4, 5]).unwrap();
    let rhs_flat = arange::<_, f32>(None, 24.0, None).unwrap();
    let rhs = reshape(rhs_flat, &[4, 3, 2]).unwrap();
    let contracted = tensordot_axes(&lhs, &rhs, &[1i32, 0], &[0i32, 1]).unwrap();
    let wanted = Array::from_slice(
        &[4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306],
        &[5, 2],
    );
    assert_eq!(contracted, wanted);

    // This axis pairing on [3, 4, 5, 6] x [6, 4, 5, 3] is expected to be rejected.
    let lhs_flat = arange::<_, f32>(None, 360.0, None).unwrap();
    let lhs = reshape(lhs_flat, &[3, 4, 5, 6]).unwrap();
    let rhs_flat = arange::<_, f32>(None, 360.0, None).unwrap();
    let rhs = reshape(rhs_flat, &[6, 4, 5, 3]).unwrap();
    assert!(tensordot_axes(&lhs, &rhs, &[2, 1, 3], &[1, 2, 0]).is_err());

    // Contraction specified by a single axis count.
    let lhs_flat = arange::<_, f32>(None, 60.0, None).unwrap();
    let lhs = reshape(lhs_flat, &[3, 4, 5]).unwrap();
    let rhs_flat = arange::<_, f32>(None, 120.0, None).unwrap();
    let rhs = reshape(rhs_flat, &[4, 5, 6]).unwrap();

    let contracted = tensordot_axis(&lhs, &rhs, 2).unwrap();
    let wanted = Array::from_slice(
        &[
            14820.0, 15010.0, 15200.0, 15390.0, 15580.0, 15770.0, 37620.0, 38210.0, 38800.0,
            39390.0, 39980.0, 40570.0, 60420.0, 61410.0, 62400.0, 63390.0, 64380.0, 65370.0,
        ],
        &[3, 6],
    );
    assert_eq!(contracted, wanted);
}
2731
#[test]
fn test_outer() {
    // Outer product of two ranges gives a 4x3 multiplication table.
    let rows = arange::<_, f32>(1.0, 5.0, None).unwrap();
    let cols = arange::<_, f32>(1.0, 4.0, None).unwrap();
    let table = outer(&rows, &cols).unwrap();
    let wanted = Array::from_slice(
        &[1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0],
        &[4, 3],
    );
    assert_eq!(table, wanted);

    // Outer product with a vector of ones repeats the second vector on every row.
    let unit = ones::<f32>(&[5]).unwrap();
    let ramp = linspace::<_, f32>(-2.0, 2.0, 5).unwrap();
    let repeated = outer(&unit, &ramp).unwrap();
    let wanted = Array::from_slice(
        &[
            -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
            -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
        ],
        &[5, 5],
    );
    assert_eq!(repeated, wanted);
}
2755
#[test]
fn test_inner() {
    // Last dimensions 5 and 3 differ, so the inner product is expected to fail.
    let lhs_flat = arange::<_, f32>(None, 5.0, None).unwrap();
    let lhs = reshape(lhs_flat, &[1, 5]).unwrap();
    let rhs_flat = arange::<_, f32>(None, 6.0, None).unwrap();
    let rhs = reshape(rhs_flat, &[2, 3]).unwrap();
    assert!(inner(&lhs, &rhs).is_err());

    // Two plain vectors give a scalar.
    let lhs = array!([1.0, 2.0, 3.0]);
    let rhs = array!([0.0, 1.0, 0.0]);
    let dot = inner(&lhs, &rhs).unwrap();
    assert_eq!(dot.item::<f32>(), 2.0);

    // A [2, 3, 4] tensor against a length-4 vector gives a [2, 3] result.
    let lhs_flat = arange::<_, f32>(None, 24.0, None).unwrap();
    let lhs = reshape(lhs_flat, &[2, 3, 4]).unwrap();
    let rhs = arange::<_, f32>(None, 4.0, None).unwrap();
    let result = inner(&lhs, &rhs).unwrap();
    let wanted = Array::from_slice(&[14.0, 38.0, 62.0, 86.0, 110.0, 134.0], &[2, 3]);
    assert_eq!(result, wanted);

    // A [1, 1, 2] tensor against a [3, 2] matrix gives a [1, 1, 3] result.
    let lhs_flat = arange::<_, f32>(None, 2.0, None).unwrap();
    let lhs = reshape(lhs_flat, &[1, 1, 2]).unwrap();
    let rhs_flat = arange::<_, f32>(None, 6.0, None).unwrap();
    let rhs = reshape(rhs_flat, &[3, 2]).unwrap();
    let result = inner(&lhs, &rhs).unwrap();
    let wanted = Array::from_slice(&[1.0, 3.0, 5.0], &[1, 1, 3]);
    assert_eq!(result, wanted);

    // A 2x2 identity against a scalar scales the identity.
    let identity = eye::<f32>(2, None, None).unwrap();
    let seven = Array::from_f32(7.0);
    let result = inner(&identity, &seven).unwrap();
    let wanted = Array::from_slice(&[7.0, 0.0, 0.0, 7.0], &[2, 2]);
    assert_eq!(result, wanted);
}
2785
#[test]
fn test_divmod() {
    // Dividing by one: the quotient is the input and the remainder is zero.
    let x = array!([1.0, 2.0, 3.0]);
    let y = array!([1.0, 1.0, 1.0]);
    let out = divmod(&x, &y).unwrap();
    assert_eq!(out.0, array!([1.0, 2.0, 3.0]));
    assert_eq!(out.1, array!([0.0, 0.0, 0.0]));

    // Dividing by two. (This case used to appear twice verbatim; the
    // copy-pasted duplicate added no coverage and has been removed.)
    let x = array!([5.0, 6.0, 7.0]);
    let y = array!([2.0, 2.0, 2.0]);
    let out = divmod(&x, &y).unwrap();
    assert_eq!(out.0, array!([2.0, 3.0, 3.0]));
    assert_eq!(out.1, array!([1.0, 0.0, 1.0]));

    // Complex operands are expected to be rejected.
    let x = array![complex64::new(1.0, 0.0)];
    let y = array![complex64::new(2.0, 0.0)];
    assert!(divmod(&x, &y).is_err());

    // Both outputs evaluated together in one `eval` call.
    let x = array![1.0];
    let y = array![2.0];
    let (quo, rem) = divmod(&x, &y).unwrap();
    eval([&quo, &rem]).unwrap();
    assert_eq!(quo.item::<f32>(), 0.0);
    assert_eq!(rem.item::<f32>(), 1.0);

    // Both outputs feeding one downstream computation.
    let x = array![1.0];
    let y = array![2.0];
    let (quo, rem) = divmod(&x, &y).unwrap();
    let z = quo + rem;
    assert_eq!(z.item::<f32>(), 1.0);

    // Only the quotient is kept; the remainder is dropped inside the block
    // before `eval` runs.
    let mut out_holder = {
        let (quo, _) = divmod(&x, &y).unwrap();
        vec![quo]
    };
    eval(out_holder.iter()).unwrap();
    assert_eq!(out_holder[0].item::<f32>(), 0.0);

    // Release the quotient, then repeat with only the remainder kept.
    out_holder.clear();
    let out_holder = {
        let (_, rem) = divmod(&x, &y).unwrap();
        vec![rem]
    };
    eval(out_holder.iter()).unwrap();
    assert_eq!(out_holder[0].item::<f32>(), 1.0);
}
2842}