mlx_rs/nn/
activation.rs

1use std::f32::consts::PI;
2
3use crate::module::{Module, Param};
4use crate::{
5    array,
6    error::{Exception, Result},
7    ops::{abs, exp, log_sum_exp, maximum, minimum, multiply, which},
8    transforms::compile::compile,
9    Array,
10};
11use mlx_internal_macros::{generate_builder, Buildable, Builder};
12use mlx_macros::ModuleParameters;
13
14/// Applies the element-wise sigmoid logistic sigmoid.
15///
16/// For details, please see
17/// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.sigmoid.html)
18///
19/// This is:
20///
21/// ```rust, ignore
22/// sigmoid(x)
23/// ```
24pub fn sigmoid(x: impl AsRef<Array>) -> Result<Array> {
25    crate::ops::sigmoid(x.as_ref())
26}
27
28/// Applies the Rectified Linear Unit.
29///
30/// This is:
31///
32/// ```rust, ignore
33/// maximum(x, 0)
34/// ```
35pub fn relu(x: impl AsRef<Array>) -> Result<Array> {
36    crate::ops::maximum(x.as_ref(), &array!(0))
37}
38
39/// Applies the Leaky Rectified Linear Unit.
40///
41/// `neg_slope` is default to 0.01 if not provided.
42///
43/// This is:
44///
45/// ```rust, ignore
46/// maximum(neg_slope * x, x)
47/// ```
48pub fn leaky_relu(x: impl AsRef<Array>, neg_slope: impl Into<Option<f32>>) -> Result<Array> {
49    let neg_slope = array!(neg_slope.into().unwrap_or(0.01));
50    // We have to use this indirection, otherwise the compiler cannot
51    // infer the lifetime of the value returned by the closure properly
52    compiled_leaky_relu(x.as_ref(), &neg_slope)
53}
54
55/// Applies the Log Softmax function.
56///
57/// This is:
58///
59/// ```rust, ignore
60/// x - log_sum_exp(x, axis, true)
61/// ```
62pub fn log_softmax(x: impl AsRef<Array>, axis: impl Into<Option<i32>>) -> Result<Array> {
63    let x = x.as_ref();
64    let axis = axis.into().unwrap_or(-1);
65    x.subtract(log_sum_exp(x, &[axis], true)?)
66}
67
68/// Applies the Exponential Linear Unit.
69///
70/// This is:
71///
72/// ```rust, ignore
73/// which(x.gt(0), x, alpha * (exp(x) - 1))
74/// ```
75///
76/// # Params
77///
78/// - `x`: The input array
79/// - `alpha`: Default to 1.0 if not provided
80pub fn elu(x: impl AsRef<Array>, alpha: impl Into<Option<f32>>) -> Result<Array> {
81    let alpha = array!(alpha.into().unwrap_or(1.0));
82    // We have to use this indirection, otherwise the compiler cannot
83    // infer the lifetime of the value returned by the closure properly
84    compiled_elu(x.as_ref(), &alpha)
85}
86
87/// Applies the Rectified Linear Unit 6.
88///
89/// This is:
90///
91/// ```rust, ignore
92/// minimum(maximum(x, 0), 6)
93/// ```
94pub fn relu6(x: impl AsRef<Array>) -> Result<Array> {
95    compiled_relu6(x.as_ref())
96}
97
/// Applies the Softplus function.
///
/// This is:
///
/// ```rust, ignore
/// log_add_exp(x, 0)
/// ```
pub fn softplus(x: impl AsRef<Array>) -> Result<Array> {
    crate::ops::log_add_exp(x.as_ref(), &array!(0))
}
108
109/// Applies the Softsign function.
110///
111/// This is:
112///
113/// ```rust, ignore
114/// x / (1 + abs(x))
115/// ```
116pub fn softsign(x: impl AsRef<Array>) -> Result<Array> {
117    compiled_softsign(x.as_ref())
118}
119
120/// Applies the Continuously Differentiable Exponential Linear Unit.
121///
122/// This is:
123///
124/// ```rust, ignore
125/// maximum(x, 0) + alpha * (exp(minimum(x, 0) / alpha) - 1)
126/// ```
127pub fn celu(x: impl AsRef<Array>, alpha: impl Into<Option<f32>>) -> Result<Array> {
128    let alpha = array!(alpha.into().unwrap_or(1.0));
129    // We have to use this indirection, otherwise the compiler cannot
130    // infer the lifetime of the value returned by the closure properly
131    compiled_celu(x.as_ref(), &alpha)
132}
133
134/// Applies the Sigmoid Linear Unit. Also known as Swish.
135///
136/// This is:
137///
138/// ```rust, ignore
139/// x * sigmoid(x)
140/// ```
141pub fn silu(x: impl AsRef<Array>) -> Result<Array> {
142    compiled_silu(x.as_ref())
143}
144
145/// Applies the Log Sigmoid function.
146///
147/// This is:
148///
149/// ```rust, ignore
150/// -softplus(-x)
151/// ```
152pub fn log_sigmoid(x: impl AsRef<Array>) -> Result<Array> {
153    compiled_log_sigmoid(x.as_ref())
154}
155
156/// Applies the Gaussian Error Linear Units function.
157///
158/// This is:
159///
160/// ```rust, ignore
161/// x * (1 + erf(x / 2.sqrt())) / 2
162/// ```
163pub fn gelu(x: impl AsRef<Array>) -> Result<Array> {
164    compiled_gelu(x.as_ref())
165}
166
167/// An approximation to Gaussian Error Linear Unit.
168///
169/// This is:
170///
171/// ```rust, ignore
172/// 0.5 * x * (1 + tanh(sqrt(2 / PI) * (x + 0.044715 * x ** 3)))
173/// ```
174pub fn gelu_approximate(x: impl AsRef<Array>) -> Result<Array> {
175    compiled_gelu_approximate(x.as_ref())
176}
177
178/// A fast approximation to Gaussian Error Linear Unit.
179///
180/// This is:
181///
182/// ```rust, ignore
183/// x * sigmoid(1.773 * x)
184/// ```
185pub fn gelu_fast_approximate(x: impl AsRef<Array>) -> Result<Array> {
186    compiled_gelu_fast_approximate(x.as_ref())
187}
188
189/// Applies the gated linear unit function.
190///
191/// This function splits the `axis` dimension of the input into two halves
192/// (`a` and `b`) and applies `a * sigmoid(b)`.
193pub fn glu(x: impl AsRef<Array>, axis: impl Into<Option<i32>>) -> Result<Array> {
194    let split = x.as_ref().split_equal(2, axis)?;
195    let (a, b) = (&split[0], &split[1]);
196    Ok(a * sigmoid(b)?)
197}
198
199/// Applies the Step Activation Function.
200///
201/// This function implements a binary step activation, where the output is set
202/// to 1 if the input is greater than a specified threshold, and 0 otherwise.
203///
204/// This is:
205///
206/// ```rust, ignore
207/// r#where(x.gt(threshold), 1, 0)
208/// ```
209pub fn step(x: impl AsRef<Array>, threshold: impl Into<Option<f32>>) -> Result<Array> {
210    let threshold = array!(threshold.into().unwrap_or(0.0));
211    crate::ops::r#where(&x.as_ref().gt(threshold)?, &array!(1), &array!(0))
212}
213
214/// Applies the Scaled Exponential Linear Unit.
215///
216/// This is:
217///
218/// ```rust, ignore
219/// elu(x, 1.67326) * 1.0507
220/// ```
221pub fn selu(x: impl AsRef<Array>) -> Result<Array> {
222    compiled_selu(x.as_ref())
223}
224
225/// Applies the element-wise parametric ReLU.
226///
227/// This is:
228///
229/// ```rust, ignore
230/// maximum(0, x) + alpha * minimum(0, x)
231/// ```
232pub fn prelu(x: impl AsRef<Array>, alpha: impl AsRef<Array>) -> Result<Array> {
233    compiled_prelu(x.as_ref(), alpha.as_ref())
234}
235
236/// Applies the Mish function, element-wise.
237///
238/// Mish: A Self Regularized Non-Monotonic Neural Activation Function.
239///
240/// Reference: [https://arxiv.org/abs/1908.08681](https://arxiv.org/abs/1908.08681)
241///
242/// This is:
243///
244/// ```rust, ignore
245/// x * tanh(softplus(x))
246/// ```
247pub fn mish(x: impl AsRef<Array>) -> Result<Array> {
248    compiled_mish(x.as_ref())
249}
250
251/// Applies the hardswish function, element-wise.
252///
253/// This is:
254///
255/// ```rust, ignore
256/// x * minimum(maximum(x + 3, 0), 6) / 6
257/// ```
258pub fn hard_swish(x: impl AsRef<Array>) -> Result<Array> {
259    compiled_hard_swish(x.as_ref())
260}
261
generate_builder! {
    /// Applies the gated linear unit function.
    ///
    /// This splits the `axis` dimension of the input into two halves
    /// (`a` and `b`) and applies `a * sigmoid(b)`.
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct Glu {
        /// The axis to split the input tensor. Default to [`Glu::DEFAULT_AXIS`] if not provided.
        #[builder(optional, default = Glu::DEFAULT_AXIS)]
        pub axis: i32,
    }
}

