1use std::f32::consts::PI;
2
3use crate::module::{Module, Param};
4use crate::ops::logsumexp_axis;
5use crate::{
6 array,
7 error::{Exception, Result},
8 ops::{abs, exp, maximum, minimum, multiply, which},
9 transforms::compile::compile,
10 Array,
11};
12use mlx_internal_macros::{generate_builder, Buildable, Builder};
13use mlx_macros::ModuleParameters;
14
15pub fn sigmoid(x: impl AsRef<Array>) -> Result<Array> {
26 crate::ops::sigmoid(x.as_ref())
27}
28
29pub fn relu(x: impl AsRef<Array>) -> Result<Array> {
37 crate::ops::maximum(x.as_ref(), &array!(0))
38}
39
40pub fn leaky_relu(x: impl AsRef<Array>, neg_slope: impl Into<Option<f32>>) -> Result<Array> {
50 let neg_slope = array!(neg_slope.into().unwrap_or(0.01));
51 compiled_leaky_relu(x.as_ref(), &neg_slope)
54}
55
56pub fn log_softmax(x: impl AsRef<Array>, axis: impl Into<Option<i32>>) -> Result<Array> {
64 let x = x.as_ref();
65 let axis = axis.into().unwrap_or(-1);
66 x.subtract(logsumexp_axis(x, axis, true)?)
67}
68
69pub fn elu(x: impl AsRef<Array>, alpha: impl Into<Option<f32>>) -> Result<Array> {
82 let alpha = array!(alpha.into().unwrap_or(1.0));
83 compiled_elu(x.as_ref(), &alpha)
86}
87
88pub fn relu6(x: impl AsRef<Array>) -> Result<Array> {
96 compiled_relu6(x.as_ref())
97}
98
99pub fn softplus(x: impl AsRef<Array>) -> Result<Array> {
107 crate::ops::logaddexp(x.as_ref(), &array!(0))
108}
109
110pub fn softsign(x: impl AsRef<Array>) -> Result<Array> {
118 compiled_softsign(x.as_ref())
119}
120
121pub fn celu(x: impl AsRef<Array>, alpha: impl Into<Option<f32>>) -> Result<Array> {
129 let alpha = array!(alpha.into().unwrap_or(1.0));
130 compiled_celu(x.as_ref(), &alpha)
133}
134
135pub fn silu(x: impl AsRef<Array>) -> Result<Array> {
143 compiled_silu(x.as_ref())
144}
145
146pub fn log_sigmoid(x: impl AsRef<Array>) -> Result<Array> {
154 compiled_log_sigmoid(x.as_ref())
155}
156
157pub fn gelu(x: impl AsRef<Array>) -> Result<Array> {
165 compiled_gelu(x.as_ref())
166}
167
168pub fn gelu_approximate(x: impl AsRef<Array>) -> Result<Array> {
176 compiled_gelu_approximate(x.as_ref())
177}
178
179pub fn gelu_fast_approximate(x: impl AsRef<Array>) -> Result<Array> {
187 compiled_gelu_fast_approximate(x.as_ref())
188}
189
190pub fn glu(x: impl AsRef<Array>, axis: impl Into<Option<i32>>) -> Result<Array> {
195 let split = x.as_ref().split(2, axis)?;
196 let (a, b) = (&split[0], &split[1]);
197 Ok(a * sigmoid(b)?)
198}
199
200pub fn step(x: impl AsRef<Array>, threshold: impl Into<Option<f32>>) -> Result<Array> {
211 let threshold = array!(threshold.into().unwrap_or(0.0));
212 crate::ops::r#where(&x.as_ref().gt(threshold)?, &array!(1), &array!(0))
213}
214
215pub fn selu(x: impl AsRef<Array>) -> Result<Array> {
223 compiled_selu(x.as_ref())
224}
225
226pub fn prelu(x: impl AsRef<Array>, alpha: impl AsRef<Array>) -> Result<Array> {
234 compiled_prelu(x.as_ref(), alpha.as_ref())
235}
236
237pub fn mish(x: impl AsRef<Array>) -> Result<Array> {
249 compiled_mish(x.as_ref())
250}
251
252pub fn hard_swish(x: impl AsRef<Array>) -> Result<Array> {
260 compiled_hard_swish(x.as_ref())
261}
262
263generate_builder! {
264 #[derive(Debug, Clone, ModuleParameters, Buildable)]
269 #[module(root = crate)]
270 #[buildable(root = crate)]
271 #[builder(root = crate)]
272 pub struct Glu {
273 #[builder(optional, default = Glu::DEFAULT_AXIS)]
275 pub axis: i32,
276 }
277}
278
279impl Glu {
280 pub const DEFAULT_AXIS: i32 = -1;
282}
283
284impl Module<&Array> for Glu {
285 type Error = Exception;
286 type Output = Array;
287
288 fn forward(&mut self, x: &Array) -> Result<Array> {
289 glu(x, self.axis)
290 }
291
292 fn training_mode(&mut self, _: bool) {}
293}
294
295#[derive(Debug, Clone, ModuleParameters)]
306#[module(root = crate)]
307pub struct Sigmoid;
308
309impl Module<&Array> for Sigmoid {
310 type Error = Exception;
311 type Output = Array;
312
313 fn forward(&mut self, x: &Array) -> Result<Array> {
314 sigmoid(x)
315 }
316
317 fn training_mode(&mut self, _: bool) {}
318}
319
320#[derive(Debug, Clone, ModuleParameters)]
332#[module(root = crate)]
333pub struct Mish;
334
335impl Module<&Array> for Mish {
336 type Error = Exception;
337 type Output = Array;
338
339 fn forward(&mut self, x: &Array) -> Result<Array> {
340 mish(x)
341 }
342
343 fn training_mode(&mut self, _: bool) {}
344}
345
346#[derive(Debug, Clone, ModuleParameters)]
354#[module(root = crate)]
355pub struct Relu;
356
357impl Module<&Array> for Relu {
358 type Error = Exception;
359 type Output = Array;
360
361 fn forward(&mut self, x: &Array) -> Result<Array> {
362 relu(x)
363 }
364
365 fn training_mode(&mut self, _: bool) {}
366}
367
368generate_builder! {
369 #[derive(Debug, Clone, ModuleParameters, Buildable)]
377 #[module(root = crate)]
378 #[buildable(root = crate)]
379 #[builder(root = crate)]
380 pub struct LeakyRelu {
381 #[builder(optional, default = LeakyRelu::DEFAULT_NEG_SLOPE)]
383 pub neg_slope: f32,
384 }
385}
386
387impl LeakyRelu {
388 pub const DEFAULT_NEG_SLOPE: f32 = 0.01;
390}
391
392impl Module<&Array> for LeakyRelu {
393 type Error = Exception;
394 type Output = Array;
395
396 fn forward(&mut self, x: &Array) -> Result<Array> {
397 leaky_relu(x, self.neg_slope)
398 }
399
400 fn training_mode(&mut self, _: bool) {}
401}
402
403#[derive(Debug, Clone, ModuleParameters)]
411#[module(root = crate)]
412pub struct Relu6;
413
414impl Module<&Array> for Relu6 {
415 type Error = Exception;
416 type Output = Array;
417
418 fn forward(&mut self, x: &Array) -> Result<Array> {
419 relu6(x)
420 }
421
422 fn training_mode(&mut self, _: bool) {}
423}
424
425generate_builder! {
426 #[derive(Debug, Clone, ModuleParameters, Buildable)]
434 #[module(root = crate)]
435 #[buildable(root = crate)]
436 #[builder(root = crate)]
437 pub struct Softmax {
438 #[builder(optional, default = Softmax::DEFAULT_AXIS)]
440 pub axis: i32,
441 }
442}
443
444impl Softmax {
445 pub const DEFAULT_AXIS: i32 = -1;
447}
448
449impl Module<&Array> for Softmax {
450 type Error = Exception;
451 type Output = Array;
452
453 fn forward(&mut self, x: &Array) -> Result<Array> {
454 crate::ops::softmax_axis(x, self.axis, None)
455 }
456
457 fn training_mode(&mut self, _: bool) {}
458}
459
460#[derive(Debug, Clone, ModuleParameters)]
468#[module(root = crate)]
469pub struct Softplus;
470
471impl Module<&Array> for Softplus {
472 type Error = Exception;
473 type Output = Array;
474
475 fn forward(&mut self, x: &Array) -> Result<Array> {
476 softplus(x)
477 }
478
479 fn training_mode(&mut self, _: bool) {}
480}
481
482#[derive(Debug, Clone, ModuleParameters)]
490#[module(root = crate)]
491pub struct Softsign;
492
493impl Module<&Array> for Softsign {
494 type Error = Exception;
495 type Output = Array;
496
497 fn forward(&mut self, x: &Array) -> Result<Array> {
498 softsign(x)
499 }
500
501 fn training_mode(&mut self, _: bool) {}
502}
503
504generate_builder! {
505 #[derive(Debug, Clone, ModuleParameters, Buildable)]
514 #[module(root = crate)]
515 #[buildable(root = crate)]
516 #[builder(root = crate)]
517 pub struct Celu {
518 #[builder(optional, default = Celu::DEFAULT_ALPHA)]
520 pub alpha: f32,
521 }
522}
523
524impl Celu {
525 pub const DEFAULT_ALPHA: f32 = 1.0;
527}
528
529impl Module<&Array> for Celu {
530 type Error = Exception;
531 type Output = Array;
532
533 fn forward(&mut self, x: &Array) -> Result<Array> {
534 celu(x, self.alpha)
535 }
536
537 fn training_mode(&mut self, _: bool) {}
538}
539
540#[derive(Debug, Clone, ModuleParameters)]
548#[module(root = crate)]
549pub struct Silu;
550
551impl Module<&Array> for Silu {
552 type Error = Exception;
553 type Output = Array;
554
555 fn forward(&mut self, x: &Array) -> Result<Array> {
556 silu(x)
557 }
558
559 fn training_mode(&mut self, _: bool) {}
560}
561
562generate_builder! {
563 #[derive(Debug, Clone, ModuleParameters, Buildable)]
571 #[module(root = crate)]
572 #[buildable(root = crate)]
573 #[builder(root = crate)]
574 pub struct LogSoftmax {
575 #[builder(optional, default = LogSoftmax::DEFAULT_AXIS)]
577 pub axis: i32,
578 }
579}
580
581impl LogSoftmax {
582 pub const DEFAULT_AXIS: i32 = -1;
584}
585
586impl Module<&Array> for LogSoftmax {
587 type Error = Exception;
588 type Output = Array;
589
590 fn forward(&mut self, x: &Array) -> Result<Array> {
591 log_softmax(x, self.axis)
592 }
593
594 fn training_mode(&mut self, _: bool) {}
595}
596
597#[derive(Debug, Clone, ModuleParameters)]
605#[module(root = crate)]
606pub struct LogSigmoid;
607
608impl Module<&Array> for LogSigmoid {
609 type Error = Exception;
610 type Output = Array;
611
612 fn forward(&mut self, x: &Array) -> Result<Array> {
613 log_sigmoid(x)
614 }
615
616 fn training_mode(&mut self, _: bool) {}
617}
618
619#[derive(Debug, Clone, ModuleParameters, Buildable)]
627#[module(root = crate)]
628#[buildable(root = crate)]
629pub struct Prelu {
630 #[param]
632 #[builder(ignore)]
633 pub weight: Param<Array>, }
635
636#[derive(Debug, Clone, Builder)]
638#[builder(
639 root = crate,
640 build_with = build_prelu,
641 default_infallible,
642 err = Exception,
643)]
644pub struct PreluBuilder {
645 #[builder(optional, default = Prelu::DEFAULT_COUNT)]
647 pub count: i32,
648
649 #[builder(optional, default = Prelu::DEFAULT_VALUE)]
651 pub value: f32,
652}
653
654fn build_prelu(builder: PreluBuilder) -> Result<Prelu> {
656 let count = builder.count;
657 let value = builder.value;
658 let weight = Param::new(crate::ops::full::<f32>(&[count], &array!(value))?);
659 Ok(Prelu { weight })
660}
661
662impl Prelu {
663 pub const DEFAULT_COUNT: i32 = 1;
665
666 pub const DEFAULT_VALUE: f32 = 0.25;
668}
669
670impl Module<&Array> for Prelu {
671 type Error = Exception;
672 type Output = Array;
673
674 fn forward(&mut self, x: &Array) -> Result<Array> {
675 prelu(x, &self.weight)
676 }
677
678 fn training_mode(&mut self, _: bool) {}
679}
680
/// Selects which GELU formulation the [`Gelu`] layer evaluates.
///
/// Derives `PartialEq`/`Eq`/`Hash` so that configurations can be compared and used as keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum GeluApprox {
    /// Exact GELU: `x * (1 + erf(x / sqrt(2))) / 2`. This is the default.
    #[default]
    None,

    /// Tanh approximation: `0.5 * x * (1 + tanh(sqrt(2 / PI) * (x + 0.044715 * x^3)))`.
    Precise,

    /// Sigmoid approximation: `x * sigmoid(1.773 * x)`.
    Fast,
}
694
695generate_builder! {
696 #[derive(Debug, Clone, ModuleParameters, Buildable)]
704 #[module(root = crate)]
705 #[buildable(root = crate)]
706 #[builder(root = crate)]
707 pub struct Gelu {
708 #[builder(optional, default = GeluApprox::None)]
710 pub approximate: GeluApprox,
711 }
712}
713
714impl Module<&Array> for Gelu {
715 type Error = Exception;
716 type Output = Array;
717
718 fn forward(&mut self, x: &Array) -> Result<Array> {
719 match self.approximate {
720 GeluApprox::None => gelu(x),
721 GeluApprox::Precise => gelu_approximate(x),
722 GeluApprox::Fast => gelu_fast_approximate(x),
723 }
724 }
725
726 fn training_mode(&mut self, _: bool) {}
727}
728
729#[derive(Debug, Clone, ModuleParameters)]
731#[module(root = crate)]
732pub struct Tanh;
733
734impl Module<&Array> for Tanh {
735 type Error = Exception;
736 type Output = Array;
737
738 fn forward(&mut self, x: &Array) -> Result<Array> {
739 crate::ops::tanh(x)
740 }
741
742 fn training_mode(&mut self, _: bool) {}
743}
744
745#[derive(Debug, Clone, ModuleParameters)]
753#[module(root = crate)]
754pub struct HardSwish;
755
756impl Module<&Array> for HardSwish {
757 type Error = Exception;
758 type Output = Array;
759
760 fn forward(&mut self, x: &Array) -> Result<Array> {
761 hard_swish(x)
762 }
763
764 fn training_mode(&mut self, _: bool) {}
765}
766
767generate_builder! {
768 #[derive(Debug, Clone, ModuleParameters, Buildable)]
779 #[module(root = crate)]
780 #[buildable(root = crate)]
781 #[builder(root = crate)]
782 pub struct Step {
783 #[builder(optional, default = Step::DEFAULT_THRESHOLD)]
785 pub threshold: f32,
786 }
787}
788
789impl Step {
790 pub const DEFAULT_THRESHOLD: f32 = 0.0;
792}
793
794impl Module<&Array> for Step {
795 type Error = Exception;
796 type Output = Array;
797
798 fn forward(&mut self, x: &Array) -> Result<Array> {
799 step(x, self.threshold)
800 }
801
802 fn training_mode(&mut self, _: bool) {}
803}
804
805#[derive(Debug, Clone, ModuleParameters)]
813#[module(root = crate)]
814pub struct Selu;
815
816impl Module<&Array> for Selu {
817 type Error = Exception;
818 type Output = Array;
819
820 fn forward(&mut self, x: &Array) -> Result<Array> {
821 selu(x)
822 }
823
824 fn training_mode(&mut self, _: bool) {}
825}
826
827#[inline]
832fn compiled_leaky_relu(x: &Array, neg_slope: &Array) -> Result<Array> {
833 let f = |(x_, neg_slope_): (&Array, &Array)| {
834 let a = multiply(neg_slope_, x_)?;
836 maximum(&a, x_)
837 };
838 let mut compiled = compile(f, true);
839 compiled((x, neg_slope))
840}
841
842#[inline]
843fn compiled_elu(x: &Array, alpha: &Array) -> Result<Array> {
844 let f = |(x_, alpha_): (&Array, &Array)| {
845 which(&x_.gt(&array!(0.0))?, x_, alpha_ * (exp(x_)? - array!(1.0)))
846 };
847 let mut compiled = compile(f, true);
848 compiled((x, alpha))
849}
850
851#[inline]
852fn compiled_relu6(x: &Array) -> Result<Array> {
853 let f = |x_: &Array| minimum(maximum(x_, &array!(0.0))?, &array!(6.0));
854 let mut compiled = compile(f, true);
855 compiled(x)
856}
857
858#[inline]
859fn compiled_softsign(x: &Array) -> Result<Array> {
860 let f = |x_: &Array| x_.divide(array!(1.0) + abs(x_)?);
861 let mut compiled = compile(f, true);
862 compiled(x)
863}
864
865#[inline]
866fn compiled_celu(x: &Array, alpha: &Array) -> Result<Array> {
867 let f = |(x_, alpha_): (&Array, &Array)| {
868 maximum(x_, &array!(0.0))?
869 .add(alpha_.multiply(exp(&(minimum(x_, &array!(0.0))? / alpha_))? - array!(1.0))?)
870 };
871 let mut compiled = compile(f, true);
872 compiled((x, alpha))
873}
874
875#[inline]
876fn compiled_silu(x: &Array) -> Result<Array> {
877 let f = |x_: &Array| x_.multiply(sigmoid(x_)?);
878 let mut compiled = compile(f, true);
879 compiled(x)
880}
881
882#[inline]
883fn compiled_log_sigmoid(x: &Array) -> Result<Array> {
884 let f = |x_: &Array| Ok(-softplus(&(-x_))?);
885 let mut compiled = compile(f, true);
886 compiled(x)
887}
888
889#[inline]
890fn compiled_gelu(x: &Array) -> Result<Array> {
891 use crate::ops::erf;
892 let f = |x_: &Array| {
893 x_.multiply(array!(1) + erf(&(x_ / array!(2f32.sqrt())))?)?
894 .divide(array!(2.0))
895 };
896 let mut compiled = compile(f, true);
897 compiled(x)
898}
899
900#[inline]
901fn compiled_gelu_approximate(x: &Array) -> Result<Array> {
902 use crate::ops::{sqrt, tanh};
903
904 let f = move |x_: &Array| {
905 array!(0.5).multiply(x_)?.multiply(
907 array!(1.0).add(tanh(
908 &(sqrt(&array!(2.0 / PI))?
909 .multiply(x_ + array!(0.044715).multiply(x_.power(&array!(3))?)?)?),
910 )?)?,
911 )
912 };
913 let mut compiled = compile(f, true);
914 compiled(x)
915}
916
917#[inline]
918fn compiled_gelu_fast_approximate(x: &Array) -> Result<Array> {
919 let f = |x_: &Array| x_.multiply(sigmoid(&(array!(1.773) * x_))?);
920 let mut compiled = compile(f, true);
921 compiled(x)
922}
923
924#[inline]
925fn compiled_selu(x: &Array) -> Result<Array> {
926 let f = |x_: &Array| elu(x_, 1.67326)?.multiply(array!(1.0507));
927 let mut compiled = compile(f, true);
928 compiled(x)
929}
930
931#[inline]
932fn compiled_prelu(x: &Array, alpha: &Array) -> Result<Array> {
933 let f = |(x_, alpha_): (&Array, &Array)| {
934 maximum(&array!(0.0), x_)?.add(alpha_ * minimum(&array!(0.0), x_)?)
935 };
936 let mut compiled = compile(f, true);
937 compiled((x, alpha))
938}
939
940#[inline]
941fn compiled_mish(x: &Array) -> Result<Array> {
942 use crate::ops::tanh;
943
944 let f = |x_: &Array| x_.multiply(tanh(&softplus(x_)?)?);
945 let mut compiled = compile(f, true);
946 compiled(x)
947}
948
949#[inline]
950fn compiled_hard_swish(x: &Array) -> Result<Array> {
951 let f = |x_: &Array| {
952 let max_x_plus_3 = maximum(&(x_ + array!(3.0)), &array!(0.0))?;
953 x_.multiply(minimum(&max_x_plus_3, &array!(6.0))?)?
954 .divide(&array!(6.0))
955 };
956 let mut compiled = compile(f, true);
957 compiled(x)
958}
959
#[cfg(test)]
mod tests {
    use crate::{builder::Builder, random::uniform, Dtype};
    use float_eq::assert_float_eq;

