mlx_rs/losses.rs

//! Loss functions

use crate::{
    array,
    error::{CrossEntropyBuildError, Exception},
    ops::{
        abs, clip, exp, indexing::take_along_axis, log, log_add_exp, log_sum_exp, maximum, minimum,
        multiply, power, r#where, sqrt, square, sum,
    },
    Array,
};
use mlx_internal_macros::{generate_builder, Buildable};

#[inline]
fn check_shape(
    left: &Array,
    right: &Array,
    left_ident: &str,
    right_ident: &str,
) -> Result<(), Exception> {
    if left.shape() != right.shape() {
        return Err(Exception::custom(format!(
            "The shape of the {} ({:?}) does not match the shape of the {} ({:?})",
            left_ident,
            left.shape(),
            right_ident,
            right.shape()
        )));
    }
    Ok(())
}

/// Different types of loss reductions
#[derive(Debug, Clone, Copy)]
pub enum LossReduction {
    /// No reduction is applied.
    None,
    /// The sum of the output will be computed.
    Sum,
    /// The mean of the output will be computed.
    Mean,
}

impl LossReduction {
    /// Reduces the loss according to the reduction type.
    pub fn reduce(&self, loss: Array) -> Result<Array, Exception> {
        match self {
            LossReduction::None => Ok(loss),
            LossReduction::Sum => Ok(loss.sum(None, None)?),
            LossReduction::Mean => Ok(loss.mean(None, None)?),
        }
    }
}

/// Helper type alias for CrossEntropyBuilder weights.
pub type CrossEntropyBuilderWeights<'a> = &'a Array;

generate_builder! {
    /// Cross entropy loss function.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(
        root = crate,
        build_with = build_cross_entropy,
        err = CrossEntropyBuildError
    )]
    pub struct CrossEntropy<'a> {
        /// Weights for each target
        #[builder(optional, default = CrossEntropy::DEFAULT_WEIGHTS)]
        pub weights: Option<&'a Array>,

        /// The axis over which to compute softmax. Defaults to [`CrossEntropy::DEFAULT_AXIS`]
        #[builder(optional, default = CrossEntropy::DEFAULT_AXIS)]
        pub axis: i32,

        /// The label smoothing factor, in the range [0, 1). Defaults to
        /// [`CrossEntropy::DEFAULT_LABEL_SMOOTHING`]
        #[builder(optional, default = CrossEntropy::DEFAULT_LABEL_SMOOTHING)]
        pub label_smoothing: f32,

        /// Reduction type. Defaults to [`CrossEntropy::DEFAULT_REDUCTION`]
        #[builder(optional, default = CrossEntropy::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

fn build_cross_entropy(
    builder: CrossEntropyBuilder,
) -> Result<CrossEntropy, CrossEntropyBuildError> {
    let axis = builder.axis;
    let label_smoothing = builder.label_smoothing;
    let reduction = builder.reduction;

    if !(0.0..1.0).contains(&label_smoothing) {
        return Err(CrossEntropyBuildError::InvalidLabelSmoothingFactor);
    }

    Ok(CrossEntropy {
        weights: builder.weights,
        axis,
        label_smoothing,
        reduction,
    })
}

impl<'a> CrossEntropy<'a> {
    /// Default value for the `axis` parameter.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `label_smoothing` parameter.
    pub const DEFAULT_LABEL_SMOOTHING: f32 = 0.0;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Default value for the `weights` parameter.
    pub const DEFAULT_WEIGHTS: Option<&'a Array> = None;

    /// Apply the cross entropy loss function on the given logits and targets.
    ///
    /// # Params
    ///
    /// - `logits`: unnormalized predicted logits
    /// - `targets`: target values, as class indices or as class probabilities (with the same
    ///   number of dimensions as `logits`)
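    ///
    /// # Example
    ///
    /// A minimal usage sketch with class indices as targets, mirroring the unit
    /// tests at the bottom of this module:
    ///
    /// ```rust, ignore
    /// let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
    /// let indices = array!([0, 1]);
    /// let loss = CrossEntropy::new().unwrap().apply(&logits, &indices).unwrap();
    /// ```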
    pub fn apply(
        &self,
        logits: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let logits = logits.as_ref();
        let targets = targets.as_ref();

        let target_as_probs = targets.ndim() == logits.ndim();

        let score = if target_as_probs {
            sum(&logits.multiply(targets)?, &[self.axis], None)?
        } else {
            take_along_axis(logits, &targets.expand_dims(&[-1])?, self.axis)?.squeeze(&[-1])?
        };
        let log_sum_exp_logits = log_sum_exp(logits, &[self.axis], None)?;

        let mut loss = if self.label_smoothing > 0.0 {
            // adjust the true class score with label smoothing
            let adjusted_score = multiply(array!(1.0 - self.label_smoothing), score)?;

            // calculate the mean logit across the classes for smoothed loss
            let mean_logits = logits.mean(&[self.axis], None)?;
            let smoothed_loss = -multiply(mean_logits, array!(self.label_smoothing))?;

            // combine the adjusted score and smoothed loss with the logsumexp logits
            log_sum_exp_logits
                .subtract(adjusted_score)?
                .add(smoothed_loss)?
        } else {
            log_sum_exp_logits.subtract(score)?
        };

        if let Some(weights) = self.weights {
            check_shape(weights, &loss, "weights", "loss")?;
            loss = multiply(loss, weights)?;
        }

        self.reduction.reduce(loss)
    }
}

generate_builder! {
    /// Binary cross entropy loss.
    ///
    /// By default, this function takes the pre-sigmoid logits, which results in a faster
    /// and more precise loss. For improved numerical stability when `inputs_are_logits` is
    /// false, the loss calculation clips the input probabilities (in log-space) to a minimum
    /// value of `-100`.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct BinaryCrossEntropy<'a> {
        /// Optional weights for each target
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_WEIGHTS)]
        pub weights: Option<&'a Array>,

        /// Whether the inputs are logits. Defaults to
        /// [`BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS`]
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS)]
        pub inputs_are_logits: bool,

        /// Reduction type. Defaults to [`BinaryCrossEntropy::DEFAULT_REDUCTION`]
        #[builder(optional, default = BinaryCrossEntropy::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl<'a> BinaryCrossEntropy<'a> {
    /// Default value for the `weights` parameter.
    pub const DEFAULT_WEIGHTS: Option<&'a Array> = None;

    /// Default value for the `inputs_are_logits` parameter.
    pub const DEFAULT_INPUTS_ARE_LOGITS: bool = true;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Apply the binary cross entropy loss function on the given logits and targets.
    ///
    /// # Params
    ///
    /// - `logits`: unnormalized predicted logits (or probabilities if `inputs_are_logits` is false)
    /// - `targets`: binary target values in {0, 1}
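    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module:
    ///
    /// ```rust, ignore
    /// let logits = array!([0.105361, 0.223144, 1.20397, 0.916291]);
    /// let targets = array!([0.0, 0.0, 1.0, 1.0]);
    /// let bce = BinaryCrossEntropyBuilder::new()
    ///     .reduction(LossReduction::Mean)
    ///     .build()
    ///     .unwrap();
    /// let loss = bce.apply(&logits, &targets).unwrap();
    /// ```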
    pub fn apply(
        &self,
        logits: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let logits = logits.as_ref();
        let targets = targets.as_ref();
        let weights = self.weights;
        let inputs_are_logits = self.inputs_are_logits;
        let reduction = self.reduction;

        let mut loss = if inputs_are_logits {
            log_add_exp(array!(0.0), logits)?.subtract(targets.multiply(logits)?)?
        } else {
            let log_inputs_clip = clip(log(logits)?, (-100.0, ()))?;
            let log_inputs_inverse_clip = clip(log(&array!(1.0).subtract(logits)?)?, (-100.0, ()))?;
            -(targets.multiply(log_inputs_clip)?.add(
                array!(1.0)
                    .subtract(targets)?
                    .multiply(log_inputs_inverse_clip)?,
            )?)
        };

        if let Some(weights) = weights {
            check_shape(weights, &loss, "weights", "loss")?;
            loss = multiply(loss, weights)?;
        }

        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the L1 loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct L1Loss {
        /// Reduction type. Defaults to [`L1Loss::DEFAULT_REDUCTION`]
        #[builder(optional, default = L1Loss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl L1Loss {
    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;

    /// Compute the L1 loss.
    ///
    /// # Params
    ///
    /// - `predictions`: predicted values
    /// - `targets`: target values
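    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module:
    ///
    /// ```rust, ignore
    /// let predictions = array!([0.5, 0.2, 0.9, 0.0]);
    /// let targets = array!([0.5, 0.2, 0.9, 0.0]);
    /// let l1_loss = L1LossBuilder::new()
    ///     .reduction(LossReduction::Mean)
    ///     .build()
    ///     .unwrap();
    /// let loss = l1_loss.apply(&predictions, &targets).unwrap();
    /// ```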
    pub fn apply(
        &self,
        predictions: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let predictions = predictions.as_ref();
        let targets = targets.as_ref();
        let reduction = self.reduction;

        check_shape(predictions, targets, "predictions", "targets")?;
        let loss = predictions.subtract(targets)?.abs()?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the mean squared error loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct MseLoss {
        /// Reduction type. Defaults to [`MseLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = MseLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl MseLoss {
    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;

    /// Compute the mean squared error loss.
    ///
    /// # Params
    ///
    /// - `predictions`: predicted values
    /// - `targets`: target values
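    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module (the per-element loss before reduction is `[0.04, 0.01, 0.01, 0.04]`):
    ///
    /// ```rust, ignore
    /// let predictions = array!([0.5, 0.2, 0.9, 0.0]);
    /// let targets = array!([0.7, 0.1, 0.8, 0.2]);
    /// let mse_loss = MseLossBuilder::new()
    ///     .reduction(LossReduction::Mean)
    ///     .build()
    ///     .unwrap();
    /// let loss = mse_loss.apply(&predictions, &targets).unwrap();
    /// ```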
    pub fn apply(
        &self,
        predictions: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let predictions = predictions.as_ref();
        let targets = targets.as_ref();
        let reduction = self.reduction;

        check_shape(predictions, targets, "predictions", "targets")?;
        let loss = predictions.subtract(targets)?.square()?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the negative log likelihood loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct NllLoss {
        /// Distribution axis. Defaults to [`NllLoss::DEFAULT_AXIS`]
        #[builder(optional, default = NllLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Reduction type. Defaults to [`NllLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = NllLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl NllLoss {
    /// Default value for the `axis` parameter.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Compute the negative log likelihood loss.
    ///
    /// # Params
    ///
    /// - `inputs`: predicted distribution in log space
    /// - `targets`: target values
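    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module; `inputs` holds log-probabilities and `targets` holds class indices:
    ///
    /// ```rust, ignore
    /// let inputs = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
    /// let targets = array!([0, 1]);
    /// let nll_loss = NllLossBuilder::new().build().unwrap();
    /// let loss = nll_loss.apply(&inputs, &targets).unwrap();
    /// ```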
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let axis = self.axis;
        let reduction = self.reduction;

        let loss = -take_along_axis(inputs, &targets.expand_dims(&[-1])?, axis)?.squeeze(&[-1])?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Compute the negative log likelihood loss for a Gaussian distribution.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct GaussianNllLoss {
        /// Whether to include the constant term in the loss calculation. Defaults to
        /// [`GaussianNllLoss::DEFAULT_FULL`]
        #[builder(optional, default = GaussianNllLoss::DEFAULT_FULL)]
        pub full: bool,

        /// Small positive constant for numerical stability. Defaults to
        /// [`GaussianNllLoss::DEFAULT_EPS`]
        #[builder(optional, default = GaussianNllLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction type. Defaults to [`GaussianNllLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = GaussianNllLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl GaussianNllLoss {
    /// Default value for the `full` parameter.
    pub const DEFAULT_FULL: bool = false;

    /// Default value for the `eps` parameter.
    pub const DEFAULT_EPS: f32 = 1e-6;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Compute the negative log likelihood loss for a Gaussian distribution.
    ///
    /// # Params
    ///
    /// - `inputs`: The predicted expectation of the Gaussian distribution.
    /// - `targets`: The target values (samples from the Gaussian distribution).
    /// - `vars`: The predicted variance of the Gaussian distribution.
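    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module:
    ///
    /// ```rust, ignore
    /// let inputs = array!([[0.1, 0.2], [0.3, 0.4]]);
    /// let targets = array!([[0.2, 0.1], [0.1, 0.2]]);
    /// let vars = array!([[0.1, 0.2], [0.3, 0.4]]);
    /// let gaussian_nll_loss = GaussianNllLossBuilder::new().build().unwrap();
    /// let loss = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
    /// ```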
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
        vars: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let vars = vars.as_ref();
        let full = self.full;
        let eps = self.eps;
        let reduction = self.reduction;

        check_shape(inputs, targets, "inputs", "targets")?;
        check_shape(inputs, vars, "inputs", "vars")?;

        let vars = maximum(vars, array!(eps))?;
        let mut loss =
            array!(0.5) * (log(&vars)?.add(square(&targets.subtract(inputs)?)?.divide(&vars)?)?);

        if full {
            let pi = array!(std::f32::consts::PI);
            loss = loss.add(array!(0.5).multiply(log(&array!(2.0).multiply(pi)?)?)?)?;
        }

        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Compute the Kullback-Leibler divergence loss.
    ///
    /// Computes the following when the `reduction` is `LossReduction::None`:
    ///
    /// ```rust, ignore
    /// sum(exp(targets) * (targets - inputs), axis, None)
    /// ```
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct KlDivLoss {
        /// The distribution axis. Defaults to [`KlDivLoss::DEFAULT_AXIS`]
        #[builder(optional, default = KlDivLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Reduction type. Defaults to [`KlDivLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = KlDivLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl KlDivLoss {
    /// Default value for the `axis` parameter.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Compute the Kullback-Leibler divergence loss.
    ///
    /// # Params
    ///
    /// - `inputs`: Log probabilities for the predicted distribution.
    /// - `targets`: Log probabilities for the target distribution.
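    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module; both arguments are log-probabilities:
    ///
    /// ```rust, ignore
    /// let p_logits = array!([[0.5, 0.5], [0.8, 0.2]]).log().unwrap();
    /// let q_logits = array!([[0.5, 0.5], [0.2, 0.8]]).log().unwrap();
    /// let kl_div_loss = KlDivLossBuilder::new().build().unwrap();
    /// let loss = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
    /// ```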
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let axis = self.axis;
        let reduction = self.reduction;

        let loss = sum(
            &exp(targets)?.multiply(targets.subtract(inputs)?)?,
            &[axis],
            None,
        )?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the smooth L1 loss.
    ///
    /// The smooth L1 loss is a variant of the L1 loss which replaces the absolute
    /// difference with a squared difference when the absolute difference is less
    /// than `beta`.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct SmoothL1Loss {
        /// The threshold after which the loss changes from the squared to the absolute difference.
        /// Defaults to [`SmoothL1Loss::DEFAULT_BETA`]
        #[builder(optional, default = SmoothL1Loss::DEFAULT_BETA)]
        pub beta: f32,

        /// Reduction type. Defaults to [`SmoothL1Loss::DEFAULT_REDUCTION`]
        #[builder(optional, default = SmoothL1Loss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl SmoothL1Loss {
    /// Default value for the `beta` parameter.
    pub const DEFAULT_BETA: f32 = 1.0;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;

    /// Compute the smooth L1 loss.
    ///
    /// # Params
    ///
    /// - `predictions`: predicted values
    /// - `targets`: target values
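    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module (per those tests, `SmoothL1Loss::new()` builds directly since all
    /// fields are optional):
    ///
    /// ```rust, ignore
    /// let predictions = array!([1.5, 2.5, 0.5, 3.5]);
    /// let targets = array!([1.0, 2.0, 0.5, 2.5]);
    /// let loss = SmoothL1Loss::new().apply(&predictions, &targets).unwrap();
    /// ```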
    pub fn apply(
        &self,
        predictions: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let predictions = predictions.as_ref();
        let targets = targets.as_ref();
        let beta = self.beta;
        let reduction = self.reduction;

        check_shape(predictions, targets, "predictions", "targets")?;
        let diff = predictions.subtract(targets)?.abs()?;
        let beta = array!(beta);
        let loss = r#where(
            &diff.lt(&beta)?,
            array!(0.5).multiply(square(&diff)?)?.divide(&beta)?,
            diff.subtract(array!(0.5).multiply(beta)?)?,
        )?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the triplet loss for a set of anchor, positive, and negative samples. The margin
    /// is denoted by alpha in the mathematical description of the loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct TripletLoss {
        /// Distribution axis. Defaults to [`TripletLoss::DEFAULT_AXIS`]
        #[builder(optional, default = TripletLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// The norm degree for pairwise distance. Defaults to [`TripletLoss::DEFAULT_P`]
        #[builder(optional, default = TripletLoss::DEFAULT_P)]
        pub p: f32,

        /// Margin for the triplet loss. Defaults to [`TripletLoss::DEFAULT_MARGIN`]
        #[builder(optional, default = TripletLoss::DEFAULT_MARGIN)]
        pub margin: f32,

        /// Small positive constant for numerical stability. Defaults to [`TripletLoss::DEFAULT_EPS`]
        #[builder(optional, default = TripletLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction type. Defaults to [`TripletLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = TripletLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl TripletLoss {
    /// Default value for the `axis` parameter.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `p` parameter.
    pub const DEFAULT_P: f32 = 2.0;

    /// Default value for the `margin` parameter.
    pub const DEFAULT_MARGIN: f32 = 1.0;

    /// Default value for the `eps` parameter.
    pub const DEFAULT_EPS: f32 = 1e-6;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the triplet loss for a set of anchor, positive, and negative samples. The margin
    /// is denoted by alpha in the mathematical description of the loss.
    ///
    /// # Params
    ///
    /// - `anchors`: The anchor samples
    /// - `positives`: The positive samples
    /// - `negatives`: The negative samples
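    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module:
    ///
    /// ```rust, ignore
    /// let anchors = array!([[1, 2, 3], [1, 2, 3]]);
    /// let positives = array!([[4, 5, 6], [0, -1, 2]]);
    /// let negatives = array!([[7, 8, 9], [3, 2, 3]]);
    /// let triplet_loss = TripletLossBuilder::new().build().unwrap();
    /// let loss = triplet_loss.apply(&anchors, &positives, &negatives).unwrap();
    /// ```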
    pub fn apply(
        &self,
        anchors: impl AsRef<Array>,
        positives: impl AsRef<Array>,
        negatives: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let anchors = anchors.as_ref();
        let positives = positives.as_ref();
        let negatives = negatives.as_ref();
        let axis = self.axis;
        let p = self.p;
        let margin = self.margin;
        let eps = self.eps;
        let reduction = self.reduction;

        let eps = array!(eps);
        let p = array!(p);
        let margin = array!(margin);

        let pos = sqrt(
            &power(&anchors.subtract(positives)?, &p)?
                .sum(&[axis], None)?
                .add(&eps)?,
        )?;
        let neg = sqrt(
            &power(&anchors.subtract(negatives)?, &p)?
                .sum(&[axis], None)?
                .add(&eps)?,
        )?;
        let loss = maximum(pos.subtract(neg)?.add(margin)?, array!(0.0))?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Compute the hinge loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct HingeLoss {
        /// Reduction type. Defaults to [`HingeLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = HingeLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl HingeLoss {
    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Compute the hinge loss.
    ///
    /// # Params
    ///
    /// - `inputs`: predicted values
    /// - `targets`: target values, -1 or 1
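    ///
    /// # Example
    ///
    /// A minimal usage sketch; the input values here are illustrative and not
    /// taken from the unit tests below:
    ///
    /// ```rust, ignore
    /// let inputs = array!([0.5, -1.5, 2.0]);
    /// let targets = array!([1.0, -1.0, 1.0]);
    /// let hinge_loss = HingeLossBuilder::new()
    ///     .reduction(LossReduction::Mean)
    ///     .build()
    ///     .unwrap();
    /// let loss = hinge_loss.apply(&inputs, &targets).unwrap();
    /// ```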
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let reduction = self.reduction;

        let a = array!(1.0).subtract(inputs.multiply(targets)?)?;
        let b = array!(0.0);
        let loss = maximum(a, b)?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Compute the Huber loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct HuberLoss {
        /// The threshold at which to change between L1 and L2 loss. Defaults to
        /// [`HuberLoss::DEFAULT_DELTA`]
        #[builder(optional, default = HuberLoss::DEFAULT_DELTA)]
        pub delta: f32,

        /// Reduction type. Defaults to [`HuberLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = HuberLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl HuberLoss {
    /// Default value for the `delta` parameter.
    pub const DEFAULT_DELTA: f32 = 1.0;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Compute the Huber loss.
    ///
    /// # Params
    ///
    /// - `inputs`: predicted values
    /// - `targets`: target values
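    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module (with unit inputs and zero targets, the mean loss is `0.5`):
    ///
    /// ```rust, ignore
    /// let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
    /// let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
    /// let huber_loss = HuberLossBuilder::new()
    ///     .reduction(LossReduction::Mean)
    ///     .build()
    ///     .unwrap();
    /// let loss = huber_loss.apply(&inputs, &targets).unwrap();
    /// ```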
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let delta = self.delta;
        let reduction = self.reduction;

        let errors = inputs.subtract(targets)?;
        let abs_errors = errors.abs()?;
        let quadratic = minimum(&abs_errors, array!(delta))?;
        let linear = abs_errors.subtract(&quadratic)?;
        let loss = array!(0.5)
            .multiply(square(&quadratic)?)?
            .add(array!(delta).multiply(linear)?)?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the log cosh loss between inputs and targets.
    ///
    /// Logcosh acts like L2 loss for small errors, ensuring stable gradients,
    /// and like the L1 loss for large errors, reducing sensitivity to outliers. This
    /// dual behavior offers a balanced, robust approach for regression tasks.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct LogCoshLoss {
        /// Reduction type. Defaults to [`LogCoshLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = LogCoshLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl LogCoshLoss {
    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the log cosh loss between inputs and targets.
    ///
    /// # Params
    ///
    /// - `inputs`: predicted values
    /// - `targets`: target values
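    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module (with unit errors the mean loss is approximately `0.433781`):
    ///
    /// ```rust, ignore
    /// let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
    /// let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
    /// let log_cosh_loss = LogCoshLossBuilder::new()
    ///     .reduction(LossReduction::Mean)
    ///     .build()
    ///     .unwrap();
    /// let loss = log_cosh_loss.apply(&inputs, &targets).unwrap();
    /// ```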
    pub fn apply(
        &self,
        inputs: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs = inputs.as_ref();
        let targets = targets.as_ref();
        let reduction = self.reduction;

        let errors = inputs.subtract(targets)?;
        let neg_errors = errors.negative()?;
        let loss = log_add_exp(errors, neg_errors)?.subtract(log(&array!(2.0))?)?;
        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the cosine similarity loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct CosineSimilarityLoss {
        /// Embedding axis. Defaults to [`CosineSimilarityLoss::DEFAULT_AXIS`]
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_AXIS)]
        pub axis: i32,

        /// Minimum value of the denominator, used for numerical stability. Defaults to
        /// [`CosineSimilarityLoss::DEFAULT_EPS`]
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_EPS)]
        pub eps: f32,

        /// Reduction type. Defaults to [`CosineSimilarityLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = CosineSimilarityLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl CosineSimilarityLoss {
    /// Default value for the `axis` parameter.
    pub const DEFAULT_AXIS: i32 = -1;

    /// Default value for the `eps` parameter.
    pub const DEFAULT_EPS: f32 = 1e-8;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the cosine similarity loss.
    ///
    /// # Params
    ///
    /// - `x1`: first array
    /// - `x2`: second array
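    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module (the per-row similarity is approximately `[0.985344, 0.961074]`):
    ///
    /// ```rust, ignore
    /// let x1 = array!([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]);
    /// let x2 = array!([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]]);
    /// let cosine_similarity_loss = CosineSimilarityLossBuilder::new().build().unwrap();
    /// let loss = cosine_similarity_loss.apply(&x1, &x2).unwrap();
    /// ```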
    pub fn apply(&self, x1: impl AsRef<Array>, x2: impl AsRef<Array>) -> Result<Array, Exception> {
        let x1 = x1.as_ref();
        let x2 = x2.as_ref();
        let axis = self.axis;
        let eps = self.eps;
        let reduction = self.reduction;

        // Computes the L2 norm along the given axis.
        fn l2_loss(a: &Array, axis: i32) -> Result<Array, Exception> {
            if a.dtype().is_complex() {
                Ok(sqrt(&sum(&abs(a)?.square()?, &[axis], None)?)?)
            } else {
                Ok(sqrt(&sum(&a.square()?, &[axis], None)?)?)
            }
        }

        let x1_norm = l2_loss(x1, axis)?;
        let x2_norm = l2_loss(x2, axis)?;

        let num = sum(&x1.multiply(x2)?, &[axis], None)?;
        let den = maximum(x1_norm.multiply(x2_norm)?, array!(eps))?;
        let loss = num.divide(&den)?;

        reduction.reduce(loss)
    }
}

generate_builder! {
    /// Computes the margin ranking loss.
    #[derive(Debug, Clone, Buildable)]
    #[buildable(root = crate)]
    #[builder(root = crate)]
    pub struct MarginRankingLoss {
        /// The margin by which the scores should be separated. Defaults to
        /// [`MarginRankingLoss::DEFAULT_MARGIN`]
        #[builder(optional, default = MarginRankingLoss::DEFAULT_MARGIN)]
        pub margin: f32,

        /// Reduction type. Defaults to [`MarginRankingLoss::DEFAULT_REDUCTION`]
        #[builder(optional, default = MarginRankingLoss::DEFAULT_REDUCTION)]
        pub reduction: LossReduction,
    }
}

impl MarginRankingLoss {
    /// Default value for the `margin` parameter.
    pub const DEFAULT_MARGIN: f32 = 0.0;

    /// Default value for the `reduction` parameter.
    pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;

    /// Computes the margin ranking loss.
    ///
    /// # Params
    ///
    /// - `inputs1`: Scores for the first input.
    /// - `inputs2`: Scores for the second input.
    /// - `targets`: Labels indicating whether samples in `inputs1` should be ranked higher than
    ///   samples in `inputs2`. Values should be 1 or -1.
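    ///
    /// # Example
    ///
    /// A minimal usage sketch mirroring the unit tests at the bottom of this
    /// module:
    ///
    /// ```rust, ignore
    /// let inputs1 = array!([-0.573409, -0.765166, -0.0638]);
    /// let inputs2 = array!([0.75596, 0.225763, 0.256995]);
    /// let targets = array!([1, 1, -1]);
    /// let margin_ranking_loss = MarginRankingLossBuilder::new().build().unwrap();
    /// let loss = margin_ranking_loss.apply(&inputs1, &inputs2, &targets).unwrap();
    /// ```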
    pub fn apply(
        &self,
        inputs1: impl AsRef<Array>,
        inputs2: impl AsRef<Array>,
        targets: impl AsRef<Array>,
    ) -> Result<Array, Exception> {
        let inputs1 = inputs1.as_ref();
        let inputs2 = inputs2.as_ref();
        let targets = targets.as_ref();
        let margin = self.margin;
        let reduction = self.reduction;

        check_shape(inputs1, inputs2, "inputs1", "inputs2")?;
        check_shape(inputs1, targets, "inputs1", "targets")?;

        let margin = array!(margin);
        let diff = inputs1.subtract(inputs2)?;
        let loss = maximum(
            array!(0.0),
            targets.multiply(diff)?.negative()?.add(margin)?,
        )?;
        reduction.reduce(loss)
    }
}

#[cfg(test)]
#[allow(clippy::approx_constant)]
mod tests {
    use crate::{array, assert_array_eq, builder::Builder, ops::is_nan};
    use float_eq::assert_float_eq;

    use super::*;

    // The following unit tests are adapted from the Python API at: mlx/python/tests/test_losses.py

    #[test]
    fn test_cross_entropy() {
        // No weights, no label smoothing
        let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
        let indices = array!([0, 1]);
        let expected = array!([0.0, 0.0]);
        let loss = CrossEntropy::new()
            .unwrap()
            .apply(&logits, indices)
            .unwrap();
        assert_array_eq!(loss, expected);

        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert!(is_nan(&loss)
            .unwrap()
            .all(None, None)
            .unwrap()
            .item::<bool>());

        // With weights, no label smoothing
        let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
        let indices = array!([0, 1]);
        let weights = array!([1.0, 2.0]);
        let expected = array!([0.04858735, 0.0971747]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(&logits, indices).unwrap();
        assert_array_eq!(loss, expected);

        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert_array_eq!(loss, expected);

        // No weights, with label smoothing
        let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
        let indices = array!([0, 1]);
        let expected = array!([0.498587, 0.498587]);
        let cross_entropy = CrossEntropyBuilder::new()
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(&logits, indices).unwrap();
        assert_array_eq!(loss, expected);

        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert_array_eq!(loss, expected);

        // With weights and label smoothing
        let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
        let indices = array!([0, 1]);
        let weights = array!([1.0, 2.0]);
        let expected = array!([0.49858734, 0.9971747]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(&logits, indices).unwrap();
        assert_array_eq!(loss, expected);

        let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
        let cross_entropy = CrossEntropyBuilder::new()
            .weights(&weights)
            .label_smoothing(0.3)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = cross_entropy.apply(logits, probs).unwrap();
        assert_array_eq!(loss, expected);
    }

    #[test]
    fn test_binary_cross_entropy_with_logits_as_inputs() {
        let logits = array!([0.105361, 0.223144, 1.20397, 0.916291]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        // Test with reduction 'none'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = binary_cross_entropy.apply(&logits, &targets).unwrap();
        let expected_none = array!([0.747215, 0.810930, 0.262365, 0.336472]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = binary_cross_entropy.apply(&logits, &targets).unwrap();
        let expected_mean = expected_none.mean(None, None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&logits, &targets).unwrap();
        let expected = expected_none.sum(None, None).unwrap();
        assert_array_eq!(loss, expected);

        // With weights, no label smoothing
        let weights = array!([1.0, 2.0, 1.0, 2.0]);
        let expected = array!([0.747215, 1.62186, 0.262365, 0.672944]);
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .weights(&weights)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss, expected);
    }

    #[test]
    fn test_binary_cross_entropy_with_probs_as_inputs() {
        let probs = array!([0.5, 0.6, 0.7, 0.8]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        // Test with reduction 'none'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_none = array!([0.693147, 0.916291, 0.356675, 0.223144]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_mean = expected_none.mean(None, None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected = expected_none.sum(None, None).unwrap();
        assert_array_eq!(loss, expected);
    }

    #[test]
    fn test_binary_cross_entropy_with_tiny_probs_as_inputs() {
        let tiny_prob = 1e-59;
        let probs = array!([0.0, tiny_prob, 1.0 - tiny_prob, 1.0]);
        let targets = array!([0.0, 0.0, 1.0, 1.0]);

        // Test with reduction 'none'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_none = array!([0.0, tiny_prob, tiny_prob, 0.0]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected_mean = expected_none.mean(None, None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
            .inputs_are_logits(false)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss = binary_cross_entropy.apply(&probs, &targets).unwrap();
        let expected = expected_none.sum(None, None).unwrap();
        assert_array_eq!(loss, expected);
    }

    #[test]
    fn test_l1_loss() {
        let predictions = array!([0.5, 0.2, 0.9, 0.0]);
        let targets = array!([0.5, 0.2, 0.9, 0.0]);

        let expected_none = array!([0.0, 0.0, 0.0, 0.0]);
        let expected_sum = expected_none.sum(None, None).unwrap();
        let expected_mean = expected_none.mean(None, None).unwrap();

        let l1_loss = L1LossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let l1_loss = L1LossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);

        let l1_loss = L1LossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);
    }

    #[test]
    fn test_mse_loss() {
        let predictions = array!([0.5, 0.2, 0.9, 0.0]);
        let targets = array!([0.7, 0.1, 0.8, 0.2]);

        let expected_none = array!([0.04, 0.01, 0.01, 0.04]);
        let expected_mean = expected_none.mean(None, None).unwrap();
        let expected_sum = expected_none.sum(None, None).unwrap();

        let mse_loss = MseLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = mse_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let mse_loss = MseLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = mse_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        let mse_loss = MseLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = mse_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }

    #[test]
    fn test_smooth_l1_loss() {
        let predictions = array!([1.5, 2.5, 0.5, 3.5]);
        let targets = array!([1.0, 2.0, 0.5, 2.5]);
        let beta = 1.0;

        let expected_none = array!([0.125, 0.125, 0.0, 0.5]);
        let expected_sum = expected_none.sum(None, None).unwrap();
        let expected_mean = expected_none.mean(None, None).unwrap();

        let smooth_l1_loss = SmoothL1LossBuilder::new()
            .beta(beta)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = smooth_l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let smooth_l1_loss = SmoothL1LossBuilder::new()
            .beta(beta)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = smooth_l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);

        let smooth_l1_loss = SmoothL1LossBuilder::new()
            .beta(beta)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = smooth_l1_loss.apply(&predictions, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);
    }

    #[test]
    fn test_smooth_l1_loss_negative_diff() {
        let a = array!([1.5, 6.0, 0.5, 2.5]);
        let b = array!([1.0, 2.0, 0.5, 3.5]);

        let loss = SmoothL1Loss::new();

        let ab = loss.apply(&a, &b).unwrap();
        let ba = loss.apply(&b, &a).unwrap();
        assert_array_eq!(ab, ba);
    }

    #[test]
    fn test_nll_loss() {
        let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
        let targets = array!([0, 1]);

        let expected_none = array!([0.0, 0.0]);
        let expected_sum = expected_none.sum(None, None).unwrap();
        let expected_mean = expected_none.mean(None, None).unwrap();

        let nll_loss = NllLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = nll_loss.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss_none, expected_none);

        let nll_loss = NllLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = nll_loss.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        let nll_loss = NllLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = nll_loss.apply(&logits, &targets).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }

    #[test]
    fn test_gaussian_nll_loss() {
        let inputs = array!([[0.1, 0.2], [0.3, 0.4]]);
        let targets = array!([[0.2, 0.1], [0.1, 0.2]]);
        let vars = array!([[0.1, 0.2], [0.3, 0.4]]);

        // Test with reduction 'none', full=False
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(false)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_none = array!([[-1.101293, -0.779719], [-0.535320, -0.408145]]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean', full=False
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(false)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_mean = expected_none.mean(None, None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum', full=False
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(false)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_sum = expected_none.sum(None, None).unwrap();
        assert_array_eq!(loss_sum, expected_sum);

        // Test with reduction='none', full=True
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(true)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_none_full = array!([[-0.182354, 0.139220], [0.383619, 0.510793]]);
        assert_array_eq!(loss_none_full, expected_none_full);

        // Test with reduction='mean', full=True
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(true)
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_mean_full = expected_none_full.mean(None, None).unwrap();
        assert_array_eq!(loss_mean_full, expected_mean_full);

        // Test with reduction='sum', full=True
        let gaussian_nll_loss = GaussianNllLossBuilder::new()
            .full(true)
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
        let expected_sum_full = expected_none_full.sum(None, None).unwrap();
        assert_array_eq!(loss_sum_full, expected_sum_full);
    }

    #[test]
    fn test_kl_div_loss() {
        let p_logits = array!([[0.5, 0.5], [0.8, 0.2]]).log().unwrap();
        let q_logits = array!([[0.5, 0.5], [0.2, 0.8]]).log().unwrap();

        // Test with reduction 'none'
        let kl_div_loss = KlDivLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
        let expected_none = array!([0.0, 0.831777]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let kl_div_loss = KlDivLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
        let expected_mean = expected_none.mean(None, None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let kl_div_loss = KlDivLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
        let expected_sum = expected_none.sum(None, None).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }

    #[test]
    fn test_triplet_loss() {
        let anchors = array!([[1, 2, 3], [1, 2, 3]]);
        let positives = array!([[4, 5, 6], [0, -1, 2]]);
        let negatives = array!([[7, 8, 9], [3, 2, 3]]);

        // Test with reduction 'none'
        let triplet_loss = TripletLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = triplet_loss
            .apply(&anchors, &positives, &negatives)
            .unwrap();
        let expected_none = array!([0.0, 2.31662]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let triplet_loss = TripletLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = triplet_loss
            .apply(&anchors, &positives, &negatives)
            .unwrap();
        let expected_mean = expected_none.mean(None, None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let triplet_loss = TripletLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = triplet_loss
            .apply(&anchors, &positives, &negatives)
            .unwrap();
        let expected_sum = expected_none.sum(None, None).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }

    #[test]
    fn test_hinge_loss() {
        let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
        let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
        let hinge_loss = HingeLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss = hinge_loss.apply(&inputs, &targets).unwrap();
        assert_eq!(loss.item::<f32>(), 1.0);
    }

    #[test]
    fn test_huber_loss() {
        let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
        let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
        let huber_loss = HuberLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss = huber_loss.apply(&inputs, &targets).unwrap();
        assert_eq!(loss.item::<f32>(), 0.5);
    }

    #[test]
    fn test_log_cosh_loss() {
        let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
        let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
        let log_cosh_loss = LogCoshLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss = log_cosh_loss.apply(&inputs, &targets).unwrap();
        assert_float_eq!(loss.item::<f32>(), 0.433781, abs <= 1e-6);
    }

    #[test]
    fn test_cosine_similarity_loss() {
        let embeddings1 = array!([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]);
        let embeddings2 = array!([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]]);

        // Test with reduction 'none'
        let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss_none = cosine_similarity_loss
            .apply(&embeddings1, &embeddings2)
            .unwrap();
        let expected_none = array!([0.985344, 0.961074]);
        assert_array_eq!(loss_none, expected_none);

        // Test with reduction 'mean'
        let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
            .reduction(LossReduction::Mean)
            .build()
            .unwrap();
        let loss_mean = cosine_similarity_loss
            .apply(&embeddings1, &embeddings2)
            .unwrap();
        let expected_mean = expected_none.mean(None, None).unwrap();
        assert_array_eq!(loss_mean, expected_mean);

        // Test with reduction 'sum'
        let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
            .reduction(LossReduction::Sum)
            .build()
            .unwrap();
        let loss_sum = cosine_similarity_loss
            .apply(&embeddings1, &embeddings2)
            .unwrap();
        let expected_sum = expected_none.sum(None, None).unwrap();
        assert_array_eq!(loss_sum, expected_sum);
    }

    #[test]
    fn test_margin_ranking_loss() {
        let inputs1 = array!([-0.573409, -0.765166, -0.0638]);
        let inputs2 = array!([0.75596, 0.225763, 0.256995]);
        let targets = array!([1, 1, -1]);

        // Test with no margin
        let margin_ranking_loss = MarginRankingLossBuilder::new()
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = margin_ranking_loss
            .apply(&inputs1, &inputs2, &targets)
            .unwrap();
        let expected = array!([1.329369, 0.990929, 0.0]);
        assert_array_eq!(loss, expected);

        // Test with margin
        let margin_ranking_loss = MarginRankingLossBuilder::new()
            .margin(0.5)
            .reduction(LossReduction::None)
            .build()
            .unwrap();
        let loss = margin_ranking_loss
            .apply(&inputs1, &inputs2, &targets)
            .unwrap();
        let expected = array!([1.829369, 1.490929, 0.179205]);
        assert_array_eq!(loss, expected);
    }
}