use crate::{
error::{CrossEntropyBuildError, Exception},
abs, clip, exp, indexing::take_along_axis, log, log_add_exp, log_sum_exp, maximum, minimum,
multiply, power, r#where, sqrt, square, sum,
use mlx_internal_macros::{generate_builder, Buildable};
/// Checks that `left` and `right` have the same shape.
///
/// `left_ident` and `right_ident` name the two arrays; the error message has a slot for each
/// name next to the corresponding shape.
///
/// # Errors
///
/// Returns a custom [`Exception`] when the two shapes differ.
fn check_shape(
left: &Array,
right: &Array,
left_ident: &str,
right_ident: &str,
) -> Result<(), Exception> {
if left.shape() != right.shape() {
return Err(Exception::custom(format!(
"The shape of the {} ({:?}) does not match the shape of the {} ({:?})",
/// How a computed loss array is combined into the returned value.
///
/// The variants handled by [`LossReduction::reduce`] are `None`, `Sum` and `Mean`.
#[derive(Debug, Clone, Copy)]
pub enum LossReduction {
impl LossReduction {
/// Applies this reduction to `loss`.
///
/// `None` hands `loss` back untouched, `Sum` calls `loss.sum(None, None)` and `Mean` calls
/// `loss.mean(None, None)`.
///
/// # Errors
///
/// Propagates any [`Exception`] raised by the underlying `sum`/`mean` call.
pub fn reduce(&self, loss: Array) -> Result<Array, Exception> {
match self {
LossReduction::None => Ok(loss),
LossReduction::Sum => Ok(loss.sum(None, None)?),
LossReduction::Mean => Ok(loss.mean(None, None)?),
/// Alias for a borrowed [`Array`], named after the cross-entropy builder's weights.
pub type CrossEntropyBuilderWeights<'a> = &'a Array;
generate_builder! {
/// Cross-entropy loss configuration.
///
/// Construction goes through `build_cross_entropy`, which rejects a `label_smoothing` outside
/// `[0, 1)`.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
root = crate,
build_with = build_cross_entropy,
err = CrossEntropyBuildError
pub struct CrossEntropy<'a> {
/// Optional weights multiplied into the loss; their shape must match the shape of the loss.
/// Defaults to [`CrossEntropy::DEFAULT_WEIGHTS`].
#[builder(optional, default = CrossEntropy::DEFAULT_WEIGHTS)]
pub weights: Option<&'a Array>,
/// Axis along which the logits are reduced. Defaults to [`CrossEntropy::DEFAULT_AXIS`].
#[builder(optional, default = CrossEntropy::DEFAULT_AXIS)]
pub axis: i32,
/// Label smoothing factor; smoothing is applied only when it is greater than zero.
/// Defaults to [`CrossEntropy::DEFAULT_LABEL_SMOOTHING`].
#[builder(optional, default = CrossEntropy::DEFAULT_LABEL_SMOOTHING)]
pub label_smoothing: f32,
/// Reduction mode for the loss. Defaults to [`CrossEntropy::DEFAULT_REDUCTION`].
#[builder(optional, default = CrossEntropy::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
/// Builds a [`CrossEntropy`] from its builder, validating the label smoothing factor first.
///
/// # Errors
///
/// Returns [`CrossEntropyBuildError::InvalidLabelSmoothingFactor`] when `label_smoothing` is
/// not inside the half-open range `[0.0, 1.0)`.
fn build_cross_entropy(
builder: CrossEntropyBuilder,
) -> Result<CrossEntropy, CrossEntropyBuildError> {
let axis = builder.axis;
let label_smoothing = builder.label_smoothing;
let reduction = builder.reduction;
// Half-open range check: 0.0 is accepted, 1.0 is not. NaN also fails `contains`.
if !(0.0..1.0).contains(&label_smoothing) {
return Err(CrossEntropyBuildError::InvalidLabelSmoothingFactor);
Ok(CrossEntropy {
weights: builder.weights,
impl<'a> CrossEntropy<'a> {
/// Default axis: the last one.
pub const DEFAULT_AXIS: i32 = -1;
/// Default label smoothing factor: no smoothing.
pub const DEFAULT_LABEL_SMOOTHING: f32 = 0.0;
/// Default reduction: [`LossReduction::None`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
/// Default weights: none.
pub const DEFAULT_WEIGHTS: Option<&'a Array> = None;
/// Computes the cross-entropy loss for `logits` against `targets`.
///
/// When `targets` has as many dimensions as `logits`, the two are multiplied and summed along
/// `axis`; otherwise `targets` is used as indices to gather from `logits` along `axis`.
///
/// # Errors
///
/// Propagates any [`Exception`] from the array operations or from the weights shape check.
pub fn apply(
logits: impl AsRef<Array>,
targets: impl AsRef<Array>,
) -> Result<Array, Exception> {
let logits = logits.as_ref();
let targets = targets.as_ref();
// Same rank as the logits: targets are treated as per-class values rather than indices.
let target_as_probs = targets.ndim() == logits.ndim();
let score = if target_as_probs {
sum(&logits.multiply(targets)?, &[self.axis], None)?
} else {
// Index targets: add a trailing axis for the gather, then drop it again.
take_along_axis(logits, &targets.expand_dims(&[-1])?, self.axis)?.squeeze(&[-1])?
let log_sum_exp_logits = log_sum_exp(logits, &[self.axis], None)?;
let mut loss = if self.label_smoothing > 0.0 {
// Scale the target score by (1 - smoothing) and build a term from the mean logit
// scaled by the smoothing factor.
let adjusted_score = multiply(array!(1.0 - self.label_smoothing), score)?;
let mean_logits = logits.mean(&[self.axis], None)?;
let smoothed_loss = -multiply(mean_logits, array!(self.label_smoothing))?;
} else {
// Weights, when present, must match the loss shape and are applied element-wise.
if let Some(weights) = self.weights {
check_shape(weights, &loss, "weights", "loss")?;
loss = multiply(loss, weights)?;
generate_builder! {
/// Binary cross-entropy loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct BinaryCrossEntropy<'a> {
/// Optional weights multiplied into the loss; their shape must match the shape of the loss.
/// Defaults to [`BinaryCrossEntropy::DEFAULT_WEIGHTS`].
#[builder(optional, default = BinaryCrossEntropy::DEFAULT_WEIGHTS)]
pub weights: Option<&'a Array>,
/// Selects which of the two formulas in `apply` is used for the inputs.
/// Defaults to [`BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS`].
#[builder(optional, default = BinaryCrossEntropy::DEFAULT_INPUTS_ARE_LOGITS)]
pub inputs_are_logits: bool,
/// Reduction mode for the loss. Defaults to [`BinaryCrossEntropy::DEFAULT_REDUCTION`].
#[builder(optional, default = BinaryCrossEntropy::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl<'a> BinaryCrossEntropy<'a> {
/// Default weights: none.
pub const DEFAULT_WEIGHTS: Option<&'a Array> = None;
/// By default the inputs are interpreted as logits.
pub const DEFAULT_INPUTS_ARE_LOGITS: bool = true;
/// Default reduction: [`LossReduction::None`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
/// Computes the binary cross-entropy loss for `logits` against `targets`.
///
/// # Errors
///
/// Propagates any [`Exception`] from the array operations or from the weights shape check.
pub fn apply(
logits: impl AsRef<Array>,
targets: impl AsRef<Array>,
) -> Result<Array, Exception> {
let logits = logits.as_ref();
let targets = targets.as_ref();
let weights = self.weights;
let inputs_are_logits = self.inputs_are_logits;
let reduction = self.reduction;
let mut loss = if inputs_are_logits {
// log(1 + exp(x)) - t * x, written with `log_add_exp(0, x)`.
log_add_exp(array!(0.0), logits)?.subtract(targets.multiply(logits)?)?
} else {
// Logs of the input and of (1 - input), both bounded below by -100.
// NOTE(review): the `()` in the bounds tuple presumably means "no upper bound" — confirm.
let log_inputs_clip = clip(log(logits)?, (-100.0, ()))?;
let log_inputs_inverse_clip = clip(log(&array!(1.0).subtract(logits)?)?, (-100.0, ()))?;
// Weights, when present, must match the loss shape and are applied element-wise.
if let Some(weights) = weights {
check_shape(weights, &loss, "weights", "loss")?;
loss = multiply(loss, weights)?;
generate_builder! {
/// L1 loss configuration (absolute difference between predictions and targets).
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct L1Loss {
/// Reduction mode for the loss. Defaults to [`L1Loss::DEFAULT_REDUCTION`].
#[builder(optional, default = L1Loss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl L1Loss {
/// Default reduction: [`LossReduction::Mean`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;
/// Computes `|predictions - targets|`.
///
/// # Errors
///
/// Returns an [`Exception`] when `predictions` and `targets` differ in shape, or when an
/// array operation fails.
pub fn apply(
predictions: impl AsRef<Array>,
targets: impl AsRef<Array>,
) -> Result<Array, Exception> {
let predictions = predictions.as_ref();
let targets = targets.as_ref();
let reduction = self.reduction;
// Shapes must match exactly before subtracting.
check_shape(predictions, targets, "predictions", "targets")?;
let loss = predictions.subtract(targets)?.abs()?;
generate_builder! {
/// Squared-error loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct MseLoss {
/// Reduction mode for the loss. Defaults to [`MseLoss::DEFAULT_REDUCTION`].
#[builder(optional, default = MseLoss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl MseLoss {
/// Default reduction: [`LossReduction::Mean`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;
/// Computes `(predictions - targets)^2`.
///
/// # Errors
///
/// Returns an [`Exception`] when `predictions` and `targets` differ in shape, or when an
/// array operation fails.
pub fn apply(
predictions: impl AsRef<Array>,
targets: impl AsRef<Array>,
) -> Result<Array, Exception> {
let predictions = predictions.as_ref();
let targets = targets.as_ref();
let reduction = self.reduction;
// Shapes must match exactly before subtracting.
check_shape(predictions, targets, "predictions", "targets")?;
let loss = predictions.subtract(targets)?.square()?;
generate_builder! {
/// Negative log-likelihood loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct NllLoss {
/// Axis along which `targets` index into `inputs`. Defaults to [`NllLoss::DEFAULT_AXIS`].
#[builder(optional, default = NllLoss::DEFAULT_AXIS)]
pub axis: i32,
/// Reduction mode for the loss. Defaults to [`NllLoss::DEFAULT_REDUCTION`].
#[builder(optional, default = NllLoss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl NllLoss {
/// Default axis: the last one.
pub const DEFAULT_AXIS: i32 = -1;
/// Default reduction: [`LossReduction::None`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
/// Gathers the entries of `inputs` selected by `targets` along `axis` and negates them.
///
/// # Errors
///
/// Propagates any [`Exception`] from the array operations.
pub fn apply(
inputs: impl AsRef<Array>,
targets: impl AsRef<Array>,
) -> Result<Array, Exception> {
let inputs = inputs.as_ref();
let targets = targets.as_ref();
let axis = self.axis;
let reduction = self.reduction;
// `targets` gets a trailing axis for the gather; the same axis is squeezed away afterwards.
let loss = -take_along_axis(inputs, &targets.expand_dims(&[-1])?, axis)?.squeeze(&[-1])?;
generate_builder! {
/// Gaussian negative log-likelihood loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct GaussianNllLoss {
/// When `true`, the constant `0.5 * log(2 * pi)` is added to the loss.
/// Defaults to [`GaussianNllLoss::DEFAULT_FULL`].
#[builder(optional, default = GaussianNllLoss::DEFAULT_FULL)]
pub full: bool,
/// Lower bound applied to the variances. Defaults to [`GaussianNllLoss::DEFAULT_EPS`].
#[builder(optional, default = GaussianNllLoss::DEFAULT_EPS)]
pub eps: f32,
/// Reduction mode for the loss. Defaults to [`GaussianNllLoss::DEFAULT_REDUCTION`].
#[builder(optional, default = GaussianNllLoss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl GaussianNllLoss {
/// By default the constant term is left out.
pub const DEFAULT_FULL: bool = false;
/// Default variance floor.
pub const DEFAULT_EPS: f32 = 1e-6;
/// Default reduction: [`LossReduction::None`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
/// Computes `0.5 * (log(vars) + (targets - inputs)^2 / vars)`, with `vars` clamped to at
/// least `eps`, plus `0.5 * log(2 * pi)` when `full` is set.
///
/// # Errors
///
/// Returns an [`Exception`] when `targets` or `vars` differ in shape from `inputs`, or when
/// an array operation fails.
pub fn apply(
inputs: impl AsRef<Array>,
targets: impl AsRef<Array>,
vars: impl AsRef<Array>,
) -> Result<Array, Exception> {
let inputs = inputs.as_ref();
let targets = targets.as_ref();
let vars = vars.as_ref();
let full = self.full;
let eps = self.eps;
let reduction = self.reduction;
check_shape(inputs, targets, "inputs", "targets")?;
check_shape(inputs, vars, "inputs", "vars")?;
// Floor the variances at `eps` before taking the log and dividing by them.
let vars = maximum(vars, array!(eps))?;
let mut loss =
array!(0.5) * (log(&vars)?.add(square(&targets.subtract(inputs)?)?.divide(&vars)?)?);
if full {
// Add the constant term 0.5 * log(2 * pi).
let pi = array!(std::f32::consts::PI);
loss = loss.add(array!(0.5).multiply(log(&array!(2.0).multiply(pi)?)?)?)?;
generate_builder! {
/// Kullback-Leibler divergence loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct KlDivLoss {
/// Axis setting read by `apply`. Defaults to [`KlDivLoss::DEFAULT_AXIS`].
#[builder(optional, default = KlDivLoss::DEFAULT_AXIS)]
pub axis: i32,
/// Reduction mode for the loss. Defaults to [`KlDivLoss::DEFAULT_REDUCTION`].
#[builder(optional, default = KlDivLoss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl KlDivLoss {
/// Default axis: the last one.
pub const DEFAULT_AXIS: i32 = -1;
/// Default reduction: [`LossReduction::None`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
/// Computes the loss from `inputs` and `targets`; the result is built with a `sum` call.
///
/// # Errors
///
/// Propagates any [`Exception`] from the array operations.
pub fn apply(
inputs: impl AsRef<Array>,
targets: impl AsRef<Array>,
) -> Result<Array, Exception> {
let inputs = inputs.as_ref();
let targets = targets.as_ref();
let axis = self.axis;
let reduction = self.reduction;
let loss = sum(
generate_builder! {
/// Smooth L1 loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct SmoothL1Loss {
/// The `beta` parameter of the loss. Defaults to [`SmoothL1Loss::DEFAULT_BETA`].
#[builder(optional, default = SmoothL1Loss::DEFAULT_BETA)]
pub beta: f32,
/// Reduction mode for the loss. Defaults to [`SmoothL1Loss::DEFAULT_REDUCTION`].
#[builder(optional, default = SmoothL1Loss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl SmoothL1Loss {
/// Default `beta`.
pub const DEFAULT_BETA: f32 = 1.0;
/// Default reduction: [`LossReduction::Mean`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::Mean;
/// Computes the loss from `predictions - targets`, selecting values element-wise with
/// `where`.
///
/// # Errors
///
/// Returns an [`Exception`] when `predictions` and `targets` differ in shape, or when an
/// array operation fails.
pub fn apply(
predictions: impl AsRef<Array>,
targets: impl AsRef<Array>,
) -> Result<Array, Exception> {
let predictions = predictions.as_ref();
let targets = targets.as_ref();
let beta = self.beta;
let reduction = self.reduction;
// Shapes must match exactly before subtracting.
check_shape(predictions, targets, "predictions", "targets")?;
let diff = predictions.subtract(targets)?;
let loss = r#where(
generate_builder! {
/// Triplet loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct TripletLoss {
/// Axis along which the powered differences are summed.
/// Defaults to [`TripletLoss::DEFAULT_AXIS`].
#[builder(optional, default = TripletLoss::DEFAULT_AXIS)]
pub axis: i32,
/// Exponent applied to the element-wise differences. Defaults to [`TripletLoss::DEFAULT_P`].
#[builder(optional, default = TripletLoss::DEFAULT_P)]
pub p: f32,
/// Margin added to the distance difference. Defaults to [`TripletLoss::DEFAULT_MARGIN`].
#[builder(optional, default = TripletLoss::DEFAULT_MARGIN)]
pub margin: f32,
/// Small constant read by `apply`. Defaults to [`TripletLoss::DEFAULT_EPS`].
#[builder(optional, default = TripletLoss::DEFAULT_EPS)]
pub eps: f32,
/// Reduction mode for the loss. Defaults to [`TripletLoss::DEFAULT_REDUCTION`].
#[builder(optional, default = TripletLoss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl TripletLoss {
/// Default axis: the last one.
pub const DEFAULT_AXIS: i32 = -1;
/// Default exponent.
pub const DEFAULT_P: f32 = 2.0;
/// Default margin.
pub const DEFAULT_MARGIN: f32 = 1.0;
/// Default `eps`.
pub const DEFAULT_EPS: f32 = 1e-6;
/// Default reduction: [`LossReduction::None`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
/// Computes `max(pos - neg + margin, 0)`, where `pos` and `neg` are distance terms between
/// the anchors and the positives/negatives respectively.
///
/// # Errors
///
/// Propagates any [`Exception`] from the array operations.
pub fn apply(
anchors: impl AsRef<Array>,
positives: impl AsRef<Array>,
negatives: impl AsRef<Array>,
) -> Result<Array, Exception> {
let anchors = anchors.as_ref();
let positives = positives.as_ref();
let negatives = negatives.as_ref();
let axis = self.axis;
let p = self.p;
let margin = self.margin;
let eps = self.eps;
let reduction = self.reduction;
// Lift the scalar settings into arrays so they can take part in array arithmetic.
let eps = array!(eps);
let p = array!(p);
let margin = array!(margin);
// Anchor-to-positive term: differences raised to `p`, summed along `axis`, under a sqrt.
let pos = sqrt(
&power(&anchors.subtract(positives)?, &p)?
.sum(&[axis], None)?
// Anchor-to-negative term, built the same way.
let neg = sqrt(
&power(&anchors.subtract(negatives)?, &p)?
.sum(&[axis], None)?
// Clamp at zero.
let loss = maximum(pos.subtract(neg)?.add(margin)?, array!(0.0))?;
generate_builder! {
/// Hinge loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct HingeLoss {
/// Reduction mode for the loss. Defaults to [`HingeLoss::DEFAULT_REDUCTION`].
#[builder(optional, default = HingeLoss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl HingeLoss {
/// Default reduction: [`LossReduction::None`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
/// Computes `max(1 - inputs * targets, 0)`.
///
/// # Errors
///
/// Propagates any [`Exception`] from the array operations.
pub fn apply(
inputs: impl AsRef<Array>,
targets: impl AsRef<Array>,
) -> Result<Array, Exception> {
let inputs = inputs.as_ref();
let targets = targets.as_ref();
let reduction = self.reduction;
// a = 1 - inputs * targets, b = 0; the loss is their element-wise maximum.
let a = array!(1.0).subtract(inputs.multiply(targets)?)?;
let b = array!(0.0);
let loss = maximum(a, b)?;
generate_builder! {
/// Huber loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct HuberLoss {
/// Cap applied to the absolute error when forming the quadratic part of the loss.
/// Defaults to [`HuberLoss::DEFAULT_DELTA`].
#[builder(optional, default = HuberLoss::DEFAULT_DELTA)]
pub delta: f32,
/// Reduction mode for the loss. Defaults to [`HuberLoss::DEFAULT_REDUCTION`].
#[builder(optional, default = HuberLoss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl HuberLoss {
/// Default `delta`.
pub const DEFAULT_DELTA: f32 = 1.0;
/// Default reduction: [`LossReduction::None`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
/// Computes the Huber loss from the errors `inputs - targets`.
///
/// The absolute error is split into a part capped at `delta` (`quadratic`) and the
/// remainder above that cap (`linear`).
///
/// # Errors
///
/// Propagates any [`Exception`] from the array operations.
pub fn apply(
inputs: impl AsRef<Array>,
targets: impl AsRef<Array>,
) -> Result<Array, Exception> {
let inputs = inputs.as_ref();
let targets = targets.as_ref();
// Read the threshold from the configured loss, like every other setting.
let delta = self.delta;
let reduction = self.reduction;
let errors = inputs.subtract(targets)?;
let abs_errors = errors.abs()?;
// Portion of |error| that does not exceed `delta`.
let quadratic = minimum(&abs_errors, array!(delta))?;
// Whatever is left of |error| above `delta` (zero when |error| <= delta).
let linear = abs_errors.subtract(&quadratic)?;
let loss = array!(0.5)
generate_builder! {
/// Log-cosh loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct LogCoshLoss {
/// Reduction mode for the loss. Defaults to [`LogCoshLoss::DEFAULT_REDUCTION`].
#[builder(optional, default = LogCoshLoss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl LogCoshLoss {
/// Default reduction: [`LossReduction::None`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
/// Computes `log(exp(e) + exp(-e)) - log(2)` for the errors `e = inputs - targets`.
///
/// # Errors
///
/// Propagates any [`Exception`] from the array operations.
pub fn apply(
inputs: impl AsRef<Array>,
targets: impl AsRef<Array>,
) -> Result<Array, Exception> {
let inputs = inputs.as_ref();
let targets = targets.as_ref();
let reduction = self.reduction;
let errors = inputs.subtract(targets)?;
let neg_errors = errors.negative()?;
// `log_add_exp(e, -e)` is log(exp(e) + exp(-e)); subtracting log(2) gives log(cosh(e)).
let loss = log_add_exp(errors, neg_errors)?.subtract(log(&array!(2.0))?)?;
generate_builder! {
/// Cosine similarity loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct CosineSimilarityLoss {
/// Axis along which norms and the dot product are computed.
/// Defaults to [`CosineSimilarityLoss::DEFAULT_AXIS`].
#[builder(optional, default = CosineSimilarityLoss::DEFAULT_AXIS)]
pub axis: i32,
/// Lower bound applied to the product of the two norms.
/// Defaults to [`CosineSimilarityLoss::DEFAULT_EPS`].
#[builder(optional, default = CosineSimilarityLoss::DEFAULT_EPS)]
pub eps: f32,
/// Reduction mode for the loss. Defaults to [`CosineSimilarityLoss::DEFAULT_REDUCTION`].
#[builder(optional, default = CosineSimilarityLoss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl CosineSimilarityLoss {
/// Default axis: the last one.
pub const DEFAULT_AXIS: i32 = -1;
/// Default floor for the denominator.
pub const DEFAULT_EPS: f32 = 1e-8;
/// Default reduction: [`LossReduction::None`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
/// Computes `sum(x1 * x2) / max(||x1|| * ||x2||, eps)` along `axis`.
///
/// # Errors
///
/// Propagates any [`Exception`] from the array operations.
pub fn apply(&self, x1: impl AsRef<Array>, x2: impl AsRef<Array>) -> Result<Array, Exception> {
let x1 = x1.as_ref();
let x2 = x2.as_ref();
let axis = self.axis;
let eps = self.eps;
let reduction = self.reduction;
// L2 norm along `axis`; complex inputs go through `abs` first so the squares are of magnitudes.
fn l2_loss(a: &Array, axis: i32) -> Result<Array, Exception> {
if a.dtype().is_complex() {
Ok(sqrt(&sum(&abs(a)?.square()?, &[axis], None)?)?)
} else {
Ok(sqrt(&sum(&a.square()?, &[axis], None)?)?)
let x1_norm = l2_loss(x1, axis)?;
let x2_norm = l2_loss(x2, axis)?;
// Dot product along `axis`.
let num = sum(&x1.multiply(x2)?, &[axis], None)?;
// Keep the denominator at or above `eps`.
let den = maximum(x1_norm.multiply(x2_norm)?, array!(eps))?;
let loss = num.divide(&den)?;
generate_builder! {
/// Margin ranking loss configuration.
#[derive(Debug, Clone, Buildable)]
#[buildable(root = crate)]
#[builder(root = crate)]
pub struct MarginRankingLoss {
/// Margin setting read by `apply`. Defaults to [`MarginRankingLoss::DEFAULT_MARGIN`].
#[builder(optional, default = MarginRankingLoss::DEFAULT_MARGIN)]
pub margin: f32,
/// Reduction mode for the loss. Defaults to [`MarginRankingLoss::DEFAULT_REDUCTION`].
#[builder(optional, default = MarginRankingLoss::DEFAULT_REDUCTION)]
pub reduction: LossReduction,
impl MarginRankingLoss {
/// Default margin.
pub const DEFAULT_MARGIN: f32 = 0.0;
/// Default reduction: [`LossReduction::None`].
pub const DEFAULT_REDUCTION: LossReduction = LossReduction::None;
/// Computes the loss from `inputs1 - inputs2`, `targets` and the margin; the result is
/// built with a `maximum` call.
///
/// # Errors
///
/// Returns an [`Exception`] when `inputs2` or `targets` differ in shape from `inputs1`, or
/// when an array operation fails.
pub fn apply(
inputs1: impl AsRef<Array>,
inputs2: impl AsRef<Array>,
targets: impl AsRef<Array>,
) -> Result<Array, Exception> {
let inputs1 = inputs1.as_ref();
let inputs2 = inputs2.as_ref();
let targets = targets.as_ref();
let margin = self.margin;
let reduction = self.reduction;
// All three arrays must share one shape.
check_shape(inputs1, inputs2, "inputs1", "inputs2")?;
check_shape(inputs1, targets, "inputs1", "targets")?;
let margin = array!(margin);
let diff = inputs1.subtract(inputs2)?;
let loss = maximum(
mod tests {
use crate::{array, assert_array_eq, builder::Builder, ops::is_nan};
use float_eq::assert_float_eq;
use super::*;
// Exercises `CrossEntropy` with targets given both as 1-D indices and as 2-D per-class values,
// across several builder configurations.
fn test_cross_entropy() {
// Logits that put all mass on the target class: the expected loss is zero.
let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
let indices = array!([0, 1]);
let expected = array!([0.0, 0.0]);
let loss = CrossEntropy::new()
.apply(&logits, indices)
assert_array_eq!(loss, expected);
// Same logits, targets with the same rank as the logits.
let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
let cross_entropy = CrossEntropyBuilder::new()
let loss = cross_entropy.apply(logits, probs).unwrap();
.all(None, None)
// Finite logits; a `weights` array is defined for this case.
let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
let indices = array!([0, 1]);
let weights = array!([1.0, 2.0]);
let expected = array!([0.04858735, 0.0971747]);
let cross_entropy = CrossEntropyBuilder::new()
let loss = cross_entropy.apply(&logits, indices).unwrap();
assert_array_eq!(loss, expected);
let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
let cross_entropy = CrossEntropyBuilder::new()
let loss = cross_entropy.apply(logits, probs).unwrap();
assert_array_eq!(loss, expected);
// Next configuration: index targets first, then same-rank targets, against one expected value.
let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
let indices = array!([0, 1]);
let expected = array!([0.498587, 0.498587]);
let cross_entropy = CrossEntropyBuilder::new()
let loss = cross_entropy.apply(&logits, indices).unwrap();
assert_array_eq!(loss, expected);
let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
let cross_entropy = CrossEntropyBuilder::new()
let loss = cross_entropy.apply(logits, probs).unwrap();
assert_array_eq!(loss, expected);
// Last configuration; a `weights` array is defined again.
let logits = array!([[2.0, -1.0], [-1.0, 2.0]]);
let indices = array!([0, 1]);
let weights = array!([1.0, 2.0]);
let expected = array!([0.49858734, 0.9971747]);
let cross_entropy = CrossEntropyBuilder::new()
let loss = cross_entropy.apply(&logits, indices).unwrap();
assert_array_eq!(loss, expected);
let probs = array!([[1.0, 0.0], [0.0, 1.0]]);
let cross_entropy = CrossEntropyBuilder::new()
let loss = cross_entropy.apply(logits, probs).unwrap();
assert_array_eq!(loss, expected);
// Exercises `BinaryCrossEntropy` on logit inputs: element-wise, mean, sum and weighted results.
fn test_binary_cross_entropy_with_logits_as_inputs() {
let logits = array!([0.105361, 0.223144, 1.20397, 0.916291]);
let targets = array!([0.0, 0.0, 1.0, 1.0]);
// Element-wise result.
let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
let loss_none = binary_cross_entropy.apply(&logits, &targets).unwrap();
let expected_none = array!([0.747215, 0.810930, 0.262365, 0.336472]);
assert_array_eq!(loss_none, expected_none);
// Compared against the mean of the element-wise expectation.
let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
let loss_mean = binary_cross_entropy.apply(&logits, &targets).unwrap();
let expected_mean = expected_none.mean(None, None).unwrap();
assert_array_eq!(loss_mean, expected_mean);
// Compared against the sum of the element-wise expectation.
let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
let loss = binary_cross_entropy.apply(&logits, &targets).unwrap();
let expected = expected_none.sum(None, None).unwrap();
assert_array_eq!(loss, expected);
// Case with a `weights` array defined.
let weights = array!([1.0, 2.0, 1.0, 2.0]);
let expected = array!([0.747215, 1.62186, 0.262365, 0.672944]);
let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
let loss = binary_cross_entropy.apply(&logits, &targets).unwrap();
assert_array_eq!(loss, expected);
// Exercises `BinaryCrossEntropy` on inputs named `probs`: element-wise, mean and sum results.
fn test_binary_cross_entropy_with_probs_as_inputs() {
let probs = array!([0.5, 0.6, 0.7, 0.8]);
let targets = array!([0.0, 0.0, 1.0, 1.0]);
// Element-wise result.
let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap();
let expected_none = array!([0.693147, 0.916291, 0.356675, 0.223144]);
assert_array_eq!(loss_none, expected_none);
// Compared against the mean of the element-wise expectation.
let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap();
let expected_mean = expected_none.mean(None, None).unwrap();
assert_array_eq!(loss_mean, expected_mean);
// Compared against the sum of the element-wise expectation.
let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
let loss = binary_cross_entropy.apply(&probs, &targets).unwrap();
let expected = expected_none.sum(None, None).unwrap();
assert_array_eq!(loss, expected);
// Exercises `BinaryCrossEntropy` on inputs at, or extremely close to, 0 and 1.
fn test_binary_cross_entropy_with_tiny_probs_as_inputs() {
let tiny_prob = 1e-59;
let probs = array!([0.0, tiny_prob, 1.0 - tiny_prob, 1.0]);
let targets = array!([0.0, 0.0, 1.0, 1.0]);
// Element-wise result.
let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
let loss_none = binary_cross_entropy.apply(&probs, &targets).unwrap();
let expected_none = array!([0.0, tiny_prob, tiny_prob, 0.0]);
assert_array_eq!(loss_none, expected_none);
// Compared against the mean of the element-wise expectation.
let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
let loss_mean = binary_cross_entropy.apply(&probs, &targets).unwrap();
let expected_mean = expected_none.mean(None, None).unwrap();
assert_array_eq!(loss_mean, expected_mean);
// Compared against the sum of the element-wise expectation.
let binary_cross_entropy = BinaryCrossEntropyBuilder::new()
let loss = binary_cross_entropy.apply(&probs, &targets).unwrap();
let expected = expected_none.sum(None, None).unwrap();
assert_array_eq!(loss, expected);
// Exercises `L1Loss` with identical predictions and targets, so every expected value is zero.
fn test_l1_loss() {
let predictions = array!([0.5, 0.2, 0.9, 0.0]);
let targets = array!([0.5, 0.2, 0.9, 0.0]);
let expected_none = array!([0.0, 0.0, 0.0, 0.0]);
let expected_sum = expected_none.sum(None, None).unwrap();
let expected_mean = expected_none.mean(None, None).unwrap();
// Element-wise result.
let l1_loss = L1LossBuilder::new()
let loss_none = l1_loss.apply(&predictions, &targets).unwrap();
assert_array_eq!(loss_none, expected_none);
// Summed result.
let l1_loss = L1LossBuilder::new()
let loss_sum = l1_loss.apply(&predictions, &targets).unwrap();
assert_array_eq!(loss_sum, expected_sum);
// Averaged result.
let l1_loss = L1LossBuilder::new()
let loss_mean = l1_loss.apply(&predictions, &targets).unwrap();
assert_array_eq!(loss_mean, expected_mean);
// Exercises `MseLoss`: element-wise squared differences, then their mean and sum.
fn test_mse_loss() {
let predictions = array!([0.5, 0.2, 0.9, 0.0]);
let targets = array!([0.7, 0.1, 0.8, 0.2]);
let expected_none = array!([0.04, 0.01, 0.01, 0.04]);
let expected_mean = expected_none.mean(None, None).unwrap();
let expected_sum = expected_none.sum(None, None).unwrap();
// Element-wise result.
let mse_loss = MseLossBuilder::new()
let loss_none = mse_loss.apply(&predictions, &targets).unwrap();
assert_array_eq!(loss_none, expected_none);
// Averaged result.
let mse_loss = MseLossBuilder::new()
let loss_mean = mse_loss.apply(&predictions, &targets).unwrap();
assert_array_eq!(loss_mean, expected_mean);
// Summed result.
let mse_loss = MseLossBuilder::new()
let loss_sum = mse_loss.apply(&predictions, &targets).unwrap();
assert_array_eq!(loss_sum, expected_sum);
// Exercises `SmoothL1Loss`: element-wise values, then their sum and mean.
fn test_smooth_l1_loss() {
let predictions = array!([1.5, 2.5, 0.5, 3.5]);
let targets = array!([1.0, 2.0, 0.5, 2.5]);
let beta = 1.0;
let expected_none = array!([0.125, 0.125, 0.0, 0.5]);
let expected_sum = expected_none.sum(None, None).unwrap();
let expected_mean = expected_none.mean(None, None).unwrap();
// Element-wise result.
let smooth_l1_loss = SmoothL1LossBuilder::new()
let loss_none = smooth_l1_loss.apply(&predictions, &targets).unwrap();
assert_array_eq!(loss_none, expected_none);
// Summed result.
let smooth_l1_loss = SmoothL1LossBuilder::new()
let loss_sum = smooth_l1_loss.apply(&predictions, &targets).unwrap();
assert_array_eq!(loss_sum, expected_sum);
// Averaged result.
let smooth_l1_loss = SmoothL1LossBuilder::new()
let loss_mean = smooth_l1_loss.apply(&predictions, &targets).unwrap();
assert_array_eq!(loss_mean, expected_mean);
// Exercises `NllLoss` where each target index selects a 0.0 entry, so every expected value is zero.
fn test_nll_loss() {
let logits = array!([[0.0, f32::NEG_INFINITY], [f32::NEG_INFINITY, 0.0]]);
let targets = array!([0, 1]);
let expected_none = array!([0.0, 0.0]);
let expected_sum = expected_none.sum(None, None).unwrap();
let expected_mean = expected_none.mean(None, None).unwrap();
// Element-wise result.
let nll_loss = NllLossBuilder::new()
let loss_none = nll_loss.apply(&logits, &targets).unwrap();
assert_array_eq!(loss_none, expected_none);
// Averaged result.
let nll_loss = NllLossBuilder::new()
let loss_mean = nll_loss.apply(&logits, &targets).unwrap();
assert_array_eq!(loss_mean, expected_mean);
// Summed result.
let nll_loss = NllLossBuilder::new()
let loss_sum = nll_loss.apply(&logits, &targets).unwrap();
assert_array_eq!(loss_sum, expected_sum);
// Exercises `GaussianNllLoss`: element-wise, mean and sum results, first against
// `expected_none` and then against the `*_full` expectations.
fn test_gaussian_nll_loss() {
let inputs = array!([[0.1, 0.2], [0.3, 0.4]]);
let targets = array!([[0.2, 0.1], [0.1, 0.2]]);
let vars = array!([[0.1, 0.2], [0.3, 0.4]]);
// Element-wise result.
let gaussian_nll_loss = GaussianNllLossBuilder::new()
let loss_none = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
let expected_none = array!([[-1.101293, -0.779719], [-0.535320, -0.408145]]);
assert_array_eq!(loss_none, expected_none);
// Averaged result.
let gaussian_nll_loss = GaussianNllLossBuilder::new()
let loss_mean = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
let expected_mean = expected_none.mean(None, None).unwrap();
assert_array_eq!(loss_mean, expected_mean);
// Summed result.
let gaussian_nll_loss = GaussianNllLossBuilder::new()
let loss_sum = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
let expected_sum = expected_none.sum(None, None).unwrap();
assert_array_eq!(loss_sum, expected_sum);
// Same three checks against the `*_full` expectations.
let gaussian_nll_loss = GaussianNllLossBuilder::new()
let loss_none_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
let expected_none_full = array!([[-0.182354, 0.139220], [0.383619, 0.510793]]);
assert_array_eq!(loss_none_full, expected_none_full);
let gaussian_nll_loss = GaussianNllLossBuilder::new()
let loss_mean_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
let expected_mean_full = expected_none_full.mean(None, None).unwrap();
assert_array_eq!(loss_mean_full, expected_mean_full);
let gaussian_nll_loss = GaussianNllLossBuilder::new()
let loss_sum_full = gaussian_nll_loss.apply(&inputs, &targets, &vars).unwrap();
let expected_sum_full = expected_none_full.sum(None, None).unwrap();
assert_array_eq!(loss_sum_full, expected_sum_full);
// Exercises `KlDivLoss` on the logs of two sets of probability rows.
fn test_kl_div_loss() {
// The first rows are identical, so the first expected entry is zero.
let p_logits = array!([[0.5, 0.5], [0.8, 0.2]]).log().unwrap();
let q_logits = array!([[0.5, 0.5], [0.2, 0.8]]).log().unwrap();
// Per-row result.
let kl_div_loss = KlDivLossBuilder::new()
let loss_none = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
let expected_none = array!([0.0, 0.831777]);
assert_array_eq!(loss_none, expected_none);
// Averaged result.
let kl_div_loss = KlDivLossBuilder::new()
let loss_mean = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
let expected_mean = expected_none.mean(None, None).unwrap();
assert_array_eq!(loss_mean, expected_mean);
// Summed result.
let kl_div_loss = KlDivLossBuilder::new()
let loss_sum = kl_div_loss.apply(&p_logits, &q_logits).unwrap();
let expected_sum = expected_none.sum(None, None).unwrap();
assert_array_eq!(loss_sum, expected_sum);
// Exercises `TripletLoss` on two anchor/positive/negative triplets of integer vectors.
fn test_triplet_loss() {
let anchors = array!([[1, 2, 3], [1, 2, 3]]);
let positives = array!([[4, 5, 6], [0, -1, 2]]);
let negatives = array!([[7, 8, 9], [3, 2, 3]]);
// Per-triplet result.
let triplet_loss = TripletLossBuilder::new()
let loss_none = triplet_loss
.apply(&anchors, &positives, &negatives)
let expected_none = array!([0.0, 2.31662]);
assert_array_eq!(loss_none, expected_none);
// Averaged result.
let triplet_loss = TripletLossBuilder::new()
let loss_mean = triplet_loss
.apply(&anchors, &positives, &negatives)
let expected_mean = expected_none.mean(None, None).unwrap();
assert_array_eq!(loss_mean, expected_mean);
// Summed result.
let triplet_loss = TripletLossBuilder::new()
let loss_sum = triplet_loss
.apply(&anchors, &positives, &negatives)
let expected_sum = expected_none.sum(None, None).unwrap();
assert_array_eq!(loss_sum, expected_sum);
// Exercises `HingeLoss` with all-zero targets; the scalar result is expected to be 1.0.
fn test_hinge_loss() {
let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
let hinge_loss = HingeLossBuilder::new()
let loss = hinge_loss.apply(&inputs, &targets).unwrap();
// `item` reads a single value, so the loss has been reduced to a scalar by this point.
assert_eq!(loss.item::<f32>(), 1.0);
// Exercises `HuberLoss` where every error equals 1.0; the scalar result is expected to be 0.5.
fn test_huber_loss() {
let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
let huber_loss = HuberLossBuilder::new()
let loss = huber_loss.apply(&inputs, &targets).unwrap();
// `item` reads a single value, so the loss has been reduced to a scalar by this point.
assert_eq!(loss.item::<f32>(), 0.5);
// Exercises `LogCoshLoss` where every error equals 1.0; the scalar result is compared with a tolerance.
fn test_log_cosh_loss() {
let inputs = array!([[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]);
let targets = array!([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]);
let log_cosh_loss = LogCoshLossBuilder::new()
let loss = log_cosh_loss.apply(&inputs, &targets).unwrap();
// Floating-point comparison with an absolute tolerance of 1e-6.
assert_float_eq!(loss.item::<f32>(), 0.433781, abs <= 1e-6);
// Exercises `CosineSimilarityLoss` on two pairs of 4-element embeddings.
fn test_cosine_similarity_loss() {
let embeddings1 = array!([[0.5, 0.5, 0.2, 0.9], [0.1, 0.3, 0.5, 0.5]]);
let embeddings2 = array!([[0.6, 0.4, 0.3, 0.8], [0.2, 0.5, 0.6, 0.4]]);
// Per-pair result.
let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
let loss_none = cosine_similarity_loss
.apply(&embeddings1, &embeddings2)
let expected_none = array!([0.985344, 0.961074]);
assert_array_eq!(loss_none, expected_none);
// Averaged result.
let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
let loss_mean = cosine_similarity_loss
.apply(&embeddings1, &embeddings2)
let expected_mean = expected_none.mean(None, None).unwrap();
assert_array_eq!(loss_mean, expected_mean);
// Summed result.
let cosine_similarity_loss = CosineSimilarityLossBuilder::new()
let loss_sum = cosine_similarity_loss
.apply(&embeddings1, &embeddings2)
let expected_sum = expected_none.sum(None, None).unwrap();
assert_array_eq!(loss_sum, expected_sum);
// Exercises `MarginRankingLoss` with two builder configurations on the same inputs.
fn test_margin_ranking_loss() {
let inputs1 = array!([-0.573409, -0.765166, -0.0638]);
let inputs2 = array!([0.75596, 0.225763, 0.256995]);
let targets = array!([1, 1, -1]);
// First configuration.
let margin_ranking_loss = MarginRankingLossBuilder::new()
let loss = margin_ranking_loss
.apply(&inputs1, &inputs2, &targets)
let expected = array!([1.329369, 0.990929, 0.0]);
assert_array_eq!(loss, expected);
// Second configuration: the first two expected values are exactly 0.5 larger than before.
let margin_ranking_loss = MarginRankingLossBuilder::new()
let loss = margin_ranking_loss
.apply(&inputs1, &inputs2, &targets)
let expected = array!([1.829369, 1.490929, 0.179205]);
assert_array_eq!(loss, expected);