mlx_rs/nn/normalization.rs

use std::borrow::Cow;

use crate::{
    array,
    error::Exception,
    module::{Module, Param},
    ops::{ones, rsqrt, zeros},
    Array,
};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;

/// Shared helper: normalizes `x` over `axes` as `(x - mean) * rsqrt(variance + eps)`.
fn instance_norm(x: &Array, axes: &[i32], eps: &Array) -> Result<Array, Exception> {
    // Compute stats
    let mean = x.mean(axes, true)?;
    let variance = x.variance(axes, true, None)?;

    // Normalize
    let x = x.subtract(&mean)?.multiply(rsqrt(&variance.add(eps)?)?)?;

    Ok(x)
}

/// Builder for [`InstanceNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_instance_norm,
    err = Exception,
)]
pub struct InstanceNormBuilder {
    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability. Defaults to
    /// [`InstanceNorm::DEFAULT_EPS`].
    #[builder(optional, default = InstanceNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// If `true`, adds a trainable `weight` and `bias`. Defaults to
    /// [`InstanceNorm::DEFAULT_AFFINE`].
    #[builder(optional, default = InstanceNorm::DEFAULT_AFFINE)]
    pub affine: bool,
}

fn build_instance_norm(builder: InstanceNormBuilder) -> Result<InstanceNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(InstanceNorm {
        dimensions: builder.dimensions,
        eps: array!(eps),
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

/// Applies instance normalization [1] on the inputs.
///
/// ### References
///
/// 1. [https://arxiv.org/abs/1607.08022](https://arxiv.org/abs/1607.08022)
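///
/// ### Example
///
/// A minimal usage sketch (not compiled as a doc-test). It assumes the type is
/// reachable as `mlx_rs::nn::InstanceNorm`, that the input is shaped
/// `(batch, ..., features)`, and that the code runs in a function returning `Result`:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// // 16 features in the last dimension; `eps`/`affine` can be customized
/// // through the derived builder (see [`InstanceNormBuilder`]) instead of `new`.
/// let mut norm = mlx_rs::nn::InstanceNorm::new(16)?;
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
/// let y = norm.forward(&x)?; // same shape as `x`: [2, 8, 16]
/// ```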
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct InstanceNorm {
    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: Array,

    /// An optional trainable weight
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias
    #[param]
    pub bias: Param<Option<Array>>,
}

impl InstanceNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Disable trainable `weight` and `bias` by default.
    pub const DEFAULT_AFFINE: bool = false;
}

impl Module<&Array> for InstanceNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let reduction_axes = (1..x.ndim() as i32 - 1).collect::<Vec<_>>();

        let x = instance_norm(x, &reduction_axes, &self.eps)?;

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`LayerNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_layer_norm,
    err = Exception,
)]
pub struct LayerNormBuilder {
    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability. Defaults to
    /// [`LayerNorm::DEFAULT_EPS`].
    #[builder(optional, default = LayerNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// If `true`, adds a trainable `weight` and `bias`. Defaults to
    /// [`LayerNorm::DEFAULT_AFFINE`].
    #[builder(optional, default = LayerNorm::DEFAULT_AFFINE)]
    pub affine: bool,
}

fn build_layer_norm(builder: LayerNormBuilder) -> Result<LayerNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(LayerNorm {
        dimensions: builder.dimensions,
        eps,
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

/// Applies layer normalization [1] on the inputs.
///
/// ### References
///
/// 1. [https://arxiv.org/abs/1607.06450](https://arxiv.org/abs/1607.06450)
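///
/// ### Example
///
/// A minimal usage sketch (not compiled as a doc-test). It assumes the type is
/// reachable as `mlx_rs::nn::LayerNorm`, that normalization runs over the last
/// axis of a `(batch, ..., features)` input, and that the code runs in a
/// function returning `Result`:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// let mut norm = mlx_rs::nn::LayerNorm::new(16)?; // 16 features
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
/// let y = norm.forward(&x)?; // same shape as `x`: [2, 8, 16]
/// ```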
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct LayerNorm {
    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: f32,

    /// An optional trainable weight
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias
    #[param]
    pub bias: Param<Option<Array>>,
}

impl LayerNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Enable trainable `weight` and `bias` by default.
    pub const DEFAULT_AFFINE: bool = true;
}

impl Module<&Array> for LayerNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let weight = self.weight.as_ref();
        let bias = self.bias.as_ref();
        let eps = self.eps;
        crate::fast::layer_norm(x, weight, bias, eps)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`RmsNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_rms_norm,
    err = Exception,
)]
pub struct RmsNormBuilder {
    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability. Defaults to
    /// [`RmsNorm::DEFAULT_EPS`].
    #[builder(optional, default = RmsNorm::DEFAULT_EPS)]
    pub eps: f32,
}

fn build_rms_norm(builder: RmsNormBuilder) -> Result<RmsNorm, Exception> {
    let weight = ones::<f32>(&[builder.dimensions])?;
    let eps = builder.eps;
    Ok(RmsNorm {
        weight: Param::new(weight),
        eps,
    })
}

/// Applies Root Mean Square normalization [1] to the inputs.
///
/// Concretely:
///
/// ```text
/// weight * x * rsqrt(x.square().mean() + eps)
/// ```
///
/// where `weight` is initialized with ones and `eps` is a small float added to
/// ensure the numerical stability of the inverse square root.
///
/// ### References
///
/// 1. [https://arxiv.org/abs/1910.07467](https://arxiv.org/abs/1910.07467)
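///
/// ### Example
///
/// A minimal usage sketch (not compiled as a doc-test). It assumes the type is
/// reachable as `mlx_rs::nn::RmsNorm` and that the code runs in a function
/// returning `Result`:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// let mut norm = mlx_rs::nn::RmsNorm::new(16)?; // 16 features in the last axis
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
/// let y = norm.forward(&x)?; // same shape as `x`: [2, 8, 16]
/// ```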
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct RmsNorm {
    /// Weight
    #[param]
    pub weight: Param<Array>,

    /// A small float added to ensure numerical stability
    pub eps: f32,
}

impl RmsNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;
}

impl Module<&Array> for RmsNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let weight = self.weight.as_ref();
        let eps = self.eps;
        crate::fast::rms_norm(x, weight, eps)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`GroupNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_group_norm,
    err = Exception,
)]
pub struct GroupNormBuilder {
    /// Number of groups to separate the features into
    pub group_count: i32,

    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability. Defaults to
    /// [`GroupNorm::DEFAULT_EPS`].
    #[builder(optional, default = GroupNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// If `true`, adds a trainable `weight` and `bias`. Defaults to
    /// [`GroupNorm::DEFAULT_AFFINE`].
    #[builder(optional, default = GroupNorm::DEFAULT_AFFINE)]
    pub affine: bool,

    /// If `true`, performs the group normalization in the same order/grouping as PyTorch.
    /// Defaults to [`GroupNorm::DEFAULT_PYTORCH_COMPATIBLE`].
    #[builder(optional, default = GroupNorm::DEFAULT_PYTORCH_COMPATIBLE)]
    pub pytorch_compatible: bool,
}

fn build_group_norm(builder: GroupNormBuilder) -> Result<GroupNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;
    let pytorch_compatible = builder.pytorch_compatible;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(GroupNorm {
        group_count: builder.group_count,
        dimensions: builder.dimensions,
        eps: array!(eps),
        pytorch_compatible,
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

/// Applies Group Normalization [1] on the inputs.
///
/// ### References
///
/// 1. [https://arxiv.org/abs/1803.08494](https://arxiv.org/abs/1803.08494)
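///
/// ### Example
///
/// A minimal usage sketch (not compiled as a doc-test). It assumes the type is
/// reachable as `mlx_rs::nn::GroupNorm`, that the feature count is divisible by
/// the group count, and that the code runs in a function returning `Result`:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// // 4 groups over 16 features
/// let mut norm = mlx_rs::nn::GroupNorm::new(4, 16)?;
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
/// let y = norm.forward(&x)?; // same shape as `x`: [2, 8, 16]
/// ```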
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct GroupNorm {
    /// Number of groups to separate the features into
    pub group_count: i32,

    /// Number of features in the input
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: Array,

    /// If `true`, performs the group normalization in the same order/grouping as PyTorch.
    pub pytorch_compatible: bool,

    /// An optional trainable weight
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias
    #[param]
    pub bias: Param<Option<Array>>,
}

impl GroupNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Enable trainable `weight` and `bias` by default.
    pub const DEFAULT_AFFINE: bool = true;

    /// Default value for `pytorch_compatible`.
    pub const DEFAULT_PYTORCH_COMPATIBLE: bool = false;

    /// Group normalization matching PyTorch's grouping: channels are split into
    /// `group_count` contiguous groups before normalizing.
    fn pytorch_group_norm(&self, x: &Array) -> Result<Array, Exception> {
        let batch = x.dim(0);
        let dims = x.dim(-1);
        let rest = &x.shape()[1..x.ndim() - 1];
        let group_size = dims / self.group_count;

        // Split into groups
        let x = x.reshape(&[batch, -1, self.group_count, group_size])?;
        let x = x
            .transpose(&[0, 2, 1, 3])?
            .reshape(&[batch, self.group_count, -1])?;

        // Normalize
        let x = crate::fast::layer_norm(x, None, None, self.eps.item::<f32>())?;

        let x = x.reshape(&[batch, self.group_count, -1, group_size])?;

        let new_shape: Vec<_> = [batch]
            .into_iter()
            .chain(rest.iter().copied())
            .chain([dims])
            .collect();
        x.transpose(&[0, 2, 1, 3])?.reshape(&new_shape[..])
    }

    /// Default group normalization: groups are formed by a plain reshape of the
    /// flattened input.
    fn group_norm(&self, x: &Array) -> Result<Array, Exception> {
        let batch = x.dim(0);
        let dims = x.dim(-1);
        let rest = &x.shape()[1..x.ndim() - 1];

        // Split into groups
        let x = x.reshape(&[batch, -1, self.group_count])?;

        // Normalize
        let x = instance_norm(&x, &[1], &self.eps)?;

        let new_shape: Vec<_> = [batch]
            .into_iter()
            .chain(rest.iter().copied())
            .chain([dims])
            .collect();
        x.reshape(&new_shape[..])
    }
}

impl Module<&Array> for GroupNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let x = if self.pytorch_compatible {
            self.pytorch_group_norm(x)?
        } else {
            self.group_norm(x)?
        };

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(&x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`BatchNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_batch_norm,
    err = Exception,
)]
pub struct BatchNormBuilder {
    /// Number of features in the input
    pub feature_count: i32,

    /// Value added to the denominator for numerical stability. Defaults to
    /// [`BatchNorm::DEFAULT_EPS`].
    #[builder(optional, default = BatchNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// Momentum for updating the running mean and variance. Defaults to
    /// [`BatchNorm::DEFAULT_MOMENTUM`].
    #[builder(optional, default = BatchNorm::DEFAULT_MOMENTUM)]
    pub momentum: f32,

    /// If `true`, adds a trainable `weight` and `bias`. Defaults to
    /// [`BatchNorm::DEFAULT_AFFINE`].
    #[builder(optional, default = BatchNorm::DEFAULT_AFFINE)]
    pub affine: bool,

    /// If `true`, tracks the running mean and variance. Defaults to
    /// [`BatchNorm::DEFAULT_TRACK_RUNNING_STATS`].
    #[builder(optional, default = BatchNorm::DEFAULT_TRACK_RUNNING_STATS)]
    pub track_running_stats: bool,
}

fn build_batch_norm(builder: BatchNormBuilder) -> Result<BatchNorm, Exception> {
    let eps = builder.eps;
    let momentum = builder.momentum;
    let affine = builder.affine;
    let track_running_stats = builder.track_running_stats;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.feature_count])?),
            Some(zeros::<f32>(&[builder.feature_count])?),
        )
    } else {
        (None, None)
    };

    let (running_mean, running_var) = if track_running_stats {
        (
            Some(zeros::<f32>(&[builder.feature_count])?),
            Some(ones::<f32>(&[builder.feature_count])?),
        )
    } else {
        (None, None)
    };

    Ok(BatchNorm {
        feature_count: builder.feature_count,
        eps: array!(eps),
        momentum: array!(momentum),
        weight: Param::new(weight),
        bias: Param::new(bias),
        running_mean: Param::new(running_mean),
        running_var: Param::new(running_var),
        training: BatchNorm::DEFAULT_TRAINING,
    })
}

/// Applies batch normalization [1] on the inputs.
///
/// ### References
///
/// 1. [https://arxiv.org/abs/1502.03167](https://arxiv.org/abs/1502.03167)
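///
/// ### Example
///
/// A minimal usage sketch (not compiled as a doc-test). It assumes the type is
/// reachable as `mlx_rs::nn::BatchNorm`, that the input is 2- to 4-dimensional
/// with features in the last axis, and that the code runs in a function
/// returning `Result`:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// let mut norm = mlx_rs::nn::BatchNorm::new(16)?; // 16 features
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
///
/// // Training mode (default): batch statistics are used and the tracked
/// // running mean/variance are updated.
/// let y = norm.forward(&x)?;
///
/// // Evaluation mode: the tracked running statistics are used instead.
/// norm.training_mode(false);
/// let y_eval = norm.forward(&x)?;
/// ```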
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct BatchNorm {
    /// Number of features in the input
    pub feature_count: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: Array,

    /// Momentum for updating the running mean and variance.
    pub momentum: Array,

    /// An optional trainable weight
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias
    #[param]
    pub bias: Param<Option<Array>>,

    /// Tracked running mean
    #[param]
    pub running_mean: Param<Option<Array>>,

    /// Tracked running variance
    #[param]
    pub running_var: Param<Option<Array>>,

    /// If `true`, the module is in training mode.
    pub training: bool,
}

impl BatchNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Default value for `momentum`.
    pub const DEFAULT_MOMENTUM: f32 = 0.1;

    /// Enable trainable `weight` and `bias` by default.
    pub const DEFAULT_AFFINE: bool = true;

    /// Enable tracking of running mean and variance by default.
    pub const DEFAULT_TRACK_RUNNING_STATS: bool = true;

    /// Enable training mode by default.
    pub const DEFAULT_TRAINING: bool = true;

    /// Compute the per-feature mean and variance over all axes except the last
    /// (feature) axis.
    fn stats(x: &Array) -> Result<(Array, Array), Exception> {
        let reduction_axes = (0..x.ndim() as i32 - 1).collect::<Vec<_>>();

        let mean = x.mean(&reduction_axes, None)?;
        let variance = x.variance(&reduction_axes, None, None)?;

        Ok((mean, variance))
    }
}

impl Module<&Array> for BatchNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let ndim = x.ndim();
        if !(2..=4).contains(&ndim) {
            return Err(Exception::custom(
                "Input tensor must be at least 2 dimensions and at most 4 dimensions",
            ));
        }

        let (mean, variance) = Self::stats(x)?;
        let mut mean = Cow::Owned(mean);
        let mut variance = Cow::Owned(variance);

        if let (Some(running_mean), Some(running_var)) =
            (self.running_mean.as_mut(), self.running_var.as_mut())
        {
            if self.training {
                let mu = &self.momentum;
                // SAFETY: momentum is a single element array
                let one_minus_mu = array!(1.0) - mu;

                *running_mean = one_minus_mu
                    .multiply(&running_mean)?
                    .add(mu.multiply(&mean)?)?;
                *running_var = one_minus_mu
                    .multiply(&running_var)?
                    .add(mu.multiply(&variance)?)?;
            } else {
                mean = Cow::Borrowed(&*running_mean);
                variance = Cow::Borrowed(&*running_var);
            }
        }

        let x = x
            .subtract(&mean)?
            .multiply(rsqrt(&variance.add(&self.eps)?)?)?;

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(&x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        ops::indexing::{Ellipsis, IndexOp},
        Dtype,
    };
    use float_eq::assert_float_eq;

    use super::*;

    #[test]
    fn test_instance_norm() {
        crate::random::seed(435).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.500_064_6,
            abs <= 0.010_001_292
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            128.016_54,
            abs <= 2.560_330_9
        );

        let result = InstanceNorm::new(8)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.106_454_11,
            abs <= 0.002_129_082_3
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            1.703_265_8,
            abs <= 0.034_065_317
        );
    }

    #[test]
    fn test_layer_norm() {
        crate::random::seed(635).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.492_690_32,
            abs <= 0.009_853_806
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            126.128_72,
            abs <= 2.522_574_4
        );

        let result = LayerNorm::new(16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((Ellipsis, 0));
        assert_eq!(result.shape(), &[2, 8]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.290_990_38,
            abs <= 0.005_819_807_8
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            4.655_846,
            abs <= 0.093_116_924
        );
    }

    #[test]
    fn test_rms_norm() {
        crate::random::seed(103).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.505_476_36,
            abs <= 0.010_109_527
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            129.401_95,
            abs <= 2.588_039
        );

        let result = RmsNorm::new(16).unwrap().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.872_938_75,
            abs <= 0.017_458_774
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            223.472_32,
            abs <= 4.469_446
        );
    }

    #[test]
    fn test_group_norm() {
        crate::random::seed(855).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.486_665_87,
            abs <= 0.009_733_317
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            124.586_464,
            abs <= 2.491_729_3
        );

        let result = GroupNorm::new(4, 16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            -0.054_606_52,
            abs <= 0.001_092_130_4
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            -0.873_704_3,
            abs <= 0.017_474_087
        );
    }

    #[test]
    fn test_batch_norm() {
        crate::random::seed(266).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.505_814_7,
            abs <= 0.010_116_293
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            129.488_56,
            abs <= 2.589_771
        );

        let result = BatchNorm::new(16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.439_785_24,
            abs <= 0.008_795_705
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            7.036_564,
            abs <= 0.140_731_28
        );
    }
}