mlx_rs/nn/dropout.rs

use crate::module::Module;
use crate::Array;
use crate::{array, error::Exception, ops::multiply, random::bernoulli};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;

use crate::error::DropoutBuildError;

/// Builder for [`Dropout`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct DropoutBuilder {
    /// The probability of zeroing an element.
    #[builder(optional, default = Dropout::DEFAULT_P)]
    p: f32,
}

fn build_dropout(builder: DropoutBuilder) -> Result<Dropout, DropoutBuildError> {
    let p = builder.p;

    if !(0.0..1.0).contains(&p) {
        return Err(DropoutBuildError::InvalidProbability);
    }

    Ok(Dropout {
        one_minus_p: 1.0 - p,
        training: Dropout::DEFAULT_TRAINING,
    })
}

/// Randomly zero a portion of the elements during training.
///
/// The remaining elements are multiplied by `1 / (1 - p)` where
/// `p` is the probability of zeroing an element. This is done so the
/// expected value of a given element will remain the same.
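///
/// For example, with `p = 0.5`, an element `x` becomes `0` with probability
/// `0.5` and `2 * x` otherwise, so `E[y] = 0.5 * 0 + 0.5 * 2x = x`.
///
/// # Example
///
/// A minimal usage sketch. `Dropout::new()` uses [`Dropout::DEFAULT_P`]; a
/// custom probability goes through [`DropoutBuilder`] (crate paths outside
/// this file are assumed):
///
/// ```rust,ignore
/// use mlx_rs::{array, module::Module, nn::Dropout};
///
/// let mut dropout = Dropout::new(); // p = 0.5
/// let x = array!([1.0f32, 2.0, 3.0, 4.0]);
///
/// // Training mode: each element is zeroed with probability p and the
/// // survivors are scaled by 1 / (1 - p) = 2.0.
/// let y = dropout.forward(&x).unwrap();
///
/// // Evaluation mode: the input is returned unchanged.
/// dropout.training_mode(false);
/// let y_eval = dropout.forward(&x).unwrap();
/// ```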
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout {
    /// `1 - p`, where `p` is the probability of zeroing an element. `p` defaults to
    /// [`Dropout::DEFAULT_P`] if not specified.
    pub one_minus_p: f32,

    /// Whether the layer is in training mode. Defaults to [`Dropout::DEFAULT_TRAINING`] if not
    /// specified.
    pub training: bool,
}

impl Dropout {
    /// Default value for the probability of zeroing an element.
    pub const DEFAULT_P: f32 = 0.5;

    /// Default value for the training mode.
    pub const DEFAULT_TRAINING: bool = true;
}

impl Module<&Array> for Dropout {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        // Identity when p == 0 or when not in training mode.
        if self.one_minus_p == 1.0 || !self.training {
            return Ok(x.clone());
        }

        // Keep each element with probability 1 - p, then scale the survivors
        // by 1 / (1 - p) to preserve the expected value.
        let p1 = array!(self.one_minus_p);
        let mask = bernoulli(&p1, x.shape(), None)?;
        multiply(multiply(array!(1.0 / self.one_minus_p), mask)?, x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

/// Builder for [`Dropout2d`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout2d,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct Dropout2dBuilder {
    /// The probability of zeroing a channel.
    #[builder(optional, default = Dropout2d::DEFAULT_P)]
    p: f32,
}

fn build_dropout2d(builder: Dropout2dBuilder) -> Result<Dropout2d, DropoutBuildError> {
    let p = builder.p;

    if !(0.0..1.0).contains(&p) {
        return Err(DropoutBuildError::InvalidProbability);
    }

    Ok(Dropout2d {
        one_minus_p: 1.0 - p,
        training: Dropout2d::DEFAULT_TRAINING,
    })
}

/// Apply 2D channel-wise dropout during training.
///
/// Randomly zero out entire channels independently with probability `p`.
/// This layer expects the channels to be last, i.e., the input shape should be
/// `NHWC` or `HWC` where: `N` is the batch dimension, `H` is the input
/// image height, `W` is the input image width, and `C` is the number of
/// input channels.
///
/// The remaining channels are scaled by `1 / (1 - p)` to
/// maintain the expected value of each element. Unlike traditional dropout,
/// which zeros individual entries, this layer zeros entire channels. This is
/// beneficial for early convolution layers where adjacent pixels are
/// correlated. In such cases, traditional dropout may not effectively
/// regularize activations. For more details, see [1].
///
/// [1]: Tompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler, C., 2015.
/// Efficient Object Localization Using Convolutional Networks. CVPR 2015.
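///
/// # Example
///
/// A minimal sketch of the channel-wise behavior (crate paths outside this
/// file are assumed):
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::Dropout2d, random::uniform};
///
/// // A single 8x8 image with 4 channels, in HWC layout.
/// let x = uniform::<_, f32>(0.0, 1.0, &[8, 8, 4], None).unwrap();
///
/// let mut dropout = Dropout2d::new(); // p = 0.5
/// // Each of the 4 channels is either zeroed entirely or scaled by
/// // 1 / (1 - p) = 2.0 across all of its 8x8 entries.
/// let y = dropout.forward(&x).unwrap();
/// ```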
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout2d {
    /// `1 - p`, where `p` is the probability of zeroing a channel. `p` defaults to
    /// [`Dropout2d::DEFAULT_P`] if not specified.
    pub one_minus_p: f32,

    /// Whether the layer is in training mode. Defaults to [`Dropout2d::DEFAULT_TRAINING`] if not
    /// specified.
    pub training: bool,
}

impl Dropout2d {
    /// Default value for the probability of zeroing a channel.
    pub const DEFAULT_P: f32 = 0.5;

    /// Default value for the training mode.
    pub const DEFAULT_TRAINING: bool = true;
}

impl Module<&Array> for Dropout2d {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let ndim = x.ndim();

        if ndim != 3 && ndim != 4 {
            return Err(Exception::custom("Expecting 3D or 4D input"));
        }

        if self.one_minus_p == 1.0 || !self.training {
            return Ok(x.clone());
        }

        // Dropout is applied to whole channels, so the mask is broadcast
        // over the spatial dimensions:
        // 3D input: (1, 1, C)
        // 4D input: (B, 1, 1, C)
        let mut mask_shape = x.shape().to_vec();
        let len = mask_shape.len();
        mask_shape[len - 2] = 1;
        mask_shape[len - 3] = 1;

        let p1 = array!(self.one_minus_p);
        let mask = bernoulli(&p1, &mask_shape, None)?;

        multiply(multiply(array!(1.0 / self.one_minus_p), mask)?, x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

/// Builder for [`Dropout3d`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout3d,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct Dropout3dBuilder {
    /// The probability of zeroing a channel.
    #[builder(optional, default = Dropout3d::DEFAULT_P)]
    p: f32,
}

fn build_dropout3d(builder: Dropout3dBuilder) -> Result<Dropout3d, DropoutBuildError> {
    let p = builder.p;

    if !(0.0..1.0).contains(&p) {
        return Err(DropoutBuildError::InvalidProbability);
    }

    Ok(Dropout3d {
        one_minus_p: 1.0 - p,
        training: Dropout3d::DEFAULT_TRAINING,
    })
}

/// Apply 3D channel-wise dropout during training.
///
/// Randomly zero out entire channels independently with probability `p`.
/// This layer expects the channels to be last, i.e., the input shape should be
/// `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
/// `H` is the input image height, `W` is the input image width, and `C` is
/// the number of input channels.
///
/// The remaining channels are scaled by `1 / (1 - p)` to
/// maintain the expected value of each element. Unlike traditional dropout,
/// which zeros individual entries, this layer zeros entire channels. This is
/// often beneficial for convolutional layers processing 3D data, like in
/// medical imaging or video processing.
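///
/// # Example
///
/// A minimal sketch for a volumetric input (crate paths outside this file
/// are assumed):
///
/// ```rust,ignore
/// use mlx_rs::{module::Module, nn::Dropout3d, random::uniform};
///
/// // A single 4x8x8 volume with 4 channels, in DHWC layout.
/// let x = uniform::<_, f32>(0.0, 1.0, &[4, 8, 8, 4], None).unwrap();
///
/// let mut dropout = Dropout3d::new(); // p = 0.5
/// // Each channel is zeroed or kept across the entire D x H x W volume;
/// // kept channels are scaled by 1 / (1 - p) = 2.0.
/// let y = dropout.forward(&x).unwrap();
/// ```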
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout3d {
    /// `1 - p`, where `p` is the probability of zeroing a channel. `p` defaults to
    /// [`Dropout3d::DEFAULT_P`] if not specified.
    pub one_minus_p: f32,

    /// Whether the layer is in training mode. Defaults to [`Dropout3d::DEFAULT_TRAINING`] if not
    /// specified.
    pub training: bool,
}

impl Dropout3d {
    /// Default value for the probability of zeroing a channel.
    pub const DEFAULT_P: f32 = 0.5;

    /// Default value for the training mode.
    pub const DEFAULT_TRAINING: bool = true;
}

impl Module<&Array> for Dropout3d {
    type Error = Exception;
    type Output = Array;

    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let ndim = x.ndim();

        if ndim != 4 && ndim != 5 {
            return Err(Exception::custom("Expecting 4D or 5D input"));
        }

        if self.one_minus_p == 1.0 || !self.training {
            return Ok(x.clone());
        }

        // Dropout is applied to whole channels, so the mask is broadcast
        // over the depth and spatial dimensions:
        // 4D input: (1, 1, 1, C)
        // 5D input: (B, 1, 1, 1, C)
        let mut mask_shape = x.shape().to_vec();
        let len = mask_shape.len();
        mask_shape[len - 2] = 1;
        mask_shape[len - 3] = 1;
        mask_shape[len - 4] = 1;

        let p1 = array!(self.one_minus_p);
        let mask = bernoulli(&p1, &mask_shape, None)?;

        multiply(multiply(array!(1.0 / self.one_minus_p), mask)?, x)
    }

    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

// The following tests were ported from the Swift bindings:
// mlx-swift/Tests/MLXTests/IntegrationTests.swift
#[cfg(test)]
mod tests {
    use crate::random::uniform;
    use float_eq::assert_float_eq;

    use super::*;

    #[test]
    fn test_dropout() {
        crate::random::seed(959).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.511_429_2,
            abs <= 0.010_228_584
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            130.925_87,
            abs <= 2.618_517_4
        );
        let result = Dropout::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.477_913_62,
            abs <= 0.009_558_273
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            122.345_89,
            abs <= 2.446_917_8
        );
    }

    #[test]
    fn test_dropout2d() {
        crate::random::seed(695).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.457_839_9,
            abs <= 0.009_156_798
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            117.207_016,
            abs <= 2.344_140_3
        );
        let result = Dropout2d::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 16]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.368_284_34,
            abs <= 0.007_365_687
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            94.280_79,
            abs <= 1.885_615_8
        );
    }

    #[test]
    fn test_dropout3d() {
        crate::random::seed(23).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 8, 4], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 8, 4]);
        assert_eq!(a.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.500_606_2,
            abs <= 0.010_012_124
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            256.310_36,
            abs <= 5.126_207_4
        );
        let result = Dropout3d::new().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 8, 8, 4]);
        assert_eq!(result.dtype(), crate::Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.237_284_15,
            abs <= 0.004_745_683
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            121.489_49,
            abs <= 2.429_789_8
        );
    }
}