1use std::f32::consts::PI;
2
3use crate::module::{Module, Param};
4use crate::{
5 array,
6 error::{Exception, Result},
7 ops::{abs, exp, log_sum_exp, maximum, minimum, multiply, which},
8 transforms::compile::compile,
9 Array,
10};
11use mlx_internal_macros::{generate_builder, Buildable, Builder};
12use mlx_macros::ModuleParameters;
13
14pub fn sigmoid(x: impl AsRef<Array>) -> Result<Array> {
25 crate::ops::sigmoid(x.as_ref())
26}
27
28pub fn relu(x: impl AsRef<Array>) -> Result<Array> {
36 crate::ops::maximum(x.as_ref(), &array!(0))
37}
38
39pub fn leaky_relu(x: impl AsRef<Array>, neg_slope: impl Into<Option<f32>>) -> Result<Array> {
49 let neg_slope = array!(neg_slope.into().unwrap_or(0.01));
50 compiled_leaky_relu(x.as_ref(), &neg_slope)
53}
54
55pub fn log_softmax(x: impl AsRef<Array>, axis: impl Into<Option<i32>>) -> Result<Array> {
63 let x = x.as_ref();
64 let axis = axis.into().unwrap_or(-1);
65 x.subtract(log_sum_exp(x, &[axis], true)?)
66}
67
68pub fn elu(x: impl AsRef<Array>, alpha: impl Into<Option<f32>>) -> Result<Array> {
81 let alpha = array!(alpha.into().unwrap_or(1.0));
82 compiled_elu(x.as_ref(), &alpha)
85}
86
87pub fn relu6(x: impl AsRef<Array>) -> Result<Array> {
95 compiled_relu6(x.as_ref())
96}
97
98pub fn softplus(x: impl AsRef<Array>) -> Result<Array> {
106 crate::ops::log_add_exp(x.as_ref(), &array!(0))
107}
108
109pub fn softsign(x: impl AsRef<Array>) -> Result<Array> {
117 compiled_softsign(x.as_ref())
118}
119
120pub fn celu(x: impl AsRef<Array>, alpha: impl Into<Option<f32>>) -> Result<Array> {
128 let alpha = array!(alpha.into().unwrap_or(1.0));
129 compiled_celu(x.as_ref(), &alpha)
132}
133
134pub fn silu(x: impl AsRef<Array>) -> Result<Array> {
142 compiled_silu(x.as_ref())
143}
144
145pub fn log_sigmoid(x: impl AsRef<Array>) -> Result<Array> {
153 compiled_log_sigmoid(x.as_ref())
154}
155
156pub fn gelu(x: impl AsRef<Array>) -> Result<Array> {
164 compiled_gelu(x.as_ref())
165}
166
167pub fn gelu_approximate(x: impl AsRef<Array>) -> Result<Array> {
175 compiled_gelu_approximate(x.as_ref())
176}
177
178pub fn gelu_fast_approximate(x: impl AsRef<Array>) -> Result<Array> {
186 compiled_gelu_fast_approximate(x.as_ref())
187}
188
189pub fn glu(x: impl AsRef<Array>, axis: impl Into<Option<i32>>) -> Result<Array> {
194 let split = x.as_ref().split_equal(2, axis)?;
195 let (a, b) = (&split[0], &split[1]);
196 Ok(a * sigmoid(b)?)
197}
198
199pub fn step(x: impl AsRef<Array>, threshold: impl Into<Option<f32>>) -> Result<Array> {
210 let threshold = array!(threshold.into().unwrap_or(0.0));
211 crate::ops::r#where(&x.as_ref().gt(threshold)?, &array!(1), &array!(0))
212}
213
214pub fn selu(x: impl AsRef<Array>) -> Result<Array> {
222 compiled_selu(x.as_ref())
223}
224
225pub fn prelu(x: impl AsRef<Array>, alpha: impl AsRef<Array>) -> Result<Array> {
233 compiled_prelu(x.as_ref(), alpha.as_ref())
234}
235
236pub fn mish(x: impl AsRef<Array>) -> Result<Array> {
248 compiled_mish(x.as_ref())
249}
250
251pub fn hard_swish(x: impl AsRef<Array>) -> Result<Array> {
259 compiled_hard_swish(x.as_ref())
260}
261
262generate_builder! {
263 #[derive(Debug, Clone, ModuleParameters, Buildable)]
268 #[module(root = crate)]
269 #[buildable(root = crate)]
270 #[builder(root = crate)]
271 pub struct Glu {
272 #[builder(optional, default = Glu::DEFAULT_AXIS)]
274 pub axis: i32,
275 }
276}
277
278impl Glu {
279 pub const DEFAULT_AXIS: i32 = -1;
281}
282
283impl Module<&Array> for Glu {
284 type Error = Exception;
285 type Output = Array;
286
287 fn forward(&mut self, x: &Array) -> Result<Array> {
288 glu(x, self.axis)
289 }
290
291 fn training_mode(&mut self, _: bool) {}
292}
293
294#[derive(Debug, Clone, ModuleParameters)]
305#[module(root = crate)]
306pub struct Sigmoid;
307
308impl Module<&Array> for Sigmoid {
309 type Error = Exception;
310 type Output = Array;
311
312 fn forward(&mut self, x: &Array) -> Result<Array> {
313 sigmoid(x)
314 }
315
316 fn training_mode(&mut self, _: bool) {}
317}
318
319#[derive(Debug, Clone, ModuleParameters)]
331#[module(root = crate)]
332pub struct Mish;
333
334impl Module<&Array> for Mish {
335 type Error = Exception;
336 type Output = Array;
337
338 fn forward(&mut self, x: &Array) -> Result<Array> {
339 mish(x)
340 }
341
342 fn training_mode(&mut self, _: bool) {}
343}
344
345#[derive(Debug, Clone, ModuleParameters)]
353#[module(root = crate)]
354pub struct Relu;
355
356impl Module<&Array> for Relu {
357 type Error = Exception;
358 type Output = Array;
359
360 fn forward(&mut self, x: &Array) -> Result<Array> {
361 relu(x)
362 }
363
364 fn training_mode(&mut self, _: bool) {}
365}
366
367generate_builder! {
368 #[derive(Debug, Clone, ModuleParameters, Buildable)]
376 #[module(root = crate)]
377 #[buildable(root = crate)]
378 #[builder(root = crate)]
379 pub struct LeakyRelu {
380 #[builder(optional, default = LeakyRelu::DEFAULT_NEG_SLOPE)]
382 pub neg_slope: f32,
383 }
384}
385
386impl LeakyRelu {
387 pub const DEFAULT_NEG_SLOPE: f32 = 0.01;
389}
390
391impl Module<&Array> for LeakyRelu {
392 type Error = Exception;
393 type Output = Array;
394
395 fn forward(&mut self, x: &Array) -> Result<Array> {
396 leaky_relu(x, self.neg_slope)
397 }
398
399 fn training_mode(&mut self, _: bool) {}
400}
401
402#[derive(Debug, Clone, ModuleParameters)]
410#[module(root = crate)]
411pub struct Relu6;
412
413impl Module<&Array> for Relu6 {
414 type Error = Exception;
415 type Output = Array;
416
417 fn forward(&mut self, x: &Array) -> Result<Array> {
418 relu6(x)
419 }
420
421 fn training_mode(&mut self, _: bool) {}
422}
423
424generate_builder! {
425 #[derive(Debug, Clone, ModuleParameters, Buildable)]
433 #[module(root = crate)]
434 #[buildable(root = crate)]
435 #[builder(root = crate)]
436 pub struct Softmax {
437 #[builder(optional, default = Softmax::DEFAULT_AXIS)]
439 pub axis: i32,
440 }
441}
442
443impl Softmax {
444 pub const DEFAULT_AXIS: i32 = -1;
446}
447
448impl Module<&Array> for Softmax {
449 type Error = Exception;
450 type Output = Array;
451
452 fn forward(&mut self, x: &Array) -> Result<Array> {
453 crate::ops::softmax(x, &[self.axis], None)
454 }
455
456 fn training_mode(&mut self, _: bool) {}
457}
458
459#[derive(Debug, Clone, ModuleParameters)]
467#[module(root = crate)]
468pub struct Softplus;
469
470impl Module<&Array> for Softplus {
471 type Error = Exception;
472 type Output = Array;
473
474 fn forward(&mut self, x: &Array) -> Result<Array> {
475 softplus(x)
476 }
477
478 fn training_mode(&mut self, _: bool) {}
479}
480
481#[derive(Debug, Clone, ModuleParameters)]
489#[module(root = crate)]
490pub struct Softsign;
491
492impl Module<&Array> for Softsign {
493 type Error = Exception;
494 type Output = Array;
495
496 fn forward(&mut self, x: &Array) -> Result<Array> {
497 softsign(x)
498 }
499
500 fn training_mode(&mut self, _: bool) {}
501}
502
503generate_builder! {
504 #[derive(Debug, Clone, ModuleParameters, Buildable)]
513 #[module(root = crate)]
514 #[buildable(root = crate)]
515 #[builder(root = crate)]
516 pub struct Celu {
517 #[builder(optional, default = Celu::DEFAULT_ALPHA)]
519 pub alpha: f32,
520 }
521}
522
523impl Celu {
524 pub const DEFAULT_ALPHA: f32 = 1.0;
526}
527
528impl Module<&Array> for Celu {
529 type Error = Exception;
530 type Output = Array;
531
532 fn forward(&mut self, x: &Array) -> Result<Array> {
533 celu(x, self.alpha)
534 }
535
536 fn training_mode(&mut self, _: bool) {}
537}
538
539#[derive(Debug, Clone, ModuleParameters)]
547#[module(root = crate)]
548pub struct Silu;
549
550impl Module<&Array> for Silu {
551 type Error = Exception;
552 type Output = Array;
553
554 fn forward(&mut self, x: &Array) -> Result<Array> {
555 silu(x)
556 }
557
558 fn training_mode(&mut self, _: bool) {}
559}
560
561generate_builder! {
562 #[derive(Debug, Clone, ModuleParameters, Buildable)]
570 #[module(root = crate)]
571 #[buildable(root = crate)]
572 #[builder(root = crate)]
573 pub struct LogSoftmax {
574 #[builder(optional, default = LogSoftmax::DEFAULT_AXIS)]
576 pub axis: i32,
577 }
578}
579
580impl LogSoftmax {
581 pub const DEFAULT_AXIS: i32 = -1;
583}
584
585impl Module<&Array> for LogSoftmax {
586 type Error = Exception;
587 type Output = Array;
588
589 fn forward(&mut self, x: &Array) -> Result<Array> {
590 log_softmax(x, self.axis)
591 }
592
593 fn training_mode(&mut self, _: bool) {}
594}
595
596#[derive(Debug, Clone, ModuleParameters)]
604#[module(root = crate)]
605pub struct LogSigmoid;
606
607impl Module<&Array> for LogSigmoid {
608 type Error = Exception;
609 type Output = Array;
610
611 fn forward(&mut self, x: &Array) -> Result<Array> {
612 log_sigmoid(x)
613 }
614
615 fn training_mode(&mut self, _: bool) {}
616}
617
618#[derive(Debug, Clone, ModuleParameters, Buildable)]
626#[module(root = crate)]
627#[buildable(root = crate)]
628pub struct Prelu {
629 #[param]
631 #[builder(ignore)]
632 pub weight: Param<Array>, }
634
635#[derive(Debug, Clone, Builder)]
637#[builder(
638 root = crate,
639 build_with = build_prelu,
640 default_infallible,
641 err = Exception,
642)]
643pub struct PreluBuilder {
644 #[builder(optional, default = Prelu::DEFAULT_COUNT)]
646 pub count: i32,
647
648 #[builder(optional, default = Prelu::DEFAULT_VALUE)]
650 pub value: f32,
651}
652
653fn build_prelu(builder: PreluBuilder) -> Result<Prelu> {
655 let count = builder.count;
656 let value = builder.value;
657 let weight = Param::new(crate::ops::full::<f32>(&[count], &array!(value))?);
658 Ok(Prelu { weight })
659}
660
661impl Prelu {
662 pub const DEFAULT_COUNT: i32 = 1;
664
665 pub const DEFAULT_VALUE: f32 = 0.25;
667}
668
669impl Module<&Array> for Prelu {
670 type Error = Exception;
671 type Output = Array;
672
673 fn forward(&mut self, x: &Array) -> Result<Array> {
674 prelu(x, &self.weight)
675 }
676
677 fn training_mode(&mut self, _: bool) {}
678}
679
/// Selects which GELU formulation the [`Gelu`] layer evaluates.
#[derive(Debug, Clone, Copy, Default)]
pub enum GeluApprox {
    /// No approximation: the layer calls [`gelu`]. This is the default variant.
    #[default]
    None,

