1use crate::array::Array;
2use crate::error::Result;
3use crate::sealed::Sealed;
4
5use crate::utils::guard::Guarded;
6use crate::utils::{IntoOption, ScalarOrArray, VectorArray};
7use crate::Stream;
8use mlx_internal_macros::{default_device, generate_macro};
9use smallvec::SmallVec;
10
11impl Array {
12 #[default_device]
25 pub fn abs_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
26 Array::try_from_op(|res| unsafe {
27 mlx_sys::mlx_abs(res, self.as_ptr(), stream.as_ref().as_ptr())
28 })
29 }
30
31 #[default_device]
51 pub fn add_device(
52 &self,
53 other: impl AsRef<Array>,
54 stream: impl AsRef<Stream>,
55 ) -> Result<Array> {
56 Array::try_from_op(|res| unsafe {
57 mlx_sys::mlx_add(
58 res,
59 self.as_ptr(),
60 other.as_ref().as_ptr(),
61 stream.as_ref().as_ptr(),
62 )
63 })
64 }
65
66 #[default_device]
86 pub fn subtract_device(
87 &self,
88 other: impl AsRef<Array>,
89 stream: impl AsRef<Stream>,
90 ) -> Result<Array> {
91 Array::try_from_op(|res| unsafe {
92 mlx_sys::mlx_subtract(
93 res,
94 self.as_ptr(),
95 other.as_ref().as_ptr(),
96 stream.as_ref().as_ptr(),
97 )
98 })
99 }
100
101 #[default_device]
116 pub fn negative_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
117 Array::try_from_op(|res| unsafe {
118 mlx_sys::mlx_negative(res, self.as_ptr(), stream.as_ref().as_ptr())
119 })
120 }
121
122 #[default_device]
138 pub fn multiply_device(
139 &self,
140 other: impl AsRef<Array>,
141 stream: impl AsRef<Stream>,
142 ) -> Result<Array> {
143 Array::try_from_op(|res| unsafe {
144 mlx_sys::mlx_multiply(
145 res,
146 self.as_ptr(),
147 other.as_ref().as_ptr(),
148 stream.as_ref().as_ptr(),
149 )
150 })
151 }
152
153 #[default_device]
163 pub fn nan_to_num_device(
164 &self,
165 nan: impl IntoOption<f32>,
166 pos_inf: impl IntoOption<f32>,
167 neg_inf: impl IntoOption<f32>,
168 stream: impl AsRef<Stream>,
169 ) -> Result<Array> {
170 let pos_inf = pos_inf.into_option();
171 let neg_inf = neg_inf.into_option();
172
173 let pos_inf = mlx_sys::mlx_optional_float {
174 value: pos_inf.unwrap_or(0.0),
175 has_value: pos_inf.is_some(),
176 };
177 let neg_inf = mlx_sys::mlx_optional_float {
178 value: neg_inf.unwrap_or(0.0),
179 has_value: neg_inf.is_some(),
180 };
181
182 Array::try_from_op(|res| unsafe {
183 mlx_sys::mlx_nan_to_num(
184 res,
185 self.as_ptr(),
186 nan.into_option().unwrap_or(0.),
187 pos_inf,
188 neg_inf,
189 stream.as_ref().as_ptr(),
190 )
191 })
192 }
193
194 #[default_device]
214 pub fn divide_device(
215 &self,
216 other: impl AsRef<Array>,
217 stream: impl AsRef<Stream>,
218 ) -> Result<Array> {
219 Array::try_from_op(|res| unsafe {
220 mlx_sys::mlx_divide(
221 res,
222 self.as_ptr(),
223 other.as_ref().as_ptr(),
224 stream.as_ref().as_ptr(),
225 )
226 })
227 }
228
229 #[default_device]
249 pub fn power_device(
250 &self,
251 other: impl AsRef<Array>,
252 stream: impl AsRef<Stream>,
253 ) -> Result<Array> {
254 Array::try_from_op(|res| unsafe {
255 mlx_sys::mlx_power(
256 res,
257 self.as_ptr(),
258 other.as_ref().as_ptr(),
259 stream.as_ref().as_ptr(),
260 )
261 })
262 }
263
264 #[default_device]
284 pub fn remainder_device(
285 &self,
286 other: impl AsRef<Array>,
287 stream: impl AsRef<Stream>,
288 ) -> Result<Array> {
289 Array::try_from_op(|res| unsafe {
290 mlx_sys::mlx_remainder(
291 res,
292 self.as_ptr(),
293 other.as_ref().as_ptr(),
294 stream.as_ref().as_ptr(),
295 )
296 })
297 }
298
299 #[default_device]
312 pub fn sqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
313 Array::try_from_op(|res| unsafe {
314 mlx_sys::mlx_sqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
315 })
316 }
317
318 #[default_device]
331 pub fn cos_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
332 Array::try_from_op(|res| unsafe {
333 mlx_sys::mlx_cos(res, self.as_ptr(), stream.as_ref().as_ptr())
334 })
335 }
336
337 #[default_device]
352 pub fn exp_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
353 Array::try_from_op(|res| unsafe {
354 mlx_sys::mlx_exp(res, self.as_ptr(), stream.as_ref().as_ptr())
355 })
356 }
357
358 #[default_device]
371 pub fn floor_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
372 Array::try_from_op(|res| unsafe {
373 mlx_sys::mlx_floor(res, self.as_ptr(), stream.as_ref().as_ptr())
374 })
375 }
376
377 #[default_device]
401 pub fn floor_divide_device(
402 &self,
403 other: impl AsRef<Array>,
404 stream: impl AsRef<Stream>,
405 ) -> Result<Array> {
406 Array::try_from_op(|res| unsafe {
407 mlx_sys::mlx_floor_divide(
408 res,
409 self.as_ptr(),
410 other.as_ref().as_ptr(),
411 stream.as_ref().as_ptr(),
412 )
413 })
414 }
415
416 #[default_device]
421 pub fn is_nan_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
422 Array::try_from_op(|res| unsafe {
423 mlx_sys::mlx_isnan(res, self.as_ptr(), stream.as_ref().as_ptr())
424 })
425 }
426
427 #[default_device]
432 pub fn is_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
433 Array::try_from_op(|res| unsafe {
434 mlx_sys::mlx_isinf(res, self.as_ptr(), stream.as_ref().as_ptr())
435 })
436 }
437
438 #[default_device]
443 pub fn is_finite_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
444 Array::try_from_op(|res| unsafe {
445 mlx_sys::mlx_isfinite(res, self.as_ptr(), stream.as_ref().as_ptr())
446 })
447 }
448
449 #[default_device]
454 pub fn is_neg_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
455 Array::try_from_op(|res| unsafe {
456 mlx_sys::mlx_isneginf(res, self.as_ptr(), stream.as_ref().as_ptr())
457 })
458 }
459
460 #[default_device]
465 pub fn is_pos_inf_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
466 Array::try_from_op(|res| unsafe {
467 mlx_sys::mlx_isposinf(res, self.as_ptr(), stream.as_ref().as_ptr())
468 })
469 }
470
471 #[default_device]
484 pub fn log_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
485 Array::try_from_op(|res| unsafe {
486 mlx_sys::mlx_log(res, self.as_ptr(), stream.as_ref().as_ptr())
487 })
488 }
489
490 #[default_device]
503 pub fn log2_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
504 Array::try_from_op(|res| unsafe {
505 mlx_sys::mlx_log2(res, self.as_ptr(), stream.as_ref().as_ptr())
506 })
507 }
508
509 #[default_device]
522 pub fn log10_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
523 Array::try_from_op(|res| unsafe {
524 mlx_sys::mlx_log10(res, self.as_ptr(), stream.as_ref().as_ptr())
525 })
526 }
527
528 #[default_device]
541 pub fn log1p_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
542 Array::try_from_op(|res| unsafe {
543 mlx_sys::mlx_log1p(res, self.as_ptr(), stream.as_ref().as_ptr())
544 })
545 }
546
547 #[default_device]
577 pub fn matmul_device(
578 &self,
579 other: impl AsRef<Array>,
580 stream: impl AsRef<Stream>,
581 ) -> Result<Array> {
582 Array::try_from_op(|res| unsafe {
583 mlx_sys::mlx_matmul(
584 res,
585 self.as_ptr(),
586 other.as_ref().as_ptr(),
587 stream.as_ref().as_ptr(),
588 )
589 })
590 }
591
592 #[default_device]
605 pub fn reciprocal_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
606 Array::try_from_op(|res| unsafe {
607 mlx_sys::mlx_reciprocal(res, self.as_ptr(), stream.as_ref().as_ptr())
608 })
609 }
610
611 #[default_device]
617 pub fn round_device(
618 &self,
619 decimals: impl Into<Option<i32>>,
620 stream: impl AsRef<Stream>,
621 ) -> Result<Array> {
622 Array::try_from_op(|res| unsafe {
623 mlx_sys::mlx_round(
624 res,
625 self.as_ptr(),
626 decimals.into().unwrap_or(0),
627 stream.as_ref().as_ptr(),
628 )
629 })
630 }
631
632 #[default_device]
634 pub fn rsqrt_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
635 Array::try_from_op(|res| unsafe {
636 mlx_sys::mlx_rsqrt(res, self.as_ptr(), stream.as_ref().as_ptr())
637 })
638 }
639
640 #[default_device]
642 pub fn sin_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
643 Array::try_from_op(|res| unsafe {
644 mlx_sys::mlx_sin(res, self.as_ptr(), stream.as_ref().as_ptr())
645 })
646 }
647
648 #[default_device]
650 pub fn square_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
651 Array::try_from_op(|res| unsafe {
652 mlx_sys::mlx_square(res, self.as_ptr(), stream.as_ref().as_ptr())
653 })
654 }
655
656 #[default_device]
658 pub fn real_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
659 Array::try_from_op(|res| unsafe {
660 mlx_sys::mlx_real(res, self.as_ptr(), stream.as_ref().as_ptr())
661 })
662 }
663
664 #[default_device]
666 pub fn imag_device(&self, stream: impl AsRef<Stream>) -> Result<Array> {
667 Array::try_from_op(|res| unsafe {
668 mlx_sys::mlx_imag(res, self.as_ptr(), stream.as_ref().as_ptr())
669 })
670 }
671}
672
673#[generate_macro]
684#[default_device]
685pub fn abs_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
686 a.as_ref().abs_device(stream)
687}
688
689#[generate_macro]
691#[default_device]
692pub fn acos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
693 Array::try_from_op(|res| unsafe {
694 mlx_sys::mlx_arccos(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
695 })
696}
697
698#[generate_macro]
700#[default_device]
701pub fn acosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
702 Array::try_from_op(|res| unsafe {
703 mlx_sys::mlx_arccosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
704 })
705}
706
707#[generate_macro]
709#[default_device]
710pub fn add_device(
711 lhs: impl AsRef<Array>,
712 rhs: impl AsRef<Array>,
713 #[optional] stream: impl AsRef<Stream>,
714) -> Result<Array> {
715 lhs.as_ref().add_device(rhs, stream)
716}
717
718#[generate_macro]
720#[default_device]
721pub fn asin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
722 Array::try_from_op(|res| unsafe {
723 mlx_sys::mlx_arcsin(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
724 })
725}
726
727#[generate_macro]
729#[default_device]
730pub fn asinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
731 Array::try_from_op(|res| unsafe {
732 mlx_sys::mlx_arcsinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
733 })
734}
735
736#[generate_macro]
738#[default_device]
739pub fn atan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
740 Array::try_from_op(|res| unsafe {
741 mlx_sys::mlx_arctan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
742 })
743}
744
745#[generate_macro]
747#[default_device]
748pub fn atan2_device(
749 a: impl AsRef<Array>,
750 b: impl AsRef<Array>,
751 #[optional] stream: impl AsRef<Stream>,
752) -> Result<Array> {
753 let a = a.as_ref();
754 let b = b.as_ref();
755
756 Array::try_from_op(|res| unsafe {
757 mlx_sys::mlx_arctan2(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
758 })
759}
760
761#[generate_macro]
763#[default_device]
764pub fn atanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
765 Array::try_from_op(|res| unsafe {
766 mlx_sys::mlx_arctanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
767 })
768}
769
770#[generate_macro]
772#[default_device]
773pub fn ceil_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
774 Array::try_from_op(|res| unsafe {
775 mlx_sys::mlx_ceil(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
776 })
777}
778
779pub trait ClipBound<'min, 'max>: Sealed {
784 fn into_min_max(
786 self,
787 ) -> (
788 Option<impl ScalarOrArray<'min>>,
789 Option<impl ScalarOrArray<'max>>,
790 );
791}
792
793impl<'min, Min> ClipBound<'min, 'min> for (Min, ())
794where
795 Min: ScalarOrArray<'min> + Sealed,
796{
797 fn into_min_max(
798 self,
799 ) -> (
800 Option<impl ScalarOrArray<'min>>,
801 Option<impl ScalarOrArray<'min>>,
802 ) {
803 (Some(self.0), Option::<Min>::None)
804 }
805}
806
807impl<'max, Max> ClipBound<'max, 'max> for ((), Max)
808where
809 Max: ScalarOrArray<'max> + Sealed,
810{
811 fn into_min_max(
812 self,
813 ) -> (
814 Option<impl ScalarOrArray<'max>>,
815 Option<impl ScalarOrArray<'max>>,
816 ) {
817 (Option::<Max>::None, Some(self.1))
818 }
819}
820
821impl<'min, 'max, Min, Max> ClipBound<'min, 'max> for (Min, Max)
822where
823 Min: ScalarOrArray<'min> + Sealed,
824 Max: ScalarOrArray<'max> + Sealed,
825{
826 fn into_min_max(
827 self,
828 ) -> (
829 Option<impl ScalarOrArray<'min>>,
830 Option<impl ScalarOrArray<'max>>,
831 ) {
832 (Some(self.0), Some(self.1))
833 }
834}
835
836#[generate_macro]
858#[default_device]
859pub fn clip_device<'min, 'max>(
860 a: impl AsRef<Array>,
861 bound: impl ClipBound<'min, 'max>,
862 #[optional] stream: impl AsRef<Stream>,
863) -> Result<Array> {
864 let (a_min, a_max) = bound.into_min_max();
865
866 let a_min = a_min.map(|min| min.into_owned_or_ref_array());
868 let a_max = a_max.map(|max| max.into_owned_or_ref_array());
869
870 unsafe {
871 let min_ptr = match &a_min {
872 Some(a_min) => a_min.as_ref().as_ptr(),
873 None => mlx_sys::mlx_array_new(),
874 };
875 let max_ptr = match &a_max {
876 Some(a_max) => a_max.as_ref().as_ptr(),
877 None => mlx_sys::mlx_array_new(),
878 };
879
880 Array::try_from_op(|res| {
881 mlx_sys::mlx_clip(
882 res,
883 a.as_ref().as_ptr(),
884 min_ptr,
885 max_ptr,
886 stream.as_ref().as_ptr(),
887 )
888 })
889 }
890}
891
892#[generate_macro]
894#[default_device]
895pub fn cos_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
896 a.as_ref().cos_device(stream)
897}
898
899#[generate_macro]
901#[default_device]
902pub fn cosh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
903 Array::try_from_op(|res| unsafe {
904 mlx_sys::mlx_cosh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
905 })
906}
907
908#[generate_macro]
910#[default_device]
911pub fn degrees_device(
912 a: impl AsRef<Array>,
913 #[optional] stream: impl AsRef<Stream>,
914) -> Result<Array> {
915 Array::try_from_op(|res| unsafe {
916 mlx_sys::mlx_degrees(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
917 })
918}
919
920#[generate_macro]
922#[default_device]
923pub fn divide_device(
924 a: impl AsRef<Array>,
925 b: impl AsRef<Array>,
926 #[optional] stream: impl AsRef<Stream>,
927) -> Result<Array> {
928 a.as_ref().divide_device(b, stream)
929}
930
931#[generate_macro]
938#[default_device]
939pub fn divmod_device(
940 a: impl AsRef<Array>,
941 b: impl AsRef<Array>,
942 #[optional] stream: impl AsRef<Stream>,
943) -> Result<(Array, Array)> {
944 let a_ptr = a.as_ref().as_ptr();
945 let b_ptr = b.as_ref().as_ptr();
946
947 let vec = VectorArray::try_from_op(|res| unsafe {
948 mlx_sys::mlx_divmod(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
949 })?;
950
951 let vals: SmallVec<[_; 2]> = vec.try_into_values()?;
952 let mut iter = vals.into_iter();
953 let quotient = iter.next().unwrap();
954 let remainder = iter.next().unwrap();
955
956 Ok((quotient, remainder))
957}
958
959#[generate_macro]
961#[default_device]
962pub fn erf_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
963 Array::try_from_op(|res| unsafe {
964 mlx_sys::mlx_erf(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
965 })
966}
967
968#[generate_macro]
970#[default_device]
971pub fn erfinv_device(
972 a: impl AsRef<Array>,
973 #[optional] stream: impl AsRef<Stream>,
974) -> Result<Array> {
975 Array::try_from_op(|res| unsafe {
976 mlx_sys::mlx_erfinv(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
977 })
978}
979
980#[generate_macro]
982#[default_device]
983pub fn exp_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
984 a.as_ref().exp_device(stream)
985}
986
987#[generate_macro]
989#[default_device]
990pub fn expm1_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
991 Array::try_from_op(|res| unsafe {
992 mlx_sys::mlx_expm1(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
993 })
994}
995
996#[generate_macro]
998#[default_device]
999pub fn floor_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1000 a.as_ref().floor_device(stream)
1001}
1002
1003#[generate_macro]
1005#[default_device]
1006pub fn floor_divide_device(
1007 a: impl AsRef<Array>,
1008 other: impl AsRef<Array>,
1009 #[optional] stream: impl AsRef<Stream>,
1010) -> Result<Array> {
1011 a.as_ref().floor_divide_device(other, stream)
1012}
1013
1014#[generate_macro]
1016#[default_device]
1017pub fn log_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1018 a.as_ref().log_device(stream)
1019}
1020
1021#[generate_macro]
1023#[default_device]
1024pub fn log10_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1025 a.as_ref().log10_device(stream)
1026}
1027
1028#[generate_macro]
1030#[default_device]
1031pub fn log1p_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1032 a.as_ref().log1p_device(stream)
1033}
1034
1035#[generate_macro]
1037#[default_device]
1038pub fn log2_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1039 a.as_ref().log2_device(stream)
1040}
1041
1042#[generate_macro]
1049#[default_device]
1050pub fn logaddexp_device(
1051 a: impl AsRef<Array>,
1052 b: impl AsRef<Array>,
1053 #[optional] stream: impl AsRef<Stream>,
1054) -> Result<Array> {
1055 let a_ptr = a.as_ref().as_ptr();
1056 let b_ptr = b.as_ref().as_ptr();
1057
1058 Array::try_from_op(|res| unsafe {
1059 mlx_sys::mlx_logaddexp(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1060 })
1061}
1062
1063#[generate_macro]
1065#[default_device]
1066pub fn matmul_device(
1067 a: impl AsRef<Array>,
1068 b: impl AsRef<Array>,
1069 #[optional] stream: impl AsRef<Stream>,
1070) -> Result<Array> {
1071 a.as_ref().matmul_device(b, stream)
1072}
1073
1074#[generate_macro]
1079#[default_device]
1080pub fn maximum_device(
1081 a: impl AsRef<Array>,
1082 b: impl AsRef<Array>,
1083 #[optional] stream: impl AsRef<Stream>,
1084) -> Result<Array> {
1085 let a_ptr = a.as_ref().as_ptr();
1086 let b_ptr = b.as_ref().as_ptr();
1087
1088 Array::try_from_op(|res| unsafe {
1089 mlx_sys::mlx_maximum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1090 })
1091}
1092
1093#[generate_macro]
1098#[default_device]
1099pub fn minimum_device(
1100 a: impl AsRef<Array>,
1101 b: impl AsRef<Array>,
1102 #[optional] stream: impl AsRef<Stream>,
1103) -> Result<Array> {
1104 let a_ptr = a.as_ref().as_ptr();
1105 let b_ptr = b.as_ref().as_ptr();
1106
1107 Array::try_from_op(|res| unsafe {
1108 mlx_sys::mlx_minimum(res, a_ptr, b_ptr, stream.as_ref().as_ptr())
1109 })
1110}
1111
1112#[generate_macro]
1114#[default_device]
1115pub fn multiply_device(
1116 a: impl AsRef<Array>,
1117 b: impl AsRef<Array>,
1118 #[optional] stream: impl AsRef<Stream>,
1119) -> Result<Array> {
1120 a.as_ref().multiply_device(b, stream)
1121}
1122
1123#[generate_macro]
1125#[default_device]
1126pub fn negative_device(
1127 a: impl AsRef<Array>,
1128 #[optional] stream: impl AsRef<Stream>,
1129) -> Result<Array> {
1130 a.as_ref().negative_device(stream)
1131}
1132
1133#[generate_macro]
1135#[default_device]
1136pub fn power_device(
1137 a: impl AsRef<Array>,
1138 b: impl AsRef<Array>,
1139 #[optional] stream: impl AsRef<Stream>,
1140) -> Result<Array> {
1141 a.as_ref().power_device(b, stream)
1142}
1143
1144#[generate_macro]
1146#[default_device]
1147pub fn radians_device(
1148 a: impl AsRef<Array>,
1149 #[optional] stream: impl AsRef<Stream>,
1150) -> Result<Array> {
1151 Array::try_from_op(|res| unsafe {
1152 mlx_sys::mlx_radians(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1153 })
1154}
1155
1156#[generate_macro]
1158#[default_device]
1159pub fn reciprocal_device(
1160 a: impl AsRef<Array>,
1161 #[optional] stream: impl AsRef<Stream>,
1162) -> Result<Array> {
1163 a.as_ref().reciprocal_device(stream)
1164}
1165
1166#[generate_macro]
1168#[default_device]
1169pub fn remainder_device(
1170 a: impl AsRef<Array>,
1171 b: impl AsRef<Array>,
1172 #[optional] stream: impl AsRef<Stream>,
1173) -> Result<Array> {
1174 a.as_ref().remainder_device(b, stream)
1175}
1176
1177#[generate_macro]
1179#[default_device]
1180pub fn round_device(
1181 a: impl AsRef<Array>,
1182 decimals: impl Into<Option<i32>>,
1183 #[optional] stream: impl AsRef<Stream>,
1184) -> Result<Array> {
1185 a.as_ref().round_device(decimals, stream)
1186}
1187
1188#[generate_macro]
1190#[default_device]
1191pub fn rsqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1192 a.as_ref().rsqrt_device(stream)
1193}
1194
1195#[generate_macro]
1201#[default_device]
1202pub fn sigmoid_device(
1203 a: impl AsRef<Array>,
1204 #[optional] stream: impl AsRef<Stream>,
1205) -> Result<Array> {
1206 Array::try_from_op(|res| unsafe {
1207 mlx_sys::mlx_sigmoid(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1208 })
1209}
1210
1211#[generate_macro]
1213#[default_device]
1214pub fn sign_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1215 Array::try_from_op(|res| unsafe {
1216 mlx_sys::mlx_sign(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1217 })
1218}
1219
1220#[generate_macro]
1222#[default_device]
1223pub fn sin_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1224 a.as_ref().sin_device(stream)
1225}
1226
1227#[generate_macro]
1229#[default_device]
1230pub fn sinh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1231 Array::try_from_op(|res| unsafe {
1232 mlx_sys::mlx_sinh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1233 })
1234}
1235
1236#[generate_macro]
1242#[default_device]
1243pub fn softmax_axes_device(
1244 a: impl AsRef<Array>,
1245 axes: &[i32],
1246 precise: impl Into<Option<bool>>,
1247 #[optional] stream: impl AsRef<Stream>,
1248) -> Result<Array> {
1249 let precise = precise.into().unwrap_or(false);
1250 let s = stream.as_ref().as_ptr();
1251
1252 Array::try_from_op(|res| unsafe {
1253 mlx_sys::mlx_softmax_axes(
1254 res,
1255 a.as_ref().as_ptr(),
1256 axes.as_ptr(),
1257 axes.len(),
1258 precise,
1259 s,
1260 )
1261 })
1262}
1263
1264#[generate_macro]
1266#[default_device]
1267pub fn softmax_axis_device(
1268 a: impl AsRef<Array>,
1269 axis: i32,
1270 precise: impl Into<Option<bool>>,
1271 #[optional] stream: impl AsRef<Stream>,
1272) -> Result<Array> {
1273 let precise = precise.into().unwrap_or(false);
1274 let s = stream.as_ref().as_ptr();
1275
1276 Array::try_from_op(|res| unsafe {
1277 mlx_sys::mlx_softmax_axis(res, a.as_ref().as_ptr(), axis, precise, s)
1278 })
1279}
1280
1281#[generate_macro]
1283#[default_device]
1284pub fn softmax_device(
1285 a: impl AsRef<Array>,
1286 precise: impl Into<Option<bool>>,
1287 #[optional] stream: impl AsRef<Stream>,
1288) -> Result<Array> {
1289 let precise = precise.into().unwrap_or(false);
1290 let s = stream.as_ref().as_ptr();
1291
1292 Array::try_from_op(|res| unsafe { mlx_sys::mlx_softmax(res, a.as_ref().as_ptr(), precise, s) })
1293}
1294
1295#[generate_macro]
1297#[default_device]
1298pub fn sqrt_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1299 a.as_ref().sqrt_device(stream)
1300}
1301
1302#[generate_macro]
1304#[default_device]
1305pub fn square_device(
1306 a: impl AsRef<Array>,
1307 #[optional] stream: impl AsRef<Stream>,
1308) -> Result<Array> {
1309 a.as_ref().square_device(stream)
1310}
1311
1312#[generate_macro]
1314#[default_device]
1315pub fn subtract_device(
1316 a: impl AsRef<Array>,
1317 b: impl AsRef<Array>,
1318 #[optional] stream: impl AsRef<Stream>,
1319) -> Result<Array> {
1320 a.as_ref().subtract_device(b, stream)
1321}
1322
1323#[generate_macro]
1325#[default_device]
1326pub fn tan_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1327 Array::try_from_op(|res| unsafe {
1328 mlx_sys::mlx_tan(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1329 })
1330}
1331
1332#[generate_macro]
1334#[default_device]
1335pub fn tanh_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1336 Array::try_from_op(|res| unsafe {
1337 mlx_sys::mlx_tanh(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1338 })
1339}
1340
1341#[generate_macro]
1343#[default_device]
1344pub fn real_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1345 Array::try_from_op(|res| unsafe {
1346 mlx_sys::mlx_real(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1347 })
1348}
1349
1350#[generate_macro]
1352#[default_device]
1353pub fn imag_device(a: impl AsRef<Array>, #[optional] stream: impl AsRef<Stream>) -> Result<Array> {
1354 Array::try_from_op(|res| unsafe {
1355 mlx_sys::mlx_imag(res, a.as_ref().as_ptr(), stream.as_ref().as_ptr())
1356 })
1357}
1358
1359#[generate_macro]
1365#[default_device]
1366pub fn block_masked_mm_device<'mo, 'lhs, 'rhs>(
1367 a: impl AsRef<Array>,
1368 b: impl AsRef<Array>,
1369 #[optional] block_size: impl Into<Option<i32>>,
1370 #[optional] mask_out: impl Into<Option<&'mo Array>>,
1371 #[optional] mask_lhs: impl Into<Option<&'lhs Array>>,
1372 #[optional] mask_rhs: impl Into<Option<&'rhs Array>>,
1373 #[optional] stream: impl AsRef<Stream>,
1374) -> Result<Array> {
1375 let a_ptr = a.as_ref().as_ptr();
1376 let b_ptr = b.as_ref().as_ptr();
1377 unsafe {
1378 let mask_out_ptr = mask_out
1379 .into()
1380 .map(|m| m.as_ptr())
1381 .unwrap_or(mlx_sys::mlx_array_new());
1382 let mask_lhs_ptr = mask_lhs
1383 .into()
1384 .map(|m| m.as_ptr())
1385 .unwrap_or(mlx_sys::mlx_array_new());
1386 let mask_rhs_ptr = mask_rhs
1387 .into()
1388 .map(|m| m.as_ptr())
1389 .unwrap_or(mlx_sys::mlx_array_new());
1390
1391 Array::try_from_op(|res| {
1392 mlx_sys::mlx_block_masked_mm(
1393 res,
1394 a_ptr,
1395 b_ptr,
1396 block_size.into().unwrap_or(32),
1397 mask_out_ptr,
1398 mask_lhs_ptr,
1399 mask_rhs_ptr,
1400 stream.as_ref().as_ptr(),
1401 )
1402 })
1403 }
1404}
1405
1406#[generate_macro]
1419#[default_device]
1420pub fn addmm_device(
1421 c: impl AsRef<Array>,
1422 a: impl AsRef<Array>,
1423 b: impl AsRef<Array>,
1424 #[optional] alpha: impl Into<Option<f32>>,
1425 #[optional] beta: impl Into<Option<f32>>,
1426 #[optional] stream: impl AsRef<Stream>,
1427) -> Result<Array> {
1428 let c_ptr = c.as_ref().as_ptr();
1429 let a_ptr = a.as_ref().as_ptr();
1430 let b_ptr = b.as_ref().as_ptr();
1431 let alpha = alpha.into().unwrap_or(1.0);
1432 let beta = beta.into().unwrap_or(1.0);
1433
1434 Array::try_from_op(|res| unsafe {
1435 mlx_sys::mlx_addmm(
1436 res,
1437 c_ptr,
1438 a_ptr,
1439 b_ptr,
1440 alpha,
1441 beta,
1442 stream.as_ref().as_ptr(),
1443 )
1444 })
1445}
1446
1447#[generate_macro]
1450#[default_device]
1451pub fn inner_device(
1452 a: impl AsRef<Array>,
1453 b: impl AsRef<Array>,
1454 #[optional] stream: impl AsRef<Stream>,
1455) -> Result<Array> {
1456 let a = a.as_ref();
1457 let b = b.as_ref();
1458 Array::try_from_op(|res| unsafe {
1459 mlx_sys::mlx_inner(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1460 })
1461}
1462
1463#[generate_macro]
1466#[default_device]
1467pub fn outer_device(
1468 a: impl AsRef<Array>,
1469 b: impl AsRef<Array>,
1470 #[optional] stream: impl AsRef<Stream>,
1471) -> Result<Array> {
1472 let a = a.as_ref();
1473 let b = b.as_ref();
1474 Array::try_from_op(|res| unsafe {
1475 mlx_sys::mlx_outer(res, a.as_ptr(), b.as_ptr(), stream.as_ref().as_ptr())
1476 })
1477}
1478
1479#[generate_macro]
1481#[default_device]
1482pub fn tensordot_axes_device(
1483 a: impl AsRef<Array>,
1484 b: impl AsRef<Array>,
1485 axes_a: &[i32],
1486 axes_b: &[i32],
1487 #[optional] stream: impl AsRef<Stream>,
1488) -> Result<Array> {
1489 let a = a.as_ref();
1490 let b = b.as_ref();
1491 Array::try_from_op(|res| unsafe {
1492 mlx_sys::mlx_tensordot(
1493 res,
1494 a.as_ptr(),
1495 b.as_ptr(),
1496 axes_a.as_ptr(),
1497 axes_a.len(),
1498 axes_b.as_ptr(),
1499 axes_b.len(),
1500 stream.as_ref().as_ptr(),
1501 )
1502 })
1503}
1504
1505#[generate_macro]
1507#[default_device]
1508pub fn tensordot_axis_device(
1509 a: impl AsRef<Array>,
1510 b: impl AsRef<Array>,
1511 axis: i32,
1512 #[optional] stream: impl AsRef<Stream>,
1513) -> Result<Array> {
1514 let a = a.as_ref();
1515 let b = b.as_ref();
1516 Array::try_from_op(|res| unsafe {
1517 mlx_sys::mlx_tensordot_axis(res, a.as_ptr(), b.as_ptr(), axis, stream.as_ref().as_ptr())
1518 })
1519}
1520
1521#[cfg(test)]
1522mod tests {
1523 use std::f32::consts::PI;
1524
1525 use super::*;
1526 use crate::{
1527 array, complex64,
1528 ops::{all_close, arange, broadcast_to, eye, full, linspace, ones, reshape, split},
1529 transforms::eval,
1530 Dtype, StreamOrDevice,
1531 };
1532 use float_eq::assert_float_eq;
1533 use pretty_assertions::assert_eq;
1534
1535 #[test]
1536 fn test_abs() {
1537 let data = [1i32, 2, -3, -4, -5];
1538 let array = Array::from_slice(&data, &[5]);
1539 let result = array.abs().unwrap();
1540
1541 let data: &[i32] = result.as_slice();
1542 assert_eq!(data, [1, 2, 3, 4, 5]);
1543
1544 let data: &[i32] = array.as_slice();
1546 assert_eq!(data, [1, 2, -3, -4, -5]);
1547 }
1548
1549 #[test]
1550 fn test_add() {
1551 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1552 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1553
1554 let c = &a + &b;
1555
1556 let c_data: &[f32] = c.as_slice();
1557 assert_eq!(c_data, &[5.0, 7.0, 9.0]);
1558
1559 let a_data: &[f32] = a.as_slice();
1561 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1562
1563 let b_data: &[f32] = b.as_slice();
1564 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1565 }
1566
1567 #[test]
1568 fn test_add_invalid_broadcast() {
1569 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1570 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1571
1572 let c = a.add(&b);
1573 assert!(c.is_err());
1574 }
1575
1576 #[test]
1577 fn test_sub() {
1578 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1579 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1580
1581 let c = &a - &b;
1582
1583 let c_data: &[f32] = c.as_slice();
1584 assert_eq!(c_data, &[-3.0, -3.0, -3.0]);
1585
1586 let a_data: &[f32] = a.as_slice();
1588 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1589
1590 let b_data: &[f32] = b.as_slice();
1591 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1592 }
1593
1594 #[test]
1595 fn test_sub_invalid_broadcast() {
1596 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1597 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1598 let c = a.subtract(&b);
1599 assert!(c.is_err());
1600 }
1601
1602 #[test]
1603 fn test_neg() {
1604 let a = Array::from_slice::<f32>(&[1.0, 2.0, 3.0], &[3]);
1605 let b = a.negative().unwrap();
1606
1607 let b_data: &[f32] = b.as_slice();
1608 assert_eq!(b_data, &[-1.0, -2.0, -3.0]);
1609
1610 let a_data: &[f32] = a.as_slice();
1612 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1613 }
1614
1615 #[test]
1616 fn test_neg_bool() {
1617 let a = Array::from_slice(&[true, false, true], &[3]);
1618 let b = a.negative();
1619 assert!(b.is_err());
1620 }
1621
1622 #[test]
1623 fn test_logical_not() {
1624 let a: Array = false.into();
1625 let b = a.logical_not().unwrap();
1626
1627 let b_data: &[bool] = b.as_slice();
1628 assert_eq!(b_data, [true]);
1629 }
1630
1631 #[test]
1632 fn test_mul() {
1633 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1634 let b = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);
1635
1636 let c = &a * &b;
1637
1638 let c_data: &[f32] = c.as_slice();
1639 assert_eq!(c_data, &[4.0, 10.0, 18.0]);
1640
1641 let a_data: &[f32] = a.as_slice();
1643 assert_eq!(a_data, &[1.0, 2.0, 3.0]);
1644
1645 let b_data: &[f32] = b.as_slice();
1646 assert_eq!(b_data, &[4.0, 5.0, 6.0]);
1647 }
1648
1649 #[test]
1650 fn test_mul_invalid_broadcast() {
1651 let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
1652 let b = Array::from_slice(&[4.0, 5.0], &[2]);
1653 let c = a.multiply(&b);
1654 assert!(c.is_err());
1655 }
1656
/// `nan_to_num` replaces the NaN entry with the requested value and keeps finite entries.
#[test]
fn test_nan_to_num() {
    let input = array!([1.0, 2.0, f32::NAN, 4.0, 5.0]);
    let cleaned = input.nan_to_num(0.0, 1.0, 0.0).unwrap();

    let cleaned_values: &[f32] = cleaned.as_slice();
    assert_eq!(cleaned_values, &[1.0, 2.0, 0.0, 4.0, 5.0]);
}
1665
/// The `/` operator on array references divides element-wise and keeps both operands intact.
#[test]
fn test_div() {
    let numerator = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let denominator = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);

    let quotient = &numerator / &denominator;
    let quotient_values: &[f32] = quotient.as_slice();
    assert_eq!(quotient_values, &[0.25, 0.4, 0.5]);

    // neither operand may have been modified
    let numerator_values: &[f32] = numerator.as_slice();
    assert_eq!(numerator_values, &[1.0, 2.0, 3.0]);

    let denominator_values: &[f32] = denominator.as_slice();
    assert_eq!(denominator_values, &[4.0, 5.0, 6.0]);
}
1683
/// Dividing a `[3]`-shaped array by a `[2]`-shaped one must report an error.
#[test]
fn test_div_invalid_broadcast() {
    let numerator = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let denominator = Array::from_slice(&[4.0, 5.0], &[2]);

    assert!(numerator.divide(&denominator).is_err());
}
1691
/// `power` raises each base to the matching exponent and keeps both operands intact.
#[test]
fn test_pow() {
    let base = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let exponent = Array::from_slice(&[2.0, 3.0, 4.0], &[3]);

    let raised = base.power(&exponent).unwrap();
    let raised_values: &[f32] = raised.as_slice();
    assert_eq!(raised_values, &[1.0, 8.0, 81.0]);

    // neither operand may have been modified
    let base_values: &[f32] = base.as_slice();
    assert_eq!(base_values, &[1.0, 2.0, 3.0]);

    let exponent_values: &[f32] = exponent.as_slice();
    assert_eq!(exponent_values, &[2.0, 3.0, 4.0]);
}
1709
/// `power` with a `[3]`-shaped base and a `[2]`-shaped exponent must report an error.
#[test]
fn test_pow_invalid_broadcast() {
    let base = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let exponent = Array::from_slice(&[2.0, 3.0], &[2]);

    assert!(base.power(&exponent).is_err());
}
1717
/// The `%` operator on array references computes element-wise remainders and keeps operands intact.
#[test]
fn test_rem() {
    let dividend = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
    let divisor = Array::from_slice(&[3.0, 4.0, 5.0], &[3]);

    let remainder = &dividend % &divisor;
    let remainder_values: &[f32] = remainder.as_slice();
    assert_eq!(remainder_values, &[1.0, 3.0, 2.0]);

    // neither operand may have been modified
    let dividend_values: &[f32] = dividend.as_slice();
    assert_eq!(dividend_values, &[10.0, 11.0, 12.0]);

    let divisor_values: &[f32] = divisor.as_slice();
    assert_eq!(divisor_values, &[3.0, 4.0, 5.0]);
}
1735
/// `remainder` with a `[3]`-shaped dividend and a `[2]`-shaped divisor must report an error.
#[test]
fn test_rem_invalid_broadcast() {
    let dividend = Array::from_slice(&[10.0, 11.0, 12.0], &[3]);
    let divisor = Array::from_slice(&[3.0, 4.0], &[2]);

    assert!(dividend.remainder(&divisor).is_err());
}
1743
/// `sqrt` of perfect squares yields their roots and leaves the input untouched.
#[test]
fn test_sqrt() {
    let squares = Array::from_slice(&[1.0, 4.0, 9.0], &[3]);
    let roots = squares.sqrt().unwrap();

    let root_values: &[f32] = roots.as_slice();
    assert_eq!(root_values, &[1.0, 2.0, 3.0]);

    // the operand must still hold its original values
    let square_values: &[f32] = squares.as_slice();
    assert_eq!(square_values, &[1.0, 4.0, 9.0]);
}
1756
/// `cos` matches reference values within tolerance and leaves the input untouched.
#[test]
fn test_cos() {
    let angles = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
    let cosines = angles.cos().unwrap();

    let cosines_expected = array!([1.0, 0.54030234, -0.41614687]);
    assert_array_all_close!(cosines, cosines_expected);

    // the operand must still hold its original values
    let angles_expected = array!([0.0, 1.0, 2.0]);
    assert_array_all_close!(angles, angles_expected);
}
1769
/// `exp` matches reference values within tolerance and leaves the input untouched.
#[test]
fn test_exp() {
    let exponents = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
    let powers = exponents.exp().unwrap();

    let powers_expected = array!([1.0, 2.7182817, 7.389056]);
    assert_array_all_close!(powers, powers_expected);

    // the operand must still hold its original values
    let exponents_expected = array!([0.0, 1.0, 2.0]);
    assert_array_all_close!(exponents, exponents_expected);
}
1782
/// `floor` rounds each element down and leaves the input untouched.
#[test]
fn test_floor() {
    let input = Array::from_slice(&[0.1, 1.9, 2.5], &[3]);
    let floored = input.floor().unwrap();

    let floored_values: &[f32] = floored.as_slice();
    assert_eq!(floored_values, &[0.0, 1.0, 2.0]);

    // the operand must still hold its original values
    let input_values: &[f32] = input.as_slice();
    assert_eq!(input_values, &[0.1, 1.9, 2.5]);
}
1795
/// `floor_device` on a complex64 scalar must report an error.
#[test]
fn test_floor_complex64() {
    let input = Array::from_complex(complex64::new(1.0, 2.0));

    assert!(input.floor_device(StreamOrDevice::default()).is_err());
}
1803
/// `floor_divide` yields the floored quotient element-wise and keeps both operands intact.
#[test]
fn test_floor_divide() {
    let numerator = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let denominator = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);

    let quotient = numerator.floor_divide(&denominator).unwrap();
    let quotient_values: &[f32] = quotient.as_slice();
    assert_eq!(quotient_values, &[0.0, 0.0, 0.0]);

    // neither operand may have been modified
    let numerator_values: &[f32] = numerator.as_slice();
    assert_eq!(numerator_values, &[1.0, 2.0, 3.0]);

    let denominator_values: &[f32] = denominator.as_slice();
    assert_eq!(denominator_values, &[4.0, 5.0, 6.0]);
}
1821
/// `floor_divide_device` with a complex64 numerator must report an error.
#[test]
fn test_floor_divide_complex64() {
    let numerator = Array::from_complex(complex64::new(1.0, 2.0));
    let denominator = Array::from_slice(&[4.0, 5.0, 6.0], &[3]);

    let outcome = numerator.floor_divide_device(&denominator, StreamOrDevice::default());
    assert!(outcome.is_err());
}
1830
/// `floor_divide_device` with `[3]`- and `[2]`-shaped operands must report an error.
#[test]
fn test_floor_divide_invalid_broadcast() {
    let numerator = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let denominator = Array::from_slice(&[4.0, 5.0], &[2]);

    let outcome = numerator.floor_divide_device(&denominator, StreamOrDevice::default());
    assert!(outcome.is_err());
}
1838
/// `is_nan` flags exactly the NaN entry.
#[test]
fn test_is_nan() {
    let input = Array::from_slice(&[1.0, f32::NAN, 3.0], &[3]);
    let mask = input.is_nan().unwrap();

    let mask_values: &[bool] = mask.as_slice();
    assert_eq!(mask_values, &[false, true, false]);
}
1847
/// `is_inf` flags exactly the infinite entry.
#[test]
fn test_is_inf() {
    let input = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
    let mask = input.is_inf().unwrap();

    let mask_values: &[bool] = mask.as_slice();
    assert_eq!(mask_values, &[false, true, false]);
}
1856
/// `is_finite` flags every entry except the infinite one.
#[test]
fn test_is_finite() {
    let input = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
    let mask = input.is_finite().unwrap();

    let mask_values: &[bool] = mask.as_slice();
    assert_eq!(mask_values, &[true, false, true]);
}
1865
/// `is_neg_inf` flags exactly the negative-infinity entry.
#[test]
fn test_is_neg_inf() {
    let input = Array::from_slice(&[1.0, f32::NEG_INFINITY, 3.0], &[3]);
    let mask = input.is_neg_inf().unwrap();

    let mask_values: &[bool] = mask.as_slice();
    assert_eq!(mask_values, &[false, true, false]);
}
1874
/// `is_pos_inf` flags exactly the positive-infinity entry.
#[test]
fn test_is_pos_inf() {
    let input = Array::from_slice(&[1.0, f32::INFINITY, 3.0], &[3]);
    let mask = input.is_pos_inf().unwrap();

    let mask_values: &[bool] = mask.as_slice();
    assert_eq!(mask_values, &[false, true, false]);
}
1883
/// `log` matches reference natural logarithms within tolerance and leaves the input untouched.
#[test]
fn test_log() {
    let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let b = a.log().unwrap();

    // ln(2) and ln(3) are not exactly representable in f32, so compare with a
    // tolerance (as `test_cos` / `test_exp` do) instead of bit-exact equality.
    let b_expected = array!([0.0, 0.6931472, 1.0986123]);
    assert_array_all_close!(b, b_expected);

    // the operand must still hold its original values
    let a_data: &[f32] = a.as_slice();
    assert_eq!(a_data, &[1.0, 2.0, 3.0]);
}
1896
/// `log2` of powers of two yields their exponents and leaves the input untouched.
#[test]
fn test_log2() {
    let powers = Array::from_slice(&[1.0, 2.0, 4.0, 8.0], &[4]);
    let exponents = powers.log2().unwrap();

    let exponent_values: &[f32] = exponents.as_slice();
    assert_eq!(exponent_values, &[0.0, 1.0, 2.0, 3.0]);

    // the operand must still hold its original values
    let power_values: &[f32] = powers.as_slice();
    assert_eq!(power_values, &[1.0, 2.0, 4.0, 8.0]);
}
1909
/// `log10` of powers of ten yields their exponents and leaves the input untouched.
#[test]
fn test_log10() {
    let powers = Array::from_slice(&[1.0, 10.0, 100.0], &[3]);
    let exponents = powers.log10().unwrap();

    let exponent_values: &[f32] = exponents.as_slice();
    assert_eq!(exponent_values, &[0.0, 1.0, 2.0]);

    // the operand must still hold its original values
    let power_values: &[f32] = powers.as_slice();
    assert_eq!(power_values, &[1.0, 10.0, 100.0]);
}
1922
/// `log1p` matches reference values of `ln(1 + x)` within tolerance and leaves the input untouched.
#[test]
fn test_log1p() {
    let a = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let b = a.log1p().unwrap();

    // The expected logarithms are not exactly representable in f32, so compare
    // with a tolerance (as `test_cos` / `test_exp` do) instead of bit-exact equality.
    let b_expected = array!([0.6931472, 1.0986123, 1.3862944]);
    assert_array_all_close!(b, b_expected);

    // the operand must still hold its original values
    let a_data: &[f32] = a.as_slice();
    assert_eq!(a_data, &[1.0, 2.0, 3.0]);
}
1935
/// `matmul` of an integer `[2, 2]` matrix with a float `[2, 3]` matrix yields a float `[2, 3]`
/// product and leaves both operands untouched.
#[test]
fn test_matmul() {
    let lhs = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
    let rhs = Array::from_slice(&[-5.0, 37.5, 4., 7., 1., 0.], &[2, 3]);

    let product = lhs.matmul(&rhs).unwrap();
    assert_eq!(product.shape(), &[2, 3]);

    let product_values: &[f32] = product.as_slice();
    assert_eq!(product_values, &[9.0, 39.5, 4.0, 13.0, 116.5, 12.0]);

    // neither operand may have been modified
    let lhs_values: &[i32] = lhs.as_slice();
    assert_eq!(lhs_values, &[1, 2, 3, 4]);

    let rhs_values: &[f32] = rhs.as_slice();
    assert_eq!(rhs_values, &[-5.0, 37.5, 4., 7., 1., 0.]);
}
1954
/// `matmul` with a scalar left operand must report an error.
#[test]
fn test_matmul_ndim_zero() {
    let scalar: Array = 1.0.into();
    let vector = Array::from_slice::<i32>(&[1], &[1]);

    assert!(scalar.matmul(&vector).is_err());
}
1962
/// `matmul` of two one-dimensional arrays of equal length succeeds.
#[test]
fn test_matmul_ndim_one() {
    let lhs = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);
    let rhs = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[4]);

    assert!(lhs.matmul(&rhs).is_ok());
}
1970
/// `matmul` of a `[2, 3]` matrix with a `[2, 5]` matrix must report an error.
#[test]
fn test_matmul_dim_mismatch() {
    let lhs = Array::from_slice(&[1, 2, 3, 4, 5, 6], &[2, 3]);
    let rhs = Array::from_slice(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], &[2, 5]);

    assert!(lhs.matmul(&rhs).is_err());
}
1978
/// `matmul` of two integer matrices must report an error.
#[test]
fn test_matmul_non_float_output_type() {
    let lhs = Array::from_slice(&[1, 2, 3, 4], &[2, 2]);
    let rhs = Array::from_slice(&[5, 37, 4, 7, 1, 0], &[2, 3]);

    assert!(lhs.matmul(&rhs).is_err());
}
1987
/// `reciprocal` yields `1 / x` element-wise and leaves the input untouched.
#[test]
fn test_reciprocal() {
    let input = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
    let inverted = input.reciprocal().unwrap();

    let inverted_values: &[f32] = inverted.as_slice();
    assert_eq!(inverted_values, &[1.0, 0.5, 0.25]);

    // the operand must still hold its original values
    let input_values: &[f32] = input.as_slice();
    assert_eq!(input_values, &[1.0, 2.0, 4.0]);
}
2000
/// `round(None)` rounds each element to a whole number and leaves the input untouched.
#[test]
fn test_round() {
    let input = Array::from_slice(&[1.1, 2.9, 3.5], &[3]);
    let rounded = input.round(None).unwrap();

    let rounded_values: &[f32] = rounded.as_slice();
    assert_eq!(rounded_values, &[1.0, 3.0, 4.0]);

    // the operand must still hold its original values
    let input_values: &[f32] = input.as_slice();
    assert_eq!(input_values, &[1.1, 2.9, 3.5]);
}
2013
/// `rsqrt` matches reference values of `1 / sqrt(x)` within tolerance and leaves the input untouched.
#[test]
fn test_rsqrt() {
    let a = Array::from_slice(&[1.0, 2.0, 4.0], &[3]);
    let b = a.rsqrt().unwrap();

    // 1/sqrt(2) is not exactly representable in f32, so compare with a
    // tolerance (as `test_cos` / `test_exp` do) instead of bit-exact equality.
    let b_expected = array!([1.0, 0.70710677, 0.5]);
    assert_array_all_close!(b, b_expected);

    // the operand must still hold its original values
    let a_data: &[f32] = a.as_slice();
    assert_eq!(a_data, &[1.0, 2.0, 4.0]);
}
2026
/// `sin` matches reference values within tolerance and leaves the input untouched.
#[test]
fn test_sin() {
    let a = Array::from_slice(&[0.0, 1.0, 2.0], &[3]);
    let b = a.sin().unwrap();

    // sin(1) and sin(2) are not exactly representable in f32, so compare with a
    // tolerance (as `test_cos` / `test_exp` do) instead of bit-exact equality.
    let b_expected = array!([0.0, 0.841471, 0.9092974]);
    assert_array_all_close!(b, b_expected);

    // the operand must still hold its original values
    let a_data: &[f32] = a.as_slice();
    assert_eq!(a_data, &[0.0, 1.0, 2.0]);
}
2039
/// `square` multiplies each element by itself and leaves the input untouched.
#[test]
fn test_square() {
    let input = Array::from_slice(&[1.0, 2.0, 3.0], &[3]);
    let squared = input.square().unwrap();

    let squared_values: &[f32] = squared.as_slice();
    assert_eq!(squared_values, &[1.0, 4.0, 9.0]);

    // the operand must still hold its original values
    let input_values: &[f32] = input.as_slice();
    assert_eq!(input_values, &[1.0, 2.0, 3.0]);
}
2052
/// Exercises `negative` and the unary `-` operator on scalar, empty and boolean arrays.
#[test]
fn test_unary_neg() {
    let one = array!(1.0);

    let via_function = negative(&one).unwrap();
    assert_eq!(via_function.item::<f32>(), -1.0);

    // the operator form consumes its operand
    let via_operator = -one;
    assert_eq!(via_operator.item::<f32>(), -1.0);

    // negating an empty array yields an empty array
    assert_eq!(-array!(), array!());

    // boolean input is rejected
    let flag = array!(true);
    assert!(negative(&flag).is_err());
}
2068
/// Exercises `abs` on float, empty, signed, unsigned and boolean arrays.
#[test]
fn test_unary_abs() {
    let floats = array!([-1.0, 0.0, 1.0]);
    let floats_abs = abs(&floats).unwrap();
    assert_eq!(floats_abs, array!([1.0, 0.0, 1.0]));

    // empty input yields an empty result
    let empty_abs = abs(array!()).unwrap();
    assert_eq!(empty_abs, array!());

    // signed integers
    let signed = array!([-1, 0, 1]);
    let signed_abs = abs(&signed).unwrap();
    assert_eq!(signed_abs, array!([1, 0, 1]));

    // unsigned integers
    let unsigned = array!([1u32, 0, 1]);
    let unsigned_abs = abs(&unsigned).unwrap();
    assert_eq!(unsigned_abs, array!([1u32, 0, 1]));

    // booleans
    let bools = array!([false, true]);
    let bools_abs = abs(&bools).unwrap();
    assert_eq!(bools_abs, array!([false, true]));
}
2089
/// Exercises `sign` on inputs that already equal their own sign, plus the empty array.
#[test]
fn test_unary_sign() {
    let floats = array!([-1.0, 0.0, 1.0]);
    let floats_sign = sign(&floats).unwrap();
    assert_eq!(floats_sign, floats);

    // empty input yields an empty result
    let empty_sign = sign(array!()).unwrap();
    assert_eq!(empty_sign, array!());

    // signed integers
    let signed = array!([-1, 0, 1]);
    let signed_sign = sign(&signed).unwrap();
    assert_eq!(signed_sign, signed);

    // unsigned integers
    let unsigned = array!([1u32, 0, 1]);
    let unsigned_sign = sign(&unsigned).unwrap();
    assert_eq!(unsigned_sign, unsigned);

    // booleans
    let bools = array!([false, true]);
    let bools_sign = sign(&bools).unwrap();
    assert_eq!(bools_sign, bools);
}
2110
/// Shorthand for negative infinity, shared by several of the unary-op tests.
const NEG_INF: f32 = f32::NEG_INFINITY;