    use super::*;

    /// Shape of every randomly generated input in these tests.
    const SHAPE: [i32; 3] = [2, 8, 16];

    /// Asserts `arr`'s shape and dtype. Also asserts that its mean and sum, read back as `f32`,
    /// each lie within an absolute tolerance of the expected value.
    ///
    /// `mean` and `sum` are `(expected, tolerance)` pairs.
    #[track_caller]
    fn assert_summary(
        arr: &Array,
        shape: &[i32],
        dtype: Dtype,
        (mean, mean_tol): (f32, f32),
        (sum, sum_tol): (f32, f32),
    ) {
        assert_eq!(arr.shape(), shape);
        assert_eq!(arr.dtype(), dtype);
        assert_float_eq!(arr.mean(None).unwrap().item::<f32>(), mean, abs <= mean_tol);
        assert_float_eq!(arr.sum(None).unwrap().item::<f32>(), sum, abs <= sum_tol);
    }

    /// Like [`assert_summary`] for the common case of a float32 array shaped [`SHAPE`].
    #[track_caller]
    fn assert_f32_summary(arr: &Array, mean: (f32, f32), sum: (f32, f32)) {
        assert_summary(arr, &SHAPE, Dtype::Float32, mean, sum);
    }

    /// Asserts that every element of `actual` is strictly within `epsilon` of `expected`.
    #[track_caller]
    fn assert_all_close(actual: &Array, expected: &Array, epsilon: &Array) {
        assert!(actual
            .subtract(expected)
            .unwrap()
            .abs()
            .unwrap()
            .lt(epsilon)
            .unwrap()
            .all(None)
            .unwrap()
            .item::<bool>());
    }