    /// The layer calls [`gelu_approximate`].
    Precise,

    /// The layer calls [`gelu_fast_approximate`].
    Fast,
}
693
694generate_builder! {
695 #[derive(Debug, Clone, ModuleParameters, Buildable)]
703 #[module(root = crate)]
704 #[buildable(root = crate)]
705 #[builder(root = crate)]
706 pub struct Gelu {
707 #[builder(optional, default = GeluApprox::None)]
709 pub approximate: GeluApprox,
710 }
711}
712
713impl Module<&Array> for Gelu {
714 type Error = Exception;
715 type Output = Array;
716
717 fn forward(&mut self, x: &Array) -> Result<Array> {
718 match self.approximate {
719 GeluApprox::None => gelu(x),
720 GeluApprox::Precise => gelu_approximate(x),
721 GeluApprox::Fast => gelu_fast_approximate(x),
722 }
723 }
724
725 fn training_mode(&mut self, _: bool) {}
726}
727
728#[derive(Debug, Clone, ModuleParameters)]
730#[module(root = crate)]
731pub struct Tanh;
732
733impl Module<&Array> for Tanh {
734 type Error = Exception;
735 type Output = Array;
736
737 fn forward(&mut self, x: &Array) -> Result<Array> {
738 crate::ops::tanh(x)
739 }
740
741 fn training_mode(&mut self, _: bool) {}
742}
743
744#[derive(Debug, Clone, ModuleParameters)]
752#[module(root = crate)]
753pub struct HardSwish;
754
755impl Module<&Array> for HardSwish {
756 type Error = Exception;
757 type Output = Array;
758
759 fn forward(&mut self, x: &Array) -> Result<Array> {
760 hard_swish(x)
761 }
762
763 fn training_mode(&mut self, _: bool) {}
764}
765
766generate_builder! {
767 #[derive(Debug, Clone, ModuleParameters, Buildable)]
778 #[module(root = crate)]
779 #[buildable(root = crate)]
780 #[builder(root = crate)]
781 pub struct Step {
782 #[builder(optional, default = Step::DEFAULT_THRESHOLD)]
784 pub threshold: f32,
785 }
786}
787
788impl Step {
789 pub const DEFAULT_THRESHOLD: f32 = 0.0;
791}
792
793impl Module<&Array> for Step {
794 type Error = Exception;
795 type Output = Array;
796
797 fn forward(&mut self, x: &Array) -> Result<Array> {
798 step(x, self.threshold)
799 }
800
801 fn training_mode(&mut self, _: bool) {}
802}
803
804#[derive(Debug, Clone, ModuleParameters)]
812#[module(root = crate)]
813pub struct Selu;
814
815impl Module<&Array> for Selu {
816 type Error = Exception;
817 type Output = Array;
818
819 fn forward(&mut self, x: &Array) -> Result<Array> {
820 selu(x)
821 }
822
823 fn training_mode(&mut self, _: bool) {}
824}
825
826#[inline]
831fn compiled_leaky_relu(x: &Array, neg_slope: &Array) -> Result<Array> {
832 let f = |(x_, neg_slope_): (&Array, &Array)| {
833 let a = multiply(neg_slope_, x_)?;
835 maximum(&a, x_)
836 };
837 let mut compiled = compile(f, true);
838 compiled((x, neg_slope))
839}
840
841#[inline]
842fn compiled_elu(x: &Array, alpha: &Array) -> Result<Array> {
843 let f = |(x_, alpha_): (&Array, &Array)| {
844 which(&x_.gt(&array!(0.0))?, x_, alpha_ * (exp(x_)? - array!(1.0)))
845 };
846 let mut compiled = compile(f, true);
847 compiled((x, alpha))
848}
849
850#[inline]
851fn compiled_relu6(x: &Array) -> Result<Array> {
852 let f = |x_: &Array| minimum(maximum(x_, &array!(0.0))?, &array!(6.0));
853 let mut compiled = compile(f, true);
854 compiled(x)
855}
856
857#[inline]
858fn compiled_softsign(x: &Array) -> Result<Array> {
859 let f = |x_: &Array| x_.divide(array!(1.0) + abs(x_)?);
860 let mut compiled = compile(f, true);
861 compiled(x)
862}
863
864#[inline]
865fn compiled_celu(x: &Array, alpha: &Array) -> Result<Array> {
866 let f = |(x_, alpha_): (&Array, &Array)| {
867 maximum(x_, &array!(0.0))?
868 .add(alpha_.multiply(exp(&(minimum(x_, &array!(0.0))? / alpha_))? - array!(1.0))?)
869 };
870 let mut compiled = compile(f, true);
871 compiled((x, alpha))
872}
873
874#[inline]
875fn compiled_silu(x: &Array) -> Result<Array> {
876 let f = |x_: &Array| x_.multiply(sigmoid(x_)?);
877 let mut compiled = compile(f, true);
878 compiled(x)
879}
880
881#[inline]
882fn compiled_log_sigmoid(x: &Array) -> Result<Array> {
883 let f = |x_: &Array| Ok(-softplus(&(-x_))?);
884 let mut compiled = compile(f, true);
885 compiled(x)
886}
887
888#[inline]
889fn compiled_gelu(x: &Array) -> Result<Array> {
890 use crate::ops::erf;
891 let f = |x_: &Array| {
892 x_.multiply(array!(1) + erf(&(x_ / array!(2f32.sqrt())))?)?
893 .divide(array!(2.0))
894 };
895 let mut compiled = compile(f, true);
896 compiled(x)
897}
898
899#[inline]
900fn compiled_gelu_approximate(x: &Array) -> Result<Array> {
901 use crate::ops::{sqrt, tanh};
902
903 let f = move |x_: &Array| {
904 array!(0.5).multiply(x_)?.multiply(
906 array!(1.0).add(tanh(
907 &(sqrt(&array!(2.0 / PI))?
908 .multiply(x_ + array!(0.044715).multiply(x_.power(&array!(3))?)?)?),
909 )?)?,
910 )
911 };
912 let mut compiled = compile(f, true);
913 compiled(x)
914}
915
916#[inline]
917fn compiled_gelu_fast_approximate(x: &Array) -> Result<Array> {
918 let f = |x_: &Array| x_.multiply(sigmoid(&(array!(1.773) * x_))?);
919 let mut compiled = compile(f, true);
920 compiled(x)
921}
922
923#[inline]
924fn compiled_selu(x: &Array) -> Result<Array> {
925 let f = |x_: &Array| elu(x_, 1.67326)?.multiply(array!(1.0507));
926 let mut compiled = compile(f, true);
927 compiled(x)
928}
929
930#[inline]
931fn compiled_prelu(x: &Array, alpha: &Array) -> Result<Array> {
932 let f = |(x_, alpha_): (&Array, &Array)| {
933 maximum(&array!(0.0), x_)?.add(alpha_ * minimum(&array!(0.0), x_)?)
934 };
935 let mut compiled = compile(f, true);
936 compiled((x, alpha))
937}
938
939#[inline]
940fn compiled_mish(x: &Array) -> Result<Array> {
941 use crate::ops::tanh;
942
943 let f = |x_: &Array| x_.multiply(tanh(&softplus(x_)?)?);
944 let mut compiled = compile(f, true);
945 compiled(x)
946}
947
948#[inline]
949fn compiled_hard_swish(x: &Array) -> Result<Array> {
950 let f = |x_: &Array| {
951 let max_x_plus_3 = maximum(&(x_ + array!(3.0)), &array!(0.0))?;
952 x_.multiply(minimum(&max_x_plus_3, &array!(6.0))?)?
953 .divide(&array!(6.0))
954 };
955 let mut compiled = compile(f, true);
956 compiled(x)
957}
958
// Each statistical test seeds the RNG, draws a uniform `[2, 8, 16]` input, checks the
// input's shape/dtype/mean/sum, runs one layer, and checks the same four properties of
// the output. The shared assertions live in helpers so every test only lists its
// reference numbers.
#[cfg(test)]
mod tests {
    use crate::{builder::Builder, random::uniform, Dtype};
    use float_eq::assert_float_eq;

