mlx_rs/nn/
activation.rs

1use std::f32::consts::PI;
2
3use crate::module::{Module, Param};
4use crate::ops::logsumexp_axis;
5use crate::{
6    array,
7    error::{Exception, Result},
8    ops::{abs, exp, maximum, minimum, multiply, which},
9    transforms::compile::compile,
10    Array,
11};
12use mlx_internal_macros::{generate_builder, Buildable, Builder};
13use mlx_macros::ModuleParameters;
14
15/// Applies the element-wise sigmoid logistic sigmoid.
16///
17/// For details, please see
18/// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.sigmoid.html)
19///
20/// This is:
21///
22/// ```rust, ignore
23/// sigmoid(x)
24/// ```
25pub fn sigmoid(x: impl AsRef<Array>) -> Result<Array> {
26    crate::ops::sigmoid(x.as_ref())
27}
28
29/// Applies the Rectified Linear Unit.
30///
31/// This is:
32///
33/// ```rust, ignore
34/// maximum(x, 0)
35/// ```
36pub fn relu(x: impl AsRef<Array>) -> Result<Array> {
37    crate::ops::maximum(x.as_ref(), &array!(0))
38}
39
40/// Applies the Leaky Rectified Linear Unit.
41///
42/// `neg_slope` is default to 0.01 if not provided.
43///
44/// This is:
45///
46/// ```rust, ignore
47/// maximum(neg_slope * x, x)
48/// ```
49pub fn leaky_relu(x: impl AsRef<Array>, neg_slope: impl Into<Option<f32>>) -> Result<Array> {
50    let neg_slope = array!(neg_slope.into().unwrap_or(0.01));
51    // We have to use this indirection, otherwise the compiler cannot
52    // infer the lifetime of the value returned by the closure properly
53    compiled_leaky_relu(x.as_ref(), &neg_slope)
54}
55
56/// Applies the Log Softmax function.
57///
58/// This is:
59///
60/// ```rust, ignore
61/// x - logsumexp_axis(x, axis, true)
62/// ```
63pub fn log_softmax(x: impl AsRef<Array>, axis: impl Into<Option<i32>>) -> Result<Array> {
64    let x = x.as_ref();
65    let axis = axis.into().unwrap_or(-1);
66    x.subtract(logsumexp_axis(x, axis, true)?)
67}
68
69/// Applies the Exponential Linear Unit.
70///
71/// This is:
72///
73/// ```rust, ignore
74/// which(x.gt(0), x, alpha * (exp(x) - 1))
75/// ```
76///
77/// # Params
78///
79/// - `x`: The input array
80/// - `alpha`: Default to 1.0 if not provided
81pub fn elu(x: impl AsRef<Array>, alpha: impl Into<Option<f32>>) -> Result<Array> {
82    let alpha = array!(alpha.into().unwrap_or(1.0));
83    // We have to use this indirection, otherwise the compiler cannot
84    // infer the lifetime of the value returned by the closure properly
85    compiled_elu(x.as_ref(), &alpha)
86}
87
88/// Applies the Rectified Linear Unit 6.
89///
90/// This is:
91///
92/// ```rust, ignore
93/// minimum(maximum(x, 0), 6)
94/// ```
95pub fn relu6(x: impl AsRef<Array>) -> Result<Array> {
96    compiled_relu6(x.as_ref())
97}
98
/// Applies the Softplus function.
///
/// This is:
///
/// ```rust, ignore
/// logaddexp(x, 0)
/// ```
pub fn softplus(x: impl AsRef<Array>) -> Result<Array> {
    crate::ops::logaddexp(x.as_ref(), &array!(0))
}
109
110/// Applies the Softsign function.
111///
112/// This is:
113///
114/// ```rust, ignore
115/// x / (1 + abs(x))
116/// ```
117pub fn softsign(x: impl AsRef<Array>) -> Result<Array> {
118    compiled_softsign(x.as_ref())
119}
120
121/// Applies the Continuously Differentiable Exponential Linear Unit.
122///
123/// This is:
124///
125/// ```rust, ignore
126/// maximum(x, 0) + alpha * (exp(minimum(x, 0) / alpha) - 1)
127/// ```
128pub fn celu(x: impl AsRef<Array>, alpha: impl Into<Option<f32>>) -> Result<Array> {
129    let alpha = array!(alpha.into().unwrap_or(1.0));
130    // We have to use this indirection, otherwise the compiler cannot
131    // infer the lifetime of the value returned by the closure properly
132    compiled_celu(x.as_ref(), &alpha)
133}
134
135/// Applies the Sigmoid Linear Unit. Also known as Swish.
136///
137/// This is:
138///
139/// ```rust, ignore
140/// x * sigmoid(x)
141/// ```
142pub fn silu(x: impl AsRef<Array>) -> Result<Array> {
143    compiled_silu(x.as_ref())
144}
145
146/// Applies the Log Sigmoid function.
147///
148/// This is:
149///
150/// ```rust, ignore
151/// -softplus(-x)
152/// ```
153pub fn log_sigmoid(x: impl AsRef<Array>) -> Result<Array> {
154    compiled_log_sigmoid(x.as_ref())
155}
156
157/// Applies the Gaussian Error Linear Units function.
158///
159/// This is:
160///
161/// ```rust, ignore
162/// x * (1 + erf(x / 2.sqrt())) / 2
163/// ```
164pub fn gelu(x: impl AsRef<Array>) -> Result<Array> {
165    compiled_gelu(x.as_ref())
166}
167
168/// An approximation to Gaussian Error Linear Unit.
169///
170/// This is:
171///
172/// ```rust, ignore
173/// 0.5 * x * (1 + tanh(sqrt(2 / PI) * (x + 0.044715 * x ** 3)))
174/// ```
175pub fn gelu_approximate(x: impl AsRef<Array>) -> Result<Array> {
176    compiled_gelu_approximate(x.as_ref())
177}
178
179/// A fast approximation to Gaussian Error Linear Unit.
180///
181/// This is:
182///
183/// ```rust, ignore
184/// x * sigmoid(1.773 * x)
185/// ```
186pub fn gelu_fast_approximate(x: impl AsRef<Array>) -> Result<Array> {
187    compiled_gelu_fast_approximate(x.as_ref())
188}
189
190/// Applies the gated linear unit function.
191///
192/// This function splits the `axis` dimension of the input into two halves
193/// (`a` and `b`) and applies `a * sigmoid(b)`.
194pub fn glu(x: impl AsRef<Array>, axis: impl Into<Option<i32>>) -> Result<Array> {
195    let split = x.as_ref().split(2, axis)?;
196    let (a, b) = (&split[0], &split[1]);
197    Ok(a * sigmoid(b)?)
198}
199
200/// Applies the Step Activation Function.
201///
202/// This function implements a binary step activation, where the output is set
203/// to 1 if the input is greater than a specified threshold, and 0 otherwise.
204///
205/// This is:
206///
207/// ```rust, ignore
208/// r#where(x.gt(threshold), 1, 0)
209/// ```
210pub fn step(x: impl AsRef<Array>, threshold: impl Into<Option<f32>>) -> Result<Array> {
211    let threshold = array!(threshold.into().unwrap_or(0.0));
212    crate::ops::r#where(&x.as_ref().gt(threshold)?, &array!(1), &array!(0))
213}
214
215/// Applies the Scaled Exponential Linear Unit.
216///
217/// This is:
218///
219/// ```rust, ignore
220/// elu(x, 1.67326) * 1.0507
221/// ```
222pub fn selu(x: impl AsRef<Array>) -> Result<Array> {
223    compiled_selu(x.as_ref())
224}
225
226/// Applies the element-wise parametric ReLU.
227///
228/// This is:
229///
230/// ```rust, ignore
231/// maximum(0, x) + alpha * minimum(0, x)
232/// ```
233pub fn prelu(x: impl AsRef<Array>, alpha: impl AsRef<Array>) -> Result<Array> {
234    compiled_prelu(x.as_ref(), alpha.as_ref())
235}
236
237/// Applies the Mish function, element-wise.
238///
239/// Mish: A Self Regularized Non-Monotonic Neural Activation Function.
240///
241/// Reference: [https://arxiv.org/abs/1908.08681](https://arxiv.org/abs/1908.08681)
242///
243/// This is:
244///
245/// ```rust, ignore
246/// x * tanh(softplus(x))
247/// ```
248pub fn mish(x: impl AsRef<Array>) -> Result<Array> {
249    compiled_mish(x.as_ref())
250}
251
252/// Applies the hardswish function, element-wise.
253///
254/// This is:
255///
256/// ```rust, ignore
257/// x * minimum(maximum(x + 3, 0), 6) / 6
258/// ```
259pub fn hard_swish(x: impl AsRef<Array>) -> Result<Array> {
260    compiled_hard_swish(x.as_ref())
261}
262
generate_builder! {
    /// Applies the gated linear unit function.
    ///
    /// This splits the `axis` dimension of the input into two halves
    /// (`a` and `b`) and applies `a * sigmoid(b)`.
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct Glu {
        /// The axis to split the input tensor. Default to [`Glu::DEFAULT_AXIS`] if not provided.
        #[builder(optional, default = Glu::DEFAULT_AXIS)]
        pub axis: i32,
    }
}