/// Exercises `floor` and `ceil` on whole, fractional, negative, infinite and complex inputs.
#[test]
fn test_unary_floor_ceil() {
    // a whole number is its own floor and ceiling
    let whole = array![1.0];
    assert_eq!(floor(&whole).unwrap().item::<f32>(), 1.0);
    assert_eq!(ceil(&whole).unwrap().item::<f32>(), 1.0);

    let positive_half = array![1.5];
    assert_eq!(floor(&positive_half).unwrap().item::<f32>(), 1.0);
    assert_eq!(ceil(&positive_half).unwrap().item::<f32>(), 2.0);

    let negative_half = array![-1.5];
    assert_eq!(floor(&negative_half).unwrap().item::<f32>(), -2.0);
    assert_eq!(ceil(&negative_half).unwrap().item::<f32>(), -1.0);

    // negative infinity passes through both operations
    let neg_infinity = array![NEG_INF];
    assert_eq!(floor(&neg_infinity).unwrap().item::<f32>(), NEG_INF);
    assert_eq!(ceil(&neg_infinity).unwrap().item::<f32>(), NEG_INF);

    // complex input is rejected by both operations
    let complex = array!([1.0, 1.0]).as_type::<complex64>().unwrap();
    assert!(floor(&complex).is_err());
    assert!(ceil(&complex).is_err());
}
2135
/// Exercises `round` with no decimals argument on floats and with `-1` on integers.
#[test]
fn test_unary_round() {
    let fractions = array!([0.5, -0.5, 1.5, -1.5, 2.3, 2.6]);
    let rounded = round(&fractions, None).unwrap();
    assert_eq!(rounded, array!([0, 0, 2, -2, 2, 3]));

    let integers = array!([11, 222, 32]);
    let rounded_to_tens = round(&integers, -1).unwrap();
    assert_eq!(rounded_to_tens, array!([10, 220, 30]));
}
2144
/// Exercises `exp` on scalar, empty, infinite, integer, broadcast and split inputs.
#[test]
fn test_unary_exp() {
    let zero = array![0.0];
    assert_eq!(exp(&zero).unwrap().item::<f32>(), 1.0);

    let two = array![2.0];
    let exp_two = exp(&two).unwrap().item::<f32>();
    assert_float_eq!(exp_two, 2.0f32.exp(), abs <= 1e-5);

    // empty input yields an empty result
    assert_eq!(exp(array!()).unwrap(), array!());

    let neg_infinity = array![NEG_INF];
    assert_eq!(exp(&neg_infinity).unwrap().item::<f32>(), 0.0);

    // integer input
    let int_two = array![2];
    assert_eq!(int_two.dtype(), Dtype::Int32);
    let exp_int_two = exp(&int_two).unwrap().item::<f32>();
    assert_float_eq!(exp_int_two, 2.0f32.exp(), abs <= 1e-5);

    // input produced by broadcasting a scalar
    let broadcasted = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let res = exp(&broadcasted).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.exp())).unwrap();
    let close = all_close(&res, &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // first of the two pieces produced by `split`
    let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let pieces = split(&data, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0f32.exp(), 2.0f32.exp()], &[2, 1]);
    let close = all_close(exp(&pieces[0]).unwrap(), &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2186
/// Exercises `expm1` against `f32::exp_m1` on float and integer scalars.
#[test]
fn test_unary_expm1() {
    let minus_one = array![-1.0];
    let got = expm1(&minus_one).unwrap().item::<f32>();
    assert_float_eq!(got, (-1.0f32).exp_m1(), abs <= 1e-5);

    let one = array![1.0];
    let got = expm1(&one).unwrap().item::<f32>();
    assert_float_eq!(got, 1.0f32.exp_m1(), abs <= 1e-5);

    // integer input produces a Float32 result
    let int_one = array![1];
    assert_eq!(expm1(&int_one).unwrap().dtype(), Dtype::Float32);
    let got = expm1(&int_one).unwrap().item::<f32>();
    assert_float_eq!(got, 1.0f32.exp_m1(), abs <= 1e-5);
}
2212
/// Exercises `sin` on scalar, empty, integer, broadcast and split inputs.
#[test]
fn test_unary_sin() {
    let zero = array![0.0];
    assert_eq!(sin(&zero).unwrap().item::<f32>(), 0.0);

    let half_pi = array![std::f32::consts::PI / 2.0];
    let got = sin(&half_pi).unwrap().item::<f32>();
    assert_float_eq!(got, (std::f32::consts::PI / 2.0f32).sin(), abs <= 1e-5);

    // empty input yields an empty result
    assert_eq!(sin(array!()).unwrap(), array!());

    // integer input
    let int_zero = array![0];
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    let got = sin(&int_zero).unwrap().item::<f32>();
    assert_float_eq!(got, 0.0f32.sin(), abs <= 1e-5);

    // input produced by broadcasting a scalar
    let broadcasted = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let res = sin(&broadcasted).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.sin())).unwrap();
    let close = all_close(&res, &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // first of the two pieces produced by `split`
    let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let pieces = split(&data, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0f32.sin(), 2.0f32.sin()], &[2, 1]);
    let close = all_close(sin(&pieces[0]).unwrap(), &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2251
/// Exercises `cos` on scalar, empty, integer, broadcast and split inputs.
#[test]
fn test_unary_cos() {
    let zero = array![0.0];
    let got = cos(&zero).unwrap().item::<f32>();
    assert_float_eq!(got, 0.0f32.cos(), abs <= 1e-5);

    let half_pi = array![std::f32::consts::PI / 2.0];
    let got = cos(&half_pi).unwrap().item::<f32>();
    assert_float_eq!(got, (std::f32::consts::PI / 2.0f32).cos(), abs <= 1e-5);

    // empty input yields an empty result
    assert_eq!(cos(array!()).unwrap(), array!());

    // integer input
    let int_zero = array![0];
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    let got = cos(&int_zero).unwrap().item::<f32>();
    assert_float_eq!(got, 0.0f32.cos(), abs <= 1e-5);

    // input produced by broadcasting a scalar
    let broadcasted = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let res = cos(&broadcasted).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.cos())).unwrap();
    let close = all_close(&res, &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // first of the two pieces produced by `split`
    let data = Array::from_slice(&[0.0, 1.0, 2.0, 3.0], &[2, 2]);
    let pieces = split(&data, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0f32.cos(), 2.0f32.cos()], &[2, 1]);
    let close = all_close(cos(&pieces[0]).unwrap(), &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2294
/// Exercises `degrees` on scalar, empty, integer, broadcast and split inputs.
#[test]
fn test_unary_degrees() {
    let zero = array![0.0];
    assert_eq!(degrees(&zero).unwrap().item::<f32>(), 0.0);

    let half_pi = array![std::f32::consts::PI / 2.0];
    assert_eq!(degrees(&half_pi).unwrap().item::<f32>(), 90.0);

    // empty input yields an empty result
    assert_eq!(degrees(array!()).unwrap(), array!());

    // integer input
    let int_zero = array![0];
    assert_eq!(int_zero.dtype(), Dtype::Int32);
    assert_eq!(degrees(&int_zero).unwrap().item::<f32>(), 0.0);

    // input produced by broadcasting a scalar
    let broadcasted = broadcast_to(&array!(std::f32::consts::PI / 2.0), &[2, 2, 2]).unwrap();
    let res = degrees(&broadcasted).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(90.0)).unwrap();
    let close = all_close(&res, &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // first of the two pieces produced by `split`
    let angles = Array::from_slice(&[0.0, PI / 2.0, PI, 1.5 * PI], &[2, 2]);
    let pieces = split(&angles, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0, 180.0], &[2, 1]);
    let close = all_close(degrees(&pieces[0]).unwrap(), &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2327
/// Exercises `radians` on scalar, empty, integer, broadcast and split inputs.
#[test]
fn test_unary_radians() {
    let zero = array![0.0];
    assert_eq!(radians(&zero).unwrap().item::<f32>(), 0.0);

    let ninety = array![90.0];
    let got = radians(&ninety).unwrap().item::<f32>();
    assert_eq!(got, std::f32::consts::PI / 2.0);

    // empty input yields an empty result
    assert_eq!(radians(array!()).unwrap(), array!());

    // integer input
    let int_ninety = array![90];
    assert_eq!(int_ninety.dtype(), Dtype::Int32);
    let got = radians(&int_ninety).unwrap().item::<f32>();
    assert_eq!(got, std::f32::consts::PI / 2.0);

    // input produced by broadcasting a scalar
    let broadcasted = broadcast_to(&array!(90.0), &[2, 2, 2]).unwrap();
    let res = radians(&broadcasted).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(std::f32::consts::PI / 2.0)).unwrap();
    let close = all_close(&res, &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // first of the two pieces produced by `split`
    let angles = Array::from_slice(&[0.0, 90.0, 180.0, 270.0], &[2, 2]);
    let pieces = split(&angles, 2, 1).unwrap();
    let expected = Array::from_slice(&[0.0, PI], &[2, 1]);
    let close = all_close(radians(&pieces[0]).unwrap(), &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2366
/// Exercises `log` on zero, one, integer, broadcast and split inputs.
#[test]
fn test_unary_log() {
    // the logarithm of zero is negative infinity
    let zero = array![0.0];
    assert_eq!(log(&zero).unwrap().item::<f32>(), NEG_INF);

    let one = array![1.0];
    assert_eq!(log(&one).unwrap().item::<f32>(), 0.0);

    // integer input produces a Float32 result
    let int_one = array![1];
    assert_eq!(log(&int_one).unwrap().dtype(), Dtype::Float32);
    assert_eq!(log(&int_one).unwrap().item::<f32>(), 0.0);

    // input produced by broadcasting a scalar
    let broadcasted = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let res = log(&broadcasted).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(0.0)).unwrap();
    let close = all_close(&res, &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // first of the two pieces produced by `split`
    let data = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    let pieces = split(&data, 2, 1).unwrap();
    let expected = Array::from_slice(&[1.0f32.ln(), 3.0f32.ln()], &[2, 1]);
    let close = all_close(log(&pieces[0]).unwrap(), &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2395
/// Exercises `log2` on zero, one and a power of two.
#[test]
fn test_unary_log2() {
    // the logarithm of zero is negative infinity
    let zero = array![0.0];
    assert_eq!(log2(&zero).unwrap().item::<f32>(), NEG_INF);

    let one = array![1.0];
    assert_eq!(log2(&one).unwrap().item::<f32>(), 0.0);

    let two_to_the_tenth = array![1024.0];
    assert_eq!(log2(&two_to_the_tenth).unwrap().item::<f32>(), 10.0);
}
2407
/// Exercises `log10` on zero, one and a power of ten.
#[test]
fn test_unary_log10() {
    // the logarithm of zero is negative infinity
    let zero = array![0.0];
    assert_eq!(log10(&zero).unwrap().item::<f32>(), NEG_INF);

    let one = array![1.0];
    assert_eq!(log10(&one).unwrap().item::<f32>(), 0.0);

    let thousand = array![1000.0];
    assert_eq!(log10(&thousand).unwrap().item::<f32>(), 3.0);
}
2419
/// Exercises `log1p` against `f32::ln_1p` on scalar, integer, broadcast and split inputs.
#[test]
fn test_unary_log1p() {
    let minus_one = array![-1.0];
    let got = log1p(&minus_one).unwrap().item::<f32>();
    assert_float_eq!(got, (-1.0f32).ln_1p(), abs <= 1e-5);

    let one = array![1.0];
    let got = log1p(&one).unwrap().item::<f32>();
    assert_float_eq!(got, 1.0f32.ln_1p(), abs <= 1e-5);

    // integer input produces a Float32 result
    let int_one = array![1];
    assert_eq!(log1p(&int_one).unwrap().dtype(), Dtype::Float32);
    let got = log1p(&int_one).unwrap().item::<f32>();
    assert_float_eq!(got, 1.0f32.ln_1p(), abs <= 1e-5);

    // input produced by broadcasting a scalar
    let broadcasted = broadcast_to(&array!(1.0), &[2, 2, 2]).unwrap();
    let res = log1p(&broadcasted).unwrap();
    let expected = Array::full::<f32>(&[2, 2, 2], array!(1.0f32.ln_1p())).unwrap();
    let close = all_close(&res, &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // first of the two pieces produced by `split`
    let data = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
    let pieces = split(&data, 2, 1).unwrap();
    let expected = Array::from_slice(&[1.0f32.ln_1p(), 3.0f32.ln_1p()], &[2, 1]);
    let close = all_close(log1p(&pieces[0]).unwrap(), &expected, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2462
/// Exercises `sigmoid` at zero (float and integer) and at both infinities.
#[test]
fn test_unary_sigmoid() {
    let zero = array![0.0];
    let got = sigmoid(&zero).unwrap().item::<f32>();
    assert_float_eq!(got, 0.5, abs <= 1e-5);

    // integer input produces a Float32 result
    let int_zero = array![0];
    assert_eq!(sigmoid(&int_zero).unwrap().dtype(), Dtype::Float32);
    let got = sigmoid(&int_zero).unwrap().item::<f32>();
    assert_float_eq!(got, 0.5, abs <= 1e-5);

    // the function saturates at the infinities
    let pos_infinity = array![f32::INFINITY];
    assert_eq!(sigmoid(&pos_infinity).unwrap().item::<f32>(), 1.0);

    let neg_infinity = array![f32::NEG_INFINITY];
    assert_eq!(sigmoid(&neg_infinity).unwrap().item::<f32>(), 0.0);
}
2488
/// Exercises `square` on a float scalar, an integer scalar and a filled `[3, 3]` array.
#[test]
fn test_unary_square() {
    let three = array![3.0];
    assert_eq!(square(&three).unwrap().item::<f32>(), 9.0);

    // integer input can be read back as i32
    let int_two = array![2];
    assert_eq!(square(&int_two).unwrap().item::<i32>(), 4);

    let twos = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
    let fours = Array::full::<f32>(&[3, 3], array!(4.0)).unwrap();
    let close = all_close(square(&twos).unwrap(), fours, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2508
/// Exercises `sqrt` and `rsqrt` on a float scalar, a filled `[3, 3]` array and an integer scalar.
#[test]
fn test_unary_sqrt_rsqrt() {
    let four = array![4.0];
    assert_eq!(sqrt(&four).unwrap().item::<f32>(), 2.0);
    assert_eq!(rsqrt(&four).unwrap().item::<f32>(), 0.5);

    let nines = Array::full::<f32>(&[3, 3], array!(9.0)).unwrap();
    let threes = Array::full::<f32>(&[3, 3], array!(3.0)).unwrap();
    let close = all_close(sqrt(&nines).unwrap(), threes, None, None, None).unwrap();
    assert!(close.item::<bool>());

    // integer input can be read back as f32
    let int_four = array![4i32];
    assert_eq!(sqrt(&int_four).unwrap().item::<f32>(), 2.0);
    assert_eq!(rsqrt(&int_four).unwrap().item::<f32>(), 0.5);
}
2530
/// Exercises `reciprocal` on a float scalar, an integer scalar and a filled `[3, 3]` array.
#[test]
fn test_unary_reciprocal() {
    let eight = array![8.0];
    assert_eq!(reciprocal(&eight).unwrap().item::<f32>(), 0.125);

    // integer input produces a Float32 result
    let int_two = array![2];
    let inverted = reciprocal(&int_two).unwrap();
    assert_eq!(inverted.dtype(), Dtype::Float32);
    assert_eq!(inverted.item::<f32>(), 0.5);

    let twos = Array::full::<f32>(&[3, 3], array!(2.0)).unwrap();
    let halves = Array::full::<f32>(&[3, 3], array!(0.5)).unwrap();
    let close = all_close(reciprocal(&twos).unwrap(), halves, None, None, None).unwrap();
    assert!(close.item::<bool>());
}
2552
/// `real` and `imag` extract the two components of a complex64 scalar.
#[test]
fn test_unary_real_imag() {
    let imaginary_unit = Array::from_complex(complex64::new(0.0, 1.0));

    let real_part = real(&imaginary_unit).unwrap();
    assert_eq!(real_part, Array::from_f32(0.0));

    let imag_part = imag(&imaginary_unit).unwrap();
    assert_eq!(imag_part, Array::from_f32(1.0));
}
2559
/// Exercises `add` and the `+` operator: scalars, chaining, equal shapes, scalar operands,
/// mixed integer/float operands, broadcasting and empty arrays.
#[test]
fn test_binary_add() {
    let x = array![1.0];
    let y = array![1.0];
    let z = add(&x, &y).unwrap();
    assert_eq!(z.item::<f32>(), 2.0);

    let z = &x + y;
    assert_eq!(z.item::<f32>(), 2.0);

    let z = add(z, &x).unwrap();
    assert_eq!(z.item::<f32>(), 3.0);

    // chaining: ten successive additions of `x` onto a copy of itself
    let mut out = x.deep_clone();
    for _ in 0..10 {
        out = add(&out, &x).unwrap();
    }
    assert_eq!(out.item::<f32>(), 11.0);

    // one-dimensional operands of equal shape
    let x = array!([1.0, 2.0, 3.0]);
    let y = array!([1.0, 2.0, 3.0]);
    let z = add(&x, &y).unwrap();
    assert_eq!(z.shape(), &[3]);
    assert_eq!(z, array!([2.0, 4.0, 6.0]));

    // scalar right-hand operand (the original repeated this check verbatim;
    // the redundant copy has been dropped)
    let x = array!([1.0, 2.0, 3.0]);
    let y = &x + 2.0;
    assert_eq!(y.dtype(), Dtype::Float32);
    assert_eq!(y, array!([3.0, 4.0, 5.0]));

    // float array plus integer scalar stays Float32
    let y = x + 2;
    assert_eq!(y.dtype(), Dtype::Float32);

    // integer array plus float scalar becomes Float32
    let y = array!([1, 2, 3]) + 2.0;
    assert_eq!(y.dtype(), Dtype::Float32);
    assert_eq!(y, array!([3.0, 4.0, 5.0]));

    // operands produced by broadcasting scalars
    let x = broadcast_to(&array!(1.0), &[10]).unwrap();
    let y = broadcast_to(&array!(2.0), &[10]).unwrap();
    let z = add(&x, &y).unwrap();
    assert_eq!(z, full::<f32>(&[10], array!(3.0)).unwrap());

    // `[1, 2]` plus `[2, 1]` broadcasts to `[2, 2]`
    let x = Array::from_slice(&[1.0, 2.0], &[1, 2]);
    let y = Array::from_slice(&[1.0, 2.0], &[2, 1]);
    let z = add(&x, &y).unwrap();
    assert_eq!(z.shape(), &[2, 2]);
    assert_eq!(z, Array::from_slice(&[2.0, 3.0, 3.0, 4.0], &[2, 2]));

    let x = ones::<f32>(&[3, 2, 1]).unwrap();
    let z = x + 2.0;
    assert_eq!(z.shape(), &[3, 2, 1]);
    let expected = Array::from_slice(&[3.0, 3.0, 3.0, 3.0, 3.0, 3.0], &[3, 2, 1]);
    assert_eq!(z, expected);

    // empty operands yield an empty result
    let x = array!();
    let y = array!();
    let z = x + y;
    z.eval().unwrap();
    assert_eq!(z.size(), 0);
    assert_eq!(z.shape(), &[0]);
}
2631
/// The `-` operator on owned arrays subtracts element-wise.
#[test]
fn test_binary_sub() {
    let minuend = array!([3.0, 2.0, 1.0]);
    let subtrahend = array!([1.0, 1.0, 1.0]);

    let difference = minuend - subtrahend;
    assert_eq!(difference, array!([2.0, 1.0, 0.0]));
}
2638
/// The `*` operator on owned arrays multiplies element-wise.
#[test]
fn test_binary_mul() {
    let values = array!([1.0, 2.0, 3.0]);
    let factors = array!([2.0, 2.0, 2.0]);

    let product = values * factors;
    assert_eq!(product, array!([2.0, 4.0, 6.0]));
}
2645
/// Exercises `divide` on float scalars and on every combination of boolean scalars.
#[test]
fn test_binary_div() {
    let numerator = array![1.0];
    let denominator = array![1.0];
    assert_eq!(divide(&numerator, &denominator).unwrap().item::<f32>(), 1.0);

    let numerator = array![1.0];
    let denominator = array![0.5];
    assert_eq!(divide(&numerator, &denominator).unwrap().item::<f32>(), 2.0);

    let numerator = array![1.0];
    let denominator = array![4.0];
    assert_eq!(divide(&numerator, &denominator).unwrap().item::<f32>(), 0.25);

    // boolean operands: the quotient is read back as f32
    let numerator = array![true];
    let denominator = array![true];
    assert_eq!(divide(&numerator, &denominator).unwrap().item::<f32>(), 1.0);

    let numerator = array![false];
    let denominator = array![true];
    assert_eq!(divide(&numerator, &denominator).unwrap().item::<f32>(), 0.0);

    // true / false is infinite
    let numerator = array![true];
    let denominator = array![false];
    let quotient = divide(&numerator, &denominator).unwrap().item::<f32>();
    assert!(quotient.is_infinite());

    // false / false is NaN
    let numerator = array![false];
    let denominator = array![false];
    let quotient = divide(&numerator, &denominator).unwrap().item::<f32>();
    assert!(quotient.is_nan());
}
2676
/// `maximum` and `minimum` pick the larger / smaller of two scalars regardless of argument order.
#[test]
fn test_binary_maximum_minimum() {
    let one = array![1.0];

    // second operand is the smaller one
    let zero = array![0.0];
    assert_eq!(maximum(&one, &zero).unwrap().item::<f32>(), 1.0);
    assert_eq!(minimum(&one, &zero).unwrap().item::<f32>(), 0.0);

    // second operand is the larger one
    let two = array![2.0];
    assert_eq!(maximum(&one, &two).unwrap().item::<f32>(), 2.0);
    assert_eq!(minimum(&one, &two).unwrap().item::<f32>(), 1.0);
}
2688
/// Exercises `logaddexp` on finite, large unsigned and infinite operands.
#[test]
fn test_binary_logaddexp() {
    // log(e^0 + e^0) == ln 2
    let lhs = array![0.0];
    let rhs = array![0.0];
    let got = logaddexp(&lhs, &rhs).unwrap().item::<f32>();
    assert_float_eq!(got, 2.0f32.ln(), abs <= 1e-5);

    // a large unsigned operand dominates the result
    let lhs = array!([0u32]);
    let rhs = array!([10000u32]);
    assert_eq!(logaddexp(&lhs, &rhs).unwrap().item::<f32>(), 10000.0);

    // +inf with a finite value
    let lhs = array![f32::INFINITY];
    let rhs = array![3.0];
    assert_eq!(logaddexp(&lhs, &rhs).unwrap().item::<f32>(), f32::INFINITY);

    // -inf with a finite value
    let lhs = array![f32::NEG_INFINITY];
    let rhs = array![3.0];
    assert_eq!(logaddexp(&lhs, &rhs).unwrap().item::<f32>(), 3.0);

    // -inf with -inf
    let lhs = array![f32::NEG_INFINITY];
    let rhs = array![f32::NEG_INFINITY];
    assert_eq!(logaddexp(&lhs, &rhs).unwrap().item::<f32>(), f32::NEG_INFINITY);

    // +inf with +inf
    let lhs = array![f32::INFINITY];
    let rhs = array![f32::INFINITY];
    assert_eq!(logaddexp(&lhs, &rhs).unwrap().item::<f32>(), f32::INFINITY);

    // -inf with +inf
    let lhs = array![f32::NEG_INFINITY];
    let rhs = array![f32::INFINITY];
    assert_eq!(logaddexp(&lhs, &rhs).unwrap().item::<f32>(), f32::INFINITY);
}
2723
/// `clip` with both bounds, given first as arrays and then as plain floats.
#[test]
fn test_basic_clip() {
    let input = array!([1.0, 4.0, 3.0, 8.0, 5.0]);
    let expected = array!([2.0, 4.0, 3.0, 6.0, 5.0]);

    // bounds supplied as arrays
    let with_array_bounds = clip(&input, (array!(2.0), array!(6.0))).unwrap();
    assert_eq!(with_array_bounds, expected);

    // bounds supplied as plain floats
    let with_float_bounds = clip(&input, (2.0, 6.0)).unwrap();
    assert_eq!(with_float_bounds, expected);
}
2735
/// `clip` with only a lower bound, given first as an array and then as a plain float.
#[test]
fn test_clip_with_only_min() {
    let input = array!([-1.0, 1.0, 0.0, 5.0]);
    let expected = array!([0.0, 1.0, 0.0, 5.0]);

    // lower bound supplied as an array
    let with_array_bound = clip(&input, (array!(0.0), ())).unwrap();
    assert_eq!(with_array_bound, expected);

    // lower bound supplied as a plain float
    let with_float_bound = clip(&input, (0.0, ())).unwrap();
    assert_eq!(with_float_bound, expected);
}
2747
/// `clip` with only an upper bound, given first as an array and then as a plain float.
#[test]
fn test_clip_with_only_max() {
    let input = array!([2.0, 3.0, 4.0, 5.0]);
    let expected = array!([2.0, 3.0, 4.0, 4.0]);

    // upper bound supplied as an array
    let with_array_bound = clip(&input, ((), array!(4.0))).unwrap();
    assert_eq!(with_array_bound, expected);

    // upper bound supplied as a plain float
    let with_float_bound = clip(&input, ((), 4.0)).unwrap();
    assert_eq!(with_float_bound, expected);
}
2759
/// Exercises `tensordot_axes` (valid and mismatched axes) and `tensordot_axis`.
#[test]
fn test_tensordot() {
    // explicit axis lists on both operands
    let lhs = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[4, 3, 2]).unwrap();
    let contracted = tensordot_axes(&lhs, &rhs, &[1i32, 0], &[0i32, 1]).unwrap();
    let expected = Array::from_slice(
        &[4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306],
        &[5, 2],
    );
    assert_eq!(contracted, expected);

    // these axis lists are rejected
    let lhs = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[3, 4, 5, 6]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 360.0, None).unwrap(), &[6, 4, 5, 3]).unwrap();
    let mismatched = tensordot_axes(&lhs, &rhs, &[2, 1, 3], &[1, 2, 0]);
    assert!(mismatched.is_err());

    // a single axis count instead of explicit axis lists
    let lhs = reshape(arange::<_, f32>(None, 60.0, None).unwrap(), &[3, 4, 5]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 120.0, None).unwrap(), &[4, 5, 6]).unwrap();
    let contracted = tensordot_axis(&lhs, &rhs, 2).unwrap();
    let expected = Array::from_slice(
        &[
            14820.0, 15010.0, 15200.0, 15390.0, 15580.0, 15770.0, 37620.0, 38210.0, 38800.0,
            39390.0, 39980.0, 40570.0, 60420.0, 61410.0, 62400.0, 63390.0, 64380.0, 65370.0,
        ],
        &[3, 6],
    );
    assert_eq!(contracted, expected);
}
2788
/// Exercises `outer` on two ranges and on a ones vector paired with a `linspace` vector.
#[test]
fn test_outer() {
    // [1, 2, 3, 4] outer [1, 2, 3]
    let column = arange::<_, f32>(1.0, 5.0, None).unwrap();
    let row = arange::<_, f32>(1.0, 4.0, None).unwrap();
    let product = outer(&column, &row).unwrap();
    let expected = Array::from_slice(
        &[1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0],
        &[4, 3],
    );
    assert_eq!(product, expected);

    // a vector of ones repeats the second operand on every row
    let column = ones::<f32>(&[5]).unwrap();
    let row = linspace::<_, f32>(-2.0, 2.0, 5).unwrap();
    let product = outer(&column, &row).unwrap();
    let expected = Array::from_slice(
        &[
            -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
            -2.0, -1.0, 0.0, 1.0, 2.0, -2.0, -1.0, 0.0, 1.0, 2.0,
        ],
        &[5, 5],
    );
    assert_eq!(product, expected);
}
2812
/// Exercises `inner` on mismatched, vector, higher-rank and scalar operands.
#[test]
fn test_inner() {
    // last dimensions 5 and 3 differ, so the call is rejected
    let lhs = reshape(arange::<_, f32>(None, 5.0, None).unwrap(), &[1, 5]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[2, 3]).unwrap();
    assert!(inner(&lhs, &rhs).is_err());

    // two vectors reduce to a scalar
    let lhs = array!([1.0, 2.0, 3.0]);
    let rhs = array!([0.0, 1.0, 0.0]);
    let product = inner(&lhs, &rhs).unwrap();
    assert_eq!(product.item::<f32>(), 2.0);

    // `[2, 3, 4]` with `[4]` yields `[2, 3]`
    let lhs = reshape(arange::<_, f32>(None, 24.0, None).unwrap(), &[2, 3, 4]).unwrap();
    let rhs = arange::<_, f32>(None, 4.0, None).unwrap();
    let product = inner(&lhs, &rhs).unwrap();
    let expected = Array::from_slice(&[14.0, 38.0, 62.0, 86.0, 110.0, 134.0], &[2, 3]);
    assert_eq!(product, expected);

    // `[1, 1, 2]` with `[3, 2]` yields `[1, 1, 3]`
    let lhs = reshape(arange::<_, f32>(None, 2.0, None).unwrap(), &[1, 1, 2]).unwrap();
    let rhs = reshape(arange::<_, f32>(None, 6.0, None).unwrap(), &[3, 2]).unwrap();
    let product = inner(&lhs, &rhs).unwrap();
    let expected = Array::from_slice(&[1.0, 3.0, 5.0], &[1, 1, 3]);
    assert_eq!(product, expected);

    // a scalar second operand scales the matrix
    let lhs = eye::<f32>(2, None, None).unwrap();
    let rhs = Array::from_f32(7.0);
    let product = inner(&lhs, &rhs).unwrap();
    let expected = Array::from_slice(&[7.0, 0.0, 0.0, 7.0], &[2, 2]);
    assert_eq!(product, expected);
}
2842
#[test]
fn test_divmod() {
    // Dividing by 1.0 everywhere: the quotient equals the dividend and the
    // remainder is all zeros.
    let dividend = array!([1.0, 2.0, 3.0]);
    let unit_divisor = array!([1.0, 1.0, 1.0]);
    let (quotient, remainder) = divmod(&dividend, &unit_divisor).unwrap();
    assert_eq!(quotient, array!([1.0, 2.0, 3.0]));
    assert_eq!(remainder, array!([0.0, 0.0, 0.0]));

    // Dividing [5, 6, 7] by 2.
    let dividend_a = array!([5.0, 6.0, 7.0]);
    let divisor_a = array!([2.0, 2.0, 2.0]);
    let (quotient_a, remainder_a) = divmod(&dividend_a, &divisor_a).unwrap();
    assert_eq!(quotient_a, array!([2.0, 3.0, 3.0]));
    assert_eq!(remainder_a, array!([1.0, 0.0, 1.0]));

    // NOTE(review): this case repeats the previous one verbatim; it may have
    // been meant to exercise a different element type -- confirm.
    let dividend_b = array!([5.0, 6.0, 7.0]);
    let divisor_b = array!([2.0, 2.0, 2.0]);
    let (quotient_b, remainder_b) = divmod(&dividend_b, &divisor_b).unwrap();
    assert_eq!(quotient_b, array!([2.0, 3.0, 3.0]));
    assert_eq!(remainder_b, array!([1.0, 0.0, 1.0]));

    // Complex operands are expected to be rejected with an error.
    let complex_lhs = array![complex64::new(1.0, 0.0)];
    let complex_rhs = array![complex64::new(2.0, 0.0)];
    assert!(divmod(&complex_lhs, &complex_rhs).is_err());

    // Both outputs can be passed to a single `eval` call.
    let eval_lhs = array![1.0];
    let eval_rhs = array![2.0];
    let (quo, rem) = divmod(&eval_lhs, &eval_rhs).unwrap();
    eval([&quo, &rem]).unwrap();
    assert_eq!(quo.item::<f32>(), 0.0);
    assert_eq!(rem.item::<f32>(), 1.0);

    // Both outputs feed a further operation before anything is read back.
    let one = array![1.0];
    let two = array![2.0];
    let (quo, rem) = divmod(&one, &two).unwrap();
    let total = quo + rem;
    assert_eq!(total.item::<f32>(), 1.0);

    // Keep only the quotient: the `_` pattern never binds the remainder, so it
    // is gone by the time `eval` runs.
    let mut quo_only = {
        let (quo, _) = divmod(&one, &two).unwrap();
        vec![quo]
    };
    eval(quo_only.iter()).unwrap();
    assert_eq!(quo_only[0].item::<f32>(), 0.0);

    // Release the retained quotient, then repeat while keeping only the
    // remainder.
    quo_only.clear();
    let rem_only = {
        let (_, rem) = divmod(&one, &two).unwrap();
        vec![rem]
    };
    eval(rem_only.iter()).unwrap();
    assert_eq!(rem_only[0].item::<f32>(), 1.0);
}
2899}