    use super::*;

    /// Shape of the random input used by the statistical tests.
    const INPUT_SHAPE: [i32; 3] = [2, 8, 16];

    /// An expected scalar paired with the absolute tolerance allowed around it.
    type Approx = (f32, f32);

    /// Draws the `[2, 8, 16]` `f32` input from `uniform(0.0, 1.0)`.
    ///
    /// Callers seed the RNG first so the reference statistics are reproducible.
    fn uniform_input() -> Array {
        uniform::<_, f32>(0.0, 1.0, &INPUT_SHAPE, None).unwrap()
    }

    /// Asserts shape, dtype, and the overall mean and sum of `arr` (each within its
    /// absolute tolerance).
    ///
    /// `#[track_caller]` makes a failing assertion point at the calling test.
    #[track_caller]
    fn assert_stats(arr: &Array, shape: &[i32], dtype: Dtype, mean: Approx, sum: Approx) {
        assert_eq!(arr.shape(), shape);
        assert_eq!(arr.dtype(), dtype);
        assert_float_eq!(
            arr.mean(None, None).unwrap().item::<f32>(),
            mean.0,
            abs <= mean.1
        );
        assert_float_eq!(
            arr.sum(None, None).unwrap().item::<f32>(),
            sum.0,
            abs <= sum.1
        );
    }