impl Glu {
    /// The default axis value (`-1`, i.e. the last axis).
    pub const DEFAULT_AXIS: i32 = -1;
}

impl Module<&Array> for Glu {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`glu`] using the configured axis.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        glu(x, self.axis)
    }

    // Stateless activation: training mode has no effect.
    fn training_mode(&mut self, _: bool) {}
}
294
295/// Applies the element-wise logistic sigmoid.
296///
297/// For details, please see
298/// [this documentation](https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.sigmoid.html)
299///
300/// This is:
301///
302/// ```rust, ignore
303/// sigmoid(x)
304/// ```
305#[derive(Debug, Clone, ModuleParameters)]
306#[module(root = crate)]
307pub struct Sigmoid;
308
309impl Module<&Array> for Sigmoid {
310    type Error = Exception;
311    type Output = Array;
312
313    fn forward(&mut self, x: &Array) -> Result<Array> {
314        sigmoid(x)
315    }
316
317    fn training_mode(&mut self, _: bool) {}
318}
319
320/// Applies the Mish function, element-wise.
321///
322/// Mish: A Self Regularized Non-Monotonic Neural Activation Function.
323///
324/// Reference: [https://arxiv.org/abs/1908.08681](https://arxiv.org/abs/1908.08681)
325///
326/// This is:
327///
328/// ```rust, ignore
329/// x * tanh(softplus(x))
330/// ```
331#[derive(Debug, Clone, ModuleParameters)]
332#[module(root = crate)]
333pub struct Mish;
334
335impl Module<&Array> for Mish {
336    type Error = Exception;
337    type Output = Array;
338
339    fn forward(&mut self, x: &Array) -> Result<Array> {
340        mish(x)
341    }
342
343    fn training_mode(&mut self, _: bool) {}
344}
345
346/// Applies the Rectified Linear Unit.
347///
348/// This is:
349///
350/// ```rust, ignore
351/// maximum(x, 0)
352/// ```
353#[derive(Debug, Clone, ModuleParameters)]
354#[module(root = crate)]
355pub struct Relu;
356
357impl Module<&Array> for Relu {
358    type Error = Exception;
359    type Output = Array;
360
361    fn forward(&mut self, x: &Array) -> Result<Array> {
362        relu(x)
363    }
364
365    fn training_mode(&mut self, _: bool) {}
366}
367
generate_builder! {
    /// Applies the Leaky Rectified Linear Unit.
    ///
    /// This is:
    ///
    /// ```rust, ignore
    /// maximum(neg_slope * x, x)
    /// ```
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LeakyRelu {
        /// The negative slope. Default to [`LeakyRelu::DEFAULT_NEG_SLOPE`] if not provided.
        #[builder(optional, default = LeakyRelu::DEFAULT_NEG_SLOPE)]
        pub neg_slope: f32,
    }
}

