use crate::module::Module;
use crate::Array;
use crate::{array, error::Exception, ops::multiply, random::bernoulli};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;
use crate::error::DropoutBuildError;
/// Builder for [`Dropout`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct DropoutBuilder {
    /// Probability of zeroing an element. Must lie in `[0, 1)`; `build_dropout` rejects any
    /// other value with [`DropoutBuildError::InvalidProbability`]. Defaults to
    /// [`Dropout::DEFAULT_P`] when not set.
    #[builder(optional, default = Dropout::DEFAULT_P)]
    p: f32,
}
fn build_dropout(builder: DropoutBuilder) -> Result<Dropout, DropoutBuildError> {
let p = builder.p;
if !(0.0..1.0).contains(&p) {
return Err(DropoutBuildError::InvalidProbability);
}
Ok(Dropout {
one_minus_p: 1.0 - p,
training: Dropout::DEFAULT_TRAINING,
})
}
/// Element-wise dropout layer.
///
/// In training mode each element of the input is independently kept with probability
/// `one_minus_p` (and zeroed otherwise), and the kept elements are scaled by
/// `1 / one_minus_p`.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout {
    /// Keep probability, i.e. `1 - p`. When it equals `1.0`, `forward` returns the input
    /// unchanged.
    pub one_minus_p: f32,

    /// Training-mode flag. When `false`, `forward` returns the input unchanged.
    pub training: bool,
}
impl Dropout {
    /// Initial value of the `training` flag for a freshly built layer.
    pub const DEFAULT_TRAINING: bool = true;

    /// Drop probability used when the builder's `p` is not set explicitly.
    pub const DEFAULT_P: f32 = 0.5_f32;
}
impl Module<&Array> for Dropout {
    type Error = Exception;
    type Output = Array;

    /// Applies element-wise dropout to `x`.
    ///
    /// The input is returned unchanged outside training mode, or when the keep probability is
    /// exactly `1` (`p == 0`). Otherwise a Bernoulli mask with the same shape as `x` is sampled
    /// with probability `1 - p`, and `x` is multiplied by `mask / (1 - p)`.
    ///
    /// # Errors
    ///
    /// Propagates any [`Exception`] raised while sampling the mask or multiplying.
    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let keep_prob = self.one_minus_p;
        let is_identity = !self.training || keep_prob == 1.0;
        if is_identity {
            return Ok(x.clone());
        }

        let keep_prob_array = array!(keep_prob);
        let mask = bernoulli(&keep_prob_array, x.shape(), None)?;
        // Inverted dropout: survivors are rescaled by 1 / (1 - p) at training time.
        let scaled_mask = multiply(array!(1.0 / keep_prob), mask)?;
        multiply(scaled_mask, x).map_err(Into::into)
    }

    /// Sets the training-mode flag consulted by `forward`.
    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}
/// Builder for [`Dropout2d`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout2d,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct Dropout2dBuilder {
    /// Probability of zeroing a slice of the input. Must lie in `[0, 1)`; `build_dropout2d`
    /// rejects any other value with [`DropoutBuildError::InvalidProbability`]. Defaults to
    /// [`Dropout2d::DEFAULT_P`] when not set.
    #[builder(optional, default = Dropout2d::DEFAULT_P)]
    p: f32,
}
fn build_dropout2d(builder: Dropout2dBuilder) -> Result<Dropout2d, DropoutBuildError> {
let p = builder.p;
if !(0.0..1.0).contains(&p) {
return Err(DropoutBuildError::InvalidProbability);
}
Ok(Dropout2d {
one_minus_p: 1.0 - p,
training: Dropout2d::DEFAULT_TRAINING,
})
}
/// Channel-wise dropout for 3D or 4D inputs.
///
/// The dropout mask is shared across the second- and third-to-last axes, so whole slices are
/// zeroed together. For an assumed channels-last `(N, H, W, C)` / `(H, W, C)` layout (not
/// checked here), this drops entire channels. Kept values are scaled by `1 / one_minus_p`.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout2d {
    /// Keep probability, i.e. `1 - p`. When it equals `1.0`, `forward` returns the input
    /// unchanged (after its rank check).
    pub one_minus_p: f32,

    /// Training-mode flag. When `false`, `forward` returns the input unchanged (after its rank
    /// check).
    pub training: bool,
}
impl Dropout2d {
    /// Initial value of the `training` flag for a freshly built layer.
    pub const DEFAULT_TRAINING: bool = true;