    /// Asserts that every element of `|actual - expected|` is strictly below `epsilon`.
    #[track_caller]
    fn assert_all_within(actual: &Array, expected: &Array, epsilon: &Array) {
        let within = actual
            .subtract(expected)
            .unwrap()
            .abs()
            .unwrap()
            .lt(epsilon)
            .unwrap()
            .all(None, None)
            .unwrap()
            .item::<bool>();
        assert!(within);
    }

    #[test]
    fn test_glu() {
        crate::random::seed(850).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.547_252_66, 0.010_945_053),
            (140.096_68, 2.801_933_5),
        );
        let result = Glu::new().forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 8],
            Dtype::Float32,
            (0.333_276_75, 0.006_665_535),
            (42.659_424, 0.853_188_46),
        );
    }

    #[test]
    fn test_sigmoid() {
        crate::random::seed(589).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.529_697_9, 0.010_593_958),
            (135.602_66, 2.712_053_3),
        );
        let result = Sigmoid.forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.627_014, 0.012_540_28),
            (160.515_58, 3.210_311_7),
        );
    }

    #[test]
    fn test_mish() {
        crate::random::seed(122).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.501_719_8, 0.010_034_395),
            (128.440_26, 2.568_805_2),
        );
        let result = Mish.forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.395_375_73, 0.007_907_514),
            (101.216_19, 2.024_323_7),
        );
    }

    #[test]
    fn test_relu() {
        crate::random::seed(400).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.478_322_74, 0.009_566_455),
            (122.450_62, 2.449_012_5),
        );
        let result = Relu.forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.478_322_74, 0.009_566_455),
            (122.450_62, 2.449_012_5),
        );
    }

    #[test]
    fn test_leaky_relu() {
        crate::random::seed(93).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.499_930_68, 0.009_998_614),
            (127.982_254, 2.559_645_2),
        );
        let result = LeakyRelu::new().forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.499_930_68, 0.009_998_614),
            (127.982_254, 2.559_645_2),
        );
    }

    #[test]
    fn test_relu6() {
        crate::random::seed(379).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.493_258_66, 0.009_865_173),
            (126.274_216, 2.525_484_3),
        );
        let result = Relu6.forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.493_258_66, 0.009_865_173),
            (126.274_216, 2.525_484_3),
        );
    }

    #[test]
    fn test_softmax() {
        crate::random::seed(853).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.514_396_3, 0.010_287_926_5),
            (131.685_46, 2.633_709_2),
        );
        let result = Softmax::new().forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.062_499_996, 0.001_25),
            (15.999_999, 0.32),
        );
    }

    #[test]
    fn test_softplus() {
        crate::random::seed(118).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.498_981_42, 0.009_979_628),
            (127.739_24, 2.554_784_8),
        );
        let result = Softplus.forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.982_857_76, 0.019_657_155),
            (251.611_59, 5.032_232),
        );
    }

    #[test]
    fn test_softsign() {
        crate::random::seed(37).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.506_551_27, 0.010_131_026),
            (129.677_12, 2.593_542_6),
        );
        let result = Softsign.forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.314_089_83, 0.006_281_797),
            (80.407, 1.608_14),
        );
    }

    #[test]
    fn test_celu() {
        let x = array!([1.0, -1.0, 0.0]);
        let epsilon = array!(1e-4);

        // Default alpha.
        let y = Celu::new().forward(&x).unwrap();
        let expected_y = array!([1.0, -0.6321, 0.0]);
        assert_all_within(&y, &expected_y, &epsilon);
        assert_eq!(y.shape(), &[3]);
        assert_eq!(y.dtype(), Dtype::Float32);

        // Explicit alpha supplied through the builder.
        let y = CeluBuilder::new()
            .alpha(1.1)
            .build()
            .unwrap()
            .forward(&x)
            .unwrap();
        let expected_y = array!([1.0, -0.6568, 0.0]);
        assert_all_within(&y, &expected_y, &epsilon);
        assert_eq!(y.shape(), &[3]);
        assert_eq!(y.dtype(), Dtype::Float32);
    }

    #[test]
    fn test_silu() {
        crate::random::seed(22).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.502_970_6, 0.010_059_412),
            (128.760_47, 2.575_209_4),
        );
        let result = Silu.forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.331_970_93, 0.006_639_418_7),
            (84.984_56, 1.699_691_2),
        );
    }

    #[test]
    fn test_log_softmax() {
        crate::random::seed(199).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.527_843_7, 0.010_556_874),
            (135.127_99, 2.702_559_7),
        );
        let result = LogSoftmax::new().forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (-2.810_954_6, 0.056_219_09),
            (-719.604_4, 14.392_087),
        );
    }

    #[test]
    fn test_log_sigmoid() {
        crate::random::seed(984).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.510_977_7, 0.010_219_553_5),
            (130.810_29, 2.616_205_7),
        );
        let result = LogSigmoid.forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (-0.479_598_55, 0.009_591_971),
            (-122.777_23, 2.455_544_5),
        );
    }

    #[test]
    fn test_prelu() {
        crate::random::seed(993).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.496_651_44, 0.009_933_028),
            (127.142_77, 2.542_855_3),
        );
        let result = Prelu::new().forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.496_651_44, 0.009_933_028),
            (127.142_77, 2.542_855_3),
        );
    }

    #[test]
    fn test_gelu() {
        crate::random::seed(189).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.492_950_32, 0.009_859_007),
            (126.195_28, 2.523_905_8),
        );
        let result = Gelu::new().forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.365_638_38, 0.007_312_767_7),
            (93.603_424, 1.872_068_5),
        );
    }

    #[test]
    fn test_tanh() {
        crate::random::seed(735).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.474_122_7, 0.009_482_454_5),
            (121.375_41, 2.427_508_4),
        );
        let result = Tanh.forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.413_079_68, 0.008_261_594),
            (105.748_4, 2.114_968),
        );
    }

    #[test]
    fn test_hardswish() {
        crate::random::seed(126).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.491_892_46, 0.009_837_849),
            (125.924_47, 2.518_489_4),
        );
        let result = HardSwish.forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.299_602_24, 0.005_992_044_7),
            (76.698_17, 1.533_963_4),
        );
    }

    #[test]
    fn test_step() {
        crate::random::seed(490).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.479_360_64, 0.009_587_212_5),
            (122.716_324, 2.454_326_4),
        );
        let result = Step::new().forward(&a).unwrap();
        // `step` selects between integer constants, so the output dtype is Int32.
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Int32,
            (1.0, 0.02),
            (256.0, 5.12),
        );
    }

    #[test]
    fn test_selu() {
        crate::random::seed(215).unwrap();
        let a = uniform_input();
        assert_stats(
            &a,
            &INPUT_SHAPE,
            Dtype::Float32,
            (0.493_026_8, 0.009_860_536),
            (126.214_86, 2.524_297_2),
        );
        let result = Selu.forward(&a).unwrap();
        assert_stats(
            &result,
            &[2, 8, 16],
            Dtype::Float32,
            (0.518_023_2, 0.010_360_463_5),
            (132.613_94, 2.652_278_7),
        );
    }
}