impl LeakyRelu {
    /// The default negative slope value.
    pub const DEFAULT_NEG_SLOPE: f32 = 0.01;
}

impl Module<&Array> for LeakyRelu {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`leaky_relu`] with the configured slope.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        leaky_relu(x, self.neg_slope)
    }

    // Stateless activation: training mode has no effect.
    fn training_mode(&mut self, _: bool) {}
}
402
403/// Applies the Rectified Linear Unit 6.
404///
405/// This is:
406///
407/// ```rust, ignore
408/// minimum(&maximum(x, 0).unwrap(), 6).unwrap()
409/// ```
410#[derive(Debug, Clone, ModuleParameters)]
411#[module(root = crate)]
412pub struct Relu6;
413
414impl Module<&Array> for Relu6 {
415    type Error = Exception;
416    type Output = Array;
417
418    fn forward(&mut self, x: &Array) -> Result<Array> {
419        relu6(x)
420    }
421
422    fn training_mode(&mut self, _: bool) {}
423}
424
generate_builder! {
    /// Applies the Softmax function.
    ///
    /// This is:
    ///
    /// ```rust, ignore
    /// softmax_axis(&x, axis, None)
    /// ```
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct Softmax {
        /// The axis to apply the softmax. Default to [`Softmax::DEFAULT_AXIS`] if not provided.
        #[builder(optional, default = Softmax::DEFAULT_AXIS)]
        pub axis: i32,
    }
}

impl Softmax {
    /// The default axis value (`-1`, i.e. the last axis).
    pub const DEFAULT_AXIS: i32 = -1;
}

impl Module<&Array> for Softmax {
    type Error = Exception;
    type Output = Array;

    /// Applies softmax along the configured axis.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        crate::ops::softmax_axis(x, self.axis, None)
    }

    // Stateless activation: training mode has no effect.
    fn training_mode(&mut self, _: bool) {}
}
459
460/// Applies the Softplus function.
461///
462/// This is:
463///
464/// ```rust, ignore
465/// logaddexp(x, 0)
466/// ```
467#[derive(Debug, Clone, ModuleParameters)]
468#[module(root = crate)]
469pub struct Softplus;
470
471impl Module<&Array> for Softplus {
472    type Error = Exception;
473    type Output = Array;
474
475    fn forward(&mut self, x: &Array) -> Result<Array> {
476        softplus(x)
477    }
478
479    fn training_mode(&mut self, _: bool) {}
480}
481
482/// Applies the Softsign function.
483///
484/// This is:
485///
486/// ```rust, ignore
487/// x / (array!(1) + abs(x)
488/// ```
489#[derive(Debug, Clone, ModuleParameters)]
490#[module(root = crate)]
491pub struct Softsign;
492
493impl Module<&Array> for Softsign {
494    type Error = Exception;
495    type Output = Array;
496
497    fn forward(&mut self, x: &Array) -> Result<Array> {
498        softsign(x)
499    }
500
501    fn training_mode(&mut self, _: bool) {}
502}
503
generate_builder! {
    /// Applies the Continuously Differentiable Exponential Linear Unit.
    ///
    /// This is:
    ///
    /// ```rust, ignore
    /// maximum(x, 0) + alpha * (exp(minimum(x, 0) / alpha) - 1)
    /// ```
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct Celu {
        /// The alpha value. Default to [`Celu::DEFAULT_ALPHA`] if not provided.
        #[builder(optional, default = Celu::DEFAULT_ALPHA)]
        pub alpha: f32,
    }
}

impl Celu {
    /// The default alpha value.
    pub const DEFAULT_ALPHA: f32 = 1.0;
}

impl Module<&Array> for Celu {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`celu`] with the configured alpha.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        celu(x, self.alpha)
    }

    // Stateless activation: training mode has no effect.
    fn training_mode(&mut self, _: bool) {}
}
539
540/// Applies the Sigmoid Linear Unit. Also known as Swish.
541///
542/// This is:
543///
544/// ```rust, ignore
545/// x * sigmoid(x)
546/// ```
547#[derive(Debug, Clone, ModuleParameters)]
548#[module(root = crate)]
549pub struct Silu;
550
551impl Module<&Array> for Silu {
552    type Error = Exception;
553    type Output = Array;
554
555    fn forward(&mut self, x: &Array) -> Result<Array> {
556        silu(x)
557    }
558
559    fn training_mode(&mut self, _: bool) {}
560}
561
generate_builder! {
    /// Applies the Log Softmax function.
    ///
    /// This is:
    ///
    /// ```rust, ignore
    /// x - logsumexp(x, axis, true)
    /// ```
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LogSoftmax {
        /// The axis value. Default to [`LogSoftmax::DEFAULT_AXIS`] if not provided.
        #[builder(optional, default = LogSoftmax::DEFAULT_AXIS)]
        pub axis: i32,
    }
}

impl LogSoftmax {
    /// The default axis value (`-1`, i.e. the last axis).
    pub const DEFAULT_AXIS: i32 = -1;
}

impl Module<&Array> for LogSoftmax {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`log_softmax`] with the configured axis.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        log_softmax(x, self.axis)
    }

    // Stateless activation: training mode has no effect.
    fn training_mode(&mut self, _: bool) {}
}
596
597/// Applies the Log Sigmoid function.
598///
599/// This is:
600///
601/// ```rust, ignore
602/// -softplus(-x)
603/// ```
604#[derive(Debug, Clone, ModuleParameters)]
605#[module(root = crate)]
606pub struct LogSigmoid;
607
608impl Module<&Array> for LogSigmoid {
609    type Error = Exception;
610    type Output = Array;
611
612    fn forward(&mut self, x: &Array) -> Result<Array> {
613        log_sigmoid(x)
614    }
615
616    fn training_mode(&mut self, _: bool) {}
617}
618
/// Applies the element-wise parametric ReLU.
///
/// This is:
///
/// ```rust, ignore
/// maximum(0, x) + alpha * minimum(0, x)
/// ```
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Prelu {
    /// The alpha value. See [`prelu`] for more details.
    #[param]
    #[builder(ignore)]
    pub weight: Param<Array>, // `#[param]` registers this as a module parameter; presumably trainable — TODO confirm
}
635
/// The builder for the Prelu module.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_prelu,
    default_infallible,
    err = Exception,
)]
pub struct PreluBuilder {
    /// Number of `alpha` (weight) entries. Default to [`Prelu::DEFAULT_COUNT`] if not provided.
    #[builder(optional, default = Prelu::DEFAULT_COUNT)]
    pub count: i32,

