// mlx_rs/losses.rs

//! Loss functions

use crate::{
    array,
    error::{CrossEntropyBuildError, Exception},
    ops::{
        abs, clip, exp, indexing::take_along_axis, log, logaddexp, logsumexp_axes, maximum,
        minimum, multiply, power, r#where, sqrt, square, sum_axes, sum_axis,
    },
    Array,
};
use mlx_internal_macros::{generate_builder, Buildable};

#[inline]
fn check_shape(
    left: &Array,
    right: &Array,
    left_ident: &str,
    right_ident: &str,
) -> Result<(), Exception> {
    if left.shape() != right.shape() {
        return Err(Exception::custom(format!(
            "The shape of the {} ({:?}) does not match the shape of the {} ({:?})",
            left_ident,
            left.shape(),
            right_ident,
            right.shape()
        )));
    }
    Ok(())
}

/// Different types of loss reductions
#[derive(Debug, Clone, Copy)]
pub enum LossReduction {
    /// No reduction is applied.
    None,
    /// The sum of the output will be computed.
    Sum,
    /// The mean of the output will be computed.
    Mean,
}

impl LossReduction {
    /// Reduces the loss according to the reduction type.
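    ///
    /// A minimal sketch of the three variants (illustrative values; assumes `Array`'s
    /// `Clone` impl):
    ///
    /// ```rust, ignore
    /// let loss = array!([1.0, 2.0, 3.0]);
    /// LossReduction::None.reduce(loss.clone())?; // [1.0, 2.0, 3.0]
    /// LossReduction::Sum.reduce(loss.clone())?;  // 6.0
    /// LossReduction::Mean.reduce(loss)?;         // 2.0
    /// ```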
    pub fn reduce(&self, loss: Array) -> Result<Array, Exception> {
        match self {
            LossReduction::None => Ok(loss),
            LossReduction::Sum => Ok(loss.sum(None)?),
            LossReduction::Mean => Ok(loss.mean(None)?),
        }
    }
}

/// Helper type alias for CrossEntropyBuilder weights.
pub type CrossEntropyBuilderWeights<'a> = &'a Array;

generate_builder! {
    /// Cross entropy loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        root = crate,
        build_with = build_cross_entropy,
        err = CrossEntropyBuildError
    )]
    pub struct CrossEntropy<'a> {
        /// Weights for each target
        #[builder(optional, default = CrossEntropy::DEFAULT_WEIGHTS)]
        pub weights: Option<&'a Array>,

        /// The axis over which to compute softmax. Defaults to [`CrossEntropy::DEFAULT_AXIS`]
        #[builder(optional, default = CrossEntropy::DEFAULT_AXIS)]
        pub axis: i32,

        /// The label smoothing factor, range [0, 1). Defaults to
        /// [`CrossEntropy::DEFAULT_LABEL_SMOOTHING`]
        #[builder(optional, default = CrossEntropy::DEFAULT_LABEL_SMOOTHING)]
        pub label_smoothing: f32,

        /// Reduction type. Defaults to [`CrossEntropy::DEFAULT_REDUCTION`]
        #[builder(optional, default = CrossEntropy::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

fn build_cross_entropy(
    builder: CrossEntropyBuilder,
) -> Result<CrossEntropy, CrossEntropyBuildError> {
    let axis = builder.axis;
    let label_smoothing = builder.label_smoothing;
    let reduction = builder.reduction;

    if !(0.0..1.0).contains(&label_smoothing) {
        return Err(CrossEntropyBuildError::InvalidLabelSmoothingFactor);
    }

    Ok(CrossEntropy {
        weights: builder.weights,
        axis,
        label_smoothing,
        reduction,
    })
}

impl<'a> CrossEntropy<'a> {
    /// Default value for the `axis` parameter.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `label_smoothing` parameter.
    pub const DEFAULT_LABEL_SMOOTHING: f32 = 0.0;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Default value for the `weights` parameter.
    pub const DEFAULT_WEIGHTS: Option<&'a Array> = None;

    /// Apply the cross entropy loss function on the given logits and targets.
    ///
    /// # Params
    ///
    /// - `logits`: unnormalized predicted logits
    /// - `targets`: target values, given either as class indices or, when `targets` has
    ///   the same number of dimensions as `logits`, as class probabilities
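    ///
    /// # Example
    ///
    /// A minimal sketch; values are taken from the unit tests below:
    ///
    /// ```rust, ignore
    /// let ce = CrossEntropy::new()?;
    /// let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
    /// let indices = array!([0, 1]);
    /// let loss = ce.apply(&logits, &indices)?; // ~[0.0486, 0.0486]
    /// ```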
    pub fn apply(
        &self,
        logits: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let logits = logits.as_ref();
        let targets = targets.as_ref();

        let target_as_probs = targets.ndim() == logits.ndim();

        let score = if target_as_probs {
            sum_axes(&logits.multiply(targets)?, &[self.axis], None)?
        } else {
            take_along_axis(logits, &targets.expand_dims_axes(&[-1])?, self.axis)?
                .squeeze_axes(&[-1])?
        };
        let log_sum_exp_logits = logsumexp_axes(logits, &[self.axis], None)?;

        let mut loss = if self.label_smoothing > 0.0 {
            // adjust the true class score with label smoothing
            let adjusted_score = multiply(array!(1.0 - self.label_smoothing), score)?;

            // calculate the mean logit across the classes for smoothed loss
            let mean_logits = logits.mean_axis(self.axis, None)?;
            let smoothed_loss = -multiply(mean_logits, array!(self.label_smoothing))?;

            // combine the adjusted score and smoothed loss with the logsumexp logits
            log_sum_exp_logits
                .subtract(adjusted_score)?
                .add(smoothed_loss)?
        } else {
            log_sum_exp_logits.subtract(score)?
        };

        if let Some(weights) = self.weights {
            check_shape(weights, &loss, "weights", "loss")?;
            loss = multiply(loss, weights)?;
        }

        self.reduction.reduce(loss)
    }
}

generate_builder! {
    /// Binary cross entropy loss.
    ///
    /// By default, this function takes the pre-sigmoid logits, which results in a faster
    /// and more precise loss. When `inputs_are_logits` is false, the loss calculation
    /// clips the input probabilities (in log-space) to a minimum value of `-100` for
    /// improved numerical stability.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct BinaryCrossEntropy<'a> {
        /// Optional weights for each target
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_WEIGHTS)]
        pub weights: Option<&'a Array>,

        /// Whether the inputs are logits. Defaults to
        /// [`BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS`]
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS)]
        pub inputs_are_logits: bool,

        /// Reduction type. Defaults to [`BinaryCrossEntropy::DEFAULT_REDUCTION`]
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl<'a> BinaryCrossEntropy<'a> {
    /// Default value for the `weights` parameter.
    pub const DEFAULT_WEIGHTS: Option<&'a Array> = None;

    /// Default value for the `inputs_are_logits` parameter.
    pub const DEFAULT_INPUTS_ARE_LOGITS: bool = true;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Apply the binary cross entropy loss function on the given logits and targets.
    ///
    /// # Params
    ///
    /// - `logits`: unnormalized predicted logits (or probabilities when
    ///   `inputs_are_logits` is false)
    /// - `targets`: binary target values in {0, 1}
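    ///
    /// # Example
    ///
    /// A minimal sketch; values are taken from the unit tests below:
    ///
    /// ```rust, ignore
    /// let bce = BinaryCrossEntropyBuilder::new()
    ///     .reduction(LossReduction::Mean)
    ///     .build()?;
    /// let logits = array!([0.105361, 0.223144, 1.20397, 0.916291]);
    /// let targets = array!([0.0, 0.0, 1.0, 1.0]);
    /// let loss = bce.apply(&logits, &targets)?;
    /// ```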
    pub fn apply(
        &self,
        logits: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let logits = logits.as_ref();
        let targets = targets.as_ref();
        let weights = self.weights;
        let inputs_are_logits = self.inputs_are_logits;
        let reduction = self.reduction;

        let mut loss = if inputs_are_logits {
            logaddexp(array!(0.0), logits)?.subtract(targets.multiply(logits)?)?
        } else {
            // `logits` holds probabilities in this branch; clip the log terms at -100
            // for numerical stability
            let log_inputs_clip = clip(log(logits)?, (-100.0, ()))?;
            let log_inputs_inverse_clip = clip(log(&array!(1.0).subtract(logits)?)?, (-100.0, ()))?;
            -(targets.multiply(log_inputs_clip)?.add(
                array!(1.0)
                    .subtract(targets)?
                    .multiply(log_inputs_inverse_clip)?,
            )?)
        };

        if let Some(weights) = weights {
            check_shape(weights, &loss, "weights", "loss")?;
            loss = multiply(loss, weights)?;
        }

        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the L1 loss
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct L1Loss {
        /// Reduction type. Defaults to [`L1Loss::DEFAULT_REDUCTION`]
        #[builder(optional, default = L1Loss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl L1Loss {
    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;

    /// Compute the L1 loss.
    ///
    /// # Params
    ///
    /// - `predictions`: predicted values
    /// - `targets`: target values
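    ///
    /// # Example
    ///
    /// A minimal sketch with illustrative values; note that the default reduction is `Mean`:
    ///
    /// ```rust, ignore
    /// let l1 = L1LossBuilder::new().build()?;
    /// let loss = l1.apply(array!([1.0, 2.0]), array!([1.5, 1.0]))?; // mean(|diff|) = 0.75
    /// ```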
    pub fn apply(
        &self,
        predictions: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let predictions = predictions.as_ref();
        let targets = targets.as_ref();
        let reduction = self.reduction;

        check_shape(predictions, targets, "predictions", "targets")?;
        let loss = predictions.subtract(targets)?.abs()?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the mean squared error loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct MseLoss {
        /// Reduction type. Defaults to [`MseLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = MseLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl MseLoss {
    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;

    /// Compute the mean squared error loss.
    ///
    /// # Params
    ///
    /// - `predictions`: predicted values
    /// - `targets`: target values
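    ///
    /// # Example
    ///
    /// A minimal sketch with values borrowed from the unit tests below; the default
    /// reduction is `Mean`:
    ///
    /// ```rust, ignore
    /// let mse = MseLossBuilder::new().build()?;
    /// let loss = mse.apply(array!([0.5, 0.2]), array!([0.7, 0.1]))?; // mean([0.04, 0.01]) = 0.025
    /// ```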
    pub fn apply(
        &self,
        predictions: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let predictions = predictions.as_ref();
        let targets = targets.as_ref();
        let reduction = self.reduction;

        check_shape(predictions, targets, "predictions", "targets")?;
        let loss = predictions.subtract(targets)?.square()?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the negative log likelihood loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct NllLoss {
        /// Distribution axis. Defaults to [`NllLoss::DEFAULT_AXIS`]
        #[builder(optional, default = NllLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Reduction type. Defaults to [`NllLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = NllLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl NllLoss {
    /// Default value for the `axis` parameter.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Compute the negative log likelihood loss.
    ///
    /// # Params
    ///
    /// - `inputs`: predicted distribution in log space
    /// - `targets`: target values, as class indices
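    ///
    /// # Example
    ///
    /// A minimal sketch; `inputs` are log-probabilities and `targets` are class indices
    /// (values taken from the unit tests below):
    ///
    /// ```rust, ignore
    /// let nll = NllLossBuilder::new().build()?;
    /// let inputs = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
    /// let targets = array!([0, 1]);
    /// let loss = nll.apply(&inputs, &targets)?; // [0.0, 0.0]
    /// ```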
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let axis = self.axis;
        let reduction = self.reduction;

        let loss = -take_along_axis(inputs, &targets.expand_dims_axes(&[-1])?, axis)?
            .squeeze_axes(&[-1])?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Compute the negative log likelihood loss for a Gaussian distribution.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct GaussianNllLoss {
        /// Whether to include the constant term in the loss calculation. Defaults to
        /// [`GaussianNllLoss::DEFAULT_FULL`]
        #[builder(optional, default = GaussianNllLoss::DEFAULT_FULL)]
        pub full: bool,

        /// Small positive constant for numerical stability. Defaults to
        /// [`GaussianNllLoss::DEFAULT_EPS`]
        #[builder(optional, default = GaussianNllLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction type. Defaults to [`GaussianNllLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = GaussianNllLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl GaussianNllLoss {
    /// Default value for the `full` parameter.
    pub const DEFAULT_FULL: bool = false;

    /// Default value for the `eps` parameter.
    pub const DEFAULT_EPS: f32 = 1e-6;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Compute the negative log likelihood loss for a Gaussian distribution.
    ///
    /// # Params
    ///
    /// - `inputs`: The predicted expectation of the Gaussian distribution.
    /// - `targets`: The target values (samples from the Gaussian distribution).
    /// - `vars`: The predicted variance of the Gaussian distribution.
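    ///
    /// # Example
    ///
    /// A minimal sketch; values are taken from the unit tests below:
    ///
    /// ```rust, ignore
    /// let gnll = GaussianNllLossBuilder::new().build()?;
    /// let inputs = array!([[0.1, 0.2], [0.3, 0.4]]);
    /// let targets = array!([[0.2, 0.1], [0.1, 0.2]]);
    /// let vars = array!([[0.1, 0.2], [0.3, 0.4]]);
    /// let loss = gnll.apply(&inputs, &targets, &vars)?;
    /// ```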
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
        vars: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let vars = vars.as_ref();
        let full = self.full;
        let eps = self.eps;
        let reduction = self.reduction;

        check_shape(inputs, targets, "inputs", "targets")?;
        check_shape(inputs, vars, "inputs", "vars")?;

        let vars = maximum(vars, array!(eps))?;
        let mut loss =
            array!(0.5) * (log(&vars)?.add(square(&targets.subtract(inputs)?)?.divide(&vars)?)?);

        if full {
            let pi = array!(std::f32::consts::PI);
            loss = loss.add(array!(0.5).multiply(log(&array!(2.0).multiply(pi)?)?)?)?;
        }

        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Compute the Kullback-Leibler divergence loss.
    ///
    /// Computes the following when the `reduction` is `LossReduction::None`:
    ///
    /// ```rust, ignore
    /// sum(exp(targets) * (targets - inputs), axis, None)
    /// ```
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct KlDivLoss {
        /// The distribution axis. Defaults to [`KlDivLoss::DEFAULT_AXIS`]
        #[builder(optional, default = KlDivLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Reduction type. Defaults to [`KlDivLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = KlDivLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl KlDivLoss {
    /// Default value for the `axis` parameter.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Compute the Kullback-Leibler divergence loss.
    ///
    /// # Params
    ///
    /// - `inputs`: Log probabilities for the predicted distribution.
    /// - `targets`: Log probabilities for the target distribution.
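    ///
    /// # Example
    ///
    /// A minimal sketch; both arguments are log-probabilities (values taken from the
    /// unit tests below):
    ///
    /// ```rust, ignore
    /// let kl = KlDivLossBuilder::new().build()?;
    /// let p = array!([[0.5, 0.5], [0.8, 0.2]]).log()?;
    /// let q = array!([[0.5, 0.5], [0.2, 0.8]]).log()?;
    /// let loss = kl.apply(&p, &q)?; // ~[0.0, 0.832]
    /// ```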
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let axis = self.axis;
        let reduction = self.reduction;

        let loss = sum_axis(
            &exp(targets)?.multiply(targets.subtract(inputs)?)?,
            axis,
            None,
        )?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the smooth L1 loss.
    ///
    /// The smooth L1 loss is a variant of the L1 loss which replaces the absolute
    /// difference with a squared difference when the absolute difference is less
    /// than `beta`.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct SmoothL1Loss {
        /// The threshold after which the loss changes from the squared to the absolute difference.
        /// Defaults to [`SmoothL1Loss::DEFAULT_BETA`]
        #[builder(optional, default = SmoothL1Loss::DEFAULT_BETA)]
        pub beta: f32,

        /// Reduction type. Defaults to [`SmoothL1Loss::DEFAULT_REDUCTION`]
        #[builder(optional, default = SmoothL1Loss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl SmoothL1Loss {
    /// Default value for the `beta` parameter.
    pub const DEFAULT_BETA: f32 = 1.0;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;

    /// Compute the smooth L1 loss.
    ///
    /// # Params
    ///
    /// - `predictions`: predicted values
    /// - `targets`: target values
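    ///
    /// # Example
    ///
    /// A minimal sketch; differences below `beta` are penalized quadratically, larger
    /// ones linearly (values taken from the unit tests below):
    ///
    /// ```rust, ignore
    /// let smooth_l1 = SmoothL1LossBuilder::new().beta(1.0).build()?;
    /// let loss = smooth_l1.apply(array!([1.5, 3.5]), array!([1.0, 2.5]))?; // mean([0.125, 0.5])
    /// ```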
    pub fn apply(
        &self,
        predictions: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let predictions = predictions.as_ref();
        let targets = targets.as_ref();
        let beta = self.beta;
        let reduction = self.reduction;

        check_shape(predictions, targets, "predictions", "targets")?;
        let diff = predictions.subtract(targets)?.abs()?;
        let beta = array!(beta);
        let loss = r#where(
            &diff.lt(&beta)?,
            array!(0.5).multiply(square(&diff)?)?.divide(&beta)?,
            diff.subtract(array!(0.5).multiply(beta)?)?,
        )?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the triplet loss for a set of anchor, positive, and negative samples. The
    /// margin is denoted by `alpha` in the original formulation.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct TripletLoss {
        /// Distribution axis. Defaults to [`TripletLoss::DEFAULT_AXIS`]
        #[builder(optional, default = TripletLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// The norm degree for pairwise distance. Defaults to [`TripletLoss::DEFAULT_P`]
        #[builder(optional, default = TripletLoss::DEFAULT_P)]
        pub p: f32,

        /// Margin for the triplet loss. Defaults to [`TripletLoss::DEFAULT_MARGIN`]
        #[builder(optional, default = TripletLoss::DEFAULT_MARGIN)]
        pub margin: f32,

        /// Small positive constant for numerical stability. Defaults to [`TripletLoss::DEFAULT_EPS`]
        #[builder(optional, default = TripletLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction type. Defaults to [`TripletLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = TripletLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl TripletLoss {
    /// Default value for the `axis` parameter.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `p` parameter.
    pub const DEFAULT_P: f32 = 2.0;

    /// Default value for the `margin` parameter.
    pub const DEFAULT_MARGIN: f32 = 1.0;

    /// Default value for the `eps` parameter.
    pub const DEFAULT_EPS: f32 = 1e-6;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the triplet loss for a set of anchor, positive, and negative samples. The
    /// margin is denoted by `alpha` in the original formulation.
    ///
    /// # Params
    ///
    /// - `anchors`: The anchor samples
    /// - `positives`: The positive samples
    /// - `negatives`: The negative samples
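    ///
    /// # Example
    ///
    /// A minimal sketch; values are taken from the unit tests below:
    ///
    /// ```rust, ignore
    /// let triplet = TripletLossBuilder::new().build()?;
    /// let anchors = array!([[1, 2, 3], [1, 2, 3]]);
    /// let positives = array!([[4, 5, 6], [0, -1, 2]]);
    /// let negatives = array!([[7, 8, 9], [3, 2, 3]]);
    /// let loss = triplet.apply(&anchors, &positives, &negatives)?; // ~[0.0, 2.317]
    /// ```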
    pub fn apply(
        &self,
        anchors: impl AsRef<Array>,
        positives: impl AsRef<Array>,
        negatives: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let anchors = anchors.as_ref();
        let positives = positives.as_ref();
        let negatives = negatives.as_ref();
        let axis = self.axis;
        let p = self.p;
        let margin = self.margin;
        let eps = self.eps;
        let reduction = self.reduction;

        let eps = array!(eps);
        let p = array!(p);
        let margin = array!(margin);

        let pos = sqrt(
            &power(&anchors.subtract(positives)?, &p)?
                .sum_axis(axis, None)?
                .add(&eps)?,
        )?;
        let neg = sqrt(
            &power(&anchors.subtract(negatives)?, &p)?
                .sum_axis(axis, None)?
                .add(&eps)?,
        )?;
        let loss = maximum(pos.subtract(neg)?.add(margin)?, array!(0.0))?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Compute the hinge loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct HingeLoss {
        /// Reduction type. Defaults to [`HingeLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = HingeLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl HingeLoss {
    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Compute the hinge loss.
    ///
    /// # Params
    ///
    /// - `inputs`: predicted values
    /// - `targets`: target values, either -1 or 1
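    ///
    /// # Example
    ///
    /// A minimal sketch with illustrative values; `hinge = max(0, 1 - inputs * targets)`:
    ///
    /// ```rust, ignore
    /// let hinge = HingeLossBuilder::new().reduction(LossReduction::Mean).build()?;
    /// let loss = hinge.apply(array!([1.0, -0.5]), array!([1.0, -1.0]))?; // mean([0.0, 0.5]) = 0.25
    /// ```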
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let reduction = self.reduction;

        let a = array!(1.0).subtract(inputs.multiply(targets)?)?;
        let b = array!(0.0);
        let loss = maximum(a, b)?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Compute the Huber loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct HuberLoss {
        /// The threshold at which to change between L1 and L2 loss. Defaults to
        /// [`HuberLoss::DEFAULT_DELTA`]
        #[builder(optional, default = HuberLoss::DEFAULT_DELTA)]
        pub delta: f32,

        /// Reduction type. Defaults to [`HuberLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = HuberLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl HuberLoss {
    /// Default value for the `delta` parameter.
    pub const DEFAULT_DELTA: f32 = 1.0;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Compute the Huber loss.
    ///
    /// # Params
    ///
    /// - `inputs`: predicted values
    /// - `targets`: target values
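    ///
    /// # Example
    ///
    /// A minimal sketch with illustrative values; errors below `delta` are penalized
    /// quadratically, larger ones linearly:
    ///
    /// ```rust, ignore
    /// let huber = HuberLossBuilder::new().delta(1.0).build()?;
    /// let loss = huber.apply(array!([0.5, 3.0]), array!([0.0, 0.0]))?; // [0.125, 2.5]
    /// ```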
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let delta = self.delta;
        let reduction = self.reduction;

        let errors = inputs.subtract(targets)?;
        let abs_errors = errors.abs()?;
        let quadratic = minimum(&abs_errors, array!(delta))?;
        let linear = abs_errors.subtract(&quadratic)?;
        let loss = array!(0.5)
            .multiply(square(&quadratic)?)?
            .add(array!(delta).multiply(linear)?)?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the log cosh loss between inputs and targets.
    ///
    /// Logcosh acts like the L2 loss for small errors, ensuring stable gradients,
    /// and like the L1 loss for large errors, reducing sensitivity to outliers. This
    /// dual behavior offers a balanced, robust approach for regression tasks.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LogCoshLoss {
        /// Reduction type. Defaults to [`LogCoshLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = LogCoshLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl LogCoshLoss {
    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the log cosh loss between inputs and targets.
    ///
    /// # Params
    ///
    /// - `inputs`: predicted values
    /// - `targets`: target values
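    ///
    /// # Example
    ///
    /// A minimal sketch; the value matches the unit tests below, since
    /// `log(cosh(1.0)) ~= 0.4338`:
    ///
    /// ```rust, ignore
    /// let log_cosh = LogCoshLossBuilder::new().reduction(LossReduction::Mean).build()?;
    /// let loss = log_cosh.apply(array!([1.0, 1.0]), array!([0.0, 0.0]))?; // ~0.4338
    /// ```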
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let reduction = self.reduction;

        let errors = inputs.subtract(targets)?;
        let neg_errors = errors.negative()?;
        let loss = logaddexp(errors, neg_errors)?.subtract(log(&array!(2.0))?)?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the cosine similarity loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct CosineSimilarityLoss {
        /// Embedding axis. Defaults to [`CosineSimilarityLoss::DEFAULT_AXIS`]
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Minimum value of the denominator, used for numerical stability. Defaults to
        /// [`CosineSimilarityLoss::DEFAULT_EPS`]
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction type. Defaults to [`CosineSimilarityLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl CosineSimilarityLoss {
    /// Default value for the `axis` parameter.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `eps` parameter.
    pub const DEFAULT_EPS: f32 = 1e-8;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the cosine similarity loss.
    ///
    /// # Params
    ///
    /// - `x1`: first array
    /// - `x2`: second array
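    ///
    /// # Example
    ///
    /// A minimal sketch; values are taken from the unit tests below:
    ///
    /// ```rust, ignore
    /// let cos = CosineSimilarityLossBuilder::new().build()?;
    /// let x1 = array!([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]);
    /// let x2 = array!([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]]);
    /// let loss = cos.apply(&x1, &x2)?; // ~[0.985, 0.961]
    /// ```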
    pub fn apply(&self, x1: impl AsRef<Array>, x2: impl AsRef<Array>) -> Result<Array, Exception> {
        let x1 = x1.as_ref();
        let x2 = x2.as_ref();
        let axis = self.axis;
        let eps = self.eps;
        let reduction = self.reduction;

        fn l2_loss(a: &Array, axis: i32) -> Result<Array, Exception> {
            if a.dtype().is_complex() {
                Ok(sqrt(&sum_axis(&abs(a)?.square()?, axis, None)?)?)
            } else {
                Ok(sqrt(&sum_axis(&a.square()?, axis, None)?)?)
            }
        }

        let x1_norm = l2_loss(x1, axis)?;
        let x2_norm = l2_loss(x2, axis)?;

        let num = sum_axis(&x1.multiply(x2)?, axis, None)?;
        let den = maximum(x1_norm.multiply(x2_norm)?, array!(eps))?;
        let loss = num.divide(&den)?;

        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the margin ranking loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct MarginRankingLoss {
        /// The margin by which the scores should be separated. Defaults to
        /// [`MarginRankingLoss::DEFAULT_MARGIN`]
        #[builder(optional, default = MarginRankingLoss::DEFAULT_MARGIN)]
        pub margin: f32,

        /// Reduction type. Defaults to [`MarginRankingLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = MarginRankingLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl MarginRankingLoss {
    /// Default value for the `margin` parameter.
    pub const DEFAULT_MARGIN: f32 = 0.0;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the margin ranking loss.
    ///
    /// # Params
    ///
    /// - `inputs1`: Scores for the first input.
    /// - `inputs2`: Scores for the second input.
    /// - `targets`: Labels indicating whether samples in `inputs1` should be ranked higher than samples
    ///   in `inputs2`. Values should be 1 or -1.
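    ///
    /// # Example
    ///
    /// A minimal sketch; values are taken from the unit tests below:
    ///
    /// ```rust, ignore
    /// let ranking = MarginRankingLossBuilder::new().build()?;
    /// let inputs1 = array!([-0.573409, -0.765166, -0.0638]);
    /// let inputs2 = array!([0.75596, 0.225763, 0.256995]);
    /// let targets = array!([1, 1, -1]);
    /// let loss = ranking.apply(&inputs1, &inputs2, &targets)?; // ~[1.329, 0.991, 0.0]
    /// ```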
    pub fn apply(
        &self,
        inputs1: impl AsRef<Array>,
        inputs2: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs1 = inputs1.as_ref();
        let inputs2 = inputs2.as_ref();
        let targets = targets.as_ref();
        let margin = self.margin;
        let reduction = self.reduction;

        check_shape(inputs1, inputs2, "inputs1", "inputs2")?;
        check_shape(inputs1, targets, "inputs1", "targets")?;

        let margin = array!(margin);
        let diff = inputs1.subtract(inputs2)?;
        let loss = maximum(
            array!(0.0),
            targets.multiply(diff)?.negative()?.add(margin)?,
        )?;
        reduction.reduce(loss)
    }
}

#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
    use crate::{array, assert_array_eq, builder::Builder, ops::is_nan};
    use float_eq::assert_float_eq;

    use super::*;

    // The following unit tests are adapted from the python API at: mlx/python/tests/test_losses.py

    #[test]
    fn test_cross_entropy() {
        // No weights, no label smoothing
        let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
        let indices = array!([0, 1]);
        let expected = array!([0.0, 0.0]);
        let loss = CrossEntropy::new()
            .unwrap()
            .apply(&logits, indices)
            .unwrap();
        assert_array_eq!(loss, expected);

        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert!(is_nan(&loss).unwrap().all(None).unwrap().item::<bool>());

        // With weights, no label smoothing
        let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
        let indices = array!([0, 1]);
        let weights = array!([1.0, 2.0]);
        let expected = array!([0.04858735, 0.0971747]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(&logits, indices).unwrap();
        assert_array_eq!(loss, expected);

        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert_array_eq!(loss, expected);

        // No weights, with label smoothing
        let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
        let indices = array!([0, 1]);
        let expected = array!([0.498587, 0.498587]);
        let cross_entropy = CrossEntropyBuilder::new()
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(&logits, indices).unwrap();
        assert_array_eq!(loss, expected);

        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert_array_eq!(loss, expected);

        // With weights and label smoothing
        let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
        let indices = array!([0, 1]);
        let weights = array!([1.0, 2.0]);
        let expected = array!([0.49858734, 0.9971747]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(&logits, indices).unwrap();
        assert_array_eq!(loss, expected);

        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert_array_eq!(loss, expected);
    }

    #[test]
    fn test_binary_cross_entropy_with_logits_as_inputs() {
        let logits = array!([0.105361, 0.223144, 1.20397, 0.916291]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        // Test with reduction 'none'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = binary_cross_entropy.apply(&logits, &targets).unwrap();
        let expected_none = array!([0.747215, 0.810930, 0.262365, 0.336472]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = binary_cross_entropy.apply(&logits, &targets).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&logits, &targets).unwrap();
        let expected = expected_none.sum(None).unwrap();
        assert_array_eq!(loss, expected);

        // With weights, no label smoothing
        let weights = array!([1.0, 2.0, 1.0, 2.0]);
        let expected = array!([0.747215, 1.62186, 0.262365, 0.672944]);
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .weights(&weights)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss, expected);
    }

    #[test]
    fn test_binary_cross_entropy_with_probs_as_inputs() {
        let probs = array!([0.5, 0.6, 0.7, 0.8]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        // Test with reduction 'none'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_none = array!([0.693147, 0.916291, 0.356675, 0.223144]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected = expected_none.sum(None).unwrap();
        assert_array_eq!(loss, expected);
    }

    #[test]
    fn test_binary_cross_entropy_with_tiny_probs_as_inputs() {
        let tiny_prob = 1e-59;
        let probs = array!([0.0, tiny_prob, 1.0 - tiny_prob, 1.0]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        // Test with reduction 'none'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_none = array!([0.0, tiny_prob, tiny_prob, 0.0]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected = expected_none.sum(None).unwrap();
        assert_array_eq!(loss, expected);
    }

    #[test]
    fn test_l1_loss() {
        let predictions = array!([0.5, 0.2, 0.9, 0.0]);
        let targets = array!([0.5, 0.2, 0.9, 0.0]);

        let expected_none = array!([0.0, 0.0, 0.0, 0.0]);
        let expected_sum = expected_none.sum(None).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();

        let l1_loss = L1LossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let l1_loss = L1LossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);

        let l1_loss = L1LossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);
    }

    #[test]
    fn test_mse_loss() {
        let predictions = array!([0.5, 0.2, 0.9, 0.0]);
        let targets = array!([0.7, 0.1, 0.8, 0.2]);

        let expected_none = array!([0.04, 0.01, 0.01, 0.04]);
        let expected_mean = expected_none.mean(None).unwrap();
        let expected_sum = expected_none.sum(None).unwrap();

        let mse_loss = MseLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = mse_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let mse_loss = MseLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = mse_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        let mse_loss = MseLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = mse_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }

    #[test]
    fn test_smooth_l1_loss() {
        let predictions = array!([1.5, 2.5, 0.5, 3.5]);
        let targets = array!([1.0, 2.0, 0.5, 2.5]);
        let beta = 1.0;

        let expected_none = array!([0.125, 0.125, 0.0, 0.5]);
        let expected_sum = expected_none.sum(None).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();

        let smooth_l1_loss = SmoothL1LossBuilder::new()
            .beta(beta)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = smooth_l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let smooth_l1_loss = SmoothL1LossBuilder::new()
            .beta(beta)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = smooth_l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);

        let smooth_l1_loss = SmoothL1LossBuilder::new()
            .beta(beta)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = smooth_l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);
    }

    #[test]
    fn test_smooth_l1_loss_negative_diff() {
        let a = array!([1.5, 6.0, 0.5, 2.5]);
        let b = array!([1.0, 2.0, 0.5, 3.5]);

        let loss = SmoothL1Loss::new();

        let ab = loss.apply(&a, &b).unwrap();
        let ba = loss.apply(&b, &a).unwrap();
        assert_array_eq!(ab, ba);
    }

    #[test]
    fn test_nll_loss() {
        let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
        let targets = array!([0, 1]);

        let expected_none = array!([0.0, 0.0]);
        let expected_sum = expected_none.sum(None).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();

        let nll_loss = NllLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = nll_loss.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let nll_loss = NllLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = nll_loss.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        let nll_loss = NllLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = nll_loss.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }

    #[test]
    fn test_gaussian_nll_loss() {
        let inputs = array!([[0.1, 0.2], [0.3, 0.4]]);
        let targets = array!([[0.2, 0.1], [0.1, 0.2]]);
        let vars = array!([[0.1, 0.2], [0.3, 0.4]]);

        // Test with reduction 'none', full=False
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(false)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_none = array!([[-1.101293, -0.779719], [-0.535320, -0.408145]]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean', full=False
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(false)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum', full=False
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(false)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_sum = expected_none.sum(None).unwrap();
        assert_array_eq!(loss_sum, expected_sum);

        // Test with reduction='none', full=True
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(true)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_none_full = array!([[-0.182354, 0.139220], [0.383619, 0.510793]]);
        assert_array_eq!(loss_none_full, expected_none_full);

        // Test with reduction='mean', full=True
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(true)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_mean_full = expected_none_full.mean(None).unwrap();
        assert_array_eq!(loss_mean_full, expected_mean_full);

        // Test with reduction='sum', full=True
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(true)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_sum_full = expected_none_full.sum(None).unwrap();
        assert_array_eq!(loss_sum_full, expected_sum_full);
    }

    #[test]
    fn test_kl_div_loss() {
        let p_logits = array!([[0.5, 0.5], [0.8, 0.2]]).log().unwrap();
        let q_logits = array!([[0.5, 0.5], [0.2, 0.8]]).log().unwrap();

        // Test with reduction 'none'
        let kl_div_loss = KlDivLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
        let expected_none = array!([0.0, 0.831777]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let kl_div_loss = KlDivLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
        let expected_mean = expected_none.mean(None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let kl_div_loss = KlDivLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
        let expected_sum = expected_none.sum(None).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }

    #[test]
    fn test_triplet_loss() {
        let anchors = array!([[1, 2, 3], [1, 2, 3]]);
        let positives = array!([[4, 5, 6], [0, -1, 2]]);
        let negatives = array!([[7, 8, 9], [3, 2, 3]]);

        // Test with reduction 'none'
        let triplet_loss = TripletLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = triplet_loss
            .apply(&anchors, &positives, &negatives)
            .unwrap();
        let expected_none = array!([0.0, 2.31662]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let triplet_loss = TripletLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = triplet_loss
            .apply(&anchors, &positives, &negatives)
            .unwrap();
        let expected_mean = expected_none.mean(None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let triplet_loss = TripletLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = triplet_loss
            .apply(&anchors, &positives, &negatives)
            .unwrap();
        let expected_sum = expected_none.sum(None).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }

    #[test]
    fn test_hinge_loss() {
        let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
        let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
        let hinge_loss = HingeLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss = hinge_loss.apply(&inputs, &targets).unwrap();
        assert_eq!(loss.item::<f32>(), 1.0);
    }

    #[test]
    fn test_huber_loss() {
        let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
        let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
        let huber_loss = HuberLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss = huber_loss.apply(&inputs, &targets).unwrap();
        assert_eq!(loss.item::<f32>(), 0.5);
    }

    #[test]
    fn test_log_cosh_loss() {
        let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
        let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
        let log_cosh_loss = LogCoshLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss = log_cosh_loss.apply(&inputs, &targets).unwrap();
        assert_float_eq!(loss.item::<f32>(), 0.433781, abs <= 1e-6);
    }

    #[test]
    fn test_cosine_similarity_loss() {
        let embeddings1 = array!([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]);
        let embeddings2 = array!([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]]);

        // Test with reduction 'none'
        let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = cosine_similarity_loss
            .apply(&embeddings1, &embeddings2)
            .unwrap();
        let expected_none = array!([0.985344, 0.961074]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = cosine_similarity_loss
            .apply(&embeddings1, &embeddings2)
            .unwrap();
        let expected_mean = expected_none.mean(None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = cosine_similarity_loss
            .apply(&embeddings1, &embeddings2)
            .unwrap();
        let expected_sum = expected_none.sum(None).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }

    #[test]
    fn test_margin_ranking_loss() {
        let inputs1 = array!([-0.573409, -0.765166, -0.0638]);
        let inputs2 = array!([0.75596, 0.225763, 0.256995]);
        let targets = array!([1, 1, -1]);

        // Test with no margin
        let margin_ranking_loss = MarginRankingLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = margin_ranking_loss
            .apply(&inputs1, &inputs2, &targets)
            .unwrap();
        let expected = array!([1.329369, 0.990929, 0.0]);
        assert_array_eq!(loss, expected);

        // Test with margin
        let margin_ranking_loss = MarginRankingLossBuilder::new()
            .margin(0.5)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = margin_ranking_loss
            .apply(&inputs1, &inputs2, &targets)
            .unwrap();
        let expected = array!([1.829369, 1.490929, 0.179205]);
        assert_array_eq!(loss, expected);
    }
}