use std::borrow::Cow;

use crate::{
    array,
    error::Exception,
    module::{Module, Param},
    ops::{ones, rsqrt, zeros},
    Array,
};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;

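/// Normalizes `x` to zero mean and unit variance over `axes`.
///
/// The statistics are computed with the reduced axes kept so they broadcast
/// against `x`; the result is `(x - mean) * rsqrt(var + eps)`.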
fn instance_norm(x: &Array, axes: &[i32], eps: &Array) -> Result<Array, Exception> {
    let mean = x.mean_axes(axes, true)?;
    let variance = x.var_axes(axes, true, None)?;

    let x = x.subtract(&mean)?.multiply(rsqrt(&variance.add(eps)?)?)?;

    Ok(x)
}

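/// Builder for [`InstanceNorm`].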
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_instance_norm,
    err = Exception,
)]
pub struct InstanceNormBuilder {
    /// The number of feature dimensions to normalize over.
    pub dimensions: i32,

    /// A small additive constant for numerical stability.
    #[builder(optional, default = InstanceNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// Whether to learn a per-feature scale and shift. Defaults to `false`.
    #[builder(optional, default = InstanceNorm::DEFAULT_AFFINE)]
    pub affine: bool,
}

fn build_instance_norm(builder: InstanceNormBuilder) -> Result<InstanceNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(InstanceNorm {
        dimensions: builder.dimensions,
        eps: array!(eps),
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

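/// Applies instance normalization on a batch of inputs.
///
/// Each sample is normalized independently over all axes except the first
/// (batch) and last (feature) axis. When built with `affine = true`, a learned
/// per-feature scale (`weight`) and shift (`bias`) are applied afterwards.
///
/// A minimal usage sketch; the `mlx_rs::nn` re-export path is an assumption,
/// as it is not visible from this file:
///
/// ```ignore
/// use mlx_rs::nn::InstanceNorm; // assumed re-export path
///
/// // 8 feature dimensions, default eps, no affine parameters.
/// let mut norm = InstanceNorm::new(8)?;
/// let y = norm.forward(&x)?; // x: [batch, ..., 8]
/// ```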
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct InstanceNorm {
    /// The number of feature dimensions to normalize over.
    pub dimensions: i32,

    /// A small additive constant for numerical stability.
    pub eps: Array,

    /// An optional learned per-feature scale.
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional learned per-feature shift.
    #[param]
    pub bias: Param<Option<Array>>,
}

impl InstanceNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Default value for `affine`.
    pub const DEFAULT_AFFINE: bool = false;
}

impl Module<&Array> for InstanceNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
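        // Normalize each sample over every axis except the first (batch) and
        // the last (feature) axis.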
        let reduction_axes = (1..x.ndim() as i32 - 1).collect::<Vec<_>>();

        let x = instance_norm(x, &reduction_axes, &self.eps)?;

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(&x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, _mode: bool) {}
}

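/// Builder for [`LayerNorm`].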
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_layer_norm,
    err = Exception,
)]
pub struct LayerNormBuilder {
    /// The number of feature dimensions to normalize over.
    pub dimensions: i32,

    /// A small additive constant for numerical stability.
    #[builder(optional, default = LayerNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// Whether to learn a per-feature scale and shift. Defaults to `true`.
    #[builder(optional, default = LayerNorm::DEFAULT_AFFINE)]
    pub affine: bool,
}

fn build_layer_norm(builder: LayerNormBuilder) -> Result<LayerNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(LayerNorm {
        dimensions: builder.dimensions,
        eps,
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

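/// Applies layer normalization on the inputs.
///
/// Normalizes over the last (feature) axis and, when built with
/// `affine = true` (the default), applies the learned `weight` and `bias`.
///
/// A minimal usage sketch; the re-export path is an assumption, as above:
///
/// ```ignore
/// let mut norm = LayerNorm::new(16)?; // 16 feature dimensions
/// let y = norm.forward(&x)?;          // x: [..., 16]
/// ```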
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct LayerNorm {
    /// The number of feature dimensions to normalize over.
    pub dimensions: i32,

    /// A small additive constant for numerical stability.
    pub eps: f32,

    /// An optional learned per-feature scale.
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional learned per-feature shift.
    #[param]
    pub bias: Param<Option<Array>>,
}

impl LayerNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Default value for `affine`.
    pub const DEFAULT_AFFINE: bool = true;
}

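// The forward pass delegates to the fused kernel in `crate::fast` instead of
// composing the normalization from individual ops.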
impl Module<&Array> for LayerNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let weight = self.weight.as_ref();
        let bias = self.bias.as_ref();
        let eps = self.eps;
        crate::fast::layer_norm(x, weight, bias, eps)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

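/// Builder for [`RmsNorm`].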
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_rms_norm,
    err = Exception,
)]
pub struct RmsNormBuilder {
    /// The number of feature dimensions to normalize over.
    pub dimensions: i32,

    /// A small additive constant for numerical stability.
    #[builder(optional, default = RmsNorm::DEFAULT_EPS)]
    pub eps: f32,
}

fn build_rms_norm(builder: RmsNormBuilder) -> Result<RmsNorm, Exception> {
    let weight = ones::<f32>(&[builder.dimensions])?;
    let eps = builder.eps;
    Ok(RmsNorm {
        weight: Param::new(weight),
        eps,
    })
}

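/// Applies root-mean-square (RMS) normalization on the inputs.
///
/// Following the usual RMSNorm definition, the output is
/// `x * weight / sqrt(mean(x * x, axis = -1) + eps)`; the computation is
/// fused inside `crate::fast::rms_norm`.
///
/// A minimal usage sketch; the re-export path is an assumption, as above:
///
/// ```ignore
/// let mut norm = RmsNorm::new(16)?;
/// let y = norm.forward(&x)?;
/// ```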
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct RmsNorm {
    /// The learned per-feature scale, initialized to ones.
    #[param]
    pub weight: Param<Array>,

    /// A small additive constant for numerical stability.
    pub eps: f32,
}

impl RmsNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;
}

impl Module<&Array> for RmsNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let weight = self.weight.as_ref();
        let eps = self.eps;
        crate::fast::rms_norm(x, weight, eps)
    }

    fn training_mode(&mut self, _mode: bool) {}
}

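/// Builder for [`GroupNorm`].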
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_group_norm,
    err = Exception,
)]
pub struct GroupNormBuilder {
    /// The number of groups to split the feature axis into.
    pub group_count: i32,

    /// The number of feature dimensions to normalize over.
    pub dimensions: i32,

    /// A small additive constant for numerical stability.
    #[builder(optional, default = GroupNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// Whether to learn a per-feature scale and shift. Defaults to `true`.
    #[builder(optional, default = GroupNorm::DEFAULT_AFFINE)]
    pub affine: bool,

    /// Whether to compute the statistics the way PyTorch's `GroupNorm` does.
    /// Defaults to `false`.
    #[builder(optional, default = GroupNorm::DEFAULT_PYTORCH_COMPATIBLE)]
    pub pytorch_compatible: bool,
}

fn build_group_norm(builder: GroupNormBuilder) -> Result<GroupNorm, Exception> {
    let eps = builder.eps;
    let affine = builder.affine;
    let pytorch_compatible = builder.pytorch_compatible;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.dimensions])?),
            Some(zeros::<f32>(&[builder.dimensions])?),
        )
    } else {
        (None, None)
    };

    Ok(GroupNorm {
        group_count: builder.group_count,
        dimensions: builder.dimensions,
        eps: array!(eps),
        pytorch_compatible,
        weight: Param::new(weight),
        bias: Param::new(bias),
    })
}

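/// Applies group normalization on the inputs.
///
/// The feature axis is split into `group_count` groups and each group is
/// normalized separately. With `pytorch_compatible = true`, the statistics are
/// computed over contiguous channel groups, matching PyTorch's `GroupNorm`.
///
/// A minimal usage sketch; the re-export path is an assumption, as above:
///
/// ```ignore
/// let mut norm = GroupNorm::new(4, 16)?; // 4 groups over 16 feature dims
/// let y = norm.forward(&x)?;
/// ```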
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct GroupNorm {
    /// The number of groups to split the feature axis into.
    pub group_count: i32,

    /// The number of feature dimensions to normalize over.
    pub dimensions: i32,

    /// A small additive constant for numerical stability.
    pub eps: Array,

    /// Whether the statistics are computed the way PyTorch's `GroupNorm` does.
    pub pytorch_compatible: bool,

    /// An optional learned per-feature scale.
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional learned per-feature shift.
    #[param]
    pub bias: Param<Option<Array>>,
}

impl GroupNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Default value for `affine`.
    pub const DEFAULT_AFFINE: bool = true;

    /// Default value for `pytorch_compatible`.
    pub const DEFAULT_PYTORCH_COMPATIBLE: bool = false;

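    /// PyTorch-compatible variant: channels are grouped into contiguous runs
    /// of `dims / group_count`, so the input is reshaped and transposed until
    /// each group forms one row, normalized with `crate::fast::layer_norm`,
    /// and then restored to the original layout.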
    fn pytorch_group_norm(&self, x: &Array) -> Result<Array, Exception> {
        let batch = x.dim(0);
        let dims = x.dim(-1);
        let rest = &x.shape()[1..x.ndim() - 1];
        let group_size = dims / self.group_count;

        let x = x.reshape(&[batch, -1, self.group_count, group_size])?;
        let x = x
            .transpose_axes(&[0, 2, 1, 3])?
            .reshape(&[batch, self.group_count, -1])?;

        let x = crate::fast::layer_norm(x, None, None, self.eps.item::<f32>())?;

        let x = x.reshape(&[batch, self.group_count, -1, group_size])?;

        let new_shape: Vec<_> = [batch]
            .into_iter()
            .chain(rest.iter().copied())
            .chain([dims])
            .collect();
        x.transpose_axes(&[0, 2, 1, 3])?.reshape(&new_shape[..])
    }

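    /// Default variant: reshapes the input to `[batch, -1, group_count]` and
    /// normalizes over the middle axis with [`instance_norm`].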
    fn group_norm(&self, x: &Array) -> Result<Array, Exception> {
        let batch = x.dim(0);
        let dims = x.dim(-1);
        let rest = &x.shape()[1..x.ndim() - 1];

        let x = x.reshape(&[batch, -1, self.group_count])?;

        let x = instance_norm(&x, &[1], &self.eps)?;

        let new_shape: Vec<_> = [batch]
            .into_iter()
            .chain(rest.iter().copied())
            .chain([dims])
            .collect();
        x.reshape(&new_shape[..])
    }
}

impl Module<&Array> for GroupNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let x = if self.pytorch_compatible {
            self.pytorch_group_norm(x)?
        } else {
            self.group_norm(x)?
        };

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(&x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, _mode: bool) {}
}

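/// Builder for [`BatchNorm`].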
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_batch_norm,
    err = Exception,
)]
pub struct BatchNormBuilder {
    /// The number of features in the input.
    pub feature_count: i32,

    /// A small additive constant for numerical stability.
    #[builder(optional, default = BatchNorm::DEFAULT_EPS)]
    pub eps: f32,

    /// The momentum used to update the running statistics. Defaults to `0.1`.
    #[builder(optional, default = BatchNorm::DEFAULT_MOMENTUM)]
    pub momentum: f32,

    /// Whether to learn a per-feature scale and shift. Defaults to `true`.
    #[builder(optional, default = BatchNorm::DEFAULT_AFFINE)]
    pub affine: bool,

    /// Whether to track running statistics. Defaults to `true`.
    #[builder(optional, default = BatchNorm::DEFAULT_TRACK_RUNNING_STATS)]
    pub track_running_stats: bool,
}

fn build_batch_norm(builder: BatchNormBuilder) -> Result<BatchNorm, Exception> {
    let eps = builder.eps;
    let momentum = builder.momentum;
    let affine = builder.affine;
    let track_running_stats = builder.track_running_stats;

    let (weight, bias) = if affine {
        (
            Some(ones::<f32>(&[builder.feature_count])?),
            Some(zeros::<f32>(&[builder.feature_count])?),
        )
    } else {
        (None, None)
    };

    let (running_mean, running_var) = if track_running_stats {
        (
            Some(zeros::<f32>(&[builder.feature_count])?),
            Some(ones::<f32>(&[builder.feature_count])?),
        )
    } else {
        (None, None)
    };

    Ok(BatchNorm {
        feature_count: builder.feature_count,
        eps: array!(eps),
        momentum: array!(momentum),
        weight: Param::new(weight),
        bias: Param::new(bias),
        running_mean: Param::new(running_mean),
        running_var: Param::new(running_var),
        training: BatchNorm::DEFAULT_TRAINING,
    })
}

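/// Applies batch normalization on 2- to 4-dimensional inputs.
///
/// Statistics are computed over every axis except the last (feature) axis.
/// When built with `track_running_stats = true` (the default), exponential
/// moving averages of the mean and variance are updated in training mode and
/// used in place of the batch statistics in evaluation mode.
///
/// A minimal usage sketch; the re-export path is an assumption, as above:
///
/// ```ignore
/// let mut norm = BatchNorm::new(16)?;
/// let y = norm.forward(&x)?;      // training mode by default
/// norm.training_mode(false);      // switch to the running statistics
/// let y_eval = norm.forward(&x)?;
/// ```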
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct BatchNorm {
    /// The number of features in the input.
    pub feature_count: i32,

    /// A small additive constant for numerical stability.
    pub eps: Array,

    /// The momentum used to update the running statistics.
    pub momentum: Array,

    /// An optional learned per-feature scale.
    #[param]
    pub weight: Param<Option<Array>>,

    /// An optional learned per-feature shift.
    #[param]
    pub bias: Param<Option<Array>>,

    /// The running mean, tracked when `track_running_stats` is enabled.
    #[param]
    pub running_mean: Param<Option<Array>>,

    /// The running variance, tracked when `track_running_stats` is enabled.
    #[param]
    pub running_var: Param<Option<Array>>,

    /// Whether the module is in training mode.
    pub training: bool,
}

impl BatchNorm {
    /// Default value for `eps`.
    pub const DEFAULT_EPS: f32 = 1e-5;

    /// Default value for `momentum`.
    pub const DEFAULT_MOMENTUM: f32 = 0.1;

    /// Default value for `affine`.
    pub const DEFAULT_AFFINE: bool = true;

    /// Default value for `track_running_stats`.
    pub const DEFAULT_TRACK_RUNNING_STATS: bool = true;

    /// Default value for `training`.
    pub const DEFAULT_TRAINING: bool = true;

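    /// Returns the mean and variance of `x` computed over every axis except
    /// the last (feature) axis.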
    fn stats(x: &Array) -> Result<(Array, Array), Exception> {
        let reduction_axes = (0..x.ndim() as i32 - 1).collect::<Vec<_>>();

        let mean = x.mean_axes(&reduction_axes, None)?;
        let variance = x.var_axes(&reduction_axes, None, None)?;

        Ok((mean, variance))
    }
}

impl Module<&Array> for BatchNorm {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let ndim = x.ndim();
        if !(2..=4).contains(&ndim) {
            return Err(Exception::custom(
                "Input tensor must have between 2 and 4 dimensions",
            ));
        }

        let (mean, variance) = Self::stats(x)?;
        let mut mean = Cow::Owned(mean);
        let mut variance = Cow::Owned(variance);

        if let (Some(running_mean), Some(running_var)) =
            (self.running_mean.as_mut(), self.running_var.as_mut())
        {
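            // In training mode, fold the batch statistics into the running
            // averages with momentum `mu`; in evaluation mode, use the stored
            // running statistics instead.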
            if self.training {
                let mu = &self.momentum;
                let one_minus_mu = array!(1.0) - mu;

                *running_mean = one_minus_mu
                    .multiply(&running_mean)?
                    .add(mu.multiply(&mean)?)?;
                *running_var = one_minus_mu
                    .multiply(&running_var)?
                    .add(mu.multiply(&variance)?)?;
            } else {
                mean = Cow::Borrowed(&*running_mean);
                variance = Cow::Borrowed(&*running_var);
            }
        }

        let x = x
            .subtract(&mean)?
            .multiply(rsqrt(&variance.add(&self.eps)?)?)?;

        if let (Some(weight), Some(bias)) = (self.weight.as_ref(), self.bias.as_ref()) {
            weight.multiply(&x)?.add(bias)
        } else {
            Ok(x)
        }
    }

    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        ops::indexing::{Ellipsis, IndexOp},
        Dtype,
    };
    use float_eq::assert_float_eq;

    use super::*;

    #[test]
    fn test_instance_norm() {
        crate::random::seed(435).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.500_064_6,
            abs <= 0.010_001_292
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            128.016_54,
            abs <= 2.560_330_9
        );

        let result = InstanceNorm::new(8)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.106_454_11,
            abs <= 0.002_129_082_3
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            1.703_265_8,
            abs <= 0.034_065_317
        );
    }

    #[test]
    fn test_layer_norm() {
        crate::random::seed(635).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.492_690_32,
            abs <= 0.009_853_806
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            126.128_72,
            abs <= 2.522_574_4
        );

        let result = LayerNorm::new(16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((Ellipsis, 0));
        assert_eq!(result.shape(), &[2, 8]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.290_990_38,
            abs <= 0.005_819_807_8
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            4.655_846,
            abs <= 0.093_116_924
        );
    }

    #[test]
    fn test_rms_norm() {
        crate::random::seed(103).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.505_476_36,
            abs <= 0.010_109_527
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            129.401_95,
            abs <= 2.588_039
        );

        let result = RmsNorm::new(16).unwrap().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.872_938_75,
            abs <= 0.017_458_774
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            223.472_32,
            abs <= 4.469_446
        );
    }

    #[test]
    fn test_group_norm() {
        crate::random::seed(855).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.486_665_87,
            abs <= 0.009_733_317
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            124.586_464,
            abs <= 2.491_729_3
        );

        let result = GroupNorm::new(4, 16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            -0.054_606_52,
            abs <= 0.001_092_130_4
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            -0.873_704_3,
            abs <= 0.017_474_087
        );
    }

    #[test]
    fn test_batch_norm() {
        crate::random::seed(266).unwrap();
        let a = crate::random::uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None).unwrap().item::<f32>(),
            0.505_814_7,
            abs <= 0.010_116_293
        );
        assert_float_eq!(
            a.sum(None).unwrap().item::<f32>(),
            129.488_56,
            abs <= 2.589_771
        );

        let result = BatchNorm::new(16)
            .unwrap()
            .forward(&a)
            .unwrap()
            .index((0, 0));
        assert_eq!(result.shape(), &[16]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None).unwrap().item::<f32>(),
            0.439_785_24,
            abs <= 0.008_795_705
        );
        assert_float_eq!(
            result.sum(None).unwrap().item::<f32>(),
            7.036_564,
            abs <= 0.140_731_28
        );
    }
}