1use crate::{
4 array,
5 error::{CrossEntropyBuildError, Exception},
6 ops::{
7 abs, clip, exp, indexing::take_along_axis, log, log_add_exp, log_sum_exp, maximum, minimum,
8 multiply, power, r#where, sqrt, square, sum,
9 },
10 Array,
11};
12use mlx_internal_macros::{generate_builder, Buildable};
13
14#[inline]
15fn check_shape(
16 left: &Array,
17 right: &Array,
18 left_ident: &str,
19 right_ident: &str,
20) -> Result<(), Exception> {
21 if left.shape() != right.shape() {
22 return Err(Exception::custom(format!(
23 "The shape of the {} ({:?}) does not match the shape of the {} ({:?})",
24 left_ident,
25 left.shape(),
26 right_ident,
27 right.shape()
28 )));
29 }
30 Ok(())
31}
32
/// Strategy for collapsing an element-wise loss tensor into the final loss value.
#[derive(Debug, Clone, Copy)]
pub enum LossReduction {
    /// Leave the element-wise loss unchanged.
    None,
    /// Sum the loss over all elements.
    Sum,
    /// Average the loss over all elements.
    Mean,
}
43
44impl LossReduction {
45 pub fn reduce(&self, loss: Array) -> Result<Array, Exception> {
47 match self {
48 LossReduction::None => Ok(loss),
49 LossReduction::Sum => Ok(loss.sum(None, None)?),
50 LossReduction::Mean => Ok(loss.mean(None, None)?),
51 }
52 }
53}
54
/// Borrowed weights array as accepted by the cross-entropy builder
/// (presumably referenced by the `generate_builder!` expansion — confirm against the macro).
pub type CrossEntropyBuilderWeights<'a> = &'a Array;
57
generate_builder! {
    /// Cross-entropy loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        root = crate,
        build_with = build_cross_entropy,
        err = CrossEntropyBuildError
    )]
    pub struct CrossEntropy<'a> {
        /// Optional element-wise weights multiplied into the loss; must match
        /// the loss shape. Default: `None`.
        #[builder(optional, default = CrossEntropy::DEFAULT_WEIGHTS)]
        pub weights: Option<&'a Array>,

        /// Axis along which the log-sum-exp / class dimension lies. Default: `-1`.
        #[builder(optional, default = CrossEntropy::DEFAULT_AXIS)]
        pub axis: i32,

        /// Label-smoothing factor; must lie in `[0, 1)`. Default: `0.0`.
        #[builder(optional, default = CrossEntropy::DEFAULT_LABEL_SMOOTHING)]
        pub label_smoothing: f32,

        /// Reduction applied to the element-wise loss. Default: `LossReduction::None`.
        #[builder(optional, default = CrossEntropy::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
86
87fn build_cross_entropy(
88 builder: CrossEntropyBuilder,
89) -> Result<CrossEntropy, CrossEntropyBuildError> {
90 let axis = builder.axis;
91 let label_smoothing = builder.label_smoothing;
92 let reduction = builder.reduction;
93
94 if !(0.0..1.0).contains(&label_smoothing) {
95 return Err(CrossEntropyBuildError::InvalidLabelSmoothingFactor);
96 }
97
98 Ok(CrossEntropy {
99 weights: builder.weights,
100 axis,
101 label_smoothing,
102 reduction,
103 })
104}
105
impl<'a> CrossEntropy<'a> {
    /// Default axis for the class dimension.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default label-smoothing factor (disabled).
    pub const DEFAULT_LABEL_SMOOTHING: f32 = 0.0;

    /// Default reduction (element-wise loss returned as-is).
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Default weights (no weighting applied).
    pub const DEFAULT_WEIGHTS: Option<&'a Array> = None;

    /// Compute the cross-entropy between `logits` and `targets`.
    ///
    /// `targets` may be either class indices (one fewer dimension than
    /// `logits`) or a probability distribution with the same rank as
    /// `logits`; the two cases are distinguished by comparing `ndim`.
    pub fn apply(
        &self,
        logits: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let logits = logits.as_ref();
        let targets = targets.as_ref();

        // Same rank => targets are per-class probabilities, not indices.
        let target_as_probs = targets.ndim() == logits.ndim();

        let score = if target_as_probs {
            // Expected logit under the target distribution.
            sum(&logits.multiply(targets)?, &[self.axis], None)?
        } else {
            // Pick the logit of the target class for each sample.
            take_along_axis(logits, &targets.expand_dims(&[-1])?, self.axis)?.squeeze(&[-1])?
        };
        let log_sum_exp_logits = log_sum_exp(logits, &[self.axis], None)?;

        let mut loss = if self.label_smoothing > 0.0 {
            // Blend the true-class score with a uniform component:
            // loss = logsumexp - (1 - s) * score - s * mean(logits).
            let adjusted_score = multiply(array!(1.0 - self.label_smoothing), score)?;

            let mean_logits = logits.mean(&[self.axis], None)?;
            let smoothed_loss = -multiply(mean_logits, array!(self.label_smoothing))?;

            log_sum_exp_logits
                .subtract(adjusted_score)?
                .add(smoothed_loss)?
        } else {
            // Standard cross entropy: logsumexp(logits) - score.
            log_sum_exp_logits.subtract(score)?
        };

        if let Some(weights) = self.weights {
            // Weights are applied element-wise and must match the loss shape.
            check_shape(weights, &loss, "weights", "loss")?;
            loss = multiply(loss, weights)?;
        }

        self.reduction.reduce(loss)
    }
}
166
generate_builder! {
    /// Binary cross-entropy loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct BinaryCrossEntropy<'a> {
        /// Optional element-wise weights multiplied into the loss; must match
        /// the loss shape. Default: `None`.
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_WEIGHTS)]
        pub weights: Option<&'a Array>,

        /// When `true`, inputs are unnormalized logits; when `false`, they are
        /// probabilities. Default: `true`.
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS)]
        pub inputs_are_logits: bool,

        /// Reduction applied to the element-wise loss. Default: `LossReduction::None`.
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
192
impl<'a> BinaryCrossEntropy<'a> {
    /// Default weights (no weighting applied).
    pub const DEFAULT_WEIGHTS: Option<&'a Array> = None;