impl Glu {
    /// The default axis value (the last axis).
    pub const DEFAULT_AXIS: i32 = -1;
}

impl Module<&Array> for Glu {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`glu`] with the configured axis.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        glu(x, self.axis)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
293
/// Applies the element-wise logistic sigmoid.
///
/// For details, please see
/// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.sigmoid.html)
///
/// This is:
///
/// ```rust, ignore
/// sigmoid(x)
/// ```
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Sigmoid;

impl Module<&Array> for Sigmoid {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`sigmoid`].
    fn forward(&mut self, x: &Array) -> Result<Array> {
        sigmoid(x)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
318
/// Applies the Mish function, element-wise.
///
/// Mish: A Self Regularized Non-Monotonic Neural Activation Function.
///
/// Reference: [https://arxiv.org/abs/1908.08681](https://arxiv.org/abs/1908.08681)
///
/// This is:
///
/// ```rust, ignore
/// x * tanh(softplus(x))
/// ```
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Mish;

impl Module<&Array> for Mish {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`mish`].
    fn forward(&mut self, x: &Array) -> Result<Array> {
        mish(x)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
344
/// Applies the Rectified Linear Unit.
///
/// This is:
///
/// ```rust, ignore
/// maximum(x, 0)
/// ```
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Relu;

impl Module<&Array> for Relu {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`relu`].
    fn forward(&mut self, x: &Array) -> Result<Array> {
        relu(x)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
366
generate_builder! {
    /// Applies the Leaky Rectified Linear Unit.
    ///
    /// This is:
    ///
    /// ```rust, ignore
    /// maximum(neg_slope * x, x)
    /// ```
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LeakyRelu {
        /// The negative slope. Default to [`LeakyRelu::DEFAULT_NEG_SLOPE`] if not provided.
        #[builder(optional, default = LeakyRelu::DEFAULT_NEG_SLOPE)]
        pub neg_slope: f32,
    }
}

impl LeakyRelu {
    /// The default negative slope value.
    pub const DEFAULT_NEG_SLOPE: f32 = 0.01;
}

impl Module<&Array> for LeakyRelu {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`leaky_relu`] with the configured slope.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        leaky_relu(x, self.neg_slope)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
401
/// Applies the Rectified Linear Unit 6.
///
/// This is:
///
/// ```rust, ignore
/// minimum(maximum(x, 0), 6)
/// ```
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Relu6;

impl Module<&Array> for Relu6 {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`relu6`].
    fn forward(&mut self, x: &Array) -> Result<Array> {
        relu6(x)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
423
generate_builder! {
    /// Applies the Softmax function.
    ///
    /// This is:
    ///
    /// ```rust, ignore
    /// softmax(&x, None, None)
    /// ```
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct Softmax {
        /// The axis to apply the softmax. Default to [`Softmax::DEFAULT_AXIS`] if not provided.
        #[builder(optional, default = Softmax::DEFAULT_AXIS)]
        pub axis: i32,
    }
}

impl Softmax {
    /// The default axis value (the last axis).
    pub const DEFAULT_AXIS: i32 = -1;
}

impl Module<&Array> for Softmax {
    type Error = Exception;
    type Output = Array;

    /// Applies `softmax` along the configured axis.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        crate::ops::softmax(x, &[self.axis], None)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
458
/// Applies the Softplus function.
///
/// This is:
///
/// ```rust, ignore
/// log_add_exp(x, 0)
/// ```
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Softplus;

impl Module<&Array> for Softplus {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`softplus`].
    fn forward(&mut self, x: &Array) -> Result<Array> {
        softplus(x)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
480
/// Applies the Softsign function.
///
/// This is:
///
/// ```rust, ignore
/// x / (1 + abs(x))
/// ```
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Softsign;

impl Module<&Array> for Softsign {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`softsign`].
    fn forward(&mut self, x: &Array) -> Result<Array> {
        softsign(x)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
502
generate_builder! {
    /// Applies the Continuously Differentiable Exponential Linear Unit.
    ///
    /// This is:
    ///
    /// ```rust, ignore
    /// maximum(x, 0) + alpha * (exp(minimum(x, 0) / alpha) - 1)
    /// ```
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct Celu {
        /// The alpha value. Default to [`Celu::DEFAULT_ALPHA`] if not provided.
        #[builder(optional, default = Celu::DEFAULT_ALPHA)]
        pub alpha: f32,
    }
}

impl Celu {
    /// The default alpha value.
    pub const DEFAULT_ALPHA: f32 = 1.0;
}

impl Module<&Array> for Celu {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`celu`] with the configured alpha.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        celu(x, self.alpha)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
538
/// Applies the Sigmoid Linear Unit. Also known as Swish.
///
/// This is:
///
/// ```rust, ignore
/// x * sigmoid(x)
/// ```
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Silu;

impl Module<&Array> for Silu {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`silu`].
    fn forward(&mut self, x: &Array) -> Result<Array> {
        silu(x)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
560
generate_builder! {
    /// Applies the Log Softmax function.
    ///
    /// This is:
    ///
    /// ```rust, ignore
    /// x - log_sum_exp(x, axis, true)
    /// ```
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LogSoftmax {
        /// The axis value. Default to [`LogSoftmax::DEFAULT_AXIS`] if not provided.
        #[builder(optional, default = LogSoftmax::DEFAULT_AXIS)]
        pub axis: i32,
    }
}

impl LogSoftmax {
    /// The default axis value (the last axis).
    pub const DEFAULT_AXIS: i32 = -1;
}

impl Module<&Array> for LogSoftmax {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`log_softmax`] with the configured axis.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        log_softmax(x, self.axis)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
595
/// Applies the Log Sigmoid function.
///
/// This is:
///
/// ```rust, ignore
/// -softplus(-x)
/// ```
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct LogSigmoid;

impl Module<&Array> for LogSigmoid {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`log_sigmoid`].
    fn forward(&mut self, x: &Array) -> Result<Array> {
        log_sigmoid(x)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
617
/// Applies the element-wise parametric ReLU.
///
/// This is:
///
/// ```rust, ignore
/// maximum(0, x) + alpha * minimum(0, x)
/// ```
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Prelu {
    /// The learnable slope for negative inputs. See [`prelu`] for more details.
    #[param]
    #[builder(ignore)]
    pub weight: Param<Array>, // TODO: double check if this is trainable
}
634
/// The builder for the Prelu module.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_prelu,
    default_infallible,
    err = Exception,
)]
pub struct PreluBuilder {
    /// Number of elements in the `weight` parameter. Default to
    /// [`Prelu::DEFAULT_COUNT`] if not provided.
    #[builder(optional, default = Prelu::DEFAULT_COUNT)]
    pub count: i32,

    /// Initial value for every element of `weight`. Default to
    /// [`Prelu::DEFAULT_VALUE`] if not provided.
    #[builder(optional, default = Prelu::DEFAULT_VALUE)]
    pub value: f32,
}
652
653/// Builds the Prelu module.
654fn build_prelu(builder: PreluBuilder) -> Result<Prelu> {
655    let count = builder.count;
656    let value = builder.value;
657    let weight = Param::new(crate::ops::full::<f32>(&[count], &array!(value))?);
658    Ok(Prelu { weight })
659}
660
impl Prelu {
    /// The default count value for the `weight` parameter.
    pub const DEFAULT_COUNT: i32 = 1;

    /// The default initial value for the `weight` parameter.
    pub const DEFAULT_VALUE: f32 = 0.25;
}

impl Module<&Array> for Prelu {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`prelu`] using the stored `weight`.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        prelu(x, &self.weight)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
679
/// Variants of Gaussian Error Linear Units function.
#[derive(Debug, Clone, Copy, Default)]
pub enum GeluApprox {
    /// Uses [`gelu`] (the exact erf-based formulation)
    #[default]
    None,

    /// Uses [`gelu_approximate`] (tanh-based approximation)
    Precise,

    /// Uses [`gelu_fast_approximate`] (sigmoid-based approximation)
    Fast,
}
693
generate_builder! {
    /// Applies the Gaussian Error Linear Units function.
    ///
    /// There are three variants:
    ///
    /// - `GeluApprox::None`: Uses [`gelu`]. This is the default.
    /// - `GeluApprox::Precise`: Uses [`gelu_approximate`]
    /// - `GeluApprox::Fast`: Uses [`gelu_fast_approximate`]
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct Gelu {
        /// The approximation to use. Default to `GeluApprox::None` if not provided.
        #[builder(optional, default = GeluApprox::None)]
        pub approximate: GeluApprox,
    }
}

impl Module<&Array> for Gelu {
    type Error = Exception;
    type Output = Array;

    /// Dispatches to the configured GELU variant.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        match self.approximate {
            GeluApprox::None => gelu(x),
            GeluApprox::Precise => gelu_approximate(x),
            GeluApprox::Fast => gelu_fast_approximate(x),
        }
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
727
/// Applies the hyperbolic tangent function.
///
/// This is:
///
/// ```rust, ignore
/// tanh(x)
/// ```
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Tanh;

impl Module<&Array> for Tanh {
    type Error = Exception;
    type Output = Array;

    /// Applies `tanh` element-wise.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        crate::ops::tanh(x)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
743
/// Applies the hardswish function, element-wise.
///
/// This is:
///
/// ```rust, ignore
/// x * minimum(maximum(x + 3, 0), 6) / 6
/// ```
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct HardSwish;

impl Module<&Array> for HardSwish {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`hard_swish`].
    fn forward(&mut self, x: &Array) -> Result<Array> {
        hard_swish(x)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
765
generate_builder! {
    /// Applies the Step Activation Function.
    ///
    /// This function implements a binary step activation, where the output is set
    /// to 1 if the input is greater than a specified threshold, and 0 otherwise.
    ///
    /// This is:
    ///
    /// ```rust, ignore
    /// r#where(x.gt(threshold), 1, 0)
    /// ```
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct Step {
        /// The threshold value. Default to [`Step::DEFAULT_THRESHOLD`] if not provided.
        #[builder(optional, default = Step::DEFAULT_THRESHOLD)]
        pub threshold: f32,
    }
}

impl Step {
    /// The default threshold value.
    pub const DEFAULT_THRESHOLD: f32 = 0.0;
}

impl Module<&Array> for Step {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`step`] with the configured threshold.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        step(x, self.threshold)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
803
/// Applies the Scaled Exponential Linear Unit.
///
/// This is:
///
/// ```rust, ignore
/// elu(x, 1.67326) * 1.0507
/// ```
#[derive(Debug, Clone, ModuleParameters)]
#[module(root = crate)]
pub struct Selu;

impl Module<&Array> for Selu {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`selu`].
    fn forward(&mut self, x: &Array) -> Result<Array> {
        selu(x)
    }

    /// No-op: this module has no training-specific behavior.
    fn training_mode(&mut self, _: bool) {}
}
825
826/* -------------------------------------------------------------------------- */
827/*                        Compiled activation functions                       */
828/* -------------------------------------------------------------------------- */
829
830#[inline]
831fn compiled_leaky_relu(x: &Array, neg_slope: &Array) -> Result<Array> {
832    let f = |(x_, neg_slope_): (&Array, &Array)| {
833        // This will not panic because a scalar can always be broadcasted to any shape
834        let a = multiply(neg_slope_, x_)?;
835        maximum(&a, x_)
836    };
837    let mut compiled = compile(f, true);
838    compiled((x, neg_slope))
839}
840
841#[inline]
842fn compiled_elu(x: &Array, alpha: &Array) -> Result<Array> {
843    let f = |(x_, alpha_): (&Array, &Array)| {
844        which(&x_.gt(&array!(0.0))?, x_, alpha_ * (exp(x_)? - array!(1.0)))
845    };
846    let mut compiled = compile(f, true);
847    compiled((x, alpha))
848}
849
850#[inline]
851fn compiled_relu6(x: &Array) -> Result<Array> {
852    let f = |x_: &Array| minimum(maximum(x_, &array!(0.0))?, &array!(6.0));
853    let mut compiled = compile(f, true);
854    compiled(x)
855}
856
857#[inline]
858fn compiled_softsign(x: &Array) -> Result<Array> {
859    let f = |x_: &Array| x_.divide(array!(1.0) + abs(x_)?);
860    let mut compiled = compile(f, true);
861    compiled(x)
862}
863
864#[inline]
865fn compiled_celu(x: &Array, alpha: &Array) -> Result<Array> {
866    let f = |(x_, alpha_): (&Array, &Array)| {
867        maximum(x_, &array!(0.0))?
868            .add(alpha_.multiply(exp(&(minimum(x_, &array!(0.0))? / alpha_))? - array!(1.0))?)
869    };
870    let mut compiled = compile(f, true);
871    compiled((x, alpha))
872}
873
874#[inline]
875fn compiled_silu(x: &Array) -> Result<Array> {
876    let f = |x_: &Array| x_.multiply(sigmoid(x_)?);
877    let mut compiled = compile(f, true);
878    compiled(x)
879}
880
881#[inline]
882fn compiled_log_sigmoid(x: &Array) -> Result<Array> {
883    let f = |x_: &Array| Ok(-softplus(&(-x_))?);
884    let mut compiled = compile(f, true);
885    compiled(x)
886}
887
888#[inline]
889fn compiled_gelu(x: &Array) -> Result<Array> {
890    use crate::ops::erf;
891    let f = |x_: &Array| {
892        x_.multiply(array!(1) + erf(&(x_ / array!(2f32.sqrt())))?)?
893            .divide(array!(2.0))
894    };
895    let mut compiled = compile(f, true);
896    compiled(x)
897}
898
899#[inline]
900fn compiled_gelu_approximate(x: &Array) -> Result<Array> {
901    use crate::ops::{sqrt, tanh};
902
903    let f = move |x_: &Array| {
904        // 0.5 * x * (1 + tanh(sqrt(2 / Float.pi) * (x + 0.044715 * x ** 3)))
905        array!(0.5).multiply(x_)?.multiply(
906            array!(1.0).add(tanh(
907                &(sqrt(&array!(2.0 / PI))?
908                    .multiply(x_ + array!(0.044715).multiply(x_.power(&array!(3))?)?)?),
909            )?)?,
910        )
911    };
912    let mut compiled = compile(f, true);
913    compiled(x)
914}
915
916#[inline]
917fn compiled_gelu_fast_approximate(x: &Array) -> Result<Array> {
918    let f = |x_: &Array| x_.multiply(sigmoid(&(array!(1.773) * x_))?);
919    let mut compiled = compile(f, true);
920    compiled(x)
921}
922
923#[inline]
924fn compiled_selu(x: &Array) -> Result<Array> {
925    let f = |x_: &Array| elu(x_, 1.67326)?.multiply(array!(1.0507));
926    let mut compiled = compile(f, true);
927    compiled(x)
928}
929
930#[inline]
931fn compiled_prelu(x: &Array, alpha: &Array) -> Result<Array> {
932    let f = |(x_, alpha_): (&Array, &Array)| {
933        maximum(&array!(0.0), x_)?.add(alpha_ * minimum(&array!(0.0), x_)?)
934    };
935    let mut compiled = compile(f, true);
936    compiled((x, alpha))
937}
938
939#[inline]
940fn compiled_mish(x: &Array) -> Result<Array> {
941    use crate::ops::tanh;
942
943    let f = |x_: &Array| x_.multiply(tanh(&softplus(x_)?)?);
944    let mut compiled = compile(f, true);
945    compiled(x)
946}
947
948#[inline]
949fn compiled_hard_swish(x: &Array) -> Result<Array> {
950    let f = |x_: &Array| {
951        let max_x_plus_3 = maximum(&(x_ + array!(3.0)), &array!(0.0))?;
952        x_.multiply(minimum(&max_x_plus_3, &array!(6.0))?)?
953            .divide(&array!(6.0))
954    };
955    let mut compiled = compile(f, true);
956    compiled(x)
957}
958
959// The following tests are ported from the swift binding:
960// mlx-swift/Tests/MLXTests/IntegrationTests.swift
961#[cfg(test)]
962mod tests {
963    use crate::{builder::Builder, random::uniform, Dtype};
964    use float_eq::assert_float_eq;
965
966    use super::*;
967
    // Checks GLU halves the last axis and matches the mlx-swift reference
    // statistics (tolerances are 2% of the expected value).
    #[test]
    fn test_glu() {
        crate::random::seed(850).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.547_252_66,
            abs <= 0.010_945_053
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            140.096_68,
            abs <= 2.801_933_5
        );
        let result = Glu::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 8]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.333_276_75,
            abs <= 0.006_665_535
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            42.659_424,
            abs <= 0.853_188_46
        );
    }
998
    // Checks Sigmoid preserves shape/dtype and matches the mlx-swift
    // reference statistics.
    #[test]
    fn test_sigmoid() {
        crate::random::seed(589).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.529_697_9,
            abs <= 0.010_593_958
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            135.602_66,
            abs <= 2.712_053_3
        );
        let result = Sigmoid.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.627_014,
            abs <= 0.012_540_28
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            160.515_58,
            abs <= 3.210_311_7
        );
    }
1029
    // Checks Mish preserves shape/dtype and matches the mlx-swift
    // reference statistics.
    #[test]
    fn test_mish() {
        crate::random::seed(122).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.501_719_8,
            abs <= 0.010_034_395
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            128.440_26,
            abs <= 2.568_805_2
        );
        let result = Mish.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.395_375_73,
            abs <= 0.007_907_514
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            101.216_19,
            abs <= 2.024_323_7
        );
    }
1060
    // Checks Relu on a non-negative input; the output statistics equal the
    // input statistics because every element is already >= 0.
    #[test]
    fn test_relu() {
        crate::random::seed(400).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.478_322_74,
            abs <= 0.009_566_455
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            122.450_62,
            abs <= 2.449_012_5
        );
        let result = Relu.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.478_322_74,
            abs <= 0.009_566_455
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            122.450_62,
            abs <= 2.449_012_5
        );
    }