    /// Drop probability used when the builder's `p` is not set explicitly.
    pub const DEFAULT_P: f32 = 0.5_f32;
}
impl Module<&Array> for Dropout2d {
    type Error = Exception;
    type Output = Array;

    /// Applies channel-wise dropout to a 3D or 4D input.
    ///
    /// The mask is sampled with the input's shape, except that the second- and third-to-last
    /// axes are collapsed to `1`. A single Bernoulli draw is therefore broadcast across those two
    /// axes. Kept values are scaled by `1 / (1 - p)`. The input is returned unchanged outside
    /// training mode or when `p == 0`.
    ///
    /// # Errors
    ///
    /// Returns an [`Exception`] if `x` is not 3D or 4D. This check runs even when not training.
    /// Any error from sampling the mask or multiplying is propagated.
    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        if !matches!(x.ndim(), 3 | 4) {
            return Err(Exception::custom("Expecting 3D or 4D input"));
        }

        let keep_prob = self.one_minus_p;
        if !self.training || keep_prob == 1.0 {
            return Ok(x.clone());
        }

        let mut mask_shape = x.shape().to_vec();
        let rank = mask_shape.len();
        // Collapse axes -3 and -2 so the same keep/drop decision is shared across them.
        mask_shape[rank - 3..rank - 1].fill(1);

        let keep_prob_array = array!(keep_prob);
        let mask = bernoulli(&keep_prob_array, &mask_shape, None)?;
        let scaled_mask = multiply(array!(1.0 / keep_prob), mask)?;
        multiply(scaled_mask, x).map_err(Into::into)
    }

    /// Sets the training-mode flag consulted by `forward`.
    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}
/// Builder for [`Dropout3d`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout3d,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct Dropout3dBuilder {
    /// Probability of zeroing a slice of the input. Must lie in `[0, 1)`; `build_dropout3d`
    /// rejects any other value with [`DropoutBuildError::InvalidProbability`]. Defaults to
    /// [`Dropout3d::DEFAULT_P`] when not set.
    #[builder(optional, default = Dropout3d::DEFAULT_P)]
    p: f32,
}
fn build_dropout3d(builder: Dropout3dBuilder) -> Result<Dropout3d, DropoutBuildError> {
let p = builder.p;
if !(0.0..1.0).contains(&p) {
return Err(DropoutBuildError::InvalidProbability);
}
Ok(Dropout3d {
one_minus_p: 1.0 - p,
training: Dropout3d::DEFAULT_TRAINING,
})
}
/// Channel-wise dropout for 4D or 5D inputs.
///
/// The dropout mask is shared across the second-, third- and fourth-to-last axes, so whole
/// volumes are zeroed together. For an assumed channels-last `(N, D, H, W, C)` /
/// `(D, H, W, C)` layout (not checked here), this drops entire channels. Kept values are
/// scaled by `1 / one_minus_p`.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout3d {
    /// Keep probability, i.e. `1 - p`. When it equals `1.0`, `forward` returns the input
    /// unchanged (after its rank check).
    pub one_minus_p: f32,

    /// Training-mode flag. When `false`, `forward` returns the input unchanged (after its rank
    /// check).
    pub training: bool,
}
impl Dropout3d {
    /// Initial value of the `training` flag for a freshly built layer.
    pub const DEFAULT_TRAINING: bool = true;

    /// Drop probability used when the builder's `p` is not set explicitly.
    pub const DEFAULT_P: f32 = 0.5_f32;
}
impl Module<&Array> for Dropout3d {
    type Error = Exception;
    type Output = Array;

    /// Applies channel-wise dropout to a 4D or 5D input.
    ///
    /// The mask is sampled with the input's shape, except that the second-, third- and
    /// fourth-to-last axes are collapsed to `1`. A single Bernoulli draw is therefore broadcast
    /// across those three axes. Kept values are scaled by `1 / (1 - p)`. The input is returned
    /// unchanged outside training mode or when `p == 0`.
    ///
    /// # Errors
    ///
    /// Returns an [`Exception`] if `x` is not 4D or 5D. This check runs even when not training.
    /// Any error from sampling the mask or multiplying is propagated.
    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        if !matches!(x.ndim(), 4 | 5) {
            return Err(Exception::custom("Expecting 4D or 5D input"));
        }

        let keep_prob = self.one_minus_p;
        if !self.training || keep_prob == 1.0 {
            return Ok(x.clone());
        }

        let mut mask_shape = x.shape().to_vec();
        let rank = mask_shape.len();
        // Collapse axes -4, -3 and -2 so the same keep/drop decision is shared across them.
        mask_shape[rank - 4..rank - 1].fill(1);