    #[test]
    fn test_glu() {
        crate::random::seed(850).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.547_252_66, 0.010_945_053), (140.096_68, 2.801_933_5));

        let result = Glu::new().forward(&a).unwrap();
        assert_summary(
            &result,
            &[2, 8, 8],
            Dtype::Float32,
            (0.333_276_75, 0.006_665_535),
            (42.659_424, 0.853_188_46),
        );
    }

    #[test]
    fn test_sigmoid() {
        crate::random::seed(589).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.529_697_9, 0.010_593_958), (135.602_66, 2.712_053_3));

        let result = Sigmoid.forward(&a).unwrap();
        assert_f32_summary(&result, (0.627_014, 0.012_540_28), (160.515_58, 3.210_311_7));
    }

    #[test]
    fn test_mish() {
        crate::random::seed(122).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.501_719_8, 0.010_034_395), (128.440_26, 2.568_805_2));

        let result = Mish.forward(&a).unwrap();
        assert_f32_summary(&result, (0.395_375_73, 0.007_907_514), (101.216_19, 2.024_323_7));
    }

    #[test]
    fn test_relu() {
        crate::random::seed(400).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        // Inputs are non-negative, so ReLU is the identity here.
        let (mean, sum) = ((0.478_322_74, 0.009_566_455), (122.450_62, 2.449_012_5));
        assert_f32_summary(&a, mean, sum);

        let result = Relu.forward(&a).unwrap();
        assert_f32_summary(&result, mean, sum);
    }

    #[test]
    fn test_leaky_relu() {
        crate::random::seed(93).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        let (mean, sum) = ((0.499_930_68, 0.009_998_614), (127.982_254, 2.559_645_2));
        assert_f32_summary(&a, mean, sum);

        let result = LeakyRelu::new().forward(&a).unwrap();
        assert_f32_summary(&result, mean, sum);
    }

    #[test]
    fn test_relu6() {
        crate::random::seed(379).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        let (mean, sum) = ((0.493_258_66, 0.009_865_173), (126.274_216, 2.525_484_3));
        assert_f32_summary(&a, mean, sum);

        let result = Relu6.forward(&a).unwrap();
        assert_f32_summary(&result, mean, sum);
    }

    #[test]
    fn test_softmax() {
        crate::random::seed(853).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.514_396_3, 0.010_287_926_5), (131.685_46, 2.633_709_2));

        let result = Softmax::new().forward(&a).unwrap();
        assert_f32_summary(&result, (0.062_499_996, 0.001_25), (15.999_999, 0.32));
    }

    #[test]
    fn test_softplus() {
        crate::random::seed(118).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.498_981_42, 0.009_979_628), (127.739_24, 2.554_784_8));

        let result = Softplus.forward(&a).unwrap();
        assert_f32_summary(&result, (0.982_857_76, 0.019_657_155), (251.611_59, 5.032_232));
    }

    #[test]
    fn test_softsign() {
        crate::random::seed(37).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.506_551_27, 0.010_131_026), (129.677_12, 2.593_542_6));

        let result = Softsign.forward(&a).unwrap();
        assert_f32_summary(&result, (0.314_089_83, 0.006_281_797), (80.407, 1.608_14));
    }

    #[test]
    fn test_celu() {
        let x = array!([1.0, -1.0, 0.0]);
        let epsilon = array!(1e-4);

        let y = Celu::new().forward(&x).unwrap();
        assert_all_close(&y, &array!([1.0, -0.6321, 0.0]), &epsilon);
        assert_eq!(y.shape(), &[3]);
        assert_eq!(y.dtype(), Dtype::Float32);

        let y = CeluBuilder::new()
            .alpha(1.1)
            .build()
            .unwrap()
            .forward(&x)
            .unwrap();
        assert_all_close(&y, &array!([1.0, -0.6568, 0.0]), &epsilon);
        assert_eq!(y.shape(), &[3]);
        assert_eq!(y.dtype(), Dtype::Float32);
    }

    #[test]
    fn test_silu() {
        crate::random::seed(22).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.502_970_6, 0.010_059_412), (128.760_47, 2.575_209_4));

        let result = Silu.forward(&a).unwrap();
        assert_f32_summary(&result, (0.331_970_93, 0.006_639_418_7), (84.984_56, 1.699_691_2));
    }

    #[test]
    fn test_log_softmax() {
        crate::random::seed(199).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.527_843_7, 0.010_556_874), (135.127_99, 2.702_559_7));

        let result = LogSoftmax::new().forward(&a).unwrap();
        assert_f32_summary(&result, (-2.810_954_6, 0.056_219_09), (-719.604_4, 14.392_087));
    }

    #[test]
    fn test_log_sigmoid() {
        crate::random::seed(984).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.510_977_7, 0.010_219_553_5), (130.810_29, 2.616_205_7));

        let result = LogSigmoid.forward(&a).unwrap();
        assert_f32_summary(&result, (-0.479_598_55, 0.009_591_971), (-122.777_23, 2.455_544_5));
    }

    #[test]
    fn test_prelu() {
        crate::random::seed(993).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        // Inputs are non-negative, so PReLU is the identity here.
        let (mean, sum) = ((0.496_651_44, 0.009_933_028), (127.142_77, 2.542_855_3));
        assert_f32_summary(&a, mean, sum);

        let result = Prelu::new().forward(&a).unwrap();
        assert_f32_summary(&result, mean, sum);
    }

    #[test]
    fn test_gelu() {
        crate::random::seed(189).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.492_950_32, 0.009_859_007), (126.195_28, 2.523_905_8));

        let result = Gelu::new().forward(&a).unwrap();
        assert_f32_summary(&result, (0.365_638_38, 0.007_312_767_7), (93.603_424, 1.872_068_5));
    }

    #[test]
    fn test_tanh() {
        crate::random::seed(735).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.474_122_7, 0.009_482_454_5), (121.375_41, 2.427_508_4));

        let result = Tanh.forward(&a).unwrap();
        assert_f32_summary(&result, (0.413_079_68, 0.008_261_594), (105.748_4, 2.114_968));
    }

    #[test]
    fn test_hardswish() {
        crate::random::seed(126).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.491_892_46, 0.009_837_849), (125.924_47, 2.518_489_4));

        let result = HardSwish.forward(&a).unwrap();
        assert_f32_summary(&result, (0.299_602_24, 0.005_992_044_7), (76.698_17, 1.533_963_4));
    }

    #[test]
    fn test_step() {
        crate::random::seed(490).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.479_360_64, 0.009_587_212_5), (122.716_324, 2.454_326_4));

        let result = Step::new().forward(&a).unwrap();
        assert_summary(&result, &SHAPE, Dtype::Int32, (1.0, 0.02), (256.0, 5.12));
    }

    #[test]
    fn test_selu() {
        crate::random::seed(215).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &SHAPE, None).unwrap();
        assert_f32_summary(&a, (0.493_026_8, 0.009_860_536), (126.214_86, 2.524_297_2));

        let result = Selu.forward(&a).unwrap();
        assert_f32_summary(&result, (0.518_023_2, 0.010_360_463_5), (132.613_94, 2.652_278_7));
    }
}