    /// Initial value for every weight entry. Default to [`Prelu::DEFAULT_VALUE`] if not provided.
    #[builder(optional, default = Prelu::DEFAULT_VALUE)]
    pub value: f32,
}
653
654/// Builds the Prelu module.
655fn build_prelu(builder: PreluBuilder) -> Result<Prelu> {
656    let count = builder.count;
657    let value = builder.value;
658    let weight = Param::new(crate::ops::full::<f32>(&[count], &array!(value))?);
659    Ok(Prelu { weight })
660}
661
662impl Prelu {
663    /// The default count value.
664    pub const DEFAULT_COUNT: i32 = 1;
665
666    /// The default value.
667    pub const DEFAULT_VALUE: f32 = 0.25;
668}
669
670impl Module<&Array> for Prelu {
671    type Error = Exception;
672    type Output = Array;
673
674    fn forward(&mut self, x: &Array) -> Result<Array> {
675        prelu(x, &self.weight)
676    }
677
678    fn training_mode(&mut self, _: bool) {}
679}
680
/// Variants of Gaussian Error Linear Units function.
///
/// Selects which of the three GELU implementations [`Gelu`] dispatches to.
#[derive(Debug, Clone, Copy, Default)]
pub enum GeluApprox {
    /// Uses [`gelu`] (the exact erf-based form). This is the default.
    #[default]
    None,

    /// Uses [`gelu_approximate`] (the tanh-based approximation)
    Precise,

    /// Uses [`gelu_fast_approximate`] (the sigmoid-based approximation)
    Fast,
}
694
generate_builder! {
    /// Applies the Gaussian Error Linear Units function.
    ///
    /// There are three variants:
    ///
    /// - `GeluApprox::None`: Uses [`gelu`]. This is the default.
    /// - `GeluApprox::Precise`: Uses [`gelu_approximate`]
    /// - `GeluApprox::Fast`: Uses [`gelu_fast_approximate`]
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct Gelu {
        /// The approximation to use. Default to `GeluApprox::None` if not provided.
        #[builder(optional, default = GeluApprox::None)]
        pub approximate: GeluApprox,
    }
}
713
714impl Module<&Array> for Gelu {
715    type Error = Exception;
716    type Output = Array;
717
718    fn forward(&mut self, x: &Array) -> Result<Array> {
719        match self.approximate {
720            GeluApprox::None => gelu(x),
721            GeluApprox::Precise => gelu_approximate(x),
722            GeluApprox::Fast => gelu_fast_approximate(x),
723        }
724    }
725
726    fn training_mode(&mut self, _: bool) {}
727}
728
729/// Applies the hyperbolic tangent function
730#[derive(Debug, Clone, ModuleParameters)]
731#[module(root = crate)]
732pub struct Tanh;
733
734impl Module<&Array> for Tanh {
735    type Error = Exception;
736    type Output = Array;
737
738    fn forward(&mut self, x: &Array) -> Result<Array> {
739        crate::ops::tanh(x)
740    }
741
742    fn training_mode(&mut self, _: bool) {}
743}
744
745/// Applies the hardswish function, element-wise
746///
747/// This is:
748///
749/// ```rust, ignore
750/// x * minimum(maximum(x + 3, 0), 6) / 6
751/// ```
752#[derive(Debug, Clone, ModuleParameters)]
753#[module(root = crate)]
754pub struct HardSwish;
755
756impl Module<&Array> for HardSwish {
757    type Error = Exception;
758    type Output = Array;
759
760    fn forward(&mut self, x: &Array) -> Result<Array> {
761        hard_swish(x)
762    }
763
764    fn training_mode(&mut self, _: bool) {}
765}
766
generate_builder! {
    /// Applies the Step Activation Function.
    ///
    /// This function implements a binary step activation, where the output is set
    /// to 1 if the input is greater than a specified threshold, and 0 otherwise.
    ///
    /// This is:
    ///
    /// ```rust, ignore
    /// r#where(x.gt(threshold), 1, 0)
    /// ```
    #[derive(Debug, Clone, ModuleParameters, Buildable)]
    #[module(root = crate)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct Step {
        /// The threshold value. Default to [`Step::DEFAULT_THRESHOLD`] if not provided.
        #[builder(optional, default = Step::DEFAULT_THRESHOLD)]
        pub threshold: f32,
    }
}

impl Step {
    /// The default threshold value.
    pub const DEFAULT_THRESHOLD: f32 = 0.0;
}

impl Module<&Array> for Step {
    type Error = Exception;
    type Output = Array;

    /// Delegates to the free function [`step`] with the configured threshold.
    fn forward(&mut self, x: &Array) -> Result<Array> {
        step(x, self.threshold)
    }

