1use crate::{
4 array,
5 error::{CrossEntropyBuildError, Exception},
6 ops::{
7 abs, clip, exp, indexing::take_along_axis, log, logaddexp, logsumexp_axes, maximum,
8 minimum, multiply, power, r#where, sqrt, square, sum_axes, sum_axis,
9 },
10 Array,
11};
12use mlx_internal_macros::{generate_builder, Buildable};
13
14#[inline]
15fn check_shape(
16 left: &Array,
17 right: &Array,
18 left_ident: &str,
19 right_ident: &str,
20) -> Result<(), Exception> {
21 if left.shape() != right.shape() {
22 return Err(Exception::custom(format!(
23 "The shape of the {} ({:?}) does not match the shape of the {} ({:?})",
24 left_ident,
25 left.shape(),
26 right_ident,
27 right.shape()
28 )));
29 }
30 Ok(())
31}
32
/// Reduction applied to an element-wise loss array.
#[derive(Debug, Clone, Copy)]
pub enum LossReduction {
    /// No reduction is applied; the element-wise loss is returned as-is.
    None,
    /// The sum of all elements of the loss is returned.
    Sum,
    /// The mean of all elements of the loss is returned.
    Mean,
}
43
44impl LossReduction {
45 pub fn reduce(&self, loss: Array) -> Result<Array, Exception> {
47 match self {
48 LossReduction::None => Ok(loss),
49 LossReduction::Sum => Ok(loss.sum(None)?),
50 LossReduction::Mean => Ok(loss.mean(None)?),
51 }
52 }
53}
54
/// Borrowed class-weights type accepted by the [`CrossEntropy`] builder.
pub type CrossEntropyBuilderWeights<'a> = &'a Array;
57
generate_builder! {
    /// Cross entropy loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        root = crate,
        build_with = build_cross_entropy,
        err = CrossEntropyBuildError
    )]
    pub struct CrossEntropy<'a> {
        /// Optional per-class weights; must match the loss shape when given.
        /// Default to [`CrossEntropy::DEFAULT_WEIGHTS`].
        #[builder(optional, default = CrossEntropy::DEFAULT_WEIGHTS)]
        pub weights: Option<&'a Array>,

        /// Axis over which softmax is computed. Default to
        /// [`CrossEntropy::DEFAULT_AXIS`].
        #[builder(optional, default = CrossEntropy::DEFAULT_AXIS)]
        pub axis: i32,

        /// Label smoothing factor; must lie in `[0, 1)` (validated at build
        /// time). Default to [`CrossEntropy::DEFAULT_LABEL_SMOOTHING`].
        #[builder(optional, default = CrossEntropy::DEFAULT_LABEL_SMOOTHING)]
        pub label_smoothing: f32,

        /// Reduction type. Default to [`CrossEntropy::DEFAULT_REDUCTION`].
        #[builder(optional, default = CrossEntropy::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
86
87fn build_cross_entropy(
88 builder: CrossEntropyBuilder,
89) -> Result<CrossEntropy, CrossEntropyBuildError> {
90 let axis = builder.axis;
91 let label_smoothing = builder.label_smoothing;
92 let reduction = builder.reduction;
93
94 if !(0.0..1.0).contains(&label_smoothing) {
95 return Err(CrossEntropyBuildError::InvalidLabelSmoothingFactor);
96 }
97
98 Ok(CrossEntropy {
99 weights: builder.weights,
100 axis,
101 label_smoothing,
102 reduction,
103 })
104}
105
impl<'a> CrossEntropy<'a> {
    /// Default value for the `axis` field: the last axis.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `label_smoothing` field (no smoothing).
    pub const DEFAULT_LABEL_SMOOTHING: f32 = 0.0;

    /// Default value for the `reduction` field.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Default value for the `weights` field (unweighted).
    pub const DEFAULT_WEIGHTS: Option<&'a Array> = None;

