1use crate::array::Array;
2use crate::error::Result;
3use crate::sealed::Sealed;
4
5use crate::utils::guard::Guarded;
6use crate::utils::{IntoOption, ScalarOrArray, VectorArray};
7use crate::Stream;
8use mlx_internal_macros::{default_device, generate_macro};
9use smallvec::SmallVec;
10
11impl Array {
12 #[default_device]
25 pub fn abs_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
26 Array::try_from_op(|res| unsafe {
27 mlx_sys::mlx_abs(res, self.as_ptr(), stream.as_ref().as_ptr())
28 })
29 }
30
31 #[default_device]
51 pub fn add_device(
52 &self,
53 other: impl AsRef<Array>,
54 stream: impl AsRef<Stream>,
55 ) -> Result<Array> {
56 Array::try_from_op(|res| unsafe {
57 mlx_sys::mlx_add(
58 res,
59 self.as_ptr(),
60 other.as_ref().as_ptr(),
61 stream.as_ref().as_ptr(),
62 )
63 })
64 }
65
66 #[default_device]
86 pub fn subtract_device(
87 &self,
88 other: impl AsRef<Array>,
89 stream: impl AsRef<Stream>,
90 ) -> Result<Array> {
91 Array::try_from_op(|res| unsafe {
92 mlx_sys::mlx_subtract(
93 res,
94 self.as_ptr(),
95 other.as_ref().as_ptr(),
96 stream.as_ref().as_ptr(),
97 )
98 })
99 }
100
101 #[default_device]
116 pub fn negative_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
117 Array::try_from_op(|res| unsafe {
118 mlx_sys::mlx_negative(res, self.as_ptr(), stream.as_ref().as_ptr())
119 })
120 }
121
122 #[default_device]
138 pub fn multiply_device(
139 &self,
140 other: impl AsRef<Array>,
141 stream: impl AsRef<Stream>,
142 ) -> Result<Array> {
143 Array::try_from_op(|res| unsafe {
144 mlx_sys::mlx_multiply(
145 res,
146 self.as_ptr(),
147 other.as_ref().as_ptr(),
148 stream.as_ref().as_ptr(),
149 )
150 })
151 }
152
153 #[default_device]
163 pub fn nan_to_num_device(
164 &self,
165 nan: impl IntoOption<f32>,
166 pos_inf: impl IntoOption<f32>,
167 neg_inf: impl IntoOption<f32>,
168 stream: impl AsRef<Stream>,
169 ) -> Result<Array> {
170 let pos_inf = pos_inf.into_option();
171 let neg_inf = neg_inf.into_option();
172
173 let pos_inf = mlx_sys::mlx_optional_float {
174 value: pos_inf.unwrap_or(0.0),
175 has_value: pos_inf.is_some(),
176 };
177 let neg_inf = mlx_sys::mlx_optional_float {
178 value: neg_inf.unwrap_or(0.0),
179 has_value: neg_inf.is_some(),
180 };
181
182 Array::try_from_op(|res| unsafe {
183 mlx_sys::mlx_nan_to_num(
184 res,
185 self.as_ptr(),
186 nan.into_option().unwrap_or(0.),
187 pos_inf,
188 neg_inf,
189 stream.as_ref().as_ptr(),
190 )
191 })
192 }
193
194 #[default_device]
214 pub fn divide_device(
215 &self,
216 other: impl AsRef<Array>,
217 stream: impl AsRef<Stream>,
218 ) -> Result<Array> {
219 Array::try_from_op(|res| unsafe {
220 mlx_sys::mlx_divide(
221 res,
222 self.as_ptr(),
223 other.as_ref().as_ptr(),
224 stream.as_ref().as_ptr(),
225 )
226 })
227 }
228
229 #[default_device]
249 pub fn power_device(
250 &self,
251 other: impl AsRef<Array>,
252 stream: impl AsRef<Stream>,
253 ) -> Result<Array> {
254 Array::try_from_op(|res| unsafe {
255 mlx_sys::mlx_power(
256 res,
257 self.as_ptr(),
258 other.as_ref().as_ptr(),
259 stream.as_ref().as_ptr(),
260 )
261 })
262 }
263
264 #[default_device]
284 pub fn remainder_device(
285 &self,
286 other: impl AsRef<Array>,
287 stream: impl AsRef<Stream>,
288 ) -> Result<Array> {
289 Array::try_from_op(|res| unsafe {
290 mlx_sys::mlx_remainder(
291 res,
292 self.as_ptr(),
293 other.as_ref().as_ptr(),
294 stream.as_ref().as_ptr(),
295 )
296 })
297 }
298
299 #[default_device]
312 pub fn sqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
313 Array::try_from_op(|res| unsafe {
314 mlx_sys::mlx_sqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
315 })
316 }
317
318 #[default_device]
331 pub fn cos_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
332 Array::try_from_op(|res| unsafe {
333 mlx_sys::mlx_cos(res, self.as_ptr(), stream.as_ref().as_ptr())
334 })
335 }
336
337 #[default_device]
352 pub fn exp_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
353 Array::try_from_op(|res| unsafe {
354 mlx_sys::mlx_exp(res, self.as_ptr(), stream.as_ref().as_ptr())
355 })
356 }
357
358 #[default_device]
371 pub fn floor_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
372 Array::try_from_op(|res| unsafe {
373 mlx_sys::mlx_floor(res, self.as_ptr(), stream.as_ref().as_ptr())
374 })
375 }
376
377 #[default_device]
401 pub fn floor_divide_device(
402 &self,
403 other: impl AsRef<Array>,
404 stream: impl AsRef<Stream>,
405 ) -> Result<Array> {
406 Array::try_from_op(|res| unsafe {
407 mlx_sys::mlx_floor_divide(
408 res,
409 self.as_ptr(),
410 other.as_ref().as_ptr(),
411 stream.as_ref().as_ptr(),
412 )
413 })
414 }
415
416 #[default_device]
421 pub fn is_nan_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
422 Array::try_from_op(|res| unsafe {
423 mlx_sys::mlx_isnan(res, self.as_ptr(), stream.as_ref().as_ptr())
424 })
425 }
426
427 #[default_device]
432 pub fn is_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
433 Array::try_from_op(|res| unsafe {
434 mlx_sys::mlx_isinf(res, self.as_ptr(), stream.as_ref().as_ptr())
435 })
436 }
437
438 #[default_device]
443 pub fn is_finite_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
444 Array::try_from_op(|res| unsafe {
445 mlx_sys::mlx_isfinite(res, self.as_ptr(), stream.as_ref().as_ptr())
446 })
447 }
448
449 #[default_device]
454 pub fn is_neg_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
455 Array::try_from_op(|res| unsafe {
456 mlx_sys::mlx_isneginf(res, self.as_ptr(), stream.as_ref().as_ptr())
457 })
458 }
459
460 #[default_device]
465 pub fn is_pos_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
466 Array::try_from_op(|res| unsafe {
467 mlx_sys::mlx_isposinf(res, self.as_ptr(), stream.as_ref().as_ptr())
468 })
469 }
470
471 #[default_device]
484 pub fn log_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
485 Array::try_from_op(|res| unsafe {
486 mlx_sys::mlx_log(res, self.as_ptr(), stream.as_ref().as_ptr())
487 })
488 }
489
490 #[default_device]
503 pub fn log2_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
504 Array::try_from_op(|res| unsafe {
505 mlx_sys::mlx_log2(res, self.as_ptr(), stream.as_ref().as_ptr())
506 })
507 }
508
509 #[default_device]
522 pub fn log10_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
523 Array::try_from_op(|res| unsafe {
524 mlx_sys::mlx_log10(res, self.as_ptr(), stream.as_ref().as_ptr())
525 })
526 }
527
528 #[default_device]
541 pub fn log1p_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
542 Array::try_from_op(|res| unsafe {
543 mlx_sys::mlx_log1p(res, self.as_ptr(), stream.as_ref().as_ptr())
544 })
545 }
546
547 #[default_device]
577 pub fn matmul_device(
578 &self,
579 other: impl AsRef<Array>,
580 stream: impl AsRef<Stream>,
581 ) -> Result<Array> {
582 Array::try_from_op(|res| unsafe {
583 mlx_sys::mlx_matmul(
584 res,
585 self.as_ptr(),
586 other.as_ref().as_ptr(),
587 stream.as_ref().as_ptr(),
588 )
589 })
590 }
591
592 #[default_device]
605 pub fn reciprocal_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
606 Array::try_from_op(|res| unsafe {
607 mlx_sys::mlx_reciprocal(res, self.as_ptr(), stream.as_ref().as_ptr())
608 })
609 }
610
611 #[default_device]
617 pub fn round_device(
618 &self,
619 decimals: impl Into<Option<i32>>,
620 stream: impl AsRef<Stream>,
621 ) -> Result<Array> {
622 Array::try_from_op(|res| unsafe {
623 mlx_sys::mlx_round(
624 res,
625 self.as_ptr(),
626 decimals.into().unwrap_or(0),
627 stream.as_ref().as_ptr(),
628 )
629 })
630 }
631
632 #[default_device]
634 pub fn rsqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
635 Array::try_from_op(|res| unsafe {
636 mlx_sys::mlx_rsqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
637 })
638 }
639
640 #[default_device]
642 pub fn sin_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
643 Array::try_from_op(|res| unsafe {
644 mlx_sys::mlx_sin(res, self.as_ptr(), stream.as_ref().as_ptr())
645 })
646 }
647
648 #[default_device]
650 pub fn square_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
651 Array::try_from_op(|res| unsafe {
652 mlx_sys::mlx_square(res, self.as_ptr(), stream.as_ref().as_ptr())
653 })
654 }
655
656 #[default_device]
658 pub fn real_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
659 Array::try_from_op(|res| unsafe {
660 mlx_sys::mlx_real(res, self.as_ptr(), stream.as_ref().as_ptr())
661 })
662 }
663
664 #[default_device]
666 pub fn imag_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
667 Array::try_from_op(|res| unsafe {
668 mlx_sys::mlx_imag(res, self.as_ptr(), stream.as_ref().as_ptr())
669 })
670 }
671}
672
673#[generate_macro]
684#[default_device]
685pub fn abs_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
686 a.as_ref().abs_device(stream)
687}
688
689#[generate_macro]
691#[default_device]
692pub fn acos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
693 Array::try_from_op(|res| unsafe {
694 mlx_sys::mlx_arccos(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
695 })
696}
697
698#[generate_macro]
700#[default_device]
701pub fn acosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
702 Array::try_from_op(|res| unsafe {
703 mlx_sys::mlx_arccosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
704 })
705}
706
707#[generate_macro]
709#[default_device]
710pub fn add_device(
711 lhs: impl AsRef<Array>,
712 rhs: impl AsRef<Array>,
713 #[optional] stream: impl AsRef<Stream>,
714) -> Result<Array> {
715 lhs.as_ref().add_device(rhs, stream)
716}
717
718#[generate_macro]
720#[default_device]
721pub fn asin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
722 Array::try_from_op(|res| unsafe {
723 mlx_sys::mlx_arcsin(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
724 })
725}
726
727#[generate_macro]
729#[default_device]
730pub fn asinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
731 Array::try_from_op(|res| unsafe {
732 mlx_sys::mlx_arcsinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
733 })
734}
735
736#[generate_macro]
738#[default_device]
739pub fn atan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
740 Array::try_from_op(|res| unsafe {
741 mlx_sys::mlx_arctan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
742 })
743}
744
745#[generate_macro]
747#[default_device]
748pub fn atanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
749 Array::try_from_op(|res| unsafe {
750 mlx_sys::mlx_arctanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
751 })
752}
753
754#[generate_macro]
756#[default_device]
757pub fn ceil_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
758 Array::try_from_op(|res| unsafe {
759 mlx_sys::mlx_ceil(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
760 })
761}
762
/// Bound specification accepted by [`clip_device`].
///
/// This trait is sealed. Within this file it is implemented for `(min, ())` (lower bound only),
/// `((), max)` (upper bound only) and `(min, max)` (both bounds). Each bound is any
/// `ScalarOrArray`.
pub trait ClipBound<'min, 'max>: Sealed {
    /// Splits the bound into optional `(min, max)` parts; `None` marks an absent side.
    fn into_min_max(
        self,
    ) -> (
        Option<impl ScalarOrArray<'min>>,
        Option<impl ScalarOrArray<'max>>,
    );
}
776
777impl<'min, Min> ClipBound<'min, 'min> for (Min, ())
778where
779 Min: ScalarOrArray<'min> + Sealed,
780{
781 fn into_min_max(
782 self,
783 ) -> (
784 Option<impl ScalarOrArray<'min>>,
785 Option<impl ScalarOrArray<'min>>,
786 ) {
787 (Some(self.0), Option::<Min>::None)
788 }
789}
790
791impl<'max, Max> ClipBound<'max, 'max> for ((), Max)
792where
793 Max: ScalarOrArray<'max> + Sealed,
794{
795 fn into_min_max(
796 self,
797 ) -> (
798 Option<impl ScalarOrArray<'max>>,
799 Option<impl ScalarOrArray<'max>>,
800 ) {
801 (Option::<Max>::None, Some(self.1))
802 }
803}
804
805impl<'min, 'max, Min, Max> ClipBound<'min, 'max> for (Min, Max)
806where
807 Min: ScalarOrArray<'min> + Sealed,
808 Max: ScalarOrArray<'max> + Sealed,
809{
810 fn into_min_max(
811 self,
812 ) -> (
813 Option<impl ScalarOrArray<'min>>,
814 Option<impl ScalarOrArray<'max>>,
815 ) {
816 (Some(self.0), Some(self.1))
817 }
818}
819
820#[generate_macro]
842#[default_device]
843pub fn clip_device<'min, 'max>(
844 a: impl AsRef<Array>,
845 bound: impl ClipBound<'min, 'max>,
846 #[optional] stream: impl AsRef<Stream>,
847) -> Result<Array> {
848 let (a_min, a_max) = bound.into_min_max();
849
850 let a_min = a_min.map(|min| min.into_owned_or_ref_array());
852 let a_max = a_max.map(|max| max.into_owned_or_ref_array());
853
854 unsafe {
855 let min_ptr = match &a_min {
856 Some(a_min) => a_min.as_ref().as_ptr(),
857 None => mlx_sys::mlx_array_new(),
858 };
859 let max_ptr = match &a_max {
860 Some(a_max) => a_max.as_ref().as_ptr(),
861 None => mlx_sys::mlx_array_new(),
862 };
863
864 Array::try_from_op(|res| {
865 mlx_sys::mlx_clip(
866 res,
867 a.as_ref().as_ptr(),
868 min_ptr,
869 max_ptr,
870 stream.as_ref().as_ptr(),
871 )
872 })
873 }
874}
875
876#[generate_macro]
878#[default_device]
879pub fn cos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
880 a.as_ref().cos_device(stream)
881}
882
883#[generate_macro]
885#[default_device]
886pub fn cosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
887 Array::try_from_op(|res| unsafe {
888 mlx_sys::mlx_cosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
889 })
890}
891
892#[generate_macro]
894#[default_device]
895pub fn degrees_device(
896 a: impl AsRef<Array>,
897 #[optional] stream: impl AsRef<Stream>,
898) -> Result<Array> {
899 Array::try_from_op(|res| unsafe {
900 mlx_sys::mlx_degrees(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
901 })
902}
903
904#[generate_macro]
906#[default_device]
907pub fn divide_device(
908 a: impl AsRef<Array>,
909 b: impl AsRef<Array>,
910 #[optional] stream: impl AsRef<Stream>,
911) -> Result<Array> {
912 a.as_ref().divide_device(b, stream)
913}
914
915#[generate_macro]
922#[default_device]
923pub fn divmod_device(
924 a: impl AsRef<Array>,
925 b: impl AsRef<Array>,
926 #[optional] stream: impl AsRef<Stream>,
927) -> Result<(Array, Array)> {
928 let a_ptr = a.as_ref().as_ptr();
929 let b_ptr = b.as_ref().as_ptr();
930
931 let vec = VectorArray::try_from_op(|res| unsafe {
932 mlx_sys::mlx_divmod(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
933 })?;
934
935 let vals: SmallVec<[_; 2]> = vec.try_into_values()?;
936 let mut iter = vals.into_iter();
937 let quotient = iter.next().unwrap();
938 let remainder = iter.next().unwrap();
939
940 Ok((quotient, remainder))
941}
942
943#[generate_macro]
945#[default_device]
946pub fn erf_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
947 Array::try_from_op(|res| unsafe {
948 mlx_sys::mlx_erf(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
949 })
950}
951
952#[generate_macro]
954#[default_device]
955pub fn erfinv_device(
956 a: impl AsRef<Array>,
957 #[optional] stream: impl AsRef<Stream>,
958) -> Result<Array> {
959 Array::try_from_op(|res| unsafe {
960 mlx_sys::mlx_erfinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
961 })
962}
963
964#[generate_macro]
966#[default_device]
967pub fn exp_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
968 a.as_ref().exp_device(stream)
969}
970
971#[generate_macro]
973#[default_device]
974pub fn expm1_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
975 Array::try_from_op(|res| unsafe {
976 mlx_sys::mlx_expm1(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
977 })
978}
979
980#[generate_macro]
982#[default_device]
983pub fn floor_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
984 a.as_ref().floor_device(stream)
985}
986
987#[generate_macro]
989#[default_device]
990pub fn floor_divide_device(
991 a: impl AsRef<Array>,
992 other: impl AsRef<Array>,
993 #[optional] stream: impl AsRef<Stream>,
994) -> Result<Array> {
995 a.as_ref().floor_divide_device(other, stream)
996}
997
998#[generate_macro]
1000#[default_device]
1001pub fn log_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1002 a.as_ref().log_device(stream)
1003}
1004
1005#[generate_macro]
1007#[default_device]
1008pub fn log10_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1009 a.as_ref().log10_device(stream)
1010}
1011
1012#[generate_macro]
1014#[default_device]
1015pub fn log1p_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1016 a.as_ref().log1p_device(stream)
1017}
1018
1019#[generate_macro]
1021#[default_device]
1022pub fn log2_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1023 a.as_ref().log2_device(stream)
1024}
1025
1026#[generate_macro]
1033#[default_device]
1034pub fn logaddexp_device(
1035 a: impl AsRef<Array>,
1036 b: impl AsRef<Array>,
1037 #[optional] stream: impl AsRef<Stream>,
1038) -> Result<Array> {
1039 let a_ptr = a.as_ref().as_ptr();
1040 let b_ptr = b.as_ref().as_ptr();
1041
1042 Array::try_from_op(|res| unsafe {
1043 mlx_sys::mlx_logaddexp(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1044 })
1045}
1046
1047#[generate_macro]
1049#[default_device]
1050pub fn matmul_device(
1051 a: impl AsRef<Array>,
1052 b: impl AsRef<Array>,
1053 #[optional] stream: impl AsRef<Stream>,
1054) -> Result<Array> {
1055 a.as_ref().matmul_device(b, stream)
1056}
1057
1058#[generate_macro]
1063#[default_device]
1064pub fn maximum_device(
1065 a: impl AsRef<Array>,
1066 b: impl AsRef<Array>,
1067 #[optional] stream: impl AsRef<Stream>,
1068) -> Result<Array> {
1069 let a_ptr = a.as_ref().as_ptr();
1070 let b_ptr = b.as_ref().as_ptr();
1071
1072 Array::try_from_op(|res| unsafe {
1073 mlx_sys::mlx_maximum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1074 })
1075}
1076
1077#[generate_macro]
1082#[default_device]
1083pub fn minimum_device(
1084 a: impl AsRef<Array>,
1085 b: impl AsRef<Array>,
1086 #[optional] stream: impl AsRef<Stream>,
1087) -> Result<Array> {
1088 let a_ptr = a.as_ref().as_ptr();
1089 let b_ptr = b.as_ref().as_ptr();
1090
1091 Array::try_from_op(|res| unsafe {
1092 mlx_sys::mlx_minimum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1093 })
1094}
1095
1096#[generate_macro]
1098#[default_device]
1099pub fn multiply_device(
1100 a: impl AsRef<Array>,
1101 b: impl AsRef<Array>,
1102 #[optional] stream: impl AsRef<Stream>,
1103) -> Result<Array> {
1104 a.as_ref().multiply_device(b, stream)
1105}
1106
1107#[generate_macro]
1109#[default_device]
1110pub fn negative_device(
1111 a: impl AsRef<Array>,
1112 #[optional] stream: impl AsRef<Stream>,
1113) -> Result<Array> {
1114 a.as_ref().negative_device(stream)
1115}
1116
1117#[generate_macro]
1119#[default_device]
1120pub fn power_device(
1121 a: impl AsRef<Array>,
1122 b: impl AsRef<Array>,
1123 #[optional] stream: impl AsRef<Stream>,
1124) -> Result<Array> {
1125 a.as_ref().power_device(b, stream)
1126}
1127
1128#[generate_macro]
1130#[default_device]
1131pub fn radians_device(
1132 a: impl AsRef<Array>,
1133 #[optional] stream: impl AsRef<Stream>,
1134) -> Result<Array> {
1135 Array::try_from_op(|res| unsafe {
1136 mlx_sys::mlx_radians(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1137 })
1138}
1139
1140#[generate_macro]
1142#[default_device]
1143pub fn reciprocal_device(
1144 a: impl AsRef<Array>,
1145 #[optional] stream: impl AsRef<Stream>,
1146) -> Result<Array> {
1147 a.as_ref().reciprocal_device(stream)
1148}
1149
1150#[generate_macro]
1152#[default_device]
1153pub fn remainder_device(
1154 a: impl AsRef<Array>,
1155 b: impl AsRef<Array>,
1156 #[optional] stream: impl AsRef<Stream>,
1157) -> Result<Array> {
1158 a.as_ref().remainder_device(b, stream)
1159}
1160
1161#[generate_macro]
1163#[default_device]
1164pub fn round_device(
1165 a: impl AsRef<Array>,
1166 decimals: impl Into<Option<i32>>,
1167 #[optional] stream: impl AsRef<Stream>,
1168) -> Result<Array> {
1169 a.as_ref().round_device(decimals, stream)
1170}
1171
1172#[generate_macro]
1174#[default_device]
1175pub fn rsqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1176 a.as_ref().rsqrt_device(stream)
1177}
1178
1179#[generate_macro]
1185#[default_device]
1186pub fn sigmoid_device(
1187 a: impl AsRef<Array>,
1188 #[optional] stream: impl AsRef<Stream>,
1189) -> Result<Array> {
1190 Array::try_from_op(|res| unsafe {
1191 mlx_sys::mlx_sigmoid(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1192 })
1193}
1194
1195#[generate_macro]
1197#[default_device]
1198pub fn sign_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1199 Array::try_from_op(|res| unsafe {
1200 mlx_sys::mlx_sign(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1201 })
1202}
1203
1204#[generate_macro]
1206#[default_device]
1207pub fn sin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1208 a.as_ref().sin_device(stream)
1209}
1210
1211#[generate_macro]
1213#[default_device]
1214pub fn sinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1215 Array::try_from_op(|res| unsafe {
1216 mlx_sys::mlx_sinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1217 })
1218}
1219
1220#[generate_macro]
1226#[default_device]
1227pub fn softmax_axes_device(
1228 a: impl AsRef<Array>,
1229 axes: &[i32],
1230 precise: impl Into<Option<bool>>,
1231 #[optional] stream: impl AsRef<Stream>,
1232) -> Result<Array> {
1233 let precise = precise.into().unwrap_or(false);
1234 let s = stream.as_ref().as_ptr();
1235
1236 Array::try_from_op(|res| unsafe {
1237 mlx_sys::mlx_softmax_axes(
1238 res,
1239 a.as_ref().as_ptr(),
1240 axes.as_ptr(),
1241 axes.len(),
1242 precise,
1243 s,
1244 )
1245 })
1246}
1247
1248#[generate_macro]
1250#[default_device]
1251pub fn softmax_axis_device(
1252 a: impl AsRef<Array>,
1253 axis: i32,
1254 precise: impl Into<Option<bool>>,
1255 #[optional] stream: impl AsRef<Stream>,
1256) -> Result<Array> {
1257 let precise = precise.into().unwrap_or(false);
1258 let s = stream.as_ref().as_ptr();
1259
1260 Array::try_from_op(|res| unsafe {
1261 mlx_sys::mlx_softmax_axis(res, a.as_ref().as_ptr(), axis, precise, s)
1262 })
1263}
1264
1265#[generate_macro]
1267#[default_device]
1268pub fn softmax_device(
1269 a: impl AsRef<Array>,
1270 precise: impl Into<Option<bool>>,
1271 #[optional] stream: impl AsRef<Stream>,
1272) -> Result<Array> {
1273 let precise = precise.into().unwrap_or(false);
1274 let s = stream.as_ref().as_ptr();
1275
1276 Array::try_from_op(|res| unsafe { mlx_sys::mlx_softmax(res, a.as_ref().as_ptr(), precise, s) })
1277}
1278
1279#[generate_macro]
1281#[default_device]
1282pub fn sqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1283 a.as_ref().sqrt_device(stream)
1284}
1285
1286#[generate_macro]
1288#[default_device]
1289pub fn square_device(
1290 a: impl AsRef<Array>,
1291 #[optional] stream: impl AsRef<Stream>,
1292) -> Result<Array> {
1293 a.as_ref().square_device(stream)
1294}
1295
1296#[generate_macro]
1298#[default_device]
1299pub fn subtract_device(
1300 a: impl AsRef<Array>,
1301 b: impl AsRef<Array>,
1302 #[optional] stream: impl AsRef<Stream>,
1303) -> Result<Array> {
1304 a.as_ref().subtract_device(b, stream)
1305}
1306
1307#[generate_macro]
1309#[default_device]
1310pub fn tan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1311 Array::try_from_op(|res| unsafe {
1312 mlx_sys::mlx_tan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1313 })
1314}
1315
1316#[generate_macro]
1318#[default_device]
1319pub fn tanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1320 Array::try_from_op(|res| unsafe {
1321 mlx_sys::mlx_tanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1322 })
1323}
1324
1325#[generate_macro]
1327#[default_device]
1328pub fn real_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1329 Array::try_from_op(|res| unsafe {
1330 mlx_sys::mlx_real(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1331 })
1332}
1333
1334#[generate_macro]
1336#[default_device]
1337pub fn imag_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1338 Array::try_from_op(|res| unsafe {
1339 mlx_sys::mlx_imag(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1340 })
1341}
1342
1343#[generate_macro]
1349#[default_device]
1350pub fn block_masked_mm_device<'mo, 'lhs, 'rhs>(
1351 a: impl AsRef<Array>,
1352 b: impl AsRef<Array>,
1353 #[optional] block_size: impl Into<Option<i32>>,
1354 #[optional] mask_out: impl Into<Option<&'mo Array>>,
1355 #[optional] mask_lhs: impl Into<Option<&'lhs Array>>,
1356 #[optional] mask_rhs: impl Into<Option<&'rhs Array>>,
1357 #[optional] stream: impl AsRef<Stream>,
1358) -> Result<Array> {
1359 let a_ptr = a.as_ref().as_ptr();
1360 let b_ptr = b.as_ref().as_ptr();
1361 unsafe {
1362 let mask_out_ptr = mask_out
1363 .into()
1364 .map(|m| m.as_ptr())
1365 .unwrap_or(mlx_sys::mlx_array_new());
1366 let mask_lhs_ptr = mask_lhs
1367 .into()
1368 .map(|m| m.as_ptr())
1369 .unwrap_or(mlx_sys::mlx_array_new());
1370 let mask_rhs_ptr = mask_rhs
1371 .into()
1372 .map(|m| m.as_ptr())
1373 .unwrap_or(mlx_sys::mlx_array_new());
1374
1375 Array::try_from_op(|res| {
1376 mlx_sys::mlx_block_masked_mm(
1377 res,
1378 a_ptr,
1379 b_ptr,
1380 block_size.into().unwrap_or(32),
1381 mask_out_ptr,
1382 mask_lhs_ptr,
1383 mask_rhs_ptr,
1384 stream.as_ref().as_ptr(),
1385 )
1386 })
1387 }
1388}
1389
1390#[generate_macro]
1403#[default_device]
1404pub fn addmm_device(
1405 c: impl AsRef<Array>,
1406 a: impl AsRef<Array>,
1407 b: impl AsRef<Array>,
1408 #[optional] alpha: impl Into<Option<f32>>,
1409 #[optional] beta: impl Into<Option<f32>>,
1410 #[optional] stream: impl AsRef<Stream>,
1411) -> Result<Array> {
1412 let c_ptr = c.as_ref().as_ptr();
1413 let a_ptr = a.as_ref().as_ptr();
1414 let b_ptr = b.as_ref().as_ptr();
1415 let alpha = alpha.into().unwrap_or(1.0);
1416 let beta = beta.into().unwrap_or(1.0);
1417
1418 Array::try_from_op(|res| unsafe {
1419 mlx_sys::mlx_addmm(
1420 res,
1421 c_ptr,
1422 a_ptr,
1423 b_ptr,
1424 alpha,
1425 beta,
1426 stream.as_ref().as_ptr(),
1427 )
1428 })
1429}
1430
1431#[generate_macro]
1434#[default_device]
1435pub fn inner_device(
1436 a: impl AsRef<Array>,
1437 b: impl AsRef<Array>,
1438 #[optional] stream: impl AsRef<Stream>,
1439) -> Result<Array> {
1440 let a = a.as_ref();
1441 let b = b.as_ref();
1442 Array::try_from_op(|res| unsafe {
1443 mlx_sys::mlx_inner(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1444 })
1445}
1446
1447#[generate_macro]
1450#[default_device]
1451pub fn outer_device(
1452 a: impl AsRef<Array>,
1453 b: impl AsRef<Array>,
1454 #[optional] stream: impl AsRef<Stream>,
1455) -> Result<Array> {
1456 let a = a.as_ref();
1457 let b = b.as_ref();
1458 Array::try_from_op(|res| unsafe {
1459 mlx_sys::mlx_outer(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1460 })
1461}
1462
1463#[generate_macro]
1465#[default_device]
1466pub fn tensordot_axes_device(
1467 a: impl AsRef<Array>,
1468 b: impl AsRef<Array>,
1469 axes_a: &[i32],
1470 axes_b: &[i32],
1471 #[optional] stream: impl AsRef<Stream>,
1472) -> Result<Array> {
1473 let a = a.as_ref();
1474 let b = b.as_ref();
1475 Array::try_from_op(|res| unsafe {
1476 mlx_sys::mlx_tensordot(
1477 res,
1478 a.as_ptr(),
1479 b.as_ptr(),
1480 axes_a.as_ptr(),
1481 axes_a.len(),
1482 axes_b.as_ptr(),
1483 axes_b.len(),
1484 stream.as_ref().as_ptr(),
1485 )
1486 })
1487}
1488
1489#[generate_macro]
1491#[default_device]
1492pub fn tensordot_axis_device(
1493 a: impl AsRef<Array>,
1494 b: impl AsRef<Array>,
1495 axis: i32,
1496 #[optional] stream: impl AsRef<Stream>,
1497) -> Result<Array> {
1498 let a = a.as_ref();
1499 let b = b.as_ref();
1500 Array::try_from_op(|res| unsafe {
1501 mlx_sys::mlx_tensordot_axis(res, a.as_ptr(), b.as_ptr(), axis, stream.as_ref().as_ptr())
1502 })
1503}
1504
1505#[cfg(test)]
1506mod tests {
1507 use std::f32::consts::PI;
1508
1509 use super::*;
1510 use crate::{
1511 array, complex64,
1512 ops::{all_close, arange, broadcast_to, eye, full, linspace, ones, reshape, split},
1513 transforms::eval,
1514 Dtype, StreamOrDevice,
1515 };
1516 use float_eq::assert_float_eq;
1517 use pretty_assertions::assert_eq;
1518
1519 #[test]
1520 fn test_abs() {
1521 let data = [1i32, 2, -3, -4, -5];
1522 let array = Array::from_slice(&data, &[5]);
1523 let result = array.abs().unwrap();
1524
1525 let data: &[i32] = result.as_slice();
1526 assert_eq!(data, [1, 2, 3, 4, 5]);
1527
1528 let data: &[i32] = array.as_slice();
1530 assert_eq!(data, [1, 2, -3, -4, -5]);
1531 }
1532
1533 #[test]
1534 fn test_add() {
1535 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1536 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1537
1538 let c = &a + &b;
1539
1540 let c_data: &[f32] = c.as_slice();
1541 assert_eq!(c_data, &[5.0, 7.0, 9.0]);
1542
1543 let a_data: &[f32] = a.as_slice();
1545 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1546
1547 let b_data: &[f32] = b.as_slice();
1548 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1549 }
1550
1551 #[test]
1552 fn test_add_invalid_broadcast() {
1553 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1554 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1555
1556 let c = a.add(&b);
1557 assert!(c.is_err());
1558 }
1559
1560 #[test]
1561 fn test_sub() {
1562 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1563 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1564
1565 let c = &a - &b;
1566
1567 let c_data: &[f32] = c.as_slice();
1568 assert_eq!(c_data, &[-3.0, -3.0, -3.0]);
1569
1570 let a_data: &[f32] = a.as_slice();
1572 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1573
1574 let b_data: &[f32] = b.as_slice();
1575 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1576 }
1577
1578 #[test]
1579 fn test_sub_invalid_broadcast() {
1580 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1581 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1582 let c = a.subtract(&b);
1583 assert!(c.is_err());
1584 }
1585
1586 #[test]
1587 fn test_neg() {
1588 let a = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3]);
1589 let b = a.negative().unwrap();
1590
1591 let b_data: &[f32] = b.as_slice();
1592 assert_eq!(b_data, &[-1.0, -2.0, -3.0]);
1593
1594 let a_data: &[f32] = a.as_slice();
1596 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1597 }
1598
1599 #[test]
1600 fn test_neg_bool() {
1601 let a = Array::from_slice(&[true, false, true], &[3]);
1602 let b = a.negative();
1603 assert!(b.is_err());
1604 }
1605
1606 #[test]
1607 fn test_logical_not() {
1608 let a: Array = false.into();
1609 let b = a.logical_not().unwrap();
1610
1611 let b_data: &[bool] = b.as_slice();
1612 assert_eq!(b_data, [true]);
1613 }
1614
1615 #[test]
1616 fn test_mul() {
1617 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1618 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1619
1620 let c = &a * &b;
1621
1622 let c_data: &[f32] = c.as_slice();
1623 assert_eq!(c_data, &[4.0, 10.0, 18.0]);
1624
1625 let a_data: &[f32] = a.as_slice();
1627 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1628
1629 let b_data: &[f32] = b.as_slice();
1630 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1631 }
1632
1633 #[test]
1634 fn test_mul_invalid_broadcast() {
1635 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1636 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1637 let c = a.multiply(&b);
1638 assert!(c.is_err());
1639 }
1640
1641 #[test]
1642 fn test_nan_to_num() {
1643 let a = array!([1.0, 2.0, f32::NAN, 4.0, 5.0]);
1644 let b = a.nan_to_num(0.0, 1.0, 0.0).unwrap();
1645
1646 let b_data: &[f32] = b.as_slice();
1647 assert_eq!(b_data, &[1.0, 2.0, 0.0, 4.0, 5.0]);
1648 }
1649
/// The `/` operator on borrowed arrays divides element-wise and leaves both inputs intact.
#[test]
fn test_div() {
    let numerator = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let denominator = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);

    let quotient = &numerator / &denominator;
    assert_eq!(quotient.as_slice::<f32>(), &[0.25, 0.4, 0.5]);

    assert_eq!(numerator.as_slice::<f32>(), &[1.0, 2.0, 3.0]);
    assert_eq!(denominator.as_slice::<f32>(), &[4.0, 5.0, 6.0]);
}
1667
/// `divide` fails when the operand shapes cannot be broadcast together.
#[test]
fn test_div_invalid_broadcast() {
    let numerator = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let denominator = Array::from_slice(&[4.0, 5.0], &[2]);
    assert!(numerator.divide(&denominator).is_err());
}
1675
/// `power` raises each base to the matching exponent without modifying either input.
#[test]
fn test_pow() {
    let base = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let exponent = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);

    let raised = base.power(&exponent).unwrap();
    assert_eq!(raised.as_slice::<f32>(), &[1.0, 8.0, 81.0]);

    assert_eq!(base.as_slice::<f32>(), &[1.0, 2.0, 3.0]);
    assert_eq!(exponent.as_slice::<f32>(), &[2.0, 3.0, 4.0]);
}
1693
/// `power` fails when base and exponent shapes are incompatible.
#[test]
fn test_pow_invalid_broadcast() {
    let base = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let exponent = Array::from_slice(&[2.0, 3.0], &[2]);
    assert!(base.power(&exponent).is_err());
}
1701
/// The `%` operator computes the element-wise remainder and leaves both inputs intact.
#[test]
fn test_rem() {
    let dividend = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
    let divisor = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);

    let remainder = &dividend % &divisor;
    assert_eq!(remainder.as_slice::<f32>(), &[1.0, 3.0, 2.0]);

    assert_eq!(dividend.as_slice::<f32>(), &[10.0, 11.0, 12.0]);
    assert_eq!(divisor.as_slice::<f32>(), &[3.0, 4.0, 5.0]);
}
1719
/// `remainder` fails when the operand shapes cannot be broadcast.
#[test]
fn test_rem_invalid_broadcast() {
    let dividend = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
    let divisor = Array::from_slice(&[3.0, 4.0], &[2]);
    assert!(dividend.remainder(&divisor).is_err());
}
1727
/// `sqrt` of perfect squares returns exact roots and leaves the input unchanged.
#[test]
fn test_sqrt() {
    let squares = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
    let roots = squares.sqrt().unwrap();

    assert_eq!(roots.as_slice::<f32>(), &[1.0, 2.0, 3.0]);
    assert_eq!(squares.as_slice::<f32>(), &[1.0, 4.0, 9.0]);
}
1740
/// `cos` matches reference values within tolerance and leaves the input unchanged.
#[test]
fn test_cos() {
    let angles = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
    let cosines = angles.cos().unwrap();

    let reference = array!([1.0, 0.54030234, -0.41614687]);
    assert_array_all_close!(cosines, reference);

    let original_angles = array!([0.0, 1.0, 2.0]);
    assert_array_all_close!(angles, original_angles);
}
1753
/// `exp` matches reference values within tolerance and leaves the input unchanged.
#[test]
fn test_exp() {
    let exponents = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
    let powers = exponents.exp().unwrap();

    let reference = array!([1.0, 2.7182817, 7.389056]);
    assert_array_all_close!(powers, reference);

    let original_exponents = array!([0.0, 1.0, 2.0]);
    assert_array_all_close!(exponents, original_exponents);
}
1766
/// `floor` rounds each value down and leaves the input unchanged.
#[test]
fn test_floor() {
    let values = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
    let floored = values.floor().unwrap();

    assert_eq!(floored.as_slice::<f32>(), &[0.0, 1.0, 2.0]);
    assert_eq!(values.as_slice::<f32>(), &[0.1, 1.9, 2.5]);
}
1779
/// `floor` on a complex64 array returns an error.
#[test]
fn test_floor_complex64() {
    let z = Array::from_complex(complex64::new(1.0, 2.0));
    assert!(z.floor_device(StreamOrDevice::default()).is_err());
}
1787
/// `floor_divide` truncates quotients below one to zero and leaves both inputs intact.
#[test]
fn test_floor_divide() {
    let numerator = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let denominator = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);

    let quotient = numerator.floor_divide(&denominator).unwrap();
    assert_eq!(quotient.as_slice::<f32>(), &[0.0, 0.0, 0.0]);

    assert_eq!(numerator.as_slice::<f32>(), &[1.0, 2.0, 3.0]);
    assert_eq!(denominator.as_slice::<f32>(), &[4.0, 5.0, 6.0]);
}
1805
/// `floor_divide` with a complex64 numerator returns an error.
#[test]
fn test_floor_divide_complex64() {
    let numerator = Array::from_complex(complex64::new(1.0, 2.0));
    let denominator = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
    let outcome = numerator.floor_divide_device(&denominator, StreamOrDevice::default());
    assert!(outcome.is_err());
}
1814
/// `floor_divide` fails when the operand shapes cannot be broadcast.
#[test]
fn test_floor_divide_invalid_broadcast() {
    let numerator = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let denominator = Array::from_slice(&[4.0, 5.0], &[2]);
    let outcome = numerator.floor_divide_device(&denominator, StreamOrDevice::default());
    assert!(outcome.is_err());
}
1822
/// `is_nan` flags exactly the NaN element.
#[test]
fn test_is_nan() {
    let values = Array::from_slice(&[1.0, f32::NAN, 3.0], &[3]);
    let mask = values.is_nan().unwrap();
    assert_eq!(mask.as_slice::<bool>(), &[false, true, false]);
}
1831
/// `is_inf` flags both positive and negative infinity, but not finite values or NaN.
#[test]
fn test_is_inf() {
    // Include -inf and NaN so a sign-specific or NaN-confused implementation is caught.
    let a = Array::from_slice(
        &[1.0, f32::INFINITY, 3.0, f32::NEG_INFINITY, f32::NAN],
        &[5],
    );
    let b = a.is_inf().unwrap();

    let b_data: &[bool] = b.as_slice();
    assert_eq!(b_data, &[false, true, false, true, false]);
}
1840
/// `is_finite` is false only for the infinite element.
#[test]
fn test_is_finite() {
    let values = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
    let mask = values.is_finite().unwrap();
    assert_eq!(mask.as_slice::<bool>(), &[true, false, true]);
}
1849
/// `is_neg_inf` flags negative infinity only; positive infinity must not match.
#[test]
fn test_is_neg_inf() {
    // +inf is included so an implementation that ignores the sign fails.
    let a = Array::from_slice(&[1.0, f32::NEG_INFINITY, 3.0, f32::INFINITY], &[4]);
    let b = a.is_neg_inf().unwrap();

    let b_data: &[bool] = b.as_slice();
    assert_eq!(b_data, &[false, true, false, false]);
}
1858
/// `is_pos_inf` flags positive infinity only; negative infinity must not match.
#[test]
fn test_is_pos_inf() {
    // -inf is included so an implementation that ignores the sign fails.
    let a = Array::from_slice(&[1.0, f32::INFINITY, 3.0, f32::NEG_INFINITY], &[4]);
    let b = a.is_pos_inf().unwrap();

    let b_data: &[bool] = b.as_slice();
    assert_eq!(b_data, &[false, true, false, false]);
}
1867
/// Natural `log` of 1, 2, 3 matches the f32 reference values; the input is unchanged.
#[test]
fn test_log() {
    let input = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let logs = input.log().unwrap();

    assert_eq!(logs.as_slice::<f32>(), &[0.0, 0.6931472, 1.0986123]);
    assert_eq!(input.as_slice::<f32>(), &[1.0, 2.0, 3.0]);
}
1880
/// `log2` of powers of two gives exact exponents; the input is unchanged.
#[test]
fn test_log2() {
    let powers_of_two = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
    let exponents = powers_of_two.log2().unwrap();

    assert_eq!(exponents.as_slice::<f32>(), &[0.0, 1.0, 2.0, 3.0]);
    assert_eq!(powers_of_two.as_slice::<f32>(), &[1.0, 2.0, 4.0, 8.0]);
}
1893
/// `log10` of powers of ten gives exact exponents; the input is unchanged.
#[test]
fn test_log10() {
    let powers_of_ten = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
    let exponents = powers_of_ten.log10().unwrap();

    assert_eq!(exponents.as_slice::<f32>(), &[0.0, 1.0, 2.0]);
    assert_eq!(powers_of_ten.as_slice::<f32>(), &[1.0, 10.0, 100.0]);
}
1906
/// `log1p(x) = ln(1 + x)` for 1, 2, 3 matches reference values; the input is unchanged.
#[test]
fn test_log1p() {
    let input = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let shifted_logs = input.log1p().unwrap();

    assert_eq!(shifted_logs.as_slice::<f32>(), &[0.6931472, 1.0986123, 1.3862944]);
    assert_eq!(input.as_slice::<f32>(), &[1.0, 2.0, 3.0]);
}
1919
/// Integer `[2, 2]` times float `[2, 3]` gives a float `[2, 3]` product; inputs are unchanged.
#[test]
fn test_matmul() {
    let lhs = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
    let rhs = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);

    let product = lhs.matmul(&rhs).unwrap();
    assert_eq!(product.shape(), &[2, 3]);
    assert_eq!(
        product.as_slice::<f32>(),
        &[9.0, 39.5, 4.0, 13.0, 116.5, 12.0]
    );

    assert_eq!(lhs.as_slice::<i32>(), &[1, 2, 3, 4]);
    assert_eq!(rhs.as_slice::<f32>(), &[-5.0, 37.5, 4., 7., 1., 0.]);
}
1938
/// A 0-dimensional operand is not accepted by `matmul`.
#[test]
fn test_matmul_ndim_zero() {
    let scalar: Array = 1.0.into();
    let vector = Array::from_slice::<i32>(&[1], &[1]);
    assert!(scalar.matmul(&vector).is_err());
}
1946
/// Two 1-D operands reduce to their dot product.
#[test]
fn test_matmul_ndim_one() {
    let a = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
    let b = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
    let c = a.matmul(&b).unwrap();

    // Checking only `is_ok()` would accept a wrong result; pin the value:
    // 1*1 + 2*2 + 3*3 + 4*4 = 30.
    assert_eq!(c.size(), 1);
    assert_eq!(c.item::<f32>(), 30.0);
}
1954
/// Inner dimensions 3 and 2 do not agree, so `matmul` must fail.
#[test]
fn test_matmul_dim_mismatch() {
    let lhs = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
    let rhs = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], &[2, 5]);
    assert!(lhs.matmul(&rhs).is_err());
}
1962
/// `matmul` with two integer operands is rejected.
#[test]
fn test_matmul_non_float_output_type() {
    let lhs = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
    let rhs = Array::from_slice(&[5, 37, 4, 7, 1, 0], &[2, 3]);

    let outcome = lhs.matmul(&rhs);
    assert!(outcome.is_err());
}
1971
/// `reciprocal` returns `1 / x` element-wise and leaves the input unchanged.
#[test]
fn test_reciprocal() {
    let values = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
    let inverses = values.reciprocal().unwrap();

    assert_eq!(inverses.as_slice::<f32>(), &[1.0, 0.5, 0.25]);
    assert_eq!(values.as_slice::<f32>(), &[1.0, 2.0, 4.0]);
}
1984
/// `round(None)` rounds to whole numbers and leaves the input unchanged.
#[test]
fn test_round() {
    let values = Array::from_slice(&[1.1, 2.9, 3.5], &[3]);
    let rounded = values.round(None).unwrap();

    assert_eq!(rounded.as_slice::<f32>(), &[1.0, 3.0, 4.0]);
    assert_eq!(values.as_slice::<f32>(), &[1.1, 2.9, 3.5]);
}
1997
/// `rsqrt` returns `1 / sqrt(x)` element-wise and leaves the input unchanged.
#[test]
fn test_rsqrt() {
    let values = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
    let inverse_roots = values.rsqrt().unwrap();

    assert_eq!(inverse_roots.as_slice::<f32>(), &[1.0, 0.70710677, 0.5]);
    assert_eq!(values.as_slice::<f32>(), &[1.0, 2.0, 4.0]);
}
2010
/// `sin` of 0, 1, 2 matches the f32 reference values; the input is unchanged.
#[test]
fn test_sin() {
    let angles = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
    let sines = angles.sin().unwrap();

    assert_eq!(sines.as_slice::<f32>(), &[0.0, 0.841471, 0.9092974]);
    assert_eq!(angles.as_slice::<f32>(), &[0.0, 1.0, 2.0]);
}
2023
/// `square` multiplies each element by itself and leaves the input unchanged.
#[test]
fn test_square() {
    let values = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let squared = values.square().unwrap();

    assert_eq!(squared.as_slice::<f32>(), &[1.0, 4.0, 9.0]);
    assert_eq!(values.as_slice::<f32>(), &[1.0, 2.0, 3.0]);
}
2036
/// Negation via the free function and the `Neg` operator, on scalars, empty arrays and booleans.
#[test]
fn test_unary_neg() {
    let one = array!(1.0);
    assert_eq!(negative(&one).unwrap().item::<f32>(), -1.0);

    // The `-` operator consumes its operand.
    let negated = -one;
    assert_eq!(negated.item::<f32>(), -1.0);

    // Empty in, empty out.
    let empty = array!();
    assert_eq!(-empty, array!());

    // Booleans cannot be negated.
    let truth = array!(true);
    assert!(negative(&truth).is_err());
}
2052
/// `abs` across float, empty, signed, unsigned and boolean arrays.
#[test]
fn test_unary_abs() {
    // (input, expected absolute value)
    let cases = [
        (array!([-1.0, 0.0, 1.0]), array!([1.0, 0.0, 1.0])),
        (array!(), array!()),
        (array!([-1, 0, 1]), array!([1, 0, 1])),
        (array!([1u32, 0, 1]), array!([1u32, 0, 1])),
        (array!([false, true]), array!([false, true])),
    ];

    for (input, expected) in cases {
        assert_eq!(abs(&input).unwrap(), expected);
    }
}
2073
/// `sign` is the identity on each of these inputs, whose elements are all in {-1, 0, 1} (or booleans).
#[test]
fn test_unary_sign() {
    let inputs = [
        array!([-1.0, 0.0, 1.0]),
        array!(),
        array!([-1, 0, 1]),
        array!([1u32, 0, 1]),
        array!([false, true]),
    ];

    for input in inputs {
        assert_eq!(sign(&input).unwrap(), input);
    }
}
2094
/// Negative infinity as an `f32`, used by several of the unary-op tests.
const NEG_INF: f32 = -f32::INFINITY;
2096
/// `floor` and `ceil` on integral, fractional, negative and infinite scalars; complex input fails.
#[test]
fn test_unary_floor_ceil() {
    // (input, expected floor, expected ceil)
    let cases: [(f32, f32, f32); 4] = [
        (1.0, 1.0, 1.0),
        (1.5, 1.0, 2.0),
        (-1.5, -2.0, -1.0),
        (NEG_INF, NEG_INF, NEG_INF),
    ];

    for (input, want_floor, want_ceil) in cases {
        let x = array![input];
        assert_eq!(floor(&x).unwrap().item::<f32>(), want_floor);
        assert_eq!(ceil(&x).unwrap().item::<f32>(), want_ceil);
    }

    // Neither operation accepts complex values.
    let z = array!([1.0, 1.0]).as_type::<complex64>().unwrap();
    assert!(floor(&z).is_err());
    assert!(ceil(&z).is_err());
}
2119
/// `round` with no decimals (halves go to the even neighbour) and with a negative decimal count.
#[test]
fn test_unary_round() {
    // 0.5 -> 0, -0.5 -> 0, 1.5 -> 2, -1.5 -> -2: ties resolve to the even integer.
    let floats = array!([0.5, -0.5, 1.5, -1.5, 2.3, 2.6]);
    let rounded = round(&floats, None).unwrap();
    assert_eq!(rounded, array!([0, 0, 2, -2, 2, 3]));

    // decimals = -1 rounds to the nearest ten.
    let ints = array!([11, 222, 32]);
    let rounded = round(&ints, -1).unwrap();
    assert_eq!(rounded, array!([10, 220, 30]));
}
2128
/// `exp` on scalars, empty input, -inf, Int32 input, a broadcast view and a split view.
#[test]
fn test_unary_exp() {
    let close = |lhs: &Array, rhs: &Array| {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };

    let zero = array![0.0];
    assert_eq!(exp(&zero).unwrap().item::<f32>(), 1.0);

    let two = array![2.0];
    assert_float_eq! {
        exp(&two).unwrap().item::<f32>(),
        2.0f32.exp(),
        abs <= 1e-5
    };

    assert_eq!(exp(array!()).unwrap(), array!());

    // exp(-inf) is exactly zero.
    let neg_inf = array![NEG_INF];
    assert_eq!(exp(&neg_inf).unwrap().item::<f32>(), 0.0);

    // Int32 input.
    let int_two = array![2];
    assert_eq!(int_two.dtype(), Dtype::Int32);
    assert_float_eq! {
        exp(&int_two).unwrap().item::<f32>(),
        2.0f32.exp(),
        abs <= 1e-5
    };

    // Broadcast input.
    let broadcast = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let actual = exp(&broadcast).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.exp())).unwrap();
    assert!(close(&actual, &expected));

    // First column produced by splitting a 2x2 matrix along axis 1.
    let matrix = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let columns = split(&matrix, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0f32.exp(), 2.0f32.exp()], &[2, 1]);
    assert!(close(&exp(&columns[0]).unwrap(), &expected));
}
2170
/// `expm1` matches `f32::exp_m1`; Int32 input yields a Float32 result.
#[test]
fn test_unary_expm1() {
    let float_cases = [
        (array![-1.0], (-1.0f32).exp_m1()),
        (array![1.0], 1.0f32.exp_m1()),
    ];
    for (input, reference) in float_cases {
        assert_float_eq! {
            expm1(&input).unwrap().item::<f32>(),
            reference,
            abs <= 1e-5
        };
    }

    let int_one = array![1];
    let result = expm1(&int_one).unwrap();
    assert_eq!(result.dtype(), Dtype::Float32);
    assert_float_eq! {
        result.item::<f32>(),
        1.0f32.exp_m1(),
        abs <= 1e-5
    };
}
2196
/// `sin` on scalars, empty input, Int32 input, a broadcast view and a split view.
#[test]
fn test_unary_sin() {
    let close = |lhs: &Array, rhs: &Array| {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };
    let half_pi = std::f32::consts::PI / 2.0;

    let zero = array![0.0];
    assert_eq!(sin(&zero).unwrap().item::<f32>(), 0.0);

    let quarter_turn = array![half_pi];
    assert_float_eq! {
        sin(&quarter_turn).unwrap().item::<f32>(),
        half_pi.sin(),
        abs <= 1e-5
    };

    assert_eq!(sin(array!()).unwrap(), array!());

    // Int32 input.
    let int_zero = array![0];
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    assert_float_eq! {
        sin(&int_zero).unwrap().item::<f32>(),
        0.0f32.sin(),
        abs <= 1e-5
    };

    // Broadcast input.
    let broadcast = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let actual = sin(&broadcast).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.sin())).unwrap();
    assert!(close(&actual, &expected));

    // First column produced by splitting a 2x2 matrix along axis 1.
    let matrix = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let columns = split(&matrix, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0f32.sin(), 2.0f32.sin()], &[2, 1]);
    assert!(close(&sin(&columns[0]).unwrap(), &expected));
}
2235
/// `cos` on scalars, empty input, Int32 input, a broadcast view and a split view.
#[test]
fn test_unary_cos() {
    let close = |lhs: &Array, rhs: &Array| {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };
    let half_pi = std::f32::consts::PI / 2.0;

    let zero = array![0.0];
    assert_float_eq! {
        cos(&zero).unwrap().item::<f32>(),
        0.0f32.cos(),
        abs <= 1e-5
    };

    let quarter_turn = array![half_pi];
    assert_float_eq! {
        cos(&quarter_turn).unwrap().item::<f32>(),
        half_pi.cos(),
        abs <= 1e-5
    };

    assert_eq!(cos(array!()).unwrap(), array!());

    // Int32 input.
    let int_zero = array![0];
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    assert_float_eq! {
        cos(&int_zero).unwrap().item::<f32>(),
        0.0f32.cos(),
        abs <= 1e-5
    };

    // Broadcast input.
    let broadcast = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let actual = cos(&broadcast).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.cos())).unwrap();
    assert!(close(&actual, &expected));

    // First column produced by splitting a 2x2 matrix along axis 1.
    let matrix = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let columns = split(&matrix, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0f32.cos(), 2.0f32.cos()], &[2, 1]);
    assert!(close(&cos(&columns[0]).unwrap(), &expected));
}
2278
/// `degrees` converts radians to degrees for scalars, empty, Int32, broadcast and split inputs.
#[test]
fn test_unary_degrees() {
    let close = |lhs: &Array, rhs: &Array| {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };
    let half_pi = std::f32::consts::PI / 2.0;

    assert_eq!(degrees(&array![0.0]).unwrap().item::<f32>(), 0.0);
    assert_eq!(degrees(&array![half_pi]).unwrap().item::<f32>(), 90.0);
    assert_eq!(degrees(array!()).unwrap(), array!());

    // Int32 input.
    let int_zero = array![0];
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    assert_eq!(degrees(&int_zero).unwrap().item::<f32>(), 0.0);

    // Broadcast input.
    let broadcast = broadcast_to(&array!(half_pi), &[2, 2, 2]).unwrap();
    let actual = degrees(&broadcast).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(90.0)).unwrap();
    assert!(close(&actual, &expected));

    // First column of a split 2x2 matrix holds 0 and pi.
    let angles = Array::from_slice(&[0.0, PI / 2.0, PI, 1.5 * PI], &[2, 2]);
    let columns = split(&angles, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0, 180.0], &[2, 1]);
    assert!(close(&degrees(&columns[0]).unwrap(), &expected));
}
2311
/// `radians` converts degrees to radians for scalars, empty, Int32, broadcast and split inputs.
#[test]
fn test_unary_radians() {
    let close = |lhs: &Array, rhs: &Array| {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };
    let half_pi = std::f32::consts::PI / 2.0;

    assert_eq!(radians(&array![0.0]).unwrap().item::<f32>(), 0.0);
    assert_eq!(radians(&array![90.0]).unwrap().item::<f32>(), half_pi);
    assert_eq!(radians(array!()).unwrap(), array!());

    // Int32 input.
    let int_ninety = array![90];
    assert_eq!(int_ninety.dtype(), Dtype::Int32);
    assert_eq!(radians(&int_ninety).unwrap().item::<f32>(), half_pi);

    // Broadcast input.
    let broadcast = broadcast_to(&array!(90.0), &[2, 2, 2]).unwrap();
    let actual = radians(&broadcast).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(std::f32::consts::PI / 2.0)).unwrap();
    assert!(close(&actual, &expected));

    // First column of a split 2x2 matrix holds 0 and 180 degrees.
    let angles = Array::from_slice(&[0.0, 90.0, 180.0, 270.0], &[2, 2]);
    let columns = split(&angles, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0, PI], &[2, 1]);
    assert!(close(&radians(&columns[0]).unwrap(), &expected));
}
2350
/// Natural `log`: log(0) = -inf, log(1) = 0, Int32 promotion, broadcast and split inputs.
#[test]
fn test_unary_log() {
    let close = |lhs: &Array, rhs: &Array| {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };

    assert_eq!(log(&array![0.0]).unwrap().item::<f32>(), NEG_INF);
    assert_eq!(log(&array![1.0]).unwrap().item::<f32>(), 0.0);

    // Int32 input yields a Float32 result.
    let int_one = array![1];
    let result = log(&int_one).unwrap();
    assert_eq!(result.dtype(), Dtype::Float32);
    assert_eq!(result.item::<f32>(), 0.0);

    // Broadcast input.
    let broadcast = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let actual = log(&broadcast).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(0.0)).unwrap();
    assert!(close(&actual, &expected));

    // First column produced by splitting a 2x2 matrix along axis 1.
    let matrix = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    let columns = split(&matrix, 2, 1).unwrap();
    let expected = Array::from_slice(&[1.0f32.ln(), 3.0f32.ln()], &[2, 1]);
    assert!(close(&log(&columns[0]).unwrap(), &expected));
}
2379
/// `log2` of 0, 1 and 1024 gives -inf, 0 and 10 exactly.
#[test]
fn test_unary_log2() {
    let cases: [(f32, f32); 3] = [(0.0, NEG_INF), (1.0, 0.0), (1024.0, 10.0)];
    for (input, expected) in cases {
        assert_eq!(log2(&array![input]).unwrap().item::<f32>(), expected);
    }
}
2391
/// `log10` of 0, 1 and 1000 gives -inf, 0 and 3 exactly.
#[test]
fn test_unary_log10() {
    let cases: [(f32, f32); 3] = [(0.0, NEG_INF), (1.0, 0.0), (1000.0, 3.0)];
    for (input, expected) in cases {
        assert_eq!(log10(&array![input]).unwrap().item::<f32>(), expected);
    }
}
2403
/// `log1p` matches `f32::ln_1p` for scalars, Int32 (promoted to Float32), broadcast and split inputs.
#[test]
fn test_unary_log1p() {
    let close = |lhs: &Array, rhs: &Array| {
        all_close(lhs, rhs, None, None, None)
            .unwrap()
            .item::<bool>()
    };

    let float_cases = [
        (array![-1.0], (-1.0f32).ln_1p()),
        (array![1.0], 1.0f32.ln_1p()),
    ];
    for (input, reference) in float_cases {
        assert_float_eq! {
            log1p(&input).unwrap().item::<f32>(),
            reference,
            abs <= 1e-5
        };
    }

    // Int32 input yields a Float32 result.
    let int_one = array![1];
    let result = log1p(&int_one).unwrap();
    assert_eq!(result.dtype(), Dtype::Float32);
    assert_float_eq! {
        result.item::<f32>(),
        1.0f32.ln_1p(),
        abs <= 1e-5
    };

    // Broadcast input.
    let broadcast = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let actual = log1p(&broadcast).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.ln_1p())).unwrap();
    assert!(close(&actual, &expected));

    // First column produced by splitting a 2x2 matrix along axis 1.
    let matrix = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    let columns = split(&matrix, 2, 1).unwrap();
    let expected = Array::from_slice(&[1.0f32.ln_1p(), 3.0f32.ln_1p()], &[2, 1]);
    assert!(close(&log1p(&columns[0]).unwrap(), &expected));
}
2446
/// `sigmoid(0) = 0.5` for float and Int32 input; it saturates to 1 and 0 at +/-inf.
#[test]
fn test_unary_sigmoid() {
    let float_zero = array![0.0];
    assert_float_eq! {
        sigmoid(&float_zero).unwrap().item::<f32>(),
        0.5,
        abs <= 1e-5
    };

    // Int32 input yields a Float32 result.
    let int_zero = array![0];
    let result = sigmoid(&int_zero).unwrap();
    assert_eq!(result.dtype(), Dtype::Float32);
    assert_float_eq! {
        result.item::<f32>(),
        0.5,
        abs <= 1e-5
    };

    let inf = f32::INFINITY;
    assert_eq!(sigmoid(&array![inf]).unwrap().item::<f32>(), 1.0);
    assert_eq!(sigmoid(&array![-inf]).unwrap().item::<f32>(), 0.0);
}
2472
/// `square` on float and integer scalars and on a 3x3 filled array.
#[test]
fn test_unary_square() {
    assert_eq!(square(&array![3.0]).unwrap().item::<f32>(), 9.0);
    assert_eq!(square(&array![2]).unwrap().item::<i32>(), 4);

    let twos = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
    let squared = square(&twos).unwrap();
    let fours = Array::full::<f32>(&[3, 3], array!(4.0)).unwrap();
    let matches = all_close(squared, fours, None, None, None)
        .unwrap()
        .item::<bool>();
    assert!(matches);
}
2492
/// `sqrt` and `rsqrt` on float scalars, a 3x3 filled array and an i32 scalar.
#[test]
fn test_unary_sqrt_rsqrt() {
    let four = array![4.0];
    assert_eq!(sqrt(&four).unwrap().item::<f32>(), 2.0);
    assert_eq!(rsqrt(&four).unwrap().item::<f32>(), 0.5);

    let nines = Array::full::<f32>(&[3, 3], array!(9.0)).unwrap();
    let roots = sqrt(&nines).unwrap();
    let threes = Array::full::<f32>(&[3, 3], array!(3.0)).unwrap();
    let matches = all_close(roots, threes, None, None, None)
        .unwrap()
        .item::<bool>();
    assert!(matches);

    // i32 input produces f32 roots.
    let int_four = array![4i32];
    assert_eq!(sqrt(&int_four).unwrap().item::<f32>(), 2.0);
    assert_eq!(rsqrt(&int_four).unwrap().item::<f32>(), 0.5);
}
2514
/// `reciprocal` on a float scalar, an Int32 scalar (Float32 result) and a 3x3 filled array.
#[test]
fn test_unary_reciprocal() {
    assert_eq!(reciprocal(&array![8.0]).unwrap().item::<f32>(), 0.125);

    let inverse_of_two = reciprocal(&array![2]).unwrap();
    assert_eq!(inverse_of_two.dtype(), Dtype::Float32);
    assert_eq!(inverse_of_two.item::<f32>(), 0.5);

    let twos = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
    let inverses = reciprocal(&twos).unwrap();
    let halves = Array::full::<f32>(&[3, 3], array!(0.5)).unwrap();
    let matches = all_close(inverses, halves, None, None, None)
        .unwrap()
        .item::<bool>();
    assert!(matches);
}
2536
/// `real` and `imag` split `0 + 1i` into 0 and 1.
#[test]
fn test_unary_real_imag() {
    let unit_imaginary = Array::from_complex(complex64::new(0.0, 1.0));

    let re = real(&unit_imaginary).unwrap();
    assert_eq!(re, Array::from_f32(0.0));

    let im = imag(&unit_imaginary).unwrap();
    assert_eq!(im, Array::from_f32(1.0));
}
2543
/// Exercises `add` / `+` on scalars, repeated accumulation, vectors, scalar operands with dtype
/// promotion, broadcast views, mixed-shape broadcasting, and empty arrays.
#[test]
fn test_binary_add() {
    // Scalar + scalar, via the free function and via the operator.
    let x = array![1.0];
    let y = array![1.0];
    let z = add(&x, &y).unwrap();
    assert_eq!(z.item::<f32>(), 2.0);

    let z = &x + y;
    assert_eq!(z.item::<f32>(), 2.0);

    let z = add(z, &x).unwrap();
    assert_eq!(z.item::<f32>(), 3.0);

    // Repeated accumulation starting from an independent copy of `x`.
    let mut out = x.deep_clone();
    for _ in 0..10 {
        out = add(&out, &x).unwrap();
    }
    assert_eq!(out.item::<f32>(), 11.0);

    // Element-wise vector addition.
    let x = array!([1.0, 2.0, 3.0]);
    let y = array!([1.0, 2.0, 3.0]);
    let z = add(&x, &y).unwrap();
    assert_eq!(z.shape(), &[3]);
    assert_eq!(z, array!([2.0, 4.0, 6.0]));

    // Float array + float scalar stays Float32.
    // (The original repeated this exact check twice; once is enough.)
    let x = array!([1.0, 2.0, 3.0]);
    let y = &x + 2.0;
    assert_eq!(y.dtype(), Dtype::Float32);
    assert_eq!(y, array!([3.0, 4.0, 5.0]));

    // Float array + integer scalar stays Float32; the values are now checked as well.
    let y = x + 2;
    assert_eq!(y.dtype(), Dtype::Float32);
    assert_eq!(y, array!([3.0, 4.0, 5.0]));

    // Integer array + float scalar is promoted to Float32.
    let y = array!([1, 2, 3]) + 2.0;
    assert_eq!(y.dtype(), Dtype::Float32);
    assert_eq!(y, array!([3.0, 4.0, 5.0]));

    // Broadcast inputs.
    let x = broadcast_to(&array!(1.0), &[10]).unwrap();
    let y = broadcast_to(&array!(2.0), &[10]).unwrap();
    let z = add(&x, &y).unwrap();
    assert_eq!(z, full::<f32>(&[10], array!(3.0)).unwrap());

    // [1, 2] + [2, 1] broadcasts to [2, 2].
    let x = Array::from_slice(&[1.0, 2.0], &[1, 2]);
    let y = Array::from_slice(&[1.0, 2.0], &[2, 1]);
    let z = add(&x, &y).unwrap();
    assert_eq!(z.shape(), &[2, 2]);
    assert_eq!(z, Array::from_slice(&[2.0, 3.0, 3.0, 4.0], &[2, 2]));

    // Scalar added to a rank-3 array keeps the array's shape.
    let x = ones::<f32>(&[3, 2, 1]).unwrap();
    let z = x + 2.0;
    assert_eq!(z.shape(), &[3, 2, 1]);
    let expected = Array::from_slice(&[3.0, 3.0, 3.0, 3.0, 3.0, 3.0], &[3, 2, 1]);
    assert_eq!(z, expected);

    // Empty + empty evaluates to an empty array of shape [0].
    let x = array!();
    let y = array!();
    let z = x + y;
    z.eval().unwrap();
    assert_eq!(z.size(), 0);
    assert_eq!(z.shape(), &[0]);
}
2615
/// The owned `-` operator subtracts element-wise.
#[test]
fn test_binary_sub() {
    let minuend = array!([3.0, 2.0, 1.0]);
    let subtrahend = array!([1.0, 1.0, 1.0]);
    let difference = minuend - subtrahend;
    assert_eq!(difference, array!([2.0, 1.0, 0.0]));
}
2622
/// The owned `*` operator multiplies element-wise.
#[test]
fn test_binary_mul() {
    let factors = array!([1.0, 2.0, 3.0]);
    let doubler = array!([2.0, 2.0, 2.0]);
    let product = factors * doubler;
    assert_eq!(product, array!([2.0, 4.0, 6.0]));
}
2629
/// `divide` on float scalars, and on booleans treated as 1/0 (including 1/0 = inf and 0/0 = NaN).
#[test]
fn test_binary_div() {
    // (numerator, denominator, expected quotient)
    let float_cases: [(f32, f32, f32); 3] = [(1.0, 1.0, 1.0), (1.0, 0.5, 2.0), (1.0, 4.0, 0.25)];
    for (num, den, want) in float_cases {
        let quotient = divide(&array![num], &array![den]).unwrap();
        assert_eq!(quotient.item::<f32>(), want);
    }

    let t = array![true];
    let f = array![false];
    assert_eq!(divide(&t, &t).unwrap().item::<f32>(), 1.0);
    assert_eq!(divide(&f, &t).unwrap().item::<f32>(), 0.0);
    assert!(divide(&t, &f).unwrap().item::<f32>().is_infinite());
    assert!(divide(&f, &f).unwrap().item::<f32>().is_nan());
}
2660
/// `maximum`/`minimum` pick the larger/smaller of two scalars, with the fixed operand on either side of the result.
#[test]
fn test_binary_maximum_minimum() {
    let one = array![1.0];

    let zero = array![0.0];
    assert_eq!(maximum(&one, &zero).unwrap().item::<f32>(), 1.0);
    assert_eq!(minimum(&one, &zero).unwrap().item::<f32>(), 0.0);

    let two = array![2.0];
    assert_eq!(maximum(&one, &two).unwrap().item::<f32>(), 2.0);
    assert_eq!(minimum(&one, &two).unwrap().item::<f32>(), 1.0);
}
2672
/// `logaddexp` for equal zeros, large u32 inputs, and every combination of infinities.
#[test]
fn test_binary_logaddexp() {
    // log(e^0 + e^0) = ln 2.
    let zero = array![0.0];
    assert_float_eq! {
        logaddexp(&zero, &zero).unwrap().item::<f32>(),
        2.0f32.ln(),
        abs <= 1e-5
    };

    // Expected to stay finite even though e^10000 is not representable as f32.
    let small = array!([0u32]);
    let large = array!([10000u32]);
    assert_eq!(logaddexp(&small, &large).unwrap().item::<f32>(), 10000.0);

    let inf = f32::INFINITY;
    let neg_inf = f32::NEG_INFINITY;
    // (lhs, rhs, expected)
    let infinite_cases = [
        (inf, 3.0, inf),
        (neg_inf, 3.0, 3.0),
        (neg_inf, neg_inf, neg_inf),
        (inf, inf, inf),
        (neg_inf, inf, inf),
    ];
    for (lhs, rhs, want) in infinite_cases {
        let got = logaddexp(&array![lhs], &array![rhs]).unwrap().item::<f32>();
        assert_eq!(got, want);
    }
}
2707
/// `clip` with both bounds, given either as arrays or as plain scalars.
#[test]
fn test_basic_clip() {
    let input = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
    let expected = array!([2.0, 4.0, 3.0, 6.0, 5.0]);

    let with_array_bounds = clip(&input, (array!(2.0), array!(6.0))).unwrap();
    assert_eq!(with_array_bounds, expected);

    let with_scalar_bounds = clip(&input, (2.0, 6.0)).unwrap();
    assert_eq!(with_scalar_bounds, expected);
}
2719
/// `clip` with only a lower bound (upper bound given as `()`).
#[test]
fn test_clip_with_only_min() {
    let input = array!([-1.0, 1.0, 0.0, 5.0]);
    let expected = array!([0.0, 1.0, 0.0, 5.0]);

    let with_array_min = clip(&input, (array!(0.0), ())).unwrap();
    assert_eq!(with_array_min, expected);

    let with_scalar_min = clip(&input, (0.0, ())).unwrap();
    assert_eq!(with_scalar_min, expected);
}
2731
/// `clip` with only an upper bound (lower bound given as `()`).
#[test]
fn test_clip_with_only_max() {
    let input = array!([2.0, 3.0, 4.0, 5.0]);
    let expected = array!([2.0, 3.0, 4.0, 4.0]);

    let with_array_max = clip(&input, ((), array!(4.0))).unwrap();
    assert_eq!(with_array_max, expected);

    let with_scalar_max = clip(&input, ((), 4.0)).unwrap();
    assert_eq!(with_scalar_max, expected);
}
2743
/// `tensordot_axes` with explicit axis lists (valid and mismatched) and `tensordot_axis` with a count.
#[test]
fn test_tensordot() {
    // Contract lhs axes [1, 0] against rhs axes [0, 1]: [3,4,5] x [4,3,2] -> [5,2].
    let lhs = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[4, 3, 2]).unwrap();
    let contracted = tensordot_axes(&lhs, &rhs, &[1i32, 0], &[0i32, 1]).unwrap();
    let expected = Array::from_slice(
        &[4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306],
        &[5, 2],
    );
    assert_eq!(contracted, expected);

    // Paired axis sizes disagree, so the contraction must fail.
    let lhs = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[3, 4, 5, 6]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[6, 4, 5, 3]).unwrap();
    assert!(tensordot_axes(&lhs, &rhs, &[2, 1, 3], &[1, 2, 0]).is_err());

    // Contract the last two axes of lhs with the first two of rhs: [3,4,5] x [4,5,6] -> [3,6].
    let lhs = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 120.0, None).unwrap(), &[4, 5, 6]).unwrap();
    let contracted = tensordot_axis(&lhs, &rhs, 2).unwrap();
    let expected = Array::from_slice(
        &[
            14820.0, 15010.0, 15200.0, 15390.0, 15580.0, 15770.0, 37620.0, 38210.0, 38800.0,
            39390.0, 39980.0, 40570.0, 60420.0, 61410.0, 62400.0, 63390.0, 64380.0, 65370.0,
        ],
        &[3, 6],
    );
    assert_eq!(contracted, expected);
}
2772
/// `outer` builds the `[len(x), len(y)]` table of pairwise products.
#[test]
fn test_outer() {
    // [1, 2, 3, 4] outer [1, 2, 3] is the 4x3 multiplication table.
    let rows = arange::<_, f32>(1.0, 5.0, None).unwrap();
    let cols = arange::<_, f32>(1.0, 4.0, None).unwrap();
    let table = outer(&rows, &cols).unwrap();
    let expected = Array::from_slice(
        &[1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0],
        &[4, 3],
    );
    assert_eq!(table, expected);

    // A vector of ones repeats the linspace in every row.
    let ones_vec = ones::<f32>(&[5]).unwrap();
    let ramp = linspace::<_, f32>(-2.0, 2.0, 5).unwrap();
    let table = outer(&ones_vec, &ramp).unwrap();
    let expected = Array::from_slice(
        &[
            -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
            -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
        ],
        &[5, 5],
    );
    assert_eq!(table, expected);
}
2796
/// `inner` contracts the last axis of both operands, rejecting mismatched sizes and accepting a 0-d operand.
#[test]
fn test_inner() {
    // Last axes are 5 and 3: incompatible.
    let lhs = reshape(arange::<_, f32>(None, 5.0, None).unwrap(), &[1, 5]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[2, 3]).unwrap();
    assert!(inner(&lhs, &rhs).is_err());

    // Plain dot product of two vectors.
    let lhs = array!([1.0, 2.0, 3.0]);
    let rhs = array!([0.0, 1.0, 0.0]);
    assert_eq!(inner(&lhs, &rhs).unwrap().item::<f32>(), 2.0);

    // [2, 3, 4] with [4] -> [2, 3].
    let lhs = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[2, 3, 4]).unwrap();
    let rhs = arange::<_, f32>(None, 4.0, None).unwrap();
    let expected = Array::from_slice(&[14.0, 38.0, 62.0, 86.0, 110.0, 134.0], &[2, 3]);
    assert_eq!(inner(&lhs, &rhs).unwrap(), expected);

    // [1, 1, 2] with [3, 2] -> [1, 1, 3].
    let lhs = reshape(arange::<_, f32>(None, 2.0, None).unwrap(), &[1, 1, 2]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[3, 2]).unwrap();
    let expected = Array::from_slice(&[1.0, 3.0, 5.0], &[1, 1, 3]);
    assert_eq!(inner(&lhs, &rhs).unwrap(), expected);

    // A 0-d operand scales the other array.
    let identity = eye::<f32>(2, None, None).unwrap();
    let seven = Array::from_f32(7.0);
    let expected = Array::from_slice(&[7.0, 0.0, 0.0, 7.0], &[2, 2]);
    assert_eq!(inner(&identity, &seven).unwrap(), expected);
}
2826
/// `divmod` returns `(quotient, remainder)` element-wise, rejects complex inputs, and each
/// output stays usable (evaluable) independently of the other.
#[test]
fn test_divmod() {
    // Dividing by one: quotient equals the input, remainder is zero.
    let x = array!([1.0, 2.0, 3.0]);
    let y = array!([1.0, 1.0, 1.0]);
    let (quo, rem) = divmod(&x, &y).unwrap();
    assert_eq!(quo, array!([1.0, 2.0, 3.0]));
    assert_eq!(rem, array!([0.0, 0.0, 0.0]));

    // Dividing by two: non-trivial quotients and remainders.
    let x = array!([5.0, 6.0, 7.0]);
    let y = array!([2.0, 2.0, 2.0]);
    let (quo, rem) = divmod(&x, &y).unwrap();
    assert_eq!(quo, array!([2.0, 3.0, 3.0]));
    assert_eq!(rem, array!([1.0, 0.0, 1.0]));

    // Dividing by three (previously an exact duplicate of the case above): exercises a
    // remainder of 2 and an exact division in the same call.
    let x = array!([7.0, 8.0, 9.0]);
    let y = array!([3.0, 3.0, 3.0]);
    let (quo, rem) = divmod(&x, &y).unwrap();
    assert_eq!(quo, array!([2.0, 2.0, 3.0]));
    assert_eq!(rem, array!([1.0, 2.0, 0.0]));

    // Complex inputs are rejected.
    let x = array![complex64::new(1.0, 0.0)];
    let y = array![complex64::new(2.0, 0.0)];
    assert!(divmod(&x, &y).is_err());

    // Both outputs can be evaluated together.
    let x = array![1.0];
    let y = array![2.0];
    let (quo, rem) = divmod(&x, &y).unwrap();
    eval([&quo, &rem]).unwrap();
    assert_eq!(quo.item::<f32>(), 0.0);
    assert_eq!(rem.item::<f32>(), 1.0);

    // Both outputs can be consumed by a further operation.
    let x = array![1.0];
    let y = array![2.0];
    let (quo, rem) = divmod(&x, &y).unwrap();
    let z = quo + rem;
    assert_eq!(z.item::<f32>(), 1.0);

    // Only the quotient is kept alive; the remainder is dropped at the end of the block.
    let mut out_holder = {
        let (quo, _) = divmod(&x, &y).unwrap();
        vec![quo]
    };
    eval(out_holder.iter()).unwrap();
    assert_eq!(out_holder[0].item::<f32>(), 0.0);

    // Drop the held quotient before checking the remainder-only case.
    out_holder.clear();
    // Only the remainder is kept alive; the quotient is dropped at the end of the block.
    let out_holder = {
        let (_, rem) = divmod(&x, &y).unwrap();
        vec![rem]
    };
    eval(out_holder.iter()).unwrap();
    assert_eq!(out_holder[0].item::<f32>(), 1.0);
}
2883}