use crate::module::Module;
use crate::Array;
use crate::{array, error::Exception, ops::multiply, random::bernoulli};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;

use crate::error::DropoutBuildError;

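/// Builder for [`Dropout`].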
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct DropoutBuilder {
    /// Probability of zeroing an element. Defaults to [`Dropout::DEFAULT_P`].
    #[builder(optional, default = Dropout::DEFAULT_P)]
    p: f32,
}

fn build_dropout(builder: DropoutBuilder) -> Result<Dropout, DropoutBuildError> {
    let p = builder.p;

    // `p` must lie in `[0, 1)`: with `p == 1.0` the `1 / (1 - p)` scaling
    // applied in `forward` would be undefined.
    if !(0.0..1.0).contains(&p) {
        return Err(DropoutBuildError::InvalidProbability);
    }

    Ok(Dropout {
        one_minus_p: 1.0 - p,
        training: Dropout::DEFAULT_TRAINING,
    })
}

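/// Randomly zeroes a portion of the elements during training.
///
/// Each element is zeroed independently with probability `p`, and the
/// surviving elements are scaled by `1 / (1 - p)` so that the expected value
/// of the output matches the input. Outside of training mode the layer is the
/// identity.
///
/// A minimal usage sketch; the `use` paths are assumptions about how the
/// crate re-exports these items, and the example is not run as a doctest:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
/// use mlx_rs::nn::Dropout;
/// use mlx_rs::random::uniform;
///
/// let mut dropout = Dropout::new(); // defaults to p = 0.5
/// let x = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
/// let y = dropout.forward(&x).unwrap();
/// assert_eq!(y.shape(), x.shape());
/// ```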
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout {
    /// `1 - p`, where `p` is the probability of zeroing an element.
    pub one_minus_p: f32,

    /// Whether the layer is in training mode. Defaults to [`Dropout::DEFAULT_TRAINING`].
    pub training: bool,
}

impl Dropout {
    /// Default probability of zeroing an element.
    pub const DEFAULT_P: f32 = 0.5;

    /// Dropout layers are created in training mode by default.
    pub const DEFAULT_TRAINING: bool = true;
}

impl Module<&Array> for Dropout {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        // With `p == 0` (`one_minus_p == 1.0`) or outside of training, dropout
        // is the identity.
        if self.one_minus_p == 1.0 || !self.training {
            return Ok(x.clone());
        }

        // Keep each element with probability `1 - p` and scale the survivors
        // by `1 / (1 - p)` so the expected value of the output is unchanged.
        let p1 = array!(self.one_minus_p);
        let mask = bernoulli(&p1, x.shape(), None)?;
        multiply(multiply(array!(1.0 / self.one_minus_p), mask)?, x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

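/// Builder for [`Dropout2d`].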
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout2d,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct Dropout2dBuilder {
    /// Probability of zeroing a channel. Defaults to [`Dropout2d::DEFAULT_P`].
    #[builder(optional, default = Dropout2d::DEFAULT_P)]
    p: f32,
}

fn build_dropout2d(builder: Dropout2dBuilder) -> Result<Dropout2d, DropoutBuildError> {
    let p = builder.p;

    if !(0.0..1.0).contains(&p) {
        return Err(DropoutBuildError::InvalidProbability);
    }

    Ok(Dropout2d {
        one_minus_p: 1.0 - p,
        training: Dropout2d::DEFAULT_TRAINING,
    })
}

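/// Applies 2D channel-wise dropout during training.
///
/// Entire channels are zeroed independently with probability `p`, and the
/// surviving channels are scaled by `1 / (1 - p)`. The input is expected in a
/// channels-last layout: a 3D or 4D array whose last axis is the channel axis
/// (e.g. `(N, H, W, C)`).
///
/// A minimal usage sketch; the `use` paths are assumptions about how the
/// crate re-exports these items, and the example is not run as a doctest:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
/// use mlx_rs::nn::Dropout2d;
/// use mlx_rs::random::uniform;
///
/// let mut dropout = Dropout2d::new(); // defaults to p = 0.5
/// let x = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap(); // channels-last 3D input
/// let y = dropout.forward(&x).unwrap();
/// assert_eq!(y.shape(), x.shape());
/// ```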
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout2d {
    /// `1 - p`, where `p` is the probability of zeroing a channel.
    pub one_minus_p: f32,

    /// Whether the layer is in training mode. Defaults to [`Dropout2d::DEFAULT_TRAINING`].
    pub training: bool,
}

impl Dropout2d {
    /// Default probability of zeroing a channel.
    pub const DEFAULT_P: f32 = 0.5;

    /// Dropout layers are created in training mode by default.
    pub const DEFAULT_TRAINING: bool = true;
}

impl Module<&Array> for Dropout2d {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let ndim = x.ndim();

        if ndim != 3 && ndim != 4 {
            return Err(Exception::custom("Expecting 3D or 4D input"));
        }

        if self.one_minus_p == 1.0 || !self.training {
            return Ok(x.clone());
        }

        // Broadcast the mask over the two spatial axes so that each channel is
        // kept or dropped as a whole (channels-last layout).
        let mut mask_shape = x.shape().to_vec();
        let len = mask_shape.len();
        mask_shape[len - 2] = 1;
        mask_shape[len - 3] = 1;

        let p1 = array!(self.one_minus_p);
        let mask = bernoulli(&p1, &mask_shape, None)?;

        multiply(multiply(array!(1.0 / self.one_minus_p), mask)?, x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

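/// Builder for [`Dropout3d`].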
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout3d,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct Dropout3dBuilder {
    /// Probability of zeroing a channel. Defaults to [`Dropout3d::DEFAULT_P`].
    #[builder(optional, default = Dropout3d::DEFAULT_P)]
    p: f32,
}

fn build_dropout3d(builder: Dropout3dBuilder) -> Result<Dropout3d, DropoutBuildError> {
    let p = builder.p;

    if !(0.0..1.0).contains(&p) {
        return Err(DropoutBuildError::InvalidProbability);
    }

    Ok(Dropout3d {
        one_minus_p: 1.0 - p,
        training: Dropout3d::DEFAULT_TRAINING,
    })
}

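/// Applies 3D channel-wise dropout during training.
///
/// Entire channels are zeroed independently with probability `p`, and the
/// surviving channels are scaled by `1 / (1 - p)`. The input is expected in a
/// channels-last layout: a 4D or 5D array whose last axis is the channel axis
/// (e.g. `(N, D, H, W, C)`).
///
/// A minimal usage sketch; the `use` paths are assumptions about how the
/// crate re-exports these items, and the example is not run as a doctest:
///
/// ```rust,ignore
/// use mlx_rs::module::Module;
/// use mlx_rs::nn::Dropout3d;
/// use mlx_rs::random::uniform;
///
/// let mut dropout = Dropout3d::new(); // defaults to p = 0.5
/// let x = uniform::<_, f32>(0.0, 1.0, &[2, 8, 8, 4], None).unwrap(); // channels-last 4D input
/// let y = dropout.forward(&x).unwrap();
/// assert_eq!(y.shape(), x.shape());
/// ```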
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout3d {
    /// `1 - p`, where `p` is the probability of zeroing a channel.
    pub one_minus_p: f32,

    /// Whether the layer is in training mode. Defaults to [`Dropout3d::DEFAULT_TRAINING`].
    pub training: bool,
}

impl Dropout3d {
    /// Default probability of zeroing a channel.
    pub const DEFAULT_P: f32 = 0.5;

    /// Dropout layers are created in training mode by default.
    pub const DEFAULT_TRAINING: bool = true;
}

impl Module<&Array> for Dropout3d {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let ndim = x.ndim();

        if ndim != 4 && ndim != 5 {
            return Err(Exception::custom("Expecting 4D or 5D input"));
        }

        if self.one_minus_p == 1.0 || !self.training {
            return Ok(x.clone());
        }

        // Broadcast the mask over the three spatial axes so that each channel
        // is kept or dropped as a whole (channels-last layout).
        let mut mask_shape = x.shape().to_vec();
        let len = mask_shape.len();
        mask_shape[len - 2] = 1;
        mask_shape[len - 3] = 1;
        mask_shape[len - 4] = 1;

        let p1 = array!(self.one_minus_p);
        let mask = bernoulli(&p1, &mask_shape, None)?;

        multiply(multiply(array!(1.0 / self.one_minus_p), mask)?, x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

#[cfg(test)]
mod tests {
    use crate::random::uniform;
    use float_eq::assert_float_eq;

    use super::*;

    #[test]
    fn test_dropout() {
        crate::random::seed(959).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.511_429_2,
            abs <= 0.010_228_584
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            130.925_87,
            abs <= 2.618_517_4
        );
        let result = Dropout::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.477_913_62,
            abs <= 0.009_558_273
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            122.345_89,
            abs <= 2.446_917_8
        );
    }

    #[test]
    fn test_dropout2d() {
        crate::random::seed(695).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.457_839_9,
            abs <= 0.009_156_798
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            117.207_016,
            abs <= 2.344_140_3
        );
        let result = Dropout2d::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.368_284_34,
            abs <= 0.007_365_687
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            94.280_79,
            abs <= 1.885_615_8
        );
    }

    #[test]
    fn test_dropout3d() {
        crate::random::seed(23).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 8, 4], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 8, 4]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.500_606_2,
            abs <= 0.010_012_124
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            256.310_36,
            abs <= 5.126_207_4
        );
        let result = Dropout3d::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 8, 4]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.237_284_15,
            abs <= 0.004_745_683
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            121.489_49,
            abs <= 2.429_789_8
        );
    }
}