    /// Computes the cross entropy between `logits` and `targets`.
    ///
    /// `targets` may be class indices (one fewer dimension than `logits`)
    /// or per-class probabilities (same number of dimensions as `logits`);
    /// the two cases are distinguished by rank.
    pub fn apply(
        &self,
        logits: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let logits = logits.as_ref();
        let targets = targets.as_ref();

        // Equal rank means `targets` holds probabilities, not class indices.
        let target_as_probs = targets.ndim() == logits.ndim();

        // "Score": the logit of the target class (or its expectation under
        // the target distribution in the probabilities case).
        let score = if target_as_probs {
            sum_axes(&logits.multiply(targets)?, &[self.axis], None)?
        } else {
            take_along_axis(logits, &targets.expand_dims_axes(&[-1])?, self.axis)?
                .squeeze_axes(&[-1])?
        };
        // log(sum(exp(logits))): the normalizer of log-softmax.
        let log_sum_exp_logits = logsumexp_axes(logits, &[self.axis], None)?;

        let mut loss = if self.label_smoothing > 0.0 {
            // Scale the true-class score by (1 - label_smoothing).
            let adjusted_score = multiply(array!(1.0 - self.label_smoothing), score)?;

            // The smoothed component spreads mass uniformly over classes.
            let mean_logits = logits.mean_axis(self.axis, None)?;
            let smoothed_loss = -multiply(mean_logits, array!(self.label_smoothing))?;

            log_sum_exp_logits
                .subtract(adjusted_score)?
                .add(smoothed_loss)?
        } else {
            log_sum_exp_logits.subtract(score)?
        };

        if let Some(weights) = self.weights {
            // Weights must match the loss shape exactly (no broadcasting).
            check_shape(weights, &loss, "weights", "loss")?;
            loss = multiply(loss, weights)?;
        }

        self.reduction.reduce(loss)
    }
}
167
generate_builder! {
    /// Binary cross entropy loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct BinaryCrossEntropy<'a> {
        /// Optional element-wise weights; must match the loss shape when
        /// given. Default to [`BinaryCrossEntropy::DEFAULT_WEIGHTS`].
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_WEIGHTS)]
        pub weights: Option<&'a Array>,

        /// Whether the inputs are raw logits rather than probabilities.
        /// Default to [`BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS`].
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS)]
        pub inputs_are_logits: bool,

        /// Reduction type. Default to
        /// [`BinaryCrossEntropy::DEFAULT_REDUCTION`].
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
193
impl<'a> BinaryCrossEntropy<'a> {
    /// Default value for the `weights` field (unweighted).
    pub const DEFAULT_WEIGHTS: Option<&'a Array> = None;

    /// Default value for the `inputs_are_logits` field.
    pub const DEFAULT_INPUTS_ARE_LOGITS: bool = true;