    // Stateless activation: training mode has no effect.
    fn training_mode(&mut self, _: bool) {}
}
804
805/// Applies the Scaled Exponential Linear Unit.
806///
807/// This is:
808///
809/// ```rust, ignore
810/// elu(x, 1.67326) * 1.0507
811/// ```
812#[derive(Debug, Clone, ModuleParameters)]
813#[module(root = crate)]
814pub struct Selu;
815
816impl Module<&Array> for Selu {
817    type Error = Exception;
818    type Output = Array;
819
820    fn forward(&mut self, x: &Array) -> Result<Array> {
821        selu(x)
822    }
823
824    fn training_mode(&mut self, _: bool) {}
825}
826
827/* -------------------------------------------------------------------------- */
828/*                        Compiled activation functions                       */
829/* -------------------------------------------------------------------------- */
830
831#[inline]
832fn compiled_leaky_relu(x: &Array, neg_slope: &Array) -> Result<Array> {
833    let f = |(x_, neg_slope_): (&Array, &Array)| {
834        // This will not panic because a scalar can always be broadcasted to any shape
835        let a = multiply(neg_slope_, x_)?;
836        maximum(&a, x_)
837    };
838    let mut compiled = compile(f, true);
839    compiled((x, neg_slope))
840}
841
842#[inline]
843fn compiled_elu(x: &Array, alpha: &Array) -> Result<Array> {
844    let f = |(x_, alpha_): (&Array, &Array)| {
845        which(&x_.gt(&array!(0.0))?, x_, alpha_ * (exp(x_)? - array!(1.0)))
846    };
847    let mut compiled = compile(f, true);
848    compiled((x, alpha))
849}
850
851#[inline]
852fn compiled_relu6(x: &Array) -> Result<Array> {
853    let f = |x_: &Array| minimum(maximum(x_, &array!(0.0))?, &array!(6.0));
854    let mut compiled = compile(f, true);
855    compiled(x)
856}
857
858#[inline]
859fn compiled_softsign(x: &Array) -> Result<Array> {
860    let f = |x_: &Array| x_.divide(array!(1.0) + abs(x_)?);
861    let mut compiled = compile(f, true);
862    compiled(x)
863}
864
865#[inline]
866fn compiled_celu(x: &Array, alpha: &Array) -> Result<Array> {
867    let f = |(x_, alpha_): (&Array, &Array)| {
868        maximum(x_, &array!(0.0))?
869            .add(alpha_.multiply(exp(&(minimum(x_, &array!(0.0))? / alpha_))? - array!(1.0))?)
870    };
871    let mut compiled = compile(f, true);
872    compiled((x, alpha))
873}
874
875#[inline]
876fn compiled_silu(x: &Array) -> Result<Array> {
877    let f = |x_: &Array| x_.multiply(sigmoid(x_)?);
878    let mut compiled = compile(f, true);
879    compiled(x)
880}
881
882#[inline]
883fn compiled_log_sigmoid(x: &Array) -> Result<Array> {
884    let f = |x_: &Array| Ok(-softplus(&(-x_))?);
885    let mut compiled = compile(f, true);
886    compiled(x)
887}
888
889#[inline]
890fn compiled_gelu(x: &Array) -> Result<Array> {
891    use crate::ops::erf;
892    let f = |x_: &Array| {
893        x_.multiply(array!(1) + erf(&(x_ / array!(2f32.sqrt())))?)?
894            .divide(array!(2.0))
895    };
896    let mut compiled = compile(f, true);
897    compiled(x)
898}
899
900#[inline]
901fn compiled_gelu_approximate(x: &Array) -> Result<Array> {
902    use crate::ops::{sqrt, tanh};
903
904    let f = move |x_: &Array| {
905        // 0.5 * x * (1 + tanh(sqrt(2 / Float.pi) * (x + 0.044715 * x ** 3)))
906        array!(0.5).multiply(x_)?.multiply(
907            array!(1.0).add(tanh(
908                &(sqrt(&array!(2.0 / PI))?
909                    .multiply(x_ + array!(0.044715).multiply(x_.power(&array!(3))?)?)?),
910            )?)?,
911        )
912    };
913    let mut compiled = compile(f, true);
914    compiled(x)
915}
916
917#[inline]
918fn compiled_gelu_fast_approximate(x: &Array) -> Result<Array> {
919    let f = |x_: &Array| x_.multiply(sigmoid(&(array!(1.773) * x_))?);
920    let mut compiled = compile(f, true);
921    compiled(x)
922}
923
924#[inline]
925fn compiled_selu(x: &Array) -> Result<Array> {
926    let f = |x_: &Array| elu(x_, 1.67326)?.multiply(array!(1.0507));
927    let mut compiled = compile(f, true);
928    compiled(x)
929}
930
931#[inline]
932fn compiled_prelu(x: &Array, alpha: &Array) -> Result<Array> {
933    let f = |(x_, alpha_): (&Array, &Array)| {
934        maximum(&array!(0.0), x_)?.add(alpha_ * minimum(&array!(0.0), x_)?)
935    };
936    let mut compiled = compile(f, true);
937    compiled((x, alpha))
938}
939
940#[inline]
941fn compiled_mish(x: &Array) -> Result<Array> {
942    use crate::ops::tanh;
943
944    let f = |x_: &Array| x_.multiply(tanh(&softplus(x_)?)?);
945    let mut compiled = compile(f, true);
946    compiled(x)
947}
948
949#[inline]
950fn compiled_hard_swish(x: &Array) -> Result<Array> {
951    let f = |x_: &Array| {
952        let max_x_plus_3 = maximum(&(x_ + array!(3.0)), &array!(0.0))?;
953        x_.multiply(minimum(&max_x_plus_3, &array!(6.0))?)?
954            .divide(&array!(6.0))
955    };
956    let mut compiled = compile(f, true);
957    compiled(x)
958}
959
960// The following tests are ported from the swift binding:
961// mlx-swift/Tests/MLXTests/IntegrationTests.swift
962#[cfg(test)]
963mod tests {
964    use crate::{builder::Builder, random::uniform, Dtype};
965    use float_eq::assert_float_eq;
966
967    use super::*;
968
    #[test]
    fn test_glu() {
        // Fixed seed keeps the sampled statistics below deterministic.
        crate::random::seed(850).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.547_252_66,
            abs <= 0.010_945_053
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            140.096_68,
            abs <= 2.801_933_5
        );
        // GLU halves the split axis: the last dimension goes 16 -> 8.
        let result = Glu::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 8]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.333_276_75,
            abs <= 0.006_665_535
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            42.659_424,
            abs <= 0.853_188_46
        );
    }
