mlx_rs/nn/normalization.rs

use std::borrow::Cow;

use crate::{
    array,
    error::Exception,
    module::{Module, Param},
    ops::{ones, rsqrt, zeros},
    Array,
};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;

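/// Shared helper: normalizes `x` over `axes` as `(x - mean) * rsqrt(var + eps)`,
/// keeping the reduced dimensions so the result broadcasts back to `x`'s shape.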
fn instance_norm(x: &Array, axes: &[i32], eps: &Array) -> Result<Array, Exception> {
    // Compute stats
    let mean = x.mean_axes(axes, true)?;
    let variance = x.var_axes(axes, true, None)?;

    // Normalize
    let x = x.subtract(&mean)?.multiply(rsqrt(&variance.add(eps)?)?)?;

    Ok(x)
}

/// Builder for [`InstanceNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_instance_norm,
    err = Exception,
)]
pub struct InstanceNormBuilder {
    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability. Defaults to
    /// [`InstanceNorm::DEFAULT_EPS`].
    #[builder(optional, default = InstanceNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// If `true`, adds a trainable `weight` and `bias`. Defaults to
    /// [`InstanceNorm::DEFAULT_AFFINE`].
    #[builder(optional, default = InstanceNorm::DEFAULT_AFFINE)]
    pub affine: bool,
}

fn build_instance_norm(builder: InstanceNormBuilder) -> Result<InstanceNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(InstanceNorm {
        dimensions: builder.dimensions,
        eps: array!(eps),
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

/// Applies instance normalization [1] on the inputs.
///
/// ### References
///
/// 1. [https://arxiv.org/abs/1607.08022](https://arxiv.org/abs/1607.08022)
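///
/// ### Example
///
/// A minimal usage sketch (kept as `ignore`; the `mlx_rs::nn` re-export path is
/// assumed here rather than verified). With the default `affine = false` the
/// layer has no parameters and simply normalizes each sample over all axes
/// except the first and last.
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::InstanceNorm};
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
/// let mut norm = InstanceNorm::new(8).unwrap();
/// let y = norm.forward(&x).unwrap();
/// assert_eq!(y.shape(), &[2, 8, 16]);
/// ```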
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct InstanceNorm {
    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: Array,

    /// An optional trainable weight
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias
    #[param]
    pub bias: Param<Option<Array>>,
}

impl InstanceNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Disable trainable `weight` and `bias` by default.
    pub const DEFAULT_AFFINE: bool = false;
}

impl Module<&Array> for InstanceNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
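        // Normalize over every axis except the batch (first) and feature (last) axes.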
        let reduction_axes = (1..x.ndim() as i32 - 1).collect::<Vec<_>>();

        let x = instance_norm(x, &reduction_axes, &self.eps)?;

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`LayerNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_layer_norm,
    err = Exception,
)]
pub struct LayerNormBuilder {
    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability. Defaults to
    /// [`LayerNorm::DEFAULT_EPS`].
    #[builder(optional, default = LayerNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// If `true`, adds a trainable `weight` and `bias`. Defaults to
    /// [`LayerNorm::DEFAULT_AFFINE`].
    #[builder(optional, default = LayerNorm::DEFAULT_AFFINE)]
    pub affine: bool,
}

fn build_layer_norm(builder: LayerNormBuilder) -> Result<LayerNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(LayerNorm {
        dimensions: builder.dimensions,
        eps,
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

/// Applies layer normalization [1] on the inputs.
///
/// ### References
///
/// 1. [https://arxiv.org/abs/1607.06450](https://arxiv.org/abs/1607.06450)
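///
/// ### Example
///
/// A minimal usage sketch (kept as `ignore`; the `mlx_rs::nn` re-export path is
/// assumed). The forward pass normalizes over the last axis and, with the
/// default `affine = true`, applies the learned `weight` and `bias` via the
/// fused `fast::layer_norm` kernel.
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::LayerNorm};
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
/// let mut norm = LayerNorm::new(16).unwrap();
/// let y = norm.forward(&x).unwrap();
/// assert_eq!(y.shape(), &[2, 8, 16]);
/// ```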
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct LayerNorm {
    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: f32,

    /// An optional trainable weight
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias
    #[param]
    pub bias: Param<Option<Array>>,
}

impl LayerNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Enable trainable `weight` and `bias` by default.
    pub const DEFAULT_AFFINE: bool = true;
}

impl Module<&Array> for LayerNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let weight = self.weight.as_ref();
        let bias = self.bias.as_ref();
        let eps = self.eps;
        crate::fast::layer_norm(x, weight, bias, eps)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`RmsNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_rms_norm,
    err = Exception,
)]
pub struct RmsNormBuilder {
    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability. Defaults to
    /// [`RmsNorm::DEFAULT_EPS`].
    #[builder(optional, default = RmsNorm::DEFAULT_EPS)]
    pub eps: f32,
}

fn build_rms_norm(builder: RmsNormBuilder) -> Result<RmsNorm, Exception> {
    let weight = ones::<f32>(&[builder.dimensions])?;
    let eps = builder.eps;
    Ok(RmsNorm {
        weight: Param::new(weight),
        eps,
    })
}

/// Applies Root Mean Square normalization [1] to the inputs.
///
/// Concretely:
///
/// ```swift
/// weight * x * MLX.rsqrt(x.square().mean() + eps)
/// ```
///
/// where `weight` is initialized with ones and `eps` is a small float to
/// ensure the numerical stability of the inverse square root.
///
/// ### References
///
/// 1. [https://arxiv.org/abs/1910.07467](https://arxiv.org/abs/1910.07467)
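///
/// ### Example
///
/// A minimal usage sketch (kept as `ignore`; the `mlx_rs::nn` re-export path is
/// assumed). `RmsNorm` normalizes over the last axis using the fused
/// `fast::rms_norm` kernel and scales the result by `weight`.
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::RmsNorm};
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
/// let mut norm = RmsNorm::new(16).unwrap();
/// let y = norm.forward(&x).unwrap();
/// assert_eq!(y.shape(), &[2, 8, 16]);
/// ```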
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct RmsNorm {
    /// Weight
    #[param]
    pub weight: Param<Array>,

    /// A small float to ensure numerical stability
    pub eps: f32,
}

impl RmsNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;
}

impl Module<&Array> for RmsNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let weight = self.weight.as_ref();
        let eps = self.eps;
        crate::fast::rms_norm(x, weight, eps)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`GroupNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_group_norm,
    err = Exception,
)]
pub struct GroupNormBuilder {
    /// Number of groups to separate the features into
    pub group_count: i32,

    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability. Defaults to
    /// [`GroupNorm::DEFAULT_EPS`].
    #[builder(optional, default = GroupNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// If `true`, adds a trainable `weight` and `bias`. Defaults to
    /// [`GroupNorm::DEFAULT_AFFINE`].
    #[builder(optional, default = GroupNorm::DEFAULT_AFFINE)]
    pub affine: bool,

    /// If `true`, performs the group normalization in the same order/grouping as PyTorch.
    /// Defaults to [`GroupNorm::DEFAULT_PYTORCH_COMPATIBLE`].
    #[builder(optional, default = GroupNorm::DEFAULT_PYTORCH_COMPATIBLE)]
    pub pytorch_compatible: bool,
}

fn build_group_norm(builder: GroupNormBuilder) -> Result<GroupNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;
    let pytorch_compatible = builder.pytorch_compatible;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(GroupNorm {
        group_count: builder.group_count,
        dimensions: builder.dimensions,
        eps: array!(eps),
        pytorch_compatible,
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

/// Applies Group Normalization [1] on the inputs.
///
/// ### References
///
/// 1. [https://arxiv.org/abs/1803.08494](https://arxiv.org/abs/1803.08494)
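///
/// ### Example
///
/// A minimal usage sketch (kept as `ignore`; the `mlx_rs::nn` re-export path is
/// assumed). The features on the last axis are split into `group_count` groups,
/// each group is normalized independently, and the default `affine = true`
/// applies a per-feature `weight` and `bias` afterwards.
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::GroupNorm};
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
/// // 16 features split into 4 groups.
/// let mut norm = GroupNorm::new(4, 16).unwrap();
/// let y = norm.forward(&x).unwrap();
/// assert_eq!(y.shape(), &[2, 8, 16]);
/// ```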
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct GroupNorm {
    /// Number of groups to separate the features into
    pub group_count: i32,

    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: Array,

    /// If `true`, performs the group normalization in the same order/grouping as PyTorch.
    pub pytorch_compatible: bool,

    /// An optional trainable weight
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias
    #[param]
    pub bias: Param<Option<Array>>,
}

impl GroupNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Enable trainable `weight` and `bias` by default.
    pub const DEFAULT_AFFINE: bool = true;

    /// Default value for `pytorch_compatible`.
    pub const DEFAULT_PYTORCH_COMPATIBLE: bool = false;

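    // PyTorch-compatible grouping: group `g` covers the contiguous channel slice
    // [g * group_size, (g + 1) * group_size). Channels are regrouped via transpose
    // so each group can be normalized with the fused layer_norm kernel.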
    fn pytorch_group_norm(&self, x: &Array) -> Result<Array, Exception> {
        let batch = x.dim(0);
        let dims = x.dim(-1);
        let rest = &x.shape()[1..x.ndim() - 1];
        let group_size = dims / self.group_count;

        // Split into groups
        let x = x.reshape(&[batch, -1, self.group_count, group_size])?;
        let x = x
            .transpose_axes(&[0, 2, 1, 3])?
            .reshape(&[batch, self.group_count, -1])?;

        // Normalize
        let x = crate::fast::layer_norm(x, None, None, self.eps.item::<f32>())?;

        let x = x.reshape(&[batch, self.group_count, -1, group_size])?;

        let new_shape: Vec<_> = [batch]
            .into_iter()
            .chain(rest.iter().copied())
            .chain([dims])
            .collect();
        x.transpose_axes(&[0, 2, 1, 3])?.reshape(&new_shape[..])
    }

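    // Default grouping: reshaping to [batch, -1, group_count] assigns channel `c`
    // to group `c % group_count` (strided rather than contiguous), then normalizes
    // each group with the shared `instance_norm` helper.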
    fn group_norm(&self, x: &Array) -> Result<Array, Exception> {
        let batch = x.dim(0);
        let dims = x.dim(-1);
        let rest = &x.shape()[1..x.ndim() - 1];

        // Split into groups
        let x = x.reshape(&[batch, -1, self.group_count])?;

        // Normalize
        let x = instance_norm(&x, &[1], &self.eps)?;

        let new_shape: Vec<_> = [batch]
            .into_iter()
            .chain(rest.iter().copied())
            .chain([dims])
            .collect();
        x.reshape(&new_shape[..])
    }
}

impl Module<&Array> for GroupNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let x = if self.pytorch_compatible {
            self.pytorch_group_norm(x)?
        } else {
            self.group_norm(x)?
        };

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(&x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`BatchNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_batch_norm,
    err = Exception,
)]
pub struct BatchNormBuilder {
    /// Number of features in the input
    pub feature_count: i32,

    /// Value added to the denominator for numerical stability. Defaults to
    /// [`BatchNorm::DEFAULT_EPS`].
    #[builder(optional, default = BatchNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// Momentum for updating the running mean and variance. Defaults to
    /// [`BatchNorm::DEFAULT_MOMENTUM`].
    #[builder(optional, default = BatchNorm::DEFAULT_MOMENTUM)]
    pub momentum: f32,

    /// If `true`, adds a trainable `weight` and `bias`. Defaults to
    /// [`BatchNorm::DEFAULT_AFFINE`].
    #[builder(optional, default = BatchNorm::DEFAULT_AFFINE)]
    pub affine: bool,

    /// If `true`, tracks the running mean and variance. Defaults to
    /// [`BatchNorm::DEFAULT_TRACK_RUNNING_STATS`].
    #[builder(optional, default = BatchNorm::DEFAULT_TRACK_RUNNING_STATS)]
    pub track_running_stats: bool,
}

fn build_batch_norm(builder: BatchNormBuilder) -> Result<BatchNorm, Exception> {
    let eps = builder.eps;
    let momentum = builder.momentum;
    let affine = builder.affine;
    let track_running_stats = builder.track_running_stats;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.feature_count])?),
            Some(zeros::<f32>(&[builder.feature_count])?),
        )
    } else {
        (None, None)
    };

    let (running_mean, running_var) = if track_running_stats {
        (
            Some(zeros::<f32>(&[builder.feature_count])?),
            Some(ones::<f32>(&[builder.feature_count])?),
        )
    } else {
        (None, None)
    };

    Ok(BatchNorm {
        feature_count: builder.feature_count,
        eps: array!(eps),
        momentum: array!(momentum),
        weight: Param::new(weight),
        bias: Param::new(bias),
        running_mean: Param::new(running_mean),
        running_var: Param::new(running_var),
        training: BatchNorm::DEFAULT_TRAINING,
    })
}

/// Applies batch normalization [1] on the inputs.
///
/// ### References
///
/// 1. [https://arxiv.org/abs/1502.03167](https://arxiv.org/abs/1502.03167)
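///
/// ### Example
///
/// A minimal usage sketch (kept as `ignore`; the `mlx_rs::nn` re-export path is
/// assumed). In training mode the batch statistics are used and the running mean
/// and variance are updated with `momentum`; calling `training_mode(false)` makes
/// subsequent forward passes normalize with the tracked running statistics instead.
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::BatchNorm};
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
/// let mut norm = BatchNorm::new(16).unwrap();
///
/// // Training step: uses batch statistics and updates the running ones.
/// let y = norm.forward(&x).unwrap();
///
/// // Inference: normalize with the tracked running statistics.
/// norm.training_mode(false);
/// let y_eval = norm.forward(&x).unwrap();
/// ```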
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct BatchNorm {
    /// Number of features in the input
    pub feature_count: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: Array,

    /// Momentum for updating the running mean and variance.
    pub momentum: Array,

    /// An optional trainable weight
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias
    #[param]
    pub bias: Param<Option<Array>>,

    /// Tracked running mean
    #[param]
    pub running_mean: Param<Option<Array>>,

    /// Tracked running variance
    #[param]
    pub running_var: Param<Option<Array>>,

    /// If `true`, the module is in training mode.
    pub training: bool,
}

impl BatchNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Default value for `momentum`.
    pub const DEFAULT_MOMENTUM: f32 = 0.1;

    /// Enable trainable `weight` and `bias` by default.
    pub const DEFAULT_AFFINE: bool = true;

    /// Enable tracking of running mean and variance by default.
    pub const DEFAULT_TRACK_RUNNING_STATS: bool = true;

    /// Enable training mode by default.
    pub const DEFAULT_TRAINING: bool = true;

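    /// Computes the per-feature mean and variance over every axis except the last.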
    fn stats(x: &Array) -> Result<(Array, Array), Exception> {
        let reduction_axes = (0..x.ndim() as i32 - 1).collect::<Vec<_>>();

        let mean = x.mean_axes(&reduction_axes, None)?;
        let variance = x.var_axes(&reduction_axes, None, None)?;

        Ok((mean, variance))
    }
}

impl Module<&Array> for BatchNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let ndim = x.ndim();
        if !(2..=4).contains(&ndim) {
            return Err(Exception::custom(
                "Input tensor must be at least 2 dimensions and at most 4 dimensions",
            ));
        }

        let (mean, variance) = Self::stats(x)?;
        let mut mean = Cow::Owned(mean);
        let mut variance = Cow::Owned(variance);

        if let (Some(running_mean), Some(running_var)) =
            (self.running_mean.as_mut(), self.running_var.as_mut())
        {
            if self.training {
                let mu = &self.momentum;
                // SAFETY: momentum is a single element array
                let one_minus_mu = array!(1.0) - mu;

                *running_mean = one_minus_mu
                    .multiply(&running_mean)?
                    .add(mu.multiply(&mean)?)?;
                *running_var = one_minus_mu
                    .multiply(&running_var)?
                    .add(mu.multiply(&variance)?)?;
            } else {
                mean = Cow::Borrowed(&*running_mean);
                variance = Cow::Borrowed(&*running_var);
            }
        }

        let x = x
            .subtract(&mean)?
            .multiply(rsqrt(&variance.add(&self.eps)?)?)?;

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(&x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        ops::indexing::{Ellipsis, IndexOp},
        Dtype,
    };
    use float_eq::assert_float_eq;

    use super::*;

    #[test]
    fn test_instance_norm() {
        crate::random::seed(435).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.500_064_6,
            abs <= 0.010_001_292
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            128.016_54,
            abs <= 2.560_330_9
        );

        let result = InstanceNorm::new(8)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.106_454_11,
            abs <= 0.002_129_082_3
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            1.703_265_8,
            abs <= 0.034_065_317
        );
    }

    #[test]
    fn test_layer_norm() {
        crate::random::seed(635).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.492_690_32,
            abs <= 0.009_853_806
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            126.128_72,
            abs <= 2.522_574_4
        );

        let result = LayerNorm::new(16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((Ellipsis, 0));
        assert_eq!(result.shape(), &[2, 8]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.290_990_38,
            abs <= 0.005_819_807_8
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            4.655_846,
            abs <= 0.093_116_924
        );
    }

    #[test]
    fn test_rms_norm() {
        crate::random::seed(103).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.505_476_36,
            abs <= 0.010_109_527
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            129.401_95,
            abs <= 2.588_039
        );

        let result = RmsNorm::new(16).unwrap().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.872_938_75,
            abs <= 0.017_458_774
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            223.472_32,
            abs <= 4.469_446
        );
    }

    #[test]
    fn test_group_norm() {
        crate::random::seed(855).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.486_665_87,
            abs <= 0.009_733_317
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            124.586_464,
            abs <= 2.491_729_3
        );

        let result = GroupNorm::new(4, 16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            -0.054_606_52,
            abs <= 0.001_092_130_4
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            -0.873_704_3,
            abs <= 0.017_474_087
        );
    }

    #[test]
    fn test_batch_norm() {
        crate::random::seed(266).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.505_814_7,
            abs <= 0.010_116_293
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            129.488_56,
            abs <= 2.589_771
        );

        let result = BatchNorm::new(16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.439_785_24,
            abs <= 0.008_795_705
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            7.036_564,
            abs <= 0.140_731_28
        );
    }
}