    /// Default value for the `reduction` field.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the binary cross entropy between `logits` (or
    /// probabilities, when `inputs_are_logits` is `false`) and `targets`.
    pub fn apply(
        &self,
        logits: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let logits = logits.as_ref();
        let targets = targets.as_ref();
        let weights = self.weights;
        let inputs_are_logits = self.inputs_are_logits;
        let reduction = self.reduction;

        let mut loss = if inputs_are_logits {
            // Numerically stable form: log(1 + e^x) - t*x == logaddexp(0, x) - t*x.
            logaddexp(array!(0.0), logits)?.subtract(targets.multiply(logits)?)?
        } else {
            // Probabilities path: clamp the logs at -100 so log(0) does not
            // propagate -inf (same convention as PyTorch's BCELoss).
            let log_inputs_clip = clip(log(logits)?, (-100.0, ()))?;
            let log_inputs_inverse_clip = clip(log(&array!(1.0).subtract(logits)?)?, (-100.0, ()))?;
            // -(t * log(p) + (1 - t) * log(1 - p))
            -(targets.multiply(log_inputs_clip)?.add(
                array!(1.0)
                    .subtract(targets)?
                    .multiply(log_inputs_inverse_clip)?,
            )?)
        };

        if let Some(weights) = weights {
            // Weights must match the loss shape exactly.
            check_shape(weights, &loss, "weights", "loss")?;
            loss = multiply(loss, weights)?;
        }

        reduction.reduce(loss)
    }
}
241
generate_builder! {
    /// L1 (mean absolute error) loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct L1Loss {
        /// Reduction type. Default to [`L1Loss::DEFAULT_REDUCTION`].
        #[builder(optional, default = L1Loss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
253
254impl L1Loss {
255 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;
257
258 pub fn apply(
265 &self,
266 predictions: impl AsRef<Array>,
267 targets: impl AsRef<Array>,
268 ) -> Result<Array, Exception> {
269 let predictions = predictions.as_ref();
270 let targets = targets.as_ref();
271 let reduction = self.reduction;
272
273 check_shape(predictions, targets, "predictions", "targets")?;
274 let loss = predictions.subtract(targets)?.abs()?;
275 reduction.reduce(loss)
276 }
277}
278
generate_builder! {
    /// Mean squared error loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct MseLoss {
        /// Reduction type. Default to [`MseLoss::DEFAULT_REDUCTION`].
        #[builder(optional, default = MseLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
290
291impl MseLoss {
292 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;
294
295 pub fn apply(
302 &self,
303 predictions: impl AsRef<Array>,
304 targets: impl AsRef<Array>,
305 ) -> Result<Array, Exception> {
306 let predictions = predictions.as_ref();
307 let targets = targets.as_ref();
308 let reduction = self.reduction;
309
310 check_shape(predictions, targets, "predictions", "targets")?;
311 let loss = predictions.subtract(targets)?.square()?;
312 reduction.reduce(loss)
313 }
314}
315
generate_builder! {
    /// Negative log likelihood loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct NllLoss {
        /// Axis along which target classes are gathered. Default to
        /// [`NllLoss::DEFAULT_AXIS`].
        #[builder(optional, default = NllLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Reduction type. Default to [`NllLoss::DEFAULT_REDUCTION`].
        #[builder(optional, default = NllLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
331
332impl NllLoss {
333 pub const DEFAULT_AXIS: i32 = -1;
335
336 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
338
339 pub fn apply(
346 &self,
347 inputs: impl AsRef<Array>,
348 targets: impl AsRef<Array>,
349 ) -> Result<Array, Exception> {
350 let inputs = inputs.as_ref();
351 let targets = targets.as_ref();
352 let axis = self.axis;
353 let reduction = self.reduction;
354
355 let loss = -take_along_axis(inputs, &targets.expand_dims_axes(&[-1])?, axis)?
356 .squeeze_axes(&[-1])?;
357 reduction.reduce(loss)
358 }
359}
360
generate_builder! {
    /// Gaussian negative log likelihood loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct GaussianNllLoss {
        /// Whether to include the constant 0.5*log(2*pi) term. Default to
        /// [`GaussianNllLoss::DEFAULT_FULL`].
        #[builder(optional, default = GaussianNllLoss::DEFAULT_FULL)]
        pub full: bool,

        /// Lower bound applied to the variance for numerical stability.
        /// Default to [`GaussianNllLoss::DEFAULT_EPS`].
        #[builder(optional, default = GaussianNllLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction type. Default to [`GaussianNllLoss::DEFAULT_REDUCTION`].
        #[builder(optional, default = GaussianNllLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
382
impl GaussianNllLoss {
    /// Default value for the `full` field (omit the constant term).
    pub const DEFAULT_FULL: bool = false;

    /// Default value for the `eps` field.
    pub const DEFAULT_EPS: f32 = 1e-6;

    /// Default value for the `reduction` field.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the negative log likelihood of `targets` under a Gaussian
    /// with mean `inputs` and variance `vars` (all three same shape).
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
        vars: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let vars = vars.as_ref();
        let full = self.full;
        let eps = self.eps;
        let reduction = self.reduction;

        check_shape(inputs, targets, "inputs", "targets")?;
        check_shape(inputs, vars, "inputs", "vars")?;

        // Clamp the variance away from zero for numerical stability.
        let vars = maximum(vars, array!(eps))?;
        // 0.5 * (log(var) + (target - mean)^2 / var)
        let mut loss =
            array!(0.5) * (log(&vars)?.add(square(&targets.subtract(inputs)?)?.divide(&vars)?)?);

        if full {
            // Add the constant 0.5 * log(2*pi) of the full likelihood.
            let pi = array!(std::f32::consts::PI);
            loss = loss.add(array!(0.5).multiply(log(&array!(2.0).multiply(pi)?)?)?)?;
        }

        reduction.reduce(loss)
    }
}
428
generate_builder! {
    /// Kullback-Leibler divergence loss function. Both inputs are expected
    /// as log-probabilities.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct KlDivLoss {
        /// Axis over which the divergence is summed. Default to
        /// [`KlDivLoss::DEFAULT_AXIS`].
        #[builder(optional, default = KlDivLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Reduction type. Default to [`KlDivLoss::DEFAULT_REDUCTION`].
        #[builder(optional, default = KlDivLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
450
451impl KlDivLoss {
452 pub const DEFAULT_AXIS: i32 = -1;
454
455 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
457
458 pub fn apply(
465 &self,
466 inputs: impl AsRef<Array>,
467 targets: impl AsRef<Array>,
468 ) -> Result<Array, Exception> {
469 let inputs = inputs.as_ref();
470 let targets = targets.as_ref();
471 let axis = self.axis;
472 let reduction = self.reduction;
473
474 let loss = sum_axis(
475 &exp(targets)?.multiply(targets.subtract(inputs)?)?,
476 axis,
477 None,
478 )?;
479 reduction.reduce(loss)
480 }
481}
482
generate_builder! {
    /// Smooth L1 (Huber-style) loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct SmoothL1Loss {
        /// Threshold between the quadratic and linear regimes. Default to
        /// [`SmoothL1Loss::DEFAULT_BETA`].
        #[builder(optional, default = SmoothL1Loss::DEFAULT_BETA)]
        pub beta: f32,

        /// Reduction type. Default to [`SmoothL1Loss::DEFAULT_REDUCTION`].
        #[builder(optional, default = SmoothL1Loss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
503
impl SmoothL1Loss {
    /// Default value for the `beta` field.
    pub const DEFAULT_BETA: f32 = 1.0;

    /// Default value for the `reduction` field.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;

    /// Computes the smooth L1 loss: quadratic for |diff| < beta, linear
    /// (offset by -0.5*beta, keeping the function continuous) otherwise.
    pub fn apply(
        &self,
        predictions: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let predictions = predictions.as_ref();
        let targets = targets.as_ref();
        let beta = self.beta;
        let reduction = self.reduction;

        check_shape(predictions, targets, "predictions", "targets")?;
        let diff = predictions.subtract(targets)?.abs()?;
        let beta = array!(beta);
        // loss = 0.5 * diff^2 / beta   if diff < beta
        //      = diff - 0.5 * beta     otherwise
        let loss = r#where(
            &diff.lt(&beta)?,
            array!(0.5).multiply(square(&diff)?)?.divide(&beta)?,
            diff.subtract(array!(0.5).multiply(beta)?)?,
        )?;
        reduction.reduce(loss)
    }
}
538
generate_builder! {
    /// Triplet loss function for metric learning.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct TripletLoss {
        /// Axis over which distances are computed. Default to
        /// [`TripletLoss::DEFAULT_AXIS`].
        #[builder(optional, default = TripletLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Norm degree used for the pairwise distance. Default to
        /// [`TripletLoss::DEFAULT_P`].
        #[builder(optional, default = TripletLoss::DEFAULT_P)]
        pub p: f32,

        /// Margin enforced between positive and negative distances.
        /// Default to [`TripletLoss::DEFAULT_MARGIN`].
        #[builder(optional, default = TripletLoss::DEFAULT_MARGIN)]
        pub margin: f32,

        /// Small constant added under the square root for numerical
        /// stability. Default to [`TripletLoss::DEFAULT_EPS`].
        #[builder(optional, default = TripletLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction type. Default to [`TripletLoss::DEFAULT_REDUCTION`].
        #[builder(optional, default = TripletLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
567
impl TripletLoss {
    /// Default value for the `axis` field: the last axis.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `p` field (Euclidean-style distance).
    pub const DEFAULT_P: f32 = 2.0;

    /// Default value for the `margin` field.
    pub const DEFAULT_MARGIN: f32 = 1.0;

    /// Default value for the `eps` field.
    pub const DEFAULT_EPS: f32 = 1e-6;

    /// Default value for the `reduction` field.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the triplet loss max(d(anchor, positive) - d(anchor, negative) + margin, 0).
    pub fn apply(
        &self,
        anchors: impl AsRef<Array>,
        positives: impl AsRef<Array>,
        negatives: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let anchors = anchors.as_ref();
        let positives = positives.as_ref();
        let negatives = negatives.as_ref();
        let axis = self.axis;
        let p = self.p;
        let margin = self.margin;
        let eps = self.eps;
        let reduction = self.reduction;

        let eps = array!(eps);
        let p = array!(p);
        let margin = array!(margin);

        // NOTE(review): `power` is applied to the raw (possibly negative)
        // difference, which matches the upstream MLX implementation but is
        // only a true p-norm for even p — confirm if other p values matter.
        // `eps` keeps the sqrt gradient finite at zero distance.
        let pos = sqrt(
            &power(&anchors.subtract(positives)?, &p)?
                .sum_axis(axis, None)?
                .add(&eps)?,
        )?;
        let neg = sqrt(
            &power(&anchors.subtract(negatives)?, &p)?
                .sum_axis(axis, None)?
                .add(&eps)?,
        )?;
        let loss = maximum(pos.subtract(neg)?.add(margin)?, array!(0.0))?;
        reduction.reduce(loss)
    }
}
625
generate_builder! {
    /// Hinge loss function, for targets in {-1, 1}.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct HingeLoss {
        /// Reduction type. Default to [`HingeLoss::DEFAULT_REDUCTION`].
        #[builder(optional, default = HingeLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
637
638impl HingeLoss {
639 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
641
642 pub fn apply(
649 &self,
650 inputs: impl AsRef<Array>,
651 targets: impl AsRef<Array>,
652 ) -> Result<Array, Exception> {
653 let inputs = inputs.as_ref();
654 let targets = targets.as_ref();
655 let reduction = self.reduction;
656
657 let a = array!(1.0).subtract(inputs.multiply(targets)?)?;
658 let b = array!(0.0);
659 let loss = maximum(a, b)?;
660 reduction.reduce(loss)
661 }
662}
663
generate_builder! {
    /// Huber loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct HuberLoss {
        /// Threshold between the quadratic and linear regimes. Default to
        /// [`HuberLoss::DEFAULT_DELTA`].
        #[builder(optional, default = HuberLoss::DEFAULT_DELTA)]
        pub delta: f32,

        /// Reduction type. Default to [`HuberLoss::DEFAULT_REDUCTION`].
        #[builder(optional, default = HuberLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
680
681impl HuberLoss {
682 pub const DEFAULT_DELTA: f32 = 1.0;
684
685 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
687
688 pub fn apply(
695 &self,
696 inputs: impl AsRef<Array>,
697 targets: impl AsRef<Array>,
698 ) -> Result<Array, Exception> {
699 let inputs = inputs.as_ref();
700 let targets = targets.as_ref();
701 let delta = self.delta;
702 let reduction = self.reduction;
703
704 let errors = inputs.subtract(targets)?;
705 let abs_errors = errors.abs()?;
706 let quadratic = minimum(&abs_errors, array!(delta))?;
707 let linear = abs_errors.subtract(&quadratic)?;
708 let loss = array!(0.5)
709 .multiply(square(&quadratic)?)?
710 .add(array!(delta).multiply(linear)?)?;
711 reduction.reduce(loss)
712 }
713}
714
generate_builder! {
    /// Log-cosh loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LogCoshLoss {
        /// Reduction type. Default to [`LogCoshLoss::DEFAULT_REDUCTION`].
        #[builder(optional, default = LogCoshLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
730
731impl LogCoshLoss {
732 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
734
735 pub fn apply(
742 &self,
743 inputs: impl AsRef<Array>,
744 targets: impl AsRef<Array>,
745 ) -> Result<Array, Exception> {
746 let inputs = inputs.as_ref();
747 let targets = targets.as_ref();
748 let reduction = self.reduction;
749
750 let errors = inputs.subtract(targets)?;
751 let neg_errors = errors.negative()?;
752 let loss = logaddexp(errors, neg_errors)?.subtract(log(&array!(2.0))?)?;
753 reduction.reduce(loss)
754 }
755}
756
generate_builder! {
    /// Cosine similarity loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct CosineSimilarityLoss {
        /// Axis along which similarity is computed. Default to
        /// [`CosineSimilarityLoss::DEFAULT_AXIS`].
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Lower bound on the norm product, guarding against division by
        /// zero. Default to [`CosineSimilarityLoss::DEFAULT_EPS`].
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction type. Default to
        /// [`CosineSimilarityLoss::DEFAULT_REDUCTION`].
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
777
impl CosineSimilarityLoss {
    /// Default value for the `axis` field: the last axis.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `eps` field.
    pub const DEFAULT_EPS: f32 = 1e-8;

    /// Default value for the `reduction` field.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the cosine similarity between `x1` and `x2` along `axis`.
    pub fn apply(&self, x1: impl AsRef<Array>, x2: impl AsRef<Array>) -> Result<Array, Exception> {
        let x1 = x1.as_ref();
        let x2 = x2.as_ref();
        let axis = self.axis;
        let eps = self.eps;
        let reduction = self.reduction;

        // L2 norm along `axis`; complex inputs go through |a| so the
        // squared magnitudes (and hence the norm) are real.
        fn l2_loss(a: &Array, axis: i32) -> Result<Array, Exception> {
            if a.dtype().is_complex() {
                Ok(sqrt(&sum_axis(&abs(a)?.square()?, axis, None)?)?)
            } else {
                Ok(sqrt(&sum_axis(&a.square()?, axis, None)?)?)
            }
        }

        let x1_norm = l2_loss(x1, axis)?;
        let x2_norm = l2_loss(x2, axis)?;

        // cos = <x1, x2> / max(|x1| * |x2|, eps)
        let num = sum_axis(&x1.multiply(x2)?, axis, None)?;
        let den = maximum(x1_norm.multiply(x2_norm)?, array!(eps))?;
        let loss = num.divide(&den)?;

        reduction.reduce(loss)
    }
}
819
generate_builder! {
    /// Margin ranking loss function, for targets in {-1, 1}.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct MarginRankingLoss {
        /// Margin enforced between the two rankings. Default to
        /// [`MarginRankingLoss::DEFAULT_MARGIN`].
        #[builder(optional, default = MarginRankingLoss::DEFAULT_MARGIN)]
        pub margin: f32,

        /// Reduction type. Default to
        /// [`MarginRankingLoss::DEFAULT_REDUCTION`].
        #[builder(optional, default = MarginRankingLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}
836
837impl MarginRankingLoss {
838 pub const DEFAULT_MARGIN: f32 = 0.0;
840
841 pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
843
844 pub fn apply(
853 &self,
854 inputs1: impl AsRef<Array>,
855 inputs2: impl AsRef<Array>,
856 targets: impl AsRef<Array>,
857 ) -> Result<Array, Exception> {
858 let inputs1 = inputs1.as_ref();
859 let inputs2 = inputs2.as_ref();
860 let targets = targets.as_ref();
861 let margin = self.margin;
862 let reduction = self.reduction;
863
864 check_shape(inputs1, inputs2, "inputs1", "inputs2")?;
865 check_shape(inputs1, targets, "inputs1", "targets")?;
866
867 let margin = array!(margin);
868 let diff = inputs1.subtract(inputs2)?;
869 let loss = maximum(
870 array!(0.0),
871 targets.multiply(diff)?.negative()?.add(margin)?,
872 )?;
873 reduction.reduce(loss)
874 }
875}
876
877#[cfg(test)]
878#[allow(clippy::approx_constant)]
879mod tests {
880 use crate::{array, assert_array_eq, builder::Builder, ops::is_nan};
881 use float_eq::assert_float_eq;
882
883 use super::*;
884
    /// Exercises `CrossEntropy` with index targets, probability targets,
    /// class weights, and label smoothing; expected values mirror the
    /// upstream MLX python tests.
    #[test]
    fn test_cross_entropy() {
        // Index targets pointing at the only finite logit => zero loss.
        let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
        let indices = array!([0, 1]);
        let expected = array!([0.0, 0.0]);
        let loss = CrossEntropy::new()
            .unwrap()
            .apply(&logits, indices)
            .unwrap();
        assert_array_eq!(loss, expected);

        // Probability targets against -inf logits produce NaN (0 * inf).
        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert!(is_nan(&loss).unwrap().all(None).unwrap().item::<bool>());

        // Weighted cross entropy, index targets.
        let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
        let indices = array!([0, 1]);
        let weights = array!([1.0, 2.0]);
        let expected = array!([0.04858735, 0.0971747]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(&logits, indices).unwrap();
        assert_array_eq!(loss, expected);

        // Weighted cross entropy, probability targets — same expected values.
        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert_array_eq!(loss, expected);

        // Label smoothing, index targets.
        let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
        let indices = array!([0, 1]);
        let expected = array!([0.498587, 0.498587]);
        let cross_entropy = CrossEntropyBuilder::new()
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(&logits, indices).unwrap();
        assert_array_eq!(loss, expected);

        // Label smoothing, probability targets.
        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert_array_eq!(loss, expected);

        // Label smoothing combined with weights, index targets.
        let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
        let indices = array!([0, 1]);
        let weights = array!([1.0, 2.0]);
        let expected = array!([0.49858734, 0.9971747]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(&logits, indices).unwrap();
        assert_array_eq!(loss, expected);

        // Label smoothing combined with weights, probability targets.
        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert_array_eq!(loss, expected);
    }
974
    /// `BinaryCrossEntropy` on raw logits: element-wise, mean, sum, and
    /// weighted variants.
    #[test]
    fn test_binary_cross_entropy_with_logits_as_inputs() {
        let logits = array!([0.105361, 0.223144, 1.20397, 0.916291]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = binary_cross_entropy.apply(&logits, &targets).unwrap();
        let expected_none = array!([0.747215, 0.810930, 0.262365, 0.336472]);
        assert_array_eq!(loss_none, expected_none);

        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = binary_cross_entropy.apply(&logits, &targets).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&logits, &targets).unwrap();
        let expected = expected_none.sum(None).unwrap();
        assert_array_eq!(loss, expected);

        // Element-wise weights scale the per-element loss.
        let weights = array!([1.0, 2.0, 1.0, 2.0]);
        let expected = array!([0.747215, 1.62186, 0.262365, 0.672944]);
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .weights(&weights)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss, expected);
    }
1018
    /// `BinaryCrossEntropy` on probabilities (`inputs_are_logits = false`)
    /// across all three reductions.
    #[test]
    fn test_binary_cross_entropy_with_probs_as_inputs() {
        let probs = array!([0.5, 0.6, 0.7, 0.8]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_none = array!([0.693147, 0.916291, 0.356675, 0.223144]);
        assert_array_eq!(loss_none, expected_none);

        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected = expected_none.sum(None).unwrap();
        assert_array_eq!(loss, expected);
    }
1054
    /// `BinaryCrossEntropy` on probabilities at the extremes (0, ~0, ~1, 1),
    /// checking the -100 log clamp keeps the loss finite.
    #[test]
    fn test_binary_cross_entropy_with_tiny_probs_as_inputs() {
        let tiny_prob = 1e-59;
        let probs = array!([0.0, tiny_prob, 1.0 - tiny_prob, 1.0]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_none = array!([0.0, tiny_prob, tiny_prob, 0.0]);
        assert_array_eq!(loss_none, expected_none);

        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected = expected_none.sum(None).unwrap();
        assert_array_eq!(loss, expected);
    }
1091
    /// `L1Loss` on identical inputs: zero loss under every reduction.
    #[test]
    fn test_l1_loss() {
        let predictions = array!([0.5, 0.2, 0.9, 0.0]);
        let targets = array!([0.5, 0.2, 0.9, 0.0]);

        let expected_none = array!([0.0, 0.0, 0.0, 0.0]);
        let expected_sum = expected_none.sum(None).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();

        let l1_loss = L1LossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let l1_loss = L1LossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);

        let l1_loss = L1LossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);
    }
1122
    /// `MseLoss` element-wise squared errors under every reduction.
    #[test]
    fn test_mse_loss() {
        let predictions = array!([0.5, 0.2, 0.9, 0.0]);
        let targets = array!([0.7, 0.1, 0.8, 0.2]);

        let expected_none = array!([0.04, 0.01, 0.01, 0.04]);
        let expected_mean = expected_none.mean(None).unwrap();
        let expected_sum = expected_none.sum(None).unwrap();

        let mse_loss = MseLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = mse_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let mse_loss = MseLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = mse_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        let mse_loss = MseLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = mse_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }
1153
    /// `SmoothL1Loss` with beta = 1.0: quadratic regime for small diffs,
    /// linear for |diff| >= beta, under every reduction.
    #[test]
    fn test_smooth_l1_loss() {
        let predictions = array!([1.5, 2.5, 0.5, 3.5]);
        let targets = array!([1.0, 2.0, 0.5, 2.5]);
        let beta = 1.0;

        let expected_none = array!([0.125, 0.125, 0.0, 0.5]);
        let expected_sum = expected_none.sum(None).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();

        let smooth_l1_loss = SmoothL1LossBuilder::new()
            .beta(beta)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = smooth_l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let smooth_l1_loss = SmoothL1LossBuilder::new()
            .beta(beta)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = smooth_l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);

        let smooth_l1_loss = SmoothL1LossBuilder::new()
            .beta(beta)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = smooth_l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);
    }
1188
    /// Smooth L1 must be symmetric: swapping predictions and targets
    /// yields the same loss (the diff is taken through abs()).
    #[test]
    fn test_smooth_l1_loss_negative_diff() {
        let a = array!([1.5, 6.0, 0.5, 2.5]);
        let b = array!([1.0, 2.0, 0.5, 3.5]);

        let loss = SmoothL1Loss::new();

        let ab = loss.apply(&a, &b).unwrap();
        let ba = loss.apply(&b, &a).unwrap();
        assert_array_eq!(ab, ba);
    }
1200
    /// `NllLoss` with targets selecting the zero log-probability entry,
    /// under every reduction.
    #[test]
    fn test_nll_loss() {
        let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
        let targets = array!([0, 1]);

        let expected_none = array!([0.0, 0.0]);
        let expected_sum = expected_none.sum(None).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();

        let nll_loss = NllLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = nll_loss.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let nll_loss = NllLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = nll_loss.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        let nll_loss = NllLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = nll_loss.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }
1231
    /// `GaussianNllLoss` with and without the constant 0.5*log(2*pi) term
    /// (`full`), under every reduction.
    #[test]
    fn test_gaussian_nll_loss() {
        let inputs = array!([[0.1, 0.2], [0.3, 0.4]]);
        let targets = array!([[0.2, 0.1], [0.1, 0.2]]);
        let vars = array!([[0.1, 0.2], [0.3, 0.4]]);

        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(false)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_none = array!([[-1.101293, -0.779719], [-0.535320, -0.408145]]);
        assert_array_eq!(loss_none, expected_none);

        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(false)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(false)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_sum = expected_none.sum(None).unwrap();
        assert_array_eq!(loss_sum, expected_sum);

        // With `full`, the loss shifts by the constant 0.5 * log(2*pi).
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(true)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_none_full = array!([[-0.182354, 0.139220], [0.383619, 0.510793]]);
        assert_array_eq!(loss_none_full, expected_none_full);

        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(true)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_mean_full = expected_none_full.mean(None).unwrap();
        assert_array_eq!(loss_mean_full, expected_mean_full);

        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(true)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_sum_full = expected_none_full.sum(None).unwrap();
        assert_array_eq!(loss_sum_full, expected_sum_full);
    }
1298
1299 #[test]
1300 fn test_kl_div_loss() {
1301 let p_logits = array!([[0.5, 0.5], [0.8, 0.2]]).log().unwrap();
1302 let q_logits = array!([[0.5, 0.5], [0.2, 0.8]]).log().unwrap();
1303
1304 let kl_div_loss = KlDivLossBuilder::new()
1306 .reduction(LossReduction::None)
1307 .build()
1308 .unwrap();
1309 let loss_none = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
1310 let expected_none = array!([0.0, 0.831777]);
1311 assert_array_eq!(loss_none, expected_none);
1312
1313 let kl_div_loss = KlDivLossBuilder::new()
1315 .reduction(LossReduction::Mean)
1316 .build()
1317 .unwrap();
1318 let loss_mean = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
1319 let expected_mean = expected_none.mean(None).unwrap();
1320 assert_array_eq!(loss_mean, expected_mean);
1321
1322 let kl_div_loss = KlDivLossBuilder::new()
1324 .reduction(LossReduction::Sum)
1325 .build()
1326 .unwrap();
1327 let loss_sum = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
1328 let expected_sum = expected_none.sum(None).unwrap();
1329 assert_array_eq!(loss_sum, expected_sum);
1330 }
1331
1332 #[test]
1333 fn test_triplet_loss() {
1334 let anchors = array!([[1, 2, 3], [1, 2, 3]]);
1335 let positives = array!([[4, 5, 6], [0, -1, 2]]);
1336 let negatives = array!([[7, 8, 9], [3, 2, 3]]);
1337
1338 let triplet_loss = TripletLossBuilder::new()
1340 .reduction(LossReduction::None)
1341 .build()
1342 .unwrap();
1343 let loss_none = triplet_loss
1344 .apply(&anchors, &positives, &negatives)
1345 .unwrap();
1346 let expected_none = array!([0.0, 2.31662]);
1347 assert_array_eq!(loss_none, expected_none);
1348
1349 let triplet_loss = TripletLossBuilder::new()
1351 .reduction(LossReduction::Mean)
1352 .build()
1353 .unwrap();
1354 let loss_mean = triplet_loss
1355 .apply(&anchors, &positives, &negatives)
1356 .unwrap();
1357 let expected_mean = expected_none.mean(None).unwrap();
1358 assert_array_eq!(loss_mean, expected_mean);
1359
1360 let triplet_loss = TripletLossBuilder::new()
1362 .reduction(LossReduction::Sum)
1363 .build()
1364 .unwrap();
1365 let loss_sum = triplet_loss
1366 .apply(&anchors, &positives, &negatives)
1367 .unwrap();
1368 let expected_sum = expected_none.sum(None).unwrap();
1369 assert_array_eq!(loss_sum, expected_sum);
1370 }
1371
1372 #[test]
1373 fn test_hinge_loss() {
1374 let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
1375 let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
1376 let hinge_loss = HingeLossBuilder::new()
1377 .reduction(LossReduction::Mean)
1378 .build()
1379 .unwrap();
1380 let loss = hinge_loss.apply(&inputs, &targets).unwrap();
1381 assert_eq!(loss.item::<f32>(), 1.0);
1382 }
1383
1384 #[test]
1385 fn test_huber_loss() {
1386 let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
1387 let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
1388 let huber_loss = HuberLossBuilder::new()
1389 .reduction(LossReduction::Mean)
1390 .build()
1391 .unwrap();
1392 let loss = huber_loss.apply(&inputs, &targets).unwrap();
1393 assert_eq!(loss.item::<f32>(), 0.5);
1394 }
1395
1396 #[test]
1397 fn test_log_cosh_loss() {
1398 let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
1399 let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
1400 let log_cosh_loss = LogCoshLossBuilder::new()
1401 .reduction(LossReduction::Mean)
1402 .build()
1403 .unwrap();
1404 let loss = log_cosh_loss.apply(&inputs, &targets).unwrap();
1405 assert_float_eq!(loss.item::<f32>(), 0.433781, abs <= 1e-6);
1406 }
1407
1408 #[test]
1409 fn test_cosine_similarity_loss() {
1410 let embeddings1 = array!([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]);
1411 let embeddings2 = array!([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]]);
1412
1413 let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
1415 .reduction(LossReduction::None)
1416 .build()
1417 .unwrap();
1418 let loss_none = cosine_similarity_loss
1419 .apply(&embeddings1, &embeddings2)
1420 .unwrap();
1421 let expected_none = array!([0.985344, 0.961074]);
1422 assert_array_eq!(loss_none, expected_none);
1423
1424 let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
1426 .reduction(LossReduction::Mean)
1427 .build()
1428 .unwrap();
1429 let loss_mean = cosine_similarity_loss
1430 .apply(&embeddings1, &embeddings2)
1431 .unwrap();
1432 let expected_mean = expected_none.mean(None).unwrap();
1433 assert_array_eq!(loss_mean, expected_mean);
1434
1435 let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
1437 .reduction(LossReduction::Sum)
1438 .build()
1439 .unwrap();
1440 let loss_sum = cosine_similarity_loss
1441 .apply(&embeddings1, &embeddings2)
1442 .unwrap();
1443 let expected_sum = expected_none.sum(None).unwrap();
1444 assert_array_eq!(loss_sum, expected_sum);
1445 }
1446
1447 #[test]
1448 fn test_margin_ranking_loss() {
1449 let inputs1 = array!([-0.573409, -0.765166, -0.0638]);
1450 let inputs2 = array!([0.75596, 0.225763, 0.256995]);
1451 let targets = array!([1, 1, -1]);
1452
1453 let margin_ranking_loss = MarginRankingLossBuilder::new()
1455 .reduction(LossReduction::None)
1456 .build()
1457 .unwrap();
1458 let loss = margin_ranking_loss
1459 .apply(&inputs1, &inputs2, &targets)
1460 .unwrap();
1461 let expected = array!([1.329369, 0.990929, 0.0]);
1462 assert_array_eq!(loss, expected);
1463
1464 let margin_ranking_loss = MarginRankingLossBuilder::new()
1466 .margin(0.5)
1467 .reduction(LossReduction::None)
1468 .build()
1469 .unwrap();
1470 let loss = margin_ranking_loss
1471 .apply(&inputs1, &inputs2, &targets)
1472 .unwrap();
1473 let expected = array!([1.829369, 1.490929, 0.179205]);
1474 assert_array_eq!(loss, expected);
1475 }
1476}