        let keep_prob_array = array!(keep_prob);
        let mask = bernoulli(&keep_prob_array, &mask_shape, None)?;
        let scaled_mask = multiply(array!(1.0 / keep_prob), mask)?;
        multiply(scaled_mask, x).map_err(Into::into)
    }

    /// Sets the training-mode flag consulted by `forward`.
    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}
#[cfg(test)]
mod tests {
    use crate::random::uniform;
    use float_eq::assert_float_eq;

    use super::*;

    // The expected statistics in the three seeded tests below are tied to MLX's global RNG.
    // NOTE(review): these tests share that global state; if the harness interleaves them, a draw
    // from another test between `seed` and `uniform` could perturb the values. The tests added
    // after them use deterministic inputs and never draw from the RNG.

    #[test]
    fn test_dropout() {
        crate::random::seed(959).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.511_429_2,
            abs <= 0.010_228_584
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            130.925_87,
            abs <= 2.618_517_4
        );
        let result = Dropout::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.477_913_62,
            abs <= 0.009_558_273
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            122.345_89,
            abs <= 2.446_917_8
        );
    }

    #[test]
    fn test_dropout2d() {
        crate::random::seed(695).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.457_839_9,
            abs <= 0.009_156_798
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            117.207_016,
            abs <= 2.344_140_3
        );
        let result = Dropout2d::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.368_284_34,
            abs <= 0.007_365_687
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            94.280_79,
            abs <= 1.885_615_8
        );
    }

    #[test]
    fn test_dropout3d() {
        crate::random::seed(23).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 8, 4], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 8, 4]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.500_606_2,
            abs <= 0.010_012_124
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            256.310_36,
            abs <= 5.126_207_4
        );
        let result = Dropout3d::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 8, 4]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.237_284_15,
            abs <= 0.004_745_683
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            121.489_49,
            abs <= 2.429_789_8
        );
    }

    /// All three builders must reject probabilities outside `[0, 1)`, including NaN.
    #[test]
    fn test_build_rejects_out_of_range_probability() {
        for p in [-0.1, 1.0, 1.5, f32::NAN] {
            assert!(matches!(
                build_dropout(DropoutBuilder { p }),
                Err(DropoutBuildError::InvalidProbability)
            ));
            assert!(matches!(
                build_dropout2d(Dropout2dBuilder { p }),
                Err(DropoutBuildError::InvalidProbability)
            ));
            assert!(matches!(
                build_dropout3d(Dropout3dBuilder { p }),
                Err(DropoutBuildError::InvalidProbability)
            ));
        }
    }

    /// Valid probabilities are stored as the keep probability `1 - p`, in training mode.
    #[test]
    fn test_build_stores_keep_probability() {
        match build_dropout(DropoutBuilder { p: 0.25 }) {
            Ok(dropout) => {
                assert_float_eq!(dropout.one_minus_p, 0.75, abs <= 1e-6);
                assert!(dropout.training);
            }
            Err(_) => panic!("p = 0.25 must be accepted"),
        }
        // The lower bound is inclusive.
        assert!(build_dropout(DropoutBuilder { p: 0.0 }).is_ok());
        assert!(build_dropout2d(Dropout2dBuilder { p: 0.0 }).is_ok());
        assert!(build_dropout3d(Dropout3dBuilder { p: 0.0 }).is_ok());
    }

    /// Outside training mode, and when nothing can be dropped, `forward` is the identity.
    #[test]
    fn test_dropout_identity_paths() {
        let a = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);

        let mut eval_dropout = Dropout::new();
        eval_dropout.training_mode(false);
        let result = eval_dropout.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            10.0,
            abs <= 1e-6
        );

        let mut keep_all = Dropout {
            one_minus_p: 1.0,
            training: true,
        };
        let result = keep_all.forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 2]);
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            10.0,
            abs <= 1e-6
        );
    }

    /// The channel-wise variants validate rank before the training-mode shortcut.
    #[test]
    fn test_channel_dropout_rejects_wrong_rank() {
        let matrix = Array::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]);
        assert!(Dropout2d::new().forward(&matrix).is_err());

        let mut eval_2d = Dropout2d::new();
        eval_2d.training_mode(false);
        assert!(eval_2d.forward(&matrix).is_err());

        let cube = Array::from_slice(&[1.0f32; 8], &[2, 2, 2]);
        assert!(Dropout3d::new().forward(&cube).is_err());
    }
}