999
    #[test]
    fn test_sigmoid() {
        // Fixed seed keeps the sampled statistics below deterministic.
        crate::random::seed(589).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.529_697_9,
            abs <= 0.010_593_958
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            135.602_66,
            abs <= 2.712_053_3
        );
        // Element-wise op: shape and dtype are preserved.
        let result = Sigmoid.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.627_014,
            abs <= 0.012_540_28
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            160.515_58,
            abs <= 3.210_311_7
        );
    }
1030
    #[test]
    fn test_mish() {
        // Fixed seed keeps the sampled statistics below deterministic.
        crate::random::seed(122).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.501_719_8,
            abs <= 0.010_034_395
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            128.440_26,
            abs <= 2.568_805_2
        );
        // Element-wise op: shape and dtype are preserved.
        let result = Mish.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.395_375_73,
            abs <= 0.007_907_514
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            101.216_19,
            abs <= 2.024_323_7
        );
    }
1061
    #[test]
    fn test_relu() {
        // Fixed seed keeps the sampled statistics below deterministic.
        crate::random::seed(400).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.478_322_74,
            abs <= 0.009_566_455
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            122.450_62,
            abs <= 2.449_012_5
        );
        // The input is sampled from [0, 1), so ReLU leaves it unchanged and
        // the statistics match the input's.
        let result = Relu.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.478_322_74,
            abs <= 0.009_566_455
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            122.450_62,
            abs <= 2.449_012_5
        );
    }
1092
    #[test]
    fn test_leaky_relu() {
        // Fixed seed keeps the sampled statistics below deterministic.
        crate::random::seed(93).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.499_930_68,
            abs <= 0.009_998_614
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            127.982_254,
            abs <= 2.559_645_2
        );
        // The input is sampled from [0, 1), so the negative slope never kicks
        // in and the statistics match the input's.
        let result = LeakyRelu::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.499_930_68,
            abs <= 0.009_998_614
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            127.982_254,
            abs <= 2.559_645_2
        );
    }
