//! Normalization layers: instance, layer, RMS, group, and batch normalization.

use std::borrow::Cow;

use crate::{
    array,
    error::Exception,
    module::{Module, Param},
    ops::{ones, rsqrt, zeros},
    Array,
};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;

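/// Normalizes `x` to zero mean and unit variance over `axes`:
/// `(x - mean) * rsqrt(variance + eps)`. The statistics are computed with
/// `keep_dims = true` so they broadcast back against the input.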
fn instance_norm(x: &Array, axes: &[i32], eps: &Array) -> Result<Array, Exception> {
    let mean = x.mean(axes, true)?;
    let variance = x.variance(axes, true, None)?;

    let x = x.subtract(&mean)?.multiply(rsqrt(&variance.add(eps)?)?)?;

    Ok(x)
}

/// Builder for [`InstanceNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_instance_norm,
    err = Exception,
)]
pub struct InstanceNormBuilder {
    /// Number of features in the input.
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    #[builder(optional, default = InstanceNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// If `true`, learn a per-feature weight and bias.
    #[builder(optional, default = InstanceNorm::DEFAULT_AFFINE)]
    pub affine: bool,
}

fn build_instance_norm(builder: InstanceNormBuilder) -> Result<InstanceNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(InstanceNorm {
        dimensions: builder.dimensions,
        eps: array!(eps),
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

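/// Applies instance normalization on the inputs, normalizing each sample
/// independently over its spatial axes (Ulyanov et al.,
/// <https://arxiv.org/abs/1607.08022>).
///
/// A minimal usage sketch, assuming the crate is consumed as `mlx_rs` (the
/// `new` constructor is generated by the `Buildable` derive and exercised in
/// the tests below):
///
/// ```rust,ignore
/// use mlx_rs::module::Module; // brings `forward` into scope
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
/// let mut norm = InstanceNorm::new(8)?; // no affine transform by default
/// let y = norm.forward(&x)?;
/// ```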
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct InstanceNorm {
    /// Number of features in the input.
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: Array,

    /// An optional trainable weight.
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias.
    #[param]
    pub bias: Param<Option<Array>>,
}

impl InstanceNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Default value for `affine`.
    pub const DEFAULT_AFFINE: bool = false;
}

impl Module<&Array> for InstanceNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        // Normalize over all axes except the first (batch) and last (feature) axis.
        let reduction_axes = (1..x.ndim() as i32 - 1).collect::<Vec<_>>();

        let x = instance_norm(x, &reduction_axes, &self.eps)?;

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(&x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`LayerNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_layer_norm,
    err = Exception,
)]
pub struct LayerNormBuilder {
    /// Number of features in the input.
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    #[builder(optional, default = LayerNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// If `true`, learn a per-feature weight and bias.
    #[builder(optional, default = LayerNorm::DEFAULT_AFFINE)]
    pub affine: bool,
}

fn build_layer_norm(builder: LayerNormBuilder) -> Result<LayerNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(LayerNorm {
        dimensions: builder.dimensions,
        eps,
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

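/// Applies layer normalization over the last (feature) axis of the input
/// (Ba et al., <https://arxiv.org/abs/1607.06450>):
/// `y = (x - E[x]) / sqrt(Var[x] + eps) * weight + bias`.
///
/// A minimal usage sketch, assuming the crate is consumed as `mlx_rs`:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
/// let mut norm = LayerNorm::new(16)?; // affine transform enabled by default
/// let y = norm.forward(&x)?;
/// ```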
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct LayerNorm {
    /// Number of features in the input.
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: f32,

    /// An optional trainable weight.
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias.
    #[param]
    pub bias: Param<Option<Array>>,
}

impl LayerNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Default value for `affine`.
    pub const DEFAULT_AFFINE: bool = true;
}

impl Module<&Array> for LayerNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let weight = self.weight.as_ref();
        let bias = self.bias.as_ref();
        let eps = self.eps;
        // Delegate to the fused fast kernel.
        crate::fast::layer_norm(x, weight, bias, eps)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`RmsNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_rms_norm,
    err = Exception,
)]
pub struct RmsNormBuilder {
    /// Number of features in the input.
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    #[builder(optional, default = RmsNorm::DEFAULT_EPS)]
    pub eps: f32,
}

fn build_rms_norm(builder: RmsNormBuilder) -> Result<RmsNorm, Exception> {
    let weight = ones::<f32>(&[builder.dimensions])?;
    let eps = builder.eps;
    Ok(RmsNorm {
        weight: Param::new(weight),
        eps,
    })
}

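/// Applies root-mean-square normalization (Zhang & Sennrich,
/// <https://arxiv.org/abs/1910.07467>): `y = x * rsqrt(mean(x^2) + eps) * weight`.
/// Unlike [`LayerNorm`], no mean is subtracted and there is no bias term.
///
/// A minimal usage sketch, assuming the crate is consumed as `mlx_rs`:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
/// let mut norm = RmsNorm::new(16)?;
/// let y = norm.forward(&x)?;
/// ```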
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct RmsNorm {
    /// Trainable per-feature weight.
    #[param]
    pub weight: Param<Array>,

    /// Value added to the denominator for numerical stability.
    pub eps: f32,
}

impl RmsNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;
}

impl Module<&Array> for RmsNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let weight = self.weight.as_ref();
        let eps = self.eps;
        // Delegate to the fused fast kernel.
        crate::fast::rms_norm(x, weight, eps)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`GroupNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_group_norm,
    err = Exception,
)]
pub struct GroupNormBuilder {
    /// Number of groups to separate the features into.
    pub group_count: i32,

    /// Number of features in the input.
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    #[builder(optional, default = GroupNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// If `true`, learn a per-feature weight and bias.
    #[builder(optional, default = GroupNorm::DEFAULT_AFFINE)]
    pub affine: bool,

    /// If `true`, group the features in the same order as PyTorch.
    #[builder(optional, default = GroupNorm::DEFAULT_PYTORCH_COMPATIBLE)]
    pub pytorch_compatible: bool,
}

fn build_group_norm(builder: GroupNormBuilder) -> Result<GroupNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;
    let pytorch_compatible = builder.pytorch_compatible;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(GroupNorm {
        group_count: builder.group_count,
        dimensions: builder.dimensions,
        eps: array!(eps),
        pytorch_compatible,
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

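/// Applies group normalization (Wu & He, <https://arxiv.org/abs/1803.08494>):
/// the features are split into `group_count` groups, each of which is
/// normalized to zero mean and unit variance, optionally followed by a learned
/// per-feature affine transform.
///
/// A minimal usage sketch, assuming the crate is consumed as `mlx_rs`:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
/// let mut norm = GroupNorm::new(4, 16)?; // 4 groups over 16 features
/// let y = norm.forward(&x)?;
/// ```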
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct GroupNorm {
    /// Number of groups to separate the features into.
    pub group_count: i32,

    /// Number of features in the input.
    pub dimensions: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: Array,

    /// If `true`, group the features in the same order as PyTorch.
    pub pytorch_compatible: bool,

    /// An optional trainable weight.
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias.
    #[param]
    pub bias: Param<Option<Array>>,
}

impl GroupNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Default value for `affine`.
    pub const DEFAULT_AFFINE: bool = true;

    /// Default value for `pytorch_compatible`.
    pub const DEFAULT_PYTORCH_COMPATIBLE: bool = false;

    fn pytorch_group_norm(&self, x: &Array) -> Result<Array, Exception> {
        let batch = x.dim(0);
        let dims = x.dim(-1);
        let rest = &x.shape()[1..x.ndim() - 1];
        let group_size = dims / self.group_count;

        // Move the groups next to the batch axis so each group forms one
        // contiguous row of shape [batch, group_count, spatial * group_size].
        let x = x.reshape(&[batch, -1, self.group_count, group_size])?;
        let x = x
            .transpose(&[0, 2, 1, 3])?
            .reshape(&[batch, self.group_count, -1])?;

        // Normalize within each group.
        let x = crate::fast::layer_norm(x, None, None, self.eps.item::<f32>())?;

        // Undo the reshuffle back to the original layout.
        let x = x.reshape(&[batch, self.group_count, -1, group_size])?;

        let new_shape: Vec<_> = [batch]
            .into_iter()
            .chain(rest.iter().copied())
            .chain([dims])
            .collect();
        x.transpose(&[0, 2, 1, 3])?.reshape(&new_shape[..])
    }

    fn group_norm(&self, x: &Array) -> Result<Array, Exception> {
        let batch = x.dim(0);
        let dims = x.dim(-1);
        let rest = &x.shape()[1..x.ndim() - 1];

        // Split the trailing elements across `group_count` groups on the last
        // axis, then normalize each group over the folded middle axis.
        let x = x.reshape(&[batch, -1, self.group_count])?;

        let x = instance_norm(&x, &[1], &self.eps)?;

        let new_shape: Vec<_> = [batch]
            .into_iter()
            .chain(rest.iter().copied())
            .chain([dims])
            .collect();
        x.reshape(&new_shape[..])
    }
}

impl Module<&Array> for GroupNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let x = if self.pytorch_compatible {
            self.pytorch_group_norm(x)?
        } else {
            self.group_norm(x)?
        };

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(&x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, _mode: bool) {}
}

/// Builder for [`BatchNorm`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_batch_norm,
    err = Exception,
)]
pub struct BatchNormBuilder {
    /// Number of features in the input.
    pub feature_count: i32,

    /// Value added to the denominator for numerical stability.
    #[builder(optional, default = BatchNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// Momentum for updating the running mean and variance.
    #[builder(optional, default = BatchNorm::DEFAULT_MOMENTUM)]
    pub momentum: f32,

    /// If `true`, learn a per-feature weight and bias.
    #[builder(optional, default = BatchNorm::DEFAULT_AFFINE)]
    pub affine: bool,

    /// If `true`, track the running mean and variance.
    #[builder(optional, default = BatchNorm::DEFAULT_TRACK_RUNNING_STATS)]
    pub track_running_stats: bool,
}

fn build_batch_norm(builder: BatchNormBuilder) -> Result<BatchNorm, Exception> {
    let eps = builder.eps;
    let momentum = builder.momentum;
    let affine = builder.affine;
    let track_running_stats = builder.track_running_stats;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.feature_count])?),
            Some(zeros::<f32>(&[builder.feature_count])?),
        )
    } else {
        (None, None)
    };

    let (running_mean, running_var) = if track_running_stats {
        (
            Some(zeros::<f32>(&[builder.feature_count])?),
            Some(ones::<f32>(&[builder.feature_count])?),
        )
    } else {
        (None, None)
    };

    Ok(BatchNorm {
        feature_count: builder.feature_count,
        eps: array!(eps),
        momentum: array!(momentum),
        weight: Param::new(weight),
        bias: Param::new(bias),
        running_mean: Param::new(running_mean),
        running_var: Param::new(running_var),
        training: BatchNorm::DEFAULT_TRAINING,
    })
}

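/// Applies batch normalization (Ioffe & Szegedy,
/// <https://arxiv.org/abs/1502.03167>), normalizing over every axis except the
/// last (feature) axis. In training mode the running mean and variance are
/// updated with an exponential moving average; in evaluation mode they are
/// used in place of the batch statistics.
///
/// A minimal usage sketch, assuming the crate is consumed as `mlx_rs`:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
///
/// let x = mlx_rs::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None)?;
/// let mut norm = BatchNorm::new(16)?;
/// let y = norm.forward(&x)?; // training mode: also updates running stats
/// norm.training_mode(false);
/// let y_eval = norm.forward(&x)?; // eval mode: uses the running stats
/// ```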
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct BatchNorm {
    /// Number of features in the input.
    pub feature_count: i32,

    /// Value added to the denominator for numerical stability.
    pub eps: Array,

    /// Momentum for updating the running mean and variance.
    pub momentum: Array,

    /// An optional trainable weight.
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional trainable bias.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Tracked running mean.
    #[param]
    pub running_mean: Param<Option<Array>>,

    /// Tracked running variance.
    #[param]
    pub running_var: Param<Option<Array>>,

    /// Whether the layer is in training mode.
    pub training: bool,
}

impl BatchNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Default value for `momentum`.
    pub const DEFAULT_MOMENTUM: f32 = 0.1;

    /// Default value for `affine`.
    pub const DEFAULT_AFFINE: bool = true;

    /// Default value for `track_running_stats`.
    pub const DEFAULT_TRACK_RUNNING_STATS: bool = true;

    /// Default value for `training`.
    pub const DEFAULT_TRAINING: bool = true;

    /// Computes the mean and variance over all axes except the last
    /// (feature) axis.
    fn stats(x: &Array) -> Result<(Array, Array), Exception> {
        let reduction_axes = (0..x.ndim() as i32 - 1).collect::<Vec<_>>();

        let mean = x.mean(&reduction_axes, None)?;
        let variance = x.variance(&reduction_axes, None, None)?;

        Ok((mean, variance))
    }
}

impl Module<&Array> for BatchNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let ndim = x.ndim();
        if !(2..=4).contains(&ndim) {
            return Err(Exception::custom(
                "Input tensor must be at least 2 dimensions and at most 4 dimensions",
            ));
        }

        let (mean, variance) = Self::stats(x)?;
        let mut mean = Cow::Owned(mean);
        let mut variance = Cow::Owned(variance);

        if let (Some(running_mean), Some(running_var)) =
            (self.running_mean.as_mut(), self.running_var.as_mut())
        {
            if self.training {
                // Exponential moving average update:
                // running = (1 - momentum) * running + momentum * batch_stat
                let mu = &self.momentum;
                let one_minus_mu = array!(1.0) - mu;

                *running_mean = one_minus_mu
                    .multiply(&running_mean)?
                    .add(mu.multiply(&mean)?)?;
                *running_var = one_minus_mu
                    .multiply(&running_var)?
                    .add(mu.multiply(&variance)?)?;
            } else {
                // In evaluation mode, normalize with the tracked statistics.
                mean = Cow::Borrowed(&*running_mean);
                variance = Cow::Borrowed(&*running_var);
            }
        }

        let x = x
            .subtract(&mean)?
            .multiply(rsqrt(&variance.add(&self.eps)?)?)?;

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(&x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        ops::indexing::{Ellipsis, IndexOp},
        Dtype,
    };
    use float_eq::assert_float_eq;

    use super::*;

    #[test]
    fn test_instance_norm() {
        crate::random::seed(435).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.500_064_6,
            abs <= 0.010_001_292
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            128.016_54,
            abs <= 2.560_330_9
        );

        let result = InstanceNorm::new(8)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.106_454_11,
            abs <= 0.002_129_082_3
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            1.703_265_8,
            abs <= 0.034_065_317
        );
    }

    #[test]
    fn test_layer_norm() {
        crate::random::seed(635).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.492_690_32,
            abs <= 0.009_853_806
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            126.128_72,
            abs <= 2.522_574_4
        );

        let result = LayerNorm::new(16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((Ellipsis, 0));
        assert_eq!(result.shape(), &[2, 8]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.290_990_38,
            abs <= 0.005_819_807_8
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            4.655_846,
            abs <= 0.093_116_924
        );
    }

    #[test]
    fn test_rms_norm() {
        crate::random::seed(103).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.505_476_36,
            abs <= 0.010_109_527
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            129.401_95,
            abs <= 2.588_039
        );

        let result = RmsNorm::new(16).unwrap().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.872_938_75,
            abs <= 0.017_458_774
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            223.472_32,
            abs <= 4.469_446
        );
    }

    #[test]
    fn test_group_norm() {
        crate::random::seed(855).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.486_665_87,
            abs <= 0.009_733_317
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            124.586_464,
            abs <= 2.491_729_3
        );

        let result = GroupNorm::new(4, 16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            -0.054_606_52,
            abs <= 0.001_092_130_4
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            -0.873_704_3,
            abs <= 0.017_474_087
        );
    }

    #[test]
    fn test_batch_norm() {
        crate::random::seed(266).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.505_814_7,
            abs <= 0.010_116_293
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            129.488_56,
            abs <= 2.589_771
        );

        let result = BatchNorm::new(16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.439_785_24,
            abs <= 0.008_795_705
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            7.036_564,
            abs <= 0.140_731_28
        );
    }
}