1091
    // Checks LeakyRelu on a non-negative input; output statistics equal the
    // input statistics because the negative-slope branch is never taken.
    #[test]
    fn test_leaky_relu() {
        crate::random::seed(93).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.499_930_68,
            abs <= 0.009_998_614
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            127.982_254,
            abs <= 2.559_645_2
        );
        let result = LeakyRelu::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.499_930_68,
            abs <= 0.009_998_614
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            127.982_254,
            abs <= 2.559_645_2
        );
    }
1122
1123    #[test]
1124    fn test_relu6() {
1125        crate::random::seed(379).unwrap();
1126        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1127        assert_eq!(a.shape(), &[2, 8, 16]);
1128        assert_eq!(a.dtype(), Dtype::Float32);
1129        assert_float_eq!(
1130            a.mean(None, None).unwrap().item::<f32>(),
1131            0.493_258_66,
1132            abs <= 0.009_865_173
1133        );
1134        assert_float_eq!(
1135            a.sum(None, None).unwrap().item::<f32>(),
1136            126.274_216,
1137            abs <= 2.525_484_3
1138        );
1139        let result = Relu6.forward(&a).unwrap();
1140        assert_eq!(result.shape(), &[2, 8, 16]);
1141        assert_eq!(result.dtype(), Dtype::Float32);
1142        assert_float_eq!(
1143            result.mean(None, None).unwrap().item::<f32>(),
1144            0.493_258_66,
1145            abs <= 0.009_865_173
1146        );
1147        assert_float_eq!(
1148            result.sum(None, None).unwrap().item::<f32>(),
1149            126.274_216,
1150            abs <= 2.525_484_3
1151        );
1152    }
1153
1154    #[test]
1155    fn test_softmax() {
1156        crate::random::seed(853).unwrap();
1157        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1158        assert_eq!(a.shape(), &[2, 8, 16]);
1159        assert_eq!(a.dtype(), Dtype::Float32);
1160        assert_float_eq!(
1161            a.mean(None, None).unwrap().item::<f32>(),
1162            0.514_396_3,
1163            abs <= 0.010_287_926_5
1164        );
1165        assert_float_eq!(
1166            a.sum(None, None).unwrap().item::<f32>(),
1167            131.685_46,
1168            abs <= 2.633_709_2
1169        );
1170        let result = Softmax::new().forward(&a).unwrap();
1171        assert_eq!(result.shape(), &[2, 8, 16]);
1172        assert_eq!(result.dtype(), Dtype::Float32);
1173        assert_float_eq!(
1174            result.mean(None, None).unwrap().item::<f32>(),
1175            0.062_499_996,
1176            abs <= 0.001_25
1177        );
1178        assert_float_eq!(
1179            result.sum(None, None).unwrap().item::<f32>(),
1180            15.999_999,
1181            abs <= 0.32
1182        );
1183    }
1184
1185    #[test]
1186    fn test_softplus() {
1187        crate::random::seed(118).unwrap();
1188        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1189        assert_eq!(a.shape(), &[2, 8, 16]);
1190        assert_eq!(a.dtype(), Dtype::Float32);
1191        assert_float_eq!(
1192            a.mean(None, None).unwrap().item::<f32>(),
1193            0.498_981_42,
1194            abs <= 0.009_979_628
1195        );
1196        assert_float_eq!(
1197            a.sum(None, None).unwrap().item::<f32>(),
1198            127.739_24,
1199            abs <= 2.554_784_8
1200        );
1201        let result = Softplus.forward(&a).unwrap();
1202        assert_eq!(result.shape(), &[2, 8, 16]);
1203        assert_eq!(result.dtype(), Dtype::Float32);
1204        assert_float_eq!(
1205            result.mean(None, None).unwrap().item::<f32>(),
1206            0.982_857_76,
1207            abs <= 0.019_657_155
1208        );
1209        assert_float_eq!(
1210            result.sum(None, None).unwrap().item::<f32>(),
1211            251.611_59,
1212            abs <= 5.032_232
1213        );
1214    }
1215
1216    #[test]
1217    fn test_softsign() {
1218        crate::random::seed(37).unwrap();
1219        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1220        assert_eq!(a.shape(), &[2, 8, 16]);
1221        assert_eq!(a.dtype(), Dtype::Float32);
1222        assert_float_eq!(
1223            a.mean(None, None).unwrap().item::<f32>(),
1224            0.506_551_27,
1225            abs <= 0.010_131_026
1226        );
1227        assert_float_eq!(
1228            a.sum(None, None).unwrap().item::<f32>(),
1229            129.677_12,
1230            abs <= 2.593_542_6
1231        );
1232        let result = Softsign.forward(&a).unwrap();
1233        assert_eq!(result.shape(), &[2, 8, 16]);
1234        assert_eq!(result.dtype(), Dtype::Float32);
1235        assert_float_eq!(
1236            result.mean(None, None).unwrap().item::<f32>(),
1237            0.314_089_83,
1238            abs <= 0.006_281_797
1239        );
1240        assert_float_eq!(
1241            result.sum(None, None).unwrap().item::<f32>(),
1242            80.407,
1243            abs <= 1.608_14
1244        );
1245    }
1246
1247    // The unit test below is adapted from the python binding:
1248    // mlx/python/tests/test_nn.py
1249    #[test]
1250    fn test_celu() {
1251        let x = array!([1.0, -1.0, 0.0]);
1252        let y = Celu::new().forward(&x).unwrap();
1253        let epsilon = array!(1e-4);
1254        let expected_y = array!([1.0, -0.6321, 0.0]);
1255        assert!(y
1256            .subtract(&expected_y)
1257            .unwrap()
1258            .abs()
1259            .unwrap()
1260            .lt(&epsilon)
1261            .unwrap()
1262            .all(None, None)
1263            .unwrap()
1264            .item::<bool>());
1265        assert_eq!(y.shape(), &[3]);
1266        assert_eq!(y.dtype(), Dtype::Float32);
1267
1268        let y = CeluBuilder::new()
1269            .alpha(1.1)
1270            .build()
1271            .unwrap()
1272            .forward(&x)
1273            .unwrap();
1274        let expected_y = array!([1.0, -0.6568, 0.0]);
1275        assert!(y
1276            .subtract(&expected_y)
1277            .unwrap()
1278            .abs()
1279            .unwrap()
1280            .lt(&epsilon)
1281            .unwrap()
1282            .all(None, None)
1283            .unwrap()
1284            .item::<bool>());
1285        assert_eq!(y.shape(), &[3]);
1286        assert_eq!(y.dtype(), Dtype::Float32);
1287    }
1288
1289    #[test]
1290    fn test_silu() {
1291        crate::random::seed(22).unwrap();
1292        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1293        assert_eq!(a.shape(), &[2, 8, 16]);
1294        assert_eq!(a.dtype(), Dtype::Float32);
1295        assert_float_eq!(
1296            a.mean(None, None).unwrap().item::<f32>(),
1297            0.502_970_6,
1298            abs <= 0.010_059_412
1299        );
1300        assert_float_eq!(
1301            a.sum(None, None).unwrap().item::<f32>(),
1302            128.760_47,
1303            abs <= 2.575_209_4
1304        );
1305        let result = Silu.forward(&a).unwrap();
1306        assert_eq!(result.shape(), &[2, 8, 16]);
1307        assert_eq!(result.dtype(), Dtype::Float32);
1308        assert_float_eq!(
1309            result.mean(None, None).unwrap().item::<f32>(),
1310            0.331_970_93,
1311            abs <= 0.006_639_418_7
1312        );
1313        assert_float_eq!(
1314            result.sum(None, None).unwrap().item::<f32>(),
1315            84.984_56,
1316            abs <= 1.699_691_2
1317        );
1318    }
1319
1320    #[test]
1321    fn test_log_softmax() {
1322        crate::random::seed(199).unwrap();
1323        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1324        assert_eq!(a.shape(), &[2, 8, 16]);
1325        assert_eq!(a.dtype(), Dtype::Float32);
1326        assert_float_eq!(
1327            a.mean(None, None).unwrap().item::<f32>(),
1328            0.527_843_7,
1329            abs <= 0.010_556_874
1330        );
1331        assert_float_eq!(
1332            a.sum(None, None).unwrap().item::<f32>(),
1333            135.127_99,
1334            abs <= 2.702_559_7
1335        );
1336        let result = LogSoftmax::new().forward(&a).unwrap();
1337        assert_eq!(result.shape(), &[2, 8, 16]);
1338        assert_eq!(result.dtype(), Dtype::Float32);
1339        assert_float_eq!(
1340            result.mean(None, None).unwrap().item::<f32>(),
1341            -2.810_954_6,
1342            abs <= 0.056_219_09
1343        );
1344        assert_float_eq!(
1345            result.sum(None, None).unwrap().item::<f32>(),
1346            -719.604_4,
1347            abs <= 14.392_087
1348        );
1349    }
1350
1351    #[test]
1352    fn test_log_sigmoid() {
1353        crate::random::seed(984).unwrap();
1354        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1355        assert_eq!(a.shape(), &[2, 8, 16]);
1356        assert_eq!(a.dtype(), Dtype::Float32);
1357        assert_float_eq!(
1358            a.mean(None, None).unwrap().item::<f32>(),
1359            0.510_977_7,
1360            abs <= 0.010_219_553_5
1361        );
1362        assert_float_eq!(
1363            a.sum(None, None).unwrap().item::<f32>(),
1364            130.810_29,
1365            abs <= 2.616_205_7
1366        );
1367        let result = LogSigmoid.forward(&a).unwrap();
1368        assert_eq!(result.shape(), &[2, 8, 16]);
1369        assert_eq!(result.dtype(), Dtype::Float32);
1370        assert_float_eq!(
1371            result.mean(None, None).unwrap().item::<f32>(),
1372            -0.479_598_55,
1373            abs <= 0.009_591_971
1374        );
1375        assert_float_eq!(
1376            result.sum(None, None).unwrap().item::<f32>(),
1377            -122.777_23,
1378            abs <= 2.455_544_5
1379        );
1380    }
1381
1382    #[test]
1383    fn test_prelu() {
1384        crate::random::seed(993).unwrap();
1385        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1386        assert_eq!(a.shape(), &[2, 8, 16]);
1387        assert_eq!(a.dtype(), Dtype::Float32);
1388        assert_float_eq!(
1389            a.mean(None, None).unwrap().item::<f32>(),
1390            0.496_651_44,
1391            abs <= 0.009_933_028
1392        );
1393        assert_float_eq!(
1394            a.sum(None, None).unwrap().item::<f32>(),
1395            127.142_77,
1396            abs <= 2.542_855_3
1397        );
1398        let result = Prelu::new().forward(&a).unwrap();
1399        assert_eq!(result.shape(), &[2, 8, 16]);
1400        assert_eq!(result.dtype(), Dtype::Float32);
1401        assert_float_eq!(
1402            result.mean(None, None).unwrap().item::<f32>(),
1403            0.496_651_44,
1404            abs <= 0.009_933_028
1405        );
1406        assert_float_eq!(
1407            result.sum(None, None).unwrap().item::<f32>(),
1408            127.142_77,
1409            abs <= 2.542_855_3
1410        );
1411    }
1412
1413    #[test]
1414    fn test_gelu() {
1415        crate::random::seed(189).unwrap();
1416        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1417        assert_eq!(a.shape(), &[2, 8, 16]);
1418        assert_eq!(a.dtype(), Dtype::Float32);
1419        assert_float_eq!(
1420            a.mean(None, None).unwrap().item::<f32>(),
1421            0.492_950_32,
1422            abs <= 0.009_859_007
1423        );
1424        assert_float_eq!(
1425            a.sum(None, None).unwrap().item::<f32>(),
1426            126.195_28,
1427            abs <= 2.523_905_8
1428        );
1429        let result = Gelu::new().forward(&a).unwrap();
1430        assert_eq!(result.shape(), &[2, 8, 16]);
1431        assert_eq!(result.dtype(), Dtype::Float32);
1432        assert_float_eq!(
1433            result.mean(None, None).unwrap().item::<f32>(),
1434            0.365_638_38,
1435            abs <= 0.007_312_767_7
1436        );
1437        assert_float_eq!(
1438            result.sum(None, None).unwrap().item::<f32>(),
1439            93.603_424,
1440            abs <= 1.872_068_5
1441        );
1442    }
1443
1444    #[test]
1445    fn test_tanh() {
1446        crate::random::seed(735).unwrap();
1447        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1448        assert_eq!(a.shape(), &[2, 8, 16]);
1449        assert_eq!(a.dtype(), Dtype::Float32);
1450        assert_float_eq!(
1451            a.mean(None, None).unwrap().item::<f32>(),
1452            0.474_122_7,
1453            abs <= 0.009_482_454_5
1454        );
1455        assert_float_eq!(
1456            a.sum(None, None).unwrap().item::<f32>(),
1457            121.375_41,
1458            abs <= 2.427_508_4
1459        );
1460        let result = Tanh.forward(&a).unwrap();
1461        assert_eq!(result.shape(), &[2, 8, 16]);
1462        assert_eq!(result.dtype(), Dtype::Float32);
1463        assert_float_eq!(
1464            result.mean(None, None).unwrap().item::<f32>(),
1465            0.413_079_68,
1466            abs <= 0.008_261_594
1467        );
1468        assert_float_eq!(
1469            result.sum(None, None).unwrap().item::<f32>(),
1470            105.748_4,
1471            abs <= 2.114_968
1472        );
1473    }
1474
1475    #[test]
1476    fn test_hardswish() {
1477        crate::random::seed(126).unwrap();
1478        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1479        assert_eq!(a.shape(), &[2, 8, 16]);
1480        assert_eq!(a.dtype(), Dtype::Float32);
1481        assert_float_eq!(
1482            a.mean(None, None).unwrap().item::<f32>(),
1483            0.491_892_46,
1484            abs <= 0.009_837_849
1485        );
1486        assert_float_eq!(
1487            a.sum(None, None).unwrap().item::<f32>(),
1488            125.924_47,
1489            abs <= 2.518_489_4
1490        );
1491        let result = HardSwish.forward(&a).unwrap();
1492        assert_eq!(result.shape(), &[2, 8, 16]);
1493        assert_eq!(result.dtype(), Dtype::Float32);
1494        assert_float_eq!(
1495            result.mean(None, None).unwrap().item::<f32>(),
1496            0.299_602_24,
1497            abs <= 0.005_992_044_7
1498        );
1499        assert_float_eq!(
1500            result.sum(None, None).unwrap().item::<f32>(),
1501            76.698_17,
1502            abs <= 1.533_963_4
1503        );
1504    }
1505
1506    #[test]
1507    fn test_step() {
1508        crate::random::seed(490).unwrap();
1509        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1510        assert_eq!(a.shape(), &[2, 8, 16]);
1511        assert_eq!(a.dtype(), Dtype::Float32);
1512        assert_float_eq!(
1513            a.mean(None, None).unwrap().item::<f32>(),
1514            0.479_360_64,
1515            abs <= 0.009_587_212_5
1516        );
1517        assert_float_eq!(
1518            a.sum(None, None).unwrap().item::<f32>(),
1519            122.716_324,
1520            abs <= 2.454_326_4
1521        );
1522        let result = Step::new().forward(&a).unwrap();
1523        assert_eq!(result.shape(), &[2, 8, 16]);
1524        assert_eq!(result.dtype(), Dtype::Int32);
1525        assert_float_eq!(
1526            result.mean(None, None).unwrap().item::<f32>(),
1527            1.0,
1528            abs <= 0.02
1529        );
1530        assert_float_eq!(
1531            result.sum(None, None).unwrap().item::<f32>(),
1532            256.0,
1533            abs <= 5.12
1534        );
1535    }
1536
1537    #[test]
1538    fn test_selu() {
1539        crate::random::seed(215).unwrap();
1540        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1541        assert_eq!(a.shape(), &[2, 8, 16]);
1542        assert_eq!(a.dtype(), Dtype::Float32);
1543        assert_float_eq!(
1544            a.mean(None, None).unwrap().item::<f32>(),
1545            0.493_026_8,
1546            abs <= 0.009_860_536
1547        );
1548        assert_float_eq!(
1549            a.sum(None, None).unwrap().item::<f32>(),
1550            126.214_86,
1551            abs <= 2.524_297_2
1552        );
1553        let result = Selu.forward(&a).unwrap();
1554        assert_eq!(result.shape(), &[2, 8, 16]);
1555        assert_eq!(result.dtype(), Dtype::Float32);
1556        assert_float_eq!(
1557            result.mean(None, None).unwrap().item::<f32>(),
1558            0.518_023_2,
1559            abs <= 0.010_360_463_5
1560        );
1561        assert_float_eq!(
1562            result.sum(None, None).unwrap().item::<f32>(),
1563            132.613_94,
1564            abs <= 2.652_278_7
1565        );
1566    }
1567}