    /// By default inputs are treated as unnormalized logits.
    pub const DEFAULT_INPUTS_ARE_LOGITS: bool = true;

    /// Default reduction (element-wise loss returned as-is).
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Compute the binary cross-entropy between `logits` (or probabilities,
    /// depending on `inputs_are_logits`) and `targets`.
    pub fn apply(
        &self,
        logits: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let logits = logits.as_ref();
        let targets = targets.as_ref();
        let weights = self.weights;
        let inputs_are_logits = self.inputs_are_logits;
        let reduction = self.reduction;

        let mut loss = if inputs_are_logits {
            // Numerically stable form: log(1 + e^x) - t * x.
            log_add_exp(array!(0.0), logits)?.subtract(targets.multiply(logits)?)?
        } else {
            // Inputs are probabilities; clamp the logs at -100 so that
            // p == 0 or p == 1 does not produce infinities.
            let log_inputs_clip = clip(log(logits)?, (-100.0, ()))?;
            let log_inputs_inverse_clip = clip(log(&array!(1.0).subtract(logits)?)?, (-100.0, ()))?;
            // -(t * log(p) + (1 - t) * log(1 - p))
            -(targets.multiply(log_inputs_clip)?.add(
                array!(1.0)
                    .subtract(targets)?
                    .multiply(log_inputs_inverse_clip)?,
            )?)
        };

        if let Some(weights) = weights {
            // Weights are applied element-wise and must match the loss shape.
            check_shape(weights, &loss, "weights", "loss")?;
            loss = multiply(loss, weights)?;
        }

        reduction.reduce(loss)
    }
}
240
generate_builder! {
    /// Mean-absolute-error (L1) loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct L1Loss {
        /// Reduction applied to the element-wise loss. Default: `LossReduction::Mean`.
        #[builder(optional, default = L1Loss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
252
253impl L1Loss {
254 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;
256
257 pub fn apply(
264 &self,
265 predictions: impl AsRef<Array>,
266 targets: impl AsRef<Array>,
267 ) -> Result<Array, Exception> {
268 let predictions = predictions.as_ref();
269 let targets = targets.as_ref();
270 let reduction = self.reduction;
271
272 check_shape(predictions, targets, "predictions", "targets")?;
273 let loss = predictions.subtract(targets)?.abs()?;
274 reduction.reduce(loss)
275 }
276}
277
generate_builder! {
    /// Mean-squared-error loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct MseLoss {
        /// Reduction applied to the element-wise loss. Default: `LossReduction::Mean`.
        #[builder(optional, default = MseLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
289
290impl MseLoss {
291 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;
293
294 pub fn apply(
301 &self,
302 predictions: impl AsRef<Array>,
303 targets: impl AsRef<Array>,
304 ) -> Result<Array, Exception> {
305 let predictions = predictions.as_ref();
306 let targets = targets.as_ref();
307 let reduction = self.reduction;
308
309 check_shape(predictions, targets, "predictions", "targets")?;
310 let loss = predictions.subtract(targets)?.square()?;
311 reduction.reduce(loss)
312 }
313}
314
generate_builder! {
    /// Negative-log-likelihood loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct NllLoss {
        /// Axis holding the class dimension of `inputs`. Default: `-1`.
        #[builder(optional, default = NllLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Reduction applied to the element-wise loss. Default: `LossReduction::None`.
        #[builder(optional, default = NllLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
330
331impl NllLoss {
332 pub const DEFAULT_AXIS: i32 = -1;
334
335 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
337
338 pub fn apply(
345 &self,
346 inputs: impl AsRef<Array>,
347 targets: impl AsRef<Array>,
348 ) -> Result<Array, Exception> {
349 let inputs = inputs.as_ref();
350 let targets = targets.as_ref();
351 let axis = self.axis;
352 let reduction = self.reduction;
353
354 let loss = -take_along_axis(inputs, &targets.expand_dims(&[-1])?, axis)?.squeeze(&[-1])?;
355 reduction.reduce(loss)
356 }
357}
358
generate_builder! {
    /// Gaussian negative-log-likelihood loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct GaussianNllLoss {
        /// When `true`, include the constant `0.5 * log(2*pi)` term. Default: `false`.
        #[builder(optional, default = GaussianNllLoss::DEFAULT_FULL)]
        pub full: bool,

        /// Lower bound applied to the variances for numerical stability. Default: `1e-6`.
        #[builder(optional, default = GaussianNllLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction applied to the element-wise loss. Default: `LossReduction::None`.
        #[builder(optional, default = GaussianNllLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
380
381impl GaussianNllLoss {
382 pub const DEFAULT_FULL: bool = false;
384
385 pub const DEFAULT_EPS: f32 = 1e-6;
387
388 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
390
391 pub fn apply(
399 &self,
400 inputs: impl AsRef<Array>,
401 targets: impl AsRef<Array>,
402 vars: impl AsRef<Array>,
403 ) -> Result<Array, Exception> {
404 let inputs = inputs.as_ref();
405 let targets = targets.as_ref();
406 let vars = vars.as_ref();
407 let full = self.full;
408 let eps = self.eps;
409 let reduction = self.reduction;
410
411 check_shape(inputs, targets, "inputs", "targets")?;
412 check_shape(inputs, vars, "inputs", "vars")?;
413
414 let vars = maximum(vars, array!(eps))?;
415 let mut loss =
416 array!(0.5) * (log(&vars)?.add(square(&targets.subtract(inputs)?)?.divide(&vars)?)?);
417
418 if full {
419 let pi = array!(std::f32::consts::PI);
420 loss = loss.add(array!(0.5).multiply(log(&array!(2.0).multiply(pi)?)?)?)?;
421 }
422
423 reduction.reduce(loss)
424 }
425}
426
generate_builder! {
    /// Kullback-Leibler divergence loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct KlDivLoss {
        /// Axis over which the divergence is summed. Default: `-1`.
        #[builder(optional, default = KlDivLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Reduction applied to the per-sample loss. Default: `LossReduction::None`.
        #[builder(optional, default = KlDivLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
448
449impl KlDivLoss {
450 pub const DEFAULT_AXIS: i32 = -1;
452
453 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
455
456 pub fn apply(
463 &self,
464 inputs: impl AsRef<Array>,
465 targets: impl AsRef<Array>,
466 ) -> Result<Array, Exception> {
467 let inputs = inputs.as_ref();
468 let targets = targets.as_ref();
469 let axis = self.axis;
470 let reduction = self.reduction;
471
472 let loss = sum(
473 &exp(targets)?.multiply(targets.subtract(inputs)?)?,
474 &[axis],
475 None,
476 )?;
477 reduction.reduce(loss)
478 }
479}
480
generate_builder! {
    /// Smooth L1 (Huber-style) loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct SmoothL1Loss {
        /// Threshold below which the loss is quadratic rather than linear. Default: `1.0`.
        #[builder(optional, default = SmoothL1Loss::DEFAULT_BETA)]
        pub beta: f32,

        /// Reduction applied to the element-wise loss. Default: `LossReduction::Mean`.
        #[builder(optional, default = SmoothL1Loss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
501
502impl SmoothL1Loss {
503 pub const DEFAULT_BETA: f32 = 1.0;
505
506 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;
508
509 pub fn apply(
516 &self,
517 predictions: impl AsRef<Array>,
518 targets: impl AsRef<Array>,
519 ) -> Result<Array, Exception> {
520 let predictions = predictions.as_ref();
521 let targets = targets.as_ref();
522 let beta = self.beta;
523 let reduction = self.reduction;
524
525 check_shape(predictions, targets, "predictions", "targets")?;
526 let diff = predictions.subtract(targets)?.abs()?;
527 let beta = array!(beta);
528 let loss = r#where(
529 &diff.lt(&beta)?,
530 array!(0.5).multiply(square(&diff)?)?.divide(&beta)?,
531 diff.subtract(array!(0.5).multiply(beta)?)?,
532 )?;
533 reduction.reduce(loss)
534 }
535}
536
generate_builder! {
    /// Triplet margin loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct TripletLoss {
        /// Axis over which distances are computed. Default: `-1`.
        #[builder(optional, default = TripletLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Exponent applied to the element-wise differences. Default: `2.0`.
        #[builder(optional, default = TripletLoss::DEFAULT_P)]
        pub p: f32,

        /// Margin enforced between positive and negative distances. Default: `1.0`.
        #[builder(optional, default = TripletLoss::DEFAULT_MARGIN)]
        pub margin: f32,

        /// Small constant added before the square root for stability. Default: `1e-6`.
        #[builder(optional, default = TripletLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction applied to the per-sample loss. Default: `LossReduction::None`.
        #[builder(optional, default = TripletLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
565
566impl TripletLoss {
567 pub const DEFAULT_AXIS: i32 = -1;
569
570 pub const DEFAULT_P: f32 = 2.0;
572
573 pub const DEFAULT_MARGIN: f32 = 1.0;
575
576 pub const DEFAULT_EPS: f32 = 1e-6;
578
579 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
581
582 pub fn apply(
591 &self,
592 anchors: impl AsRef<Array>,
593 positives: impl AsRef<Array>,
594 negatives: impl AsRef<Array>,
595 ) -> Result<Array, Exception> {
596 let anchors = anchors.as_ref();
597 let positives = positives.as_ref();
598 let negatives = negatives.as_ref();
599 let axis = self.axis;
600 let p = self.p;
601 let margin = self.margin;
602 let eps = self.eps;
603 let reduction = self.reduction;
604
605 let eps = array!(eps);
606 let p = array!(p);
607 let margin = array!(margin);
608
609 let pos = sqrt(
610 &power(&anchors.subtract(positives)?, &p)?
611 .sum(&[axis], None)?
612 .add(&eps)?,
613 )?;
614 let neg = sqrt(
615 &power(&anchors.subtract(negatives)?, &p)?
616 .sum(&[axis], None)?
617 .add(&eps)?,
618 )?;
619 let loss = maximum(pos.subtract(neg)?.add(margin)?, array!(0.0))?;
620 reduction.reduce(loss)
621 }
622}
623
generate_builder! {
    /// Hinge loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct HingeLoss {
        /// Reduction applied to the element-wise loss. Default: `LossReduction::None`.
        #[builder(optional, default = HingeLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
635
636impl HingeLoss {
637 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
639
640 pub fn apply(
647 &self,
648 inputs: impl AsRef<Array>,
649 targets: impl AsRef<Array>,
650 ) -> Result<Array, Exception> {
651 let inputs = inputs.as_ref();
652 let targets = targets.as_ref();
653 let reduction = self.reduction;
654
655 let a = array!(1.0).subtract(inputs.multiply(targets)?)?;
656 let b = array!(0.0);
657 let loss = maximum(a, b)?;
658 reduction.reduce(loss)
659 }
660}
661
generate_builder! {
    /// Huber loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct HuberLoss {
        /// Threshold at which the loss changes from quadratic to linear. Default: `1.0`.
        #[builder(optional, default = HuberLoss::DEFAULT_DELTA)]
        pub delta: f32,

        /// Reduction applied to the element-wise loss. Default: `LossReduction::None`.
        #[builder(optional, default = HuberLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
678
679impl HuberLoss {
680 pub const DEFAULT_DELTA: f32 = 1.0;
682
683 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
685
686 pub fn apply(
693 &self,
694 inputs: impl AsRef<Array>,
695 targets: impl AsRef<Array>,
696 ) -> Result<Array, Exception> {
697 let inputs = inputs.as_ref();
698 let targets = targets.as_ref();
699 let delta = self.delta;
700 let reduction = self.reduction;
701
702 let errors = inputs.subtract(targets)?;
703 let abs_errors = errors.abs()?;
704 let quadratic = minimum(&abs_errors, array!(delta))?;
705 let linear = abs_errors.subtract(&quadratic)?;
706 let loss = array!(0.5)
707 .multiply(square(&quadratic)?)?
708 .add(array!(delta).multiply(linear)?)?;
709 reduction.reduce(loss)
710 }
711}
712
generate_builder! {
    /// Log-cosh loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LogCoshLoss {
        /// Reduction applied to the element-wise loss. Default: `LossReduction::None`.
        #[builder(optional, default = LogCoshLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
728
729impl LogCoshLoss {
730 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
732
733 pub fn apply(
740 &self,
741 inputs: impl AsRef<Array>,
742 targets: impl AsRef<Array>,
743 ) -> Result<Array, Exception> {
744 let inputs = inputs.as_ref();
745 let targets = targets.as_ref();
746 let reduction = self.reduction;
747
748 let errors = inputs.subtract(targets)?;
749 let neg_errors = errors.negative()?;
750 let loss = log_add_exp(errors, neg_errors)?.subtract(log(&array!(2.0))?)?;
751 reduction.reduce(loss)
752 }
753}
754
generate_builder! {
    /// Cosine similarity loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct CosineSimilarityLoss {
        /// Axis along which the similarity is computed. Default: `-1`.
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Lower bound on the norm product to avoid division by zero. Default: `1e-8`.
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction applied to the per-sample loss. Default: `LossReduction::None`.
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
775
776impl CosineSimilarityLoss {
777 pub const DEFAULT_AXIS: i32 = -1;
779
780 pub const DEFAULT_EPS: f32 = 1e-8;
782
783 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
785
786 pub fn apply(&self, x1: impl AsRef<Array>, x2: impl AsRef<Array>) -> Result<Array, Exception> {
793 let x1 = x1.as_ref();
794 let x2 = x2.as_ref();
795 let axis = self.axis;
796 let eps = self.eps;
797 let reduction = self.reduction;
798
799 fn l2_loss(a: &Array, axis: i32) -> Result<Array, Exception> {
800 if a.dtype().is_complex() {
801 Ok(sqrt(&sum(&abs(a)?.square()?, &[axis], None)?)?)
802 } else {
803 Ok(sqrt(&sum(&a.square()?, &[axis], None)?)?)
804 }
805 }
806
807 let x1_norm = l2_loss(x1, axis)?;
808 let x2_norm = l2_loss(x2, axis)?;
809
810 let num = sum(&x1.multiply(x2)?, &[axis], None)?;
811 let den = maximum(x1_norm.multiply(x2_norm)?, array!(eps))?;
812 let loss = num.divide(&den)?;
813
814 reduction.reduce(loss)
815 }
816}
817
generate_builder! {
    /// Margin ranking loss, configurable through the generated builder.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct MarginRankingLoss {
        /// Margin enforced between the two inputs. Default: `0.0`.
        #[builder(optional, default = MarginRankingLoss::DEFAULT_MARGIN)]
        pub margin: f32,

        /// Reduction applied to the element-wise loss. Default: `LossReduction::None`.
        #[builder(optional, default = MarginRankingLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
834
835impl MarginRankingLoss {
836 pub const DEFAULT_MARGIN: f32 = 0.0;
838
839 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
841
842 pub fn apply(
851 &self,
852 inputs1: impl AsRef<Array>,
853 inputs2: impl AsRef<Array>,
854 targets: impl AsRef<Array>,
855 ) -> Result<Array, Exception> {
856 let inputs1 = inputs1.as_ref();
857 let inputs2 = inputs2.as_ref();
858 let targets = targets.as_ref();
859 let margin = self.margin;
860 let reduction = self.reduction;
861
862 check_shape(inputs1, inputs2, "inputs1", "inputs2")?;
863 check_shape(inputs1, targets, "inputs1", "targets")?;
864
865 let margin = array!(margin);
866 let diff = inputs1.subtract(inputs2)?;
867 let loss = maximum(
868 array!(0.0),
869 targets.multiply(diff)?.negative()?.add(margin)?,
870 )?;
871 reduction.reduce(loss)
872 }
873}
874
875#[cfg(test)]
876#[allow(clippy::approx_constant)]
877mod tests {
878 use crate::{array, assert_array_eq, builder::Builder, ops::is_nan};
879 use float_eq::assert_float_eq;
880
881 use super::*;
882
883 #[test]
886 fn test_cross_entropy() {
887 let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
889 let indices = array!([0, 1]);
890 let expected = array!([0.0, 0.0]);
891 let loss = CrossEntropy::new()
892 .unwrap()
893 .apply(&logits, indices)
894 .unwrap();
895 assert_array_eq!(loss, expected);
896
897 let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
898 let cross_entropy = CrossEntropyBuilder::new()
899 .reduction(LossReduction::None)
900 .build()
901 .unwrap();
902 let loss = cross_entropy.apply(logits, probs).unwrap();
903 assert!(is_nan(&loss)
904 .unwrap()
905 .all(None, None)
906 .unwrap()
907 .item::<bool>());
908
909 let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
911 let indices = array!([0, 1]);
912 let weights = array!([1.0, 2.0]);
913 let expected = array!([0.04858735, 0.0971747]);
914 let cross_entropy = CrossEntropyBuilder::new()
915 .weights(&weights)
916 .reduction(LossReduction::None)
917 .build()
918 .unwrap();
919 let loss = cross_entropy.apply(&logits, indices).unwrap();
920 assert_array_eq!(loss, expected);
921
922 let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
923 let cross_entropy = CrossEntropyBuilder::new()
924 .weights(&weights)
925 .reduction(LossReduction::None)
926 .build()
927 .unwrap();
928 let loss = cross_entropy.apply(logits, probs).unwrap();
929 assert_array_eq!(loss, expected);
930
931 let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
933 let indices = array!([0, 1]);
934 let expected = array!([0.498587, 0.498587]);
935 let cross_entropy = CrossEntropyBuilder::new()
936 .label_smoothing(0.3)
937 .reduction(LossReduction::None)
938 .build()
939 .unwrap();
940 let loss = cross_entropy.apply(&logits, indices).unwrap();
941 assert_array_eq!(loss, expected);
942
943 let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
944 let cross_entropy = CrossEntropyBuilder::new()
945 .label_smoothing(0.3)
946 .reduction(LossReduction::None)
947 .build()
948 .unwrap();
949 let loss = cross_entropy.apply(logits, probs).unwrap();
950 assert_array_eq!(loss, expected);
951
952 let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
954 let indices = array!([0, 1]);
955 let weights = array!([1.0, 2.0]);
956 let expected = array!([0.49858734, 0.9971747]);
957 let cross_entropy = CrossEntropyBuilder::new()
958 .weights(&weights)
959 .label_smoothing(0.3)
960 .reduction(LossReduction::None)
961 .build()
962 .unwrap();
963 let loss = cross_entropy.apply(&logits, indices).unwrap();
964 assert_array_eq!(loss, expected);
965
966 let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
967 let cross_entropy = CrossEntropyBuilder::new()
968 .weights(&weights)
969 .label_smoothing(0.3)
970 .reduction(LossReduction::None)
971 .build()
972 .unwrap();
973 let loss = cross_entropy.apply(logits, probs).unwrap();
974 assert_array_eq!(loss, expected);
975 }
976
977 #[test]
978 fn test_binary_cross_entropy_with_logits_as_inputs() {
979 let logits = array!([0.105361, 0.223144, 1.20397, 0.916291]);
980 let targets = array!([0.0, 0.0, 1.0, 1.0]);
981
982 let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
984 .reduction(LossReduction::None)
985 .build()
986 .unwrap();
987 let loss_none = binary_cross_entropy.apply(&logits, &targets).unwrap();
988 let expected_none = array!([0.747215, 0.810930, 0.262365, 0.336472]);
989 assert_array_eq!(loss_none, expected_none);
990
991 let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
993 .reduction(LossReduction::Mean)
994 .build()
995 .unwrap();
996 let loss_mean = binary_cross_entropy.apply(&logits, &targets).unwrap();
997 let expected_mean = expected_none.mean(None, None).unwrap();
998 assert_array_eq!(loss_mean, expected_mean);
999
1000 let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
1002 .reduction(LossReduction::Sum)
1003 .build()
1004 .unwrap();
1005 let loss = binary_cross_entropy.apply(&logits, &targets).unwrap();
1006 let expected = expected_none.sum(None, None).unwrap();
1007 assert_array_eq!(loss, expected);
1008
1009 let weights = array!([1.0, 2.0, 1.0, 2.0]);
1011 let expected = array!([0.747215, 1.62186, 0.262365, 0.672944]);
1012 let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
1013 .weights(&weights)
1014 .reduction(LossReduction::None)
1015 .build()
1016 .unwrap();
1017 let loss = binary_cross_entropy.apply(&logits, &targets).unwrap();
1018 assert_array_eq!(loss, expected);
1019 }
1020
1021 #[test]
1022 fn test_binary_cross_entropy_with_probs_as_inputs() {
1023 let probs = array!([0.5, 0.6, 0.7, 0.8]);
1024 let targets = array!([0.0, 0.0, 1.0, 1.0]);
1025
1026 let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
1028 .inputs_are_logits(false)
1029 .reduction(LossReduction::None)
1030 .build()
1031 .unwrap();
1032 let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap();
1033 let expected_none = array!([0.693147, 0.916291, 0.356675, 0.223144]);
1034 assert_array_eq!(loss_none, expected_none);
1035
1036 let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
1038 .inputs_are_logits(false)
1039 .reduction(LossReduction::Mean)
1040 .build()
1041 .unwrap();
1042 let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap();
1043 let expected_mean = expected_none.mean(None, None).unwrap();
1044 assert_array_eq!(loss_mean, expected_mean);
1045
1046 let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
1048 .inputs_are_logits(false)
1049 .reduction(LossReduction::Sum)
1050 .build()
1051 .unwrap();
1052 let loss = binary_cross_entropy.apply(&probs, &targets).unwrap();
1053 let expected = expected_none.sum(None, None).unwrap();
1054 assert_array_eq!(loss, expected);
1055 }
1056
1057 #[test]
1058 fn test_binary_cross_entropy_with_tiny_probs_as_inputs() {
1059 let tiny_prob = 1e-59;
1060 let probs = array!([0.0, tiny_prob, 1.0 - tiny_prob, 1.0]);
1061 let targets = array!([0.0, 0.0, 1.0, 1.0]);
1062
1063 let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
1065 .inputs_are_logits(false)
1066 .reduction(LossReduction::None)
1067 .build()
1068 .unwrap();
1069 let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap();
1070 let expected_none = array!([0.0, tiny_prob, tiny_prob, 0.0]);
1071 assert_array_eq!(loss_none, expected_none);
1072
1073 let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
1075 .inputs_are_logits(false)
1076 .reduction(LossReduction::Mean)
1077 .build()
1078 .unwrap();
1079 let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap();
1080 let expected_mean = expected_none.mean(None, None).unwrap();
1081 assert_array_eq!(loss_mean, expected_mean);
1082
1083 let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
1085 .inputs_are_logits(false)
1086 .reduction(LossReduction::Sum)
1087 .build()
1088 .unwrap();
1089 let loss = binary_cross_entropy.apply(&probs, &targets).unwrap();
1090 let expected = expected_none.sum(None, None).unwrap();
1091 assert_array_eq!(loss, expected);
1092 }
1093
1094 #[test]
1095 fn test_l1_loss() {
1096 let predictions = array!([0.5, 0.2, 0.9, 0.0]);
1097 let targets = array!([0.5, 0.2, 0.9, 0.0]);
1098
1099 let expected_none = array!([0.0, 0.0, 0.0, 0.0]);
1100 let expected_sum = expected_none.sum(None, None).unwrap();
1101 let expected_mean = expected_none.mean(None, None).unwrap();
1102
1103 let l1_loss = L1LossBuilder::new()
1104 .reduction(LossReduction::None)
1105 .build()
1106 .unwrap();
1107 let loss_none = l1_loss.apply(&predictions, &targets).unwrap();
1108 assert_array_eq!(loss_none, expected_none);
1109
1110 let l1_loss = L1LossBuilder::new()
1111 .reduction(LossReduction::Sum)
1112 .build()
1113 .unwrap();
1114 let loss_sum = l1_loss.apply(&predictions, &targets).unwrap();
1115 assert_array_eq!(loss_sum, expected_sum);
1116
1117 let l1_loss = L1LossBuilder::new()
1118 .reduction(LossReduction::Mean)
1119 .build()
1120 .unwrap();
1121 let loss_mean = l1_loss.apply(&predictions, &targets).unwrap();
1122 assert_array_eq!(loss_mean, expected_mean);
1123 }
1124
1125 #[test]
1126 fn test_mse_loss() {
1127 let predictions = array!([0.5, 0.2, 0.9, 0.0]);
1128 let targets = array!([0.7, 0.1, 0.8, 0.2]);
1129
1130 let expected_none = array!([0.04, 0.01, 0.01, 0.04]);
1131 let expected_mean = expected_none.mean(None, None).unwrap();
1132 let expected_sum = expected_none.sum(None, None).unwrap();
1133
1134 let mse_loss = MseLossBuilder::new()
1135 .reduction(LossReduction::None)
1136 .build()
1137 .unwrap();
1138 let loss_none = mse_loss.apply(&predictions, &targets).unwrap();
1139 assert_array_eq!(loss_none, expected_none);
1140
1141 let mse_loss = MseLossBuilder::new()
1142 .reduction(LossReduction::Mean)
1143 .build()
1144 .unwrap();
1145 let loss_mean = mse_loss.apply(&predictions, &targets).unwrap();
1146 assert_array_eq!(loss_mean, expected_mean);
1147
1148 let mse_loss = MseLossBuilder::new()
1149 .reduction(LossReduction::Sum)
1150 .build()
1151 .unwrap();
1152 let loss_sum = mse_loss.apply(&predictions, &targets).unwrap();
1153 assert_array_eq!(loss_sum, expected_sum);
1154 }
1155
1156 #[test]
1157 fn test_smooth_l1_loss() {
1158 let predictions = array!([1.5, 2.5, 0.5, 3.5]);
1159 let targets = array!([1.0, 2.0, 0.5, 2.5]);
1160 let beta = 1.0;
1161
1162 let expected_none = array!([0.125, 0.125, 0.0, 0.5]);
1163 let expected_sum = expected_none.sum(None, None).unwrap();
1164 let expected_mean = expected_none.mean(None, None).unwrap();
1165
1166 let smooth_l1_loss = SmoothL1LossBuilder::new()
1167 .beta(beta)
1168 .reduction(LossReduction::None)
1169 .build()
1170 .unwrap();
1171 let loss_none = smooth_l1_loss.apply(&predictions, &targets).unwrap();
1172 assert_array_eq!(loss_none, expected_none);
1173
1174 let smooth_l1_loss = SmoothL1LossBuilder::new()
1175 .beta(beta)
1176 .reduction(LossReduction::Sum)
1177 .build()
1178 .unwrap();
1179 let loss_sum = smooth_l1_loss.apply(&predictions, &targets).unwrap();
1180 assert_array_eq!(loss_sum, expected_sum);
1181
1182 let smooth_l1_loss = SmoothL1LossBuilder::new()
1183 .beta(beta)
1184 .reduction(LossReduction::Mean)
1185 .build()
1186 .unwrap();
1187 let loss_mean = smooth_l1_loss.apply(&predictions, &targets).unwrap();
1188 assert_array_eq!(loss_mean, expected_mean);
1189 }
1190
1191 #[test]
1192 fn test_smooth_l1_loss_negative_diff() {
1193 let a = array!([1.5, 6.0, 0.5, 2.5]);
1194 let b = array!([1.0, 2.0, 0.5, 3.5]);
1195
1196 let loss = SmoothL1Loss::new();
1197
1198 let ab = loss.apply(&a, &b).unwrap();
1199 let ba = loss.apply(&b, &a).unwrap();
1200 assert_array_eq!(ab, ba);
1201 }
1202
1203 #[test]
1204 fn test_nll_loss() {
1205 let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
1206 let targets = array!([0, 1]);
1207
1208 let expected_none = array!([0.0, 0.0]);
1209 let expected_sum = expected_none.sum(None, None).unwrap();
1210 let expected_mean = expected_none.mean(None, None).unwrap();
1211
1212 let nll_loss = NllLossBuilder::new()
1213 .reduction(LossReduction::None)
1214 .build()
1215 .unwrap();
1216 let loss_none = nll_loss.apply(&logits, &targets).unwrap();
1217 assert_array_eq!(loss_none, expected_none);
1218
1219 let nll_loss = NllLossBuilder::new()
1220 .reduction(LossReduction::Mean)
1221 .build()
1222 .unwrap();
1223 let loss_mean = nll_loss.apply(&logits, &targets).unwrap();
1224 assert_array_eq!(loss_mean, expected_mean);
1225
1226 let nll_loss = NllLossBuilder::new()
1227 .reduction(LossReduction::Sum)
1228 .build()
1229 .unwrap();
1230 let loss_sum = nll_loss.apply(&logits, &targets).unwrap();
1231 assert_array_eq!(loss_sum, expected_sum);
1232 }
1233
1234 #[test]
1235 fn test_gaussian_nll_loss() {
1236 let inputs = array!([[0.1, 0.2], [0.3, 0.4]]);
1237 let targets = array!([[0.2, 0.1], [0.1, 0.2]]);
1238 let vars = array!([[0.1, 0.2], [0.3, 0.4]]);
1239
1240 let gaussian_nll_loss = GaussianNllLossBuilder::new()
1242 .full(false)
1243 .reduction(LossReduction::None)
1244 .build()
1245 .unwrap();
1246 let loss_none = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
1247 let expected_none = array!([[-1.101293, -0.779719], [-0.535320, -0.408145]]);
1248 assert_array_eq!(loss_none, expected_none);
1249
1250 let gaussian_nll_loss = GaussianNllLossBuilder::new()
1252 .full(false)
1253 .reduction(LossReduction::Mean)
1254 .build()
1255 .unwrap();
1256 let loss_mean = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
1257 let expected_mean = expected_none.mean(None, None).unwrap();
1258 assert_array_eq!(loss_mean, expected_mean);
1259
1260 let gaussian_nll_loss = GaussianNllLossBuilder::new()
1262 .full(false)
1263 .reduction(LossReduction::Sum)
1264 .build()
1265 .unwrap();
1266 let loss_sum = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
1267 let expected_sum = expected_none.sum(None, None).unwrap();
1268 assert_array_eq!(loss_sum, expected_sum);
1269
1270 let gaussian_nll_loss = GaussianNllLossBuilder::new()
1272 .full(true)
1273 .reduction(LossReduction::None)
1274 .build()
1275 .unwrap();
1276 let loss_none_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
1277 let expected_none_full = array!([[-0.182354, 0.139220], [0.383619, 0.510793]]);
1278 assert_array_eq!(loss_none_full, expected_none_full);
1279
1280 let gaussian_nll_loss = GaussianNllLossBuilder::new()
1282 .full(true)
1283 .reduction(LossReduction::Mean)
1284 .build()
1285 .unwrap();
1286 let loss_mean_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
1287 let expected_mean_full = expected_none_full.mean(None, None).unwrap();
1288 assert_array_eq!(loss_mean_full, expected_mean_full);
1289
1290 let gaussian_nll_loss = GaussianNllLossBuilder::new()
1292 .full(true)
1293 .reduction(LossReduction::Sum)
1294 .build()
1295 .unwrap();
1296 let loss_sum_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
1297 let expected_sum_full = expected_none_full.sum(None, None).unwrap();
1298 assert_array_eq!(loss_sum_full, expected_sum_full);
1299 }
1300
1301 #[test]
1302 fn test_kl_div_loss() {
1303 let p_logits = array!([[0.5, 0.5], [0.8, 0.2]]).log().unwrap();
1304 let q_logits = array!([[0.5, 0.5], [0.2, 0.8]]).log().unwrap();
1305
1306 let kl_div_loss = KlDivLossBuilder::new()
1308 .reduction(LossReduction::None)
1309 .build()
1310 .unwrap();
1311 let loss_none = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
1312 let expected_none = array!([0.0, 0.831777]);
1313 assert_array_eq!(loss_none, expected_none);
1314
1315 let kl_div_loss = KlDivLossBuilder::new()
1317 .reduction(LossReduction::Mean)
1318 .build()
1319 .unwrap();
1320 let loss_mean = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
1321 let expected_mean = expected_none.mean(None, None).unwrap();
1322 assert_array_eq!(loss_mean, expected_mean);
1323
1324 let kl_div_loss = KlDivLossBuilder::new()
1326 .reduction(LossReduction::Sum)
1327 .build()
1328 .unwrap();
1329 let loss_sum = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
1330 let expected_sum = expected_none.sum(None, None).unwrap();
1331 assert_array_eq!(loss_sum, expected_sum);
1332 }
1333
1334 #[test]
1335 fn test_triplet_loss() {
1336 let anchors = array!([[1, 2, 3], [1, 2, 3]]);
1337 let positives = array!([[4, 5, 6], [0, -1, 2]]);
1338 let negatives = array!([[7, 8, 9], [3, 2, 3]]);
1339
1340 let triplet_loss = TripletLossBuilder::new()
1342 .reduction(LossReduction::None)
1343 .build()
1344 .unwrap();
1345 let loss_none = triplet_loss
1346 .apply(&anchors, &positives, &negatives)
1347 .unwrap();
1348 let expected_none = array!([0.0, 2.31662]);
1349 assert_array_eq!(loss_none, expected_none);
1350
1351 let triplet_loss = TripletLossBuilder::new()
1353 .reduction(LossReduction::Mean)
1354 .build()
1355 .unwrap();
1356 let loss_mean = triplet_loss
1357 .apply(&anchors, &positives, &negatives)
1358 .unwrap();
1359 let expected_mean = expected_none.mean(None, None).unwrap();
1360 assert_array_eq!(loss_mean, expected_mean);
1361
1362 let triplet_loss = TripletLossBuilder::new()
1364 .reduction(LossReduction::Sum)
1365 .build()
1366 .unwrap();
1367 let loss_sum = triplet_loss
1368 .apply(&anchors, &positives, &negatives)
1369 .unwrap();
1370 let expected_sum = expected_none.sum(None, None).unwrap();
1371 assert_array_eq!(loss_sum, expected_sum);
1372 }
1373
1374 #[test]
1375 fn test_hinge_loss() {
1376 let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
1377 let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
1378 let hinge_loss = HingeLossBuilder::new()
1379 .reduction(LossReduction::Mean)
1380 .build()
1381 .unwrap();
1382 let loss = hinge_loss.apply(&inputs, &targets).unwrap();
1383 assert_eq!(loss.item::<f32>(), 1.0);
1384 }
1385
1386 #[test]
1387 fn test_huber_loss() {
1388 let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
1389 let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
1390 let huber_loss = HuberLossBuilder::new()
1391 .reduction(LossReduction::Mean)
1392 .build()
1393 .unwrap();
1394 let loss = huber_loss.apply(&inputs, &targets).unwrap();
1395 assert_eq!(loss.item::<f32>(), 0.5);
1396 }
1397
1398 #[test]
1399 fn test_log_cosh_loss() {
1400 let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
1401 let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
1402 let log_cosh_loss = LogCoshLossBuilder::new()
1403 .reduction(LossReduction::Mean)
1404 .build()
1405 .unwrap();
1406 let loss = log_cosh_loss.apply(&inputs, &targets).unwrap();
1407 assert_float_eq!(loss.item::<f32>(), 0.433781, abs <= 1e-6);
1408 }
1409
1410 #[test]
1411 fn test_cosine_similarity_loss() {
1412 let embeddings1 = array!([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]);
1413 let embeddings2 = array!([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]]);
1414
1415 let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
1417 .reduction(LossReduction::None)
1418 .build()
1419 .unwrap();
1420 let loss_none = cosine_similarity_loss
1421 .apply(&embeddings1, &embeddings2)
1422 .unwrap();
1423 let expected_none = array!([0.985344, 0.961074]);
1424 assert_array_eq!(loss_none, expected_none);
1425
1426 let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
1428 .reduction(LossReduction::Mean)
1429 .build()
1430 .unwrap();
1431 let loss_mean = cosine_similarity_loss
1432 .apply(&embeddings1, &embeddings2)
1433 .unwrap();
1434 let expected_mean = expected_none.mean(None, None).unwrap();
1435 assert_array_eq!(loss_mean, expected_mean);
1436
1437 let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
1439 .reduction(LossReduction::Sum)
1440 .build()
1441 .unwrap();
1442 let loss_sum = cosine_similarity_loss
1443 .apply(&embeddings1, &embeddings2)
1444 .unwrap();
1445 let expected_sum = expected_none.sum(None, None).unwrap();
1446 assert_array_eq!(loss_sum, expected_sum);
1447 }
1448
1449 #[test]
1450 fn test_margin_ranking_loss() {
1451 let inputs1 = array!([-0.573409, -0.765166, -0.0638]);
1452 let inputs2 = array!([0.75596, 0.225763, 0.256995]);
1453 let targets = array!([1, 1, -1]);
1454
1455 let margin_ranking_loss = MarginRankingLossBuilder::new()
1457 .reduction(LossReduction::None)
1458 .build()
1459 .unwrap();
1460 let loss = margin_ranking_loss
1461 .apply(&inputs1, &inputs2, &targets)
1462 .unwrap();
1463 let expected = array!([1.329369, 0.990929, 0.0]);
1464 assert_array_eq!(loss, expected);
1465
1466 let margin_ranking_loss = MarginRankingLossBuilder::new()
1468 .margin(0.5)
1469 .reduction(LossReduction::None)
1470 .build()
1471 .unwrap();
1472 let loss = margin_ranking_loss
1473 .apply(&inputs1, &inputs2, &targets)
1474 .unwrap();
1475 let expected = array!([1.829369, 1.490929, 0.179205]);
1476 assert_array_eq!(loss, expected);
1477 }
1478}