1123
1124    #[test]
1125    fn test_relu6() {
1126        crate::random::seed(379).unwrap();
1127        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1128        assert_eq!(a.shape(), &[2, 8, 16]);
1129        assert_eq!(a.dtype(), Dtype::Float32);
1130        assert_float_eq!(
1131            a.mean(None).unwrap().item::<f32>(),
1132            0.493_258_66,
1133            abs <= 0.009_865_173
1134        );
1135        assert_float_eq!(
1136            a.sum(None).unwrap().item::<f32>(),
1137            126.274_216,
1138            abs <= 2.525_484_3
1139        );
1140        let result = Relu6.forward(&a).unwrap();
1141        assert_eq!(result.shape(), &[2, 8, 16]);
1142        assert_eq!(result.dtype(), Dtype::Float32);
1143        assert_float_eq!(
1144            result.mean(None).unwrap().item::<f32>(),
1145            0.493_258_66,
1146            abs <= 0.009_865_173
1147        );
1148        assert_float_eq!(
1149            result.sum(None).unwrap().item::<f32>(),
1150            126.274_216,
1151            abs <= 2.525_484_3
1152        );
1153    }
1154
1155    #[test]
1156    fn test_softmax() {
1157        crate::random::seed(853).unwrap();
1158        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1159        assert_eq!(a.shape(), &[2, 8, 16]);
1160        assert_eq!(a.dtype(), Dtype::Float32);
1161        assert_float_eq!(
1162            a.mean(None).unwrap().item::<f32>(),
1163            0.514_396_3,
1164            abs <= 0.010_287_926_5
1165        );
1166        assert_float_eq!(
1167            a.sum(None).unwrap().item::<f32>(),
1168            131.685_46,
1169            abs <= 2.633_709_2
1170        );
1171        let result = Softmax::new().forward(&a).unwrap();
1172        assert_eq!(result.shape(), &[2, 8, 16]);
1173        assert_eq!(result.dtype(), Dtype::Float32);
1174        assert_float_eq!(
1175            result.mean(None).unwrap().item::<f32>(),
1176            0.062_499_996,
1177            abs <= 0.001_25
1178        );
1179        assert_float_eq!(
1180            result.sum(None).unwrap().item::<f32>(),
1181            15.999_999,
1182            abs <= 0.32
1183        );
1184    }
1185
1186    #[test]
1187    fn test_softplus() {
1188        crate::random::seed(118).unwrap();
1189        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1190        assert_eq!(a.shape(), &[2, 8, 16]);
1191        assert_eq!(a.dtype(), Dtype::Float32);
1192        assert_float_eq!(
1193            a.mean(None).unwrap().item::<f32>(),
1194            0.498_981_42,
1195            abs <= 0.009_979_628
1196        );
1197        assert_float_eq!(
1198            a.sum(None).unwrap().item::<f32>(),
1199            127.739_24,
1200            abs <= 2.554_784_8
1201        );
1202        let result = Softplus.forward(&a).unwrap();
1203        assert_eq!(result.shape(), &[2, 8, 16]);
1204        assert_eq!(result.dtype(), Dtype::Float32);
1205        assert_float_eq!(
1206            result.mean(None).unwrap().item::<f32>(),
1207            0.982_857_76,
1208            abs <= 0.019_657_155
1209        );
1210        assert_float_eq!(
1211            result.sum(None).unwrap().item::<f32>(),
1212            251.611_59,
1213            abs <= 5.032_232
1214        );
1215    }
1216
1217    #[test]
1218    fn test_softsign() {
1219        crate::random::seed(37).unwrap();
1220        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1221        assert_eq!(a.shape(), &[2, 8, 16]);
1222        assert_eq!(a.dtype(), Dtype::Float32);
1223        assert_float_eq!(
1224            a.mean(None).unwrap().item::<f32>(),
1225            0.506_551_27,
1226            abs <= 0.010_131_026
1227        );
1228        assert_float_eq!(
1229            a.sum(None).unwrap().item::<f32>(),
1230            129.677_12,
1231            abs <= 2.593_542_6
1232        );
1233        let result = Softsign.forward(&a).unwrap();
1234        assert_eq!(result.shape(), &[2, 8, 16]);
1235        assert_eq!(result.dtype(), Dtype::Float32);
1236        assert_float_eq!(
1237            result.mean(None).unwrap().item::<f32>(),
1238            0.314_089_83,
1239            abs <= 0.006_281_797
1240        );
1241        assert_float_eq!(
1242            result.sum(None).unwrap().item::<f32>(),
1243            80.407,
1244            abs <= 1.608_14
1245        );
1246    }
1247
1248    // The unit test below is adapted from the python binding:
1249    // mlx/python/tests/test_nn.py
1250    #[test]
1251    fn test_celu() {
1252        let x = array!([1.0, -1.0, 0.0]);
1253        let y = Celu::new().forward(&x).unwrap();
1254        let epsilon = array!(1e-4);
1255        let expected_y = array!([1.0, -0.6321, 0.0]);
1256        assert!(y
1257            .subtract(&expected_y)
1258            .unwrap()
1259            .abs()
1260            .unwrap()
1261            .lt(&epsilon)
1262            .unwrap()
1263            .all(None)
1264            .unwrap()
1265            .item::<bool>());
1266        assert_eq!(y.shape(), &[3]);
1267        assert_eq!(y.dtype(), Dtype::Float32);
1268
1269        let y = CeluBuilder::new()
1270            .alpha(1.1)
1271            .build()
1272            .unwrap()
1273            .forward(&x)
1274            .unwrap();
1275        let expected_y = array!([1.0, -0.6568, 0.0]);
1276        assert!(y
1277            .subtract(&expected_y)
1278            .unwrap()
1279            .abs()
1280            .unwrap()
1281            .lt(&epsilon)
1282            .unwrap()
1283            .all(None)
1284            .unwrap()
1285            .item::<bool>());
1286        assert_eq!(y.shape(), &[3]);
1287        assert_eq!(y.dtype(), Dtype::Float32);
1288    }
1289
1290    #[test]
1291    fn test_silu() {
1292        crate::random::seed(22).unwrap();
1293        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1294        assert_eq!(a.shape(), &[2, 8, 16]);
1295        assert_eq!(a.dtype(), Dtype::Float32);
1296        assert_float_eq!(
1297            a.mean(None).unwrap().item::<f32>(),
1298            0.502_970_6,
1299            abs <= 0.010_059_412
1300        );
1301        assert_float_eq!(
1302            a.sum(None).unwrap().item::<f32>(),
1303            128.760_47,
1304            abs <= 2.575_209_4
1305        );
1306        let result = Silu.forward(&a).unwrap();
1307        assert_eq!(result.shape(), &[2, 8, 16]);
1308        assert_eq!(result.dtype(), Dtype::Float32);
1309        assert_float_eq!(
1310            result.mean(None).unwrap().item::<f32>(),
1311            0.331_970_93,
1312            abs <= 0.006_639_418_7
1313        );
1314        assert_float_eq!(
1315            result.sum(None).unwrap().item::<f32>(),
1316            84.984_56,
1317            abs <= 1.699_691_2
1318        );
1319    }
1320
1321    #[test]
1322    fn test_log_softmax() {
1323        crate::random::seed(199).unwrap();
1324        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1325        assert_eq!(a.shape(), &[2, 8, 16]);
1326        assert_eq!(a.dtype(), Dtype::Float32);
1327        assert_float_eq!(
1328            a.mean(None).unwrap().item::<f32>(),
1329            0.527_843_7,
1330            abs <= 0.010_556_874
1331        );
1332        assert_float_eq!(
1333            a.sum(None).unwrap().item::<f32>(),
1334            135.127_99,
1335            abs <= 2.702_559_7
1336        );
1337        let result = LogSoftmax::new().forward(&a).unwrap();
1338        assert_eq!(result.shape(), &[2, 8, 16]);
1339        assert_eq!(result.dtype(), Dtype::Float32);
1340        assert_float_eq!(
1341            result.mean(None).unwrap().item::<f32>(),
1342            -2.810_954_6,
1343            abs <= 0.056_219_09
1344        );
1345        assert_float_eq!(
1346            result.sum(None).unwrap().item::<f32>(),
1347            -719.604_4,
1348            abs <= 14.392_087
1349        );
1350    }
1351
1352    #[test]
1353    fn test_log_sigmoid() {
1354        crate::random::seed(984).unwrap();
1355        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1356        assert_eq!(a.shape(), &[2, 8, 16]);
1357        assert_eq!(a.dtype(), Dtype::Float32);
1358        assert_float_eq!(
1359            a.mean(None).unwrap().item::<f32>(),
1360            0.510_977_7,
1361            abs <= 0.010_219_553_5
1362        );
1363        assert_float_eq!(
1364            a.sum(None).unwrap().item::<f32>(),
1365            130.810_29,
1366            abs <= 2.616_205_7
1367        );
1368        let result = LogSigmoid.forward(&a).unwrap();
1369        assert_eq!(result.shape(), &[2, 8, 16]);
1370        assert_eq!(result.dtype(), Dtype::Float32);
1371        assert_float_eq!(
1372            result.mean(None).unwrap().item::<f32>(),
1373            -0.479_598_55,
1374            abs <= 0.009_591_971
1375        );
1376        assert_float_eq!(
1377            result.sum(None).unwrap().item::<f32>(),
1378            -122.777_23,
1379            abs <= 2.455_544_5
1380        );
1381    }
1382
1383    #[test]
1384    fn test_prelu() {
1385        crate::random::seed(993).unwrap();
1386        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1387        assert_eq!(a.shape(), &[2, 8, 16]);
1388        assert_eq!(a.dtype(), Dtype::Float32);
1389        assert_float_eq!(
1390            a.mean(None).unwrap().item::<f32>(),
1391            0.496_651_44,
1392            abs <= 0.009_933_028
1393        );
1394        assert_float_eq!(
1395            a.sum(None).unwrap().item::<f32>(),
1396            127.142_77,
1397            abs <= 2.542_855_3
1398        );
1399        let result = Prelu::new().forward(&a).unwrap();
1400        assert_eq!(result.shape(), &[2, 8, 16]);
1401        assert_eq!(result.dtype(), Dtype::Float32);
1402        assert_float_eq!(
1403            result.mean(None).unwrap().item::<f32>(),
1404            0.496_651_44,
1405            abs <= 0.009_933_028
1406        );
1407        assert_float_eq!(
1408            result.sum(None).unwrap().item::<f32>(),
1409            127.142_77,
1410            abs <= 2.542_855_3
1411        );
1412    }
1413
1414    #[test]
1415    fn test_gelu() {
1416        crate::random::seed(189).unwrap();
1417        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1418        assert_eq!(a.shape(), &[2, 8, 16]);
1419        assert_eq!(a.dtype(), Dtype::Float32);
1420        assert_float_eq!(
1421            a.mean(None).unwrap().item::<f32>(),
1422            0.492_950_32,
1423            abs <= 0.009_859_007
1424        );
1425        assert_float_eq!(
1426            a.sum(None).unwrap().item::<f32>(),
1427            126.195_28,
1428            abs <= 2.523_905_8
1429        );
1430        let result = Gelu::new().forward(&a).unwrap();
1431        assert_eq!(result.shape(), &[2, 8, 16]);
1432        assert_eq!(result.dtype(), Dtype::Float32);
1433        assert_float_eq!(
1434            result.mean(None).unwrap().item::<f32>(),
1435            0.365_638_38,
1436            abs <= 0.007_312_767_7
1437        );
1438        assert_float_eq!(
1439            result.sum(None).unwrap().item::<f32>(),
1440            93.603_424,
1441            abs <= 1.872_068_5
1442        );
1443    }
1444
1445    #[test]
1446    fn test_tanh() {
1447        crate::random::seed(735).unwrap();
1448        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1449        assert_eq!(a.shape(), &[2, 8, 16]);
1450        assert_eq!(a.dtype(), Dtype::Float32);
1451        assert_float_eq!(
1452            a.mean(None).unwrap().item::<f32>(),
1453            0.474_122_7,
1454            abs <= 0.009_482_454_5
1455        );
1456        assert_float_eq!(
1457            a.sum(None).unwrap().item::<f32>(),
1458            121.375_41,
1459            abs <= 2.427_508_4
1460        );
1461        let result = Tanh.forward(&a).unwrap();
1462        assert_eq!(result.shape(), &[2, 8, 16]);
1463        assert_eq!(result.dtype(), Dtype::Float32);
1464        assert_float_eq!(
1465            result.mean(None).unwrap().item::<f32>(),
1466            0.413_079_68,
1467            abs <= 0.008_261_594
1468        );
1469        assert_float_eq!(
1470            result.sum(None).unwrap().item::<f32>(),
1471            105.748_4,
1472            abs <= 2.114_968
1473        );
1474    }
1475
1476    #[test]
1477    fn test_hardswish() {
1478        crate::random::seed(126).unwrap();
1479        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1480        assert_eq!(a.shape(), &[2, 8, 16]);
1481        assert_eq!(a.dtype(), Dtype::Float32);
1482        assert_float_eq!(
1483            a.mean(None).unwrap().item::<f32>(),
1484            0.491_892_46,
1485            abs <= 0.009_837_849
1486        );
1487        assert_float_eq!(
1488            a.sum(None).unwrap().item::<f32>(),
1489            125.924_47,
1490            abs <= 2.518_489_4
1491        );
1492        let result = HardSwish.forward(&a).unwrap();
1493        assert_eq!(result.shape(), &[2, 8, 16]);
1494        assert_eq!(result.dtype(), Dtype::Float32);
1495        assert_float_eq!(
1496            result.mean(None).unwrap().item::<f32>(),
1497            0.299_602_24,
1498            abs <= 0.005_992_044_7
1499        );
1500        assert_float_eq!(
1501            result.sum(None).unwrap().item::<f32>(),
1502            76.698_17,
1503            abs <= 1.533_963_4
1504        );
1505    }
1506
1507    #[test]
1508    fn test_step() {
1509        crate::random::seed(490).unwrap();
1510        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1511        assert_eq!(a.shape(), &[2, 8, 16]);
1512        assert_eq!(a.dtype(), Dtype::Float32);
1513        assert_float_eq!(
1514            a.mean(None).unwrap().item::<f32>(),
1515            0.479_360_64,
1516            abs <= 0.009_587_212_5
1517        );
1518        assert_float_eq!(
1519            a.sum(None).unwrap().item::<f32>(),
1520            122.716_324,
1521            abs <= 2.454_326_4
1522        );
1523        let result = Step::new().forward(&a).unwrap();
1524        assert_eq!(result.shape(), &[2, 8, 16]);
1525        assert_eq!(result.dtype(), Dtype::Int32);
1526        assert_float_eq!(result.mean(None).unwrap().item::<f32>(), 1.0, abs <= 0.02);
1527        assert_float_eq!(result.sum(None).unwrap().item::<f32>(), 256.0, abs <= 5.12);
1528    }
1529
1530    #[test]
1531    fn test_selu() {
1532        crate::random::seed(215).unwrap();
1533        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
1534        assert_eq!(a.shape(), &[2, 8, 16]);
1535        assert_eq!(a.dtype(), Dtype::Float32);
1536        assert_float_eq!(
1537            a.mean(None).unwrap().item::<f32>(),
1538            0.493_026_8,
1539            abs <= 0.009_860_536
1540        );
1541        assert_float_eq!(
1542            a.sum(None).unwrap().item::<f32>(),
1543            126.214_86,
1544            abs <= 2.524_297_2
1545        );
1546        let result = Selu.forward(&a).unwrap();
1547        assert_eq!(result.shape(), &[2, 8, 16]);
1548        assert_eq!(result.dtype(), Dtype::Float32);
1549        assert_float_eq!(
1550            result.mean(None).unwrap().item::<f32>(),
1551            0.518_023_2,
1552            abs <= 0.010_360_463_5
1553        );
1554        assert_float_eq!(
1555            result.sum(None).unwrap().item::<f32>(),
1556            132.613_94,
1557            abs <= 2.652_278_7
1558        );
1559    }
1560}