// mlx_rs/nn/dropout.rs
use crate::module::Module;
use crate::Array;
use crate::{array, error::Exception, ops::multiply, random::bernoulli};
use mlx_internal_macros::{Buildable, Builder};
use mlx_macros::ModuleParameters;

use crate::error::DropoutBuildError;

/// Builder for [`Dropout`].
///
/// The probability is validated when the builder is built: any `p` outside the half-open
/// range `[0, 1)` (including NaN) yields [`DropoutBuildError::InvalidProbability`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct DropoutBuilder {
    /// The probability of zeroing an element. Must lie in `[0, 1)`; defaults to
    /// [`Dropout::DEFAULT_P`] when not set.
    #[builder(optional, default = Dropout::DEFAULT_P)]
    p: f32,
}

fn build_dropout(builder: DropoutBuilder) -> Result<Dropout, DropoutBuildError> {
    let p = builder.p;

    if !(0.0..1.0).contains(&p) {
        return Err(DropoutBuildError::InvalidProbability);
    }

    Ok(Dropout {
        one_minus_p: 1.0 - p,
        training: Dropout::DEFAULT_TRAINING,
    })
}

/// Randomly zero a portion of the elements during training.
///
/// Each element is zeroed with probability `p`, and the remaining elements are multiplied by
/// `1 / (1 - p)` so that the expected value of every element stays the same.
///
/// Construct it with [`DropoutBuilder`] (which rejects `p` outside `[0, 1)`), or with
/// `Dropout::new()` for the defaults.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout {
    /// `1 - p`, i.e. the probability of *keeping* an element, where `p` is the probability of
    /// zeroing it. `p` defaults to [`Dropout::DEFAULT_P`] if not specified.
    ///
    /// NOTE(review): this field is public, so assigning it directly bypasses the builder's
    /// `[0, 1)` check on `p`; a value of `0.0` would make the forward pass divide by zero.
    pub one_minus_p: f32,

    /// Whether the layer is in training mode. Defaults to [`Dropout::DEFAULT_TRAINING`] if not
    /// specified. When `false`, the forward pass returns its input unchanged.
    pub training: bool,
}

impl Dropout {
    /// Training-mode flag a freshly built layer starts with.
    pub const DEFAULT_TRAINING: bool = true;

    /// Probability `p` of zeroing an element, used when the builder is not given one.
    pub const DEFAULT_P: f32 = 0.5;
}

impl Module<&Array> for Dropout {
    type Error = Exception;
    type Output = Array;

    /// Applies element-wise dropout to `x`.
    ///
    /// In evaluation mode, or when `one_minus_p == 1.0` (nothing can be dropped), `x` is
    /// returned unchanged. Otherwise a keep-mask with the same shape as `x` is drawn with
    /// [`bernoulli`] (each entry kept with probability `one_minus_p`), and the result is
    /// `(1 / one_minus_p) * mask * x`.
    ///
    /// # Errors
    ///
    /// Returns an [`Exception`] if sampling the mask or either multiplication fails.
    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        let keep_prob = self.one_minus_p;

        // Evaluation mode, or a layer that keeps everything, is the identity.
        if !self.training || keep_prob == 1.0 {
            return Ok(x.clone());
        }

        let keep_prob_arr = array!(keep_prob);
        let keep_mask = bernoulli(&keep_prob_arr, x.shape(), None)?;

        // NOTE(review): `rescale` is built from an `f32`; if `multiply` applies the usual
        // dtype promotion, a lower-precision `x` (e.g. f16) comes back as f32 — confirm
        // whether that is intended.
        let rescale = array!(1.0 / keep_prob);
        let scaled_mask = multiply(rescale, keep_mask)?;
        multiply(scaled_mask, x).map_err(Into::into)
    }

    /// Switches between training (`true`) and evaluation (`false`) mode.
    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

/// Builder for [`Dropout2d`].
///
/// The probability is validated when the builder is built: any `p` outside the half-open
/// range `[0, 1)` (including NaN) yields [`DropoutBuildError::InvalidProbability`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout2d,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct Dropout2dBuilder {
    /// The probability of zeroing a channel. Must lie in `[0, 1)`; defaults to
    /// [`Dropout2d::DEFAULT_P`] when not set.
    #[builder(optional, default = Dropout2d::DEFAULT_P)]
    p: f32,
}

fn build_dropout2d(builder: Dropout2dBuilder) -> Result<Dropout2d, DropoutBuildError> {
    let p = builder.p;

    if !(0.0..1.0).contains(&p) {
        return Err(DropoutBuildError::InvalidProbability);
    }

    Ok(Dropout2d {
        one_minus_p: 1.0 - p,
        training: Dropout2d::DEFAULT_TRAINING,
    })
}

/// Apply 2D channel-wise dropout during training.
///
/// Randomly zero out entire channels independently with probability `p`.
/// This layer expects the channels to be last, i.e. the input shape should be
/// `NHWC` or `HWC`, where `N` is the batch dimension, `H` is the input
/// image height, `W` is the input image width, and `C` is the number of
/// input channels.
///
/// The remaining channels are scaled by `1 / (1-p)` to
/// maintain the expected value of each element. Unlike traditional dropout,
/// which zeros individual entries, this layer zeros entire channels. This is
/// beneficial for early convolution layers where adjacent pixels are
/// correlated. In such cases, traditional dropout may not effectively
/// regularize activations. For more details, see [1].
///
/// [1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015.
/// Efficient Object Localization Using Convolutional Networks. CVPR 2015.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout2d {
    /// `1-p`, where `p` is the probability of zeroing a channel. `p` defaults to
    /// [`Dropout2d::DEFAULT_P`] if not specified.
    ///
    /// NOTE(review): this field is public, so assigning it directly bypasses the builder's
    /// `[0, 1)` check on `p`; a value of `0.0` would make the forward pass divide by zero.
    pub one_minus_p: f32,

    /// Whether the layer is in training mode. Defaults to [`Dropout2d::DEFAULT_TRAINING`] if
    /// not specified. When `false`, the forward pass returns its input unchanged.
    pub training: bool,
}

impl Dropout2d {
    /// Training-mode flag a freshly built layer starts with.
    pub const DEFAULT_TRAINING: bool = true;

    /// Probability `p` of zeroing a channel, used when the builder is not given one.
    pub const DEFAULT_P: f32 = 0.5;
}

impl Module<&Array> for Dropout2d {
    type Error = Exception;
    type Output = Array;

    /// Applies channel-wise dropout to a channels-last 3D or 4D input.
    ///
    /// The rank is checked first, even outside training. Then, in evaluation mode or when
    /// `one_minus_p == 1.0`, `x` is returned unchanged. Otherwise one keep/drop decision is
    /// sampled per channel (and per leading batch entry for 4D input), broadcast over the two
    /// spatial axes, and the kept channels are rescaled by `1 / one_minus_p`.
    ///
    /// # Errors
    ///
    /// Returns an [`Exception`] if `x` is not 3D or 4D, or if sampling/multiplying fails.
    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        if !matches!(x.ndim(), 3 | 4) {
            return Err(Exception::custom("Expecting 3D or 4D input"));
        }

        if !self.training || self.one_minus_p == 1.0 {
            return Ok(x.clone());
        }

        // Collapse the two axes just before the trailing channel axis to 1 so the mask
        // broadcasts over them:
        //   3D input: (H, W, C)    -> mask (1, 1, C)
        //   4D input: (N, H, W, C) -> mask (N, 1, 1, C)
        let mut mask_shape = x.shape().to_vec();
        let channel_axis = mask_shape.len() - 1;
        mask_shape[channel_axis - 2..channel_axis].fill(1);

        let keep_prob = array!(self.one_minus_p);
        let keep_mask = bernoulli(&keep_prob, &mask_shape, None)?;

        let rescale = array!(1.0 / self.one_minus_p);
        let scaled_mask = multiply(rescale, keep_mask)?;
        multiply(scaled_mask, x).map_err(Into::into)
    }

    /// Switches between training (`true`) and evaluation (`false`) mode.
    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

/// Builder for [`Dropout3d`].
///
/// The probability is validated when the builder is built: any `p` outside the half-open
/// range `[0, 1)` (including NaN) yields [`DropoutBuildError::InvalidProbability`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_dropout3d,
    default_infallible,
    err = DropoutBuildError,
)]
pub struct Dropout3dBuilder {
    /// The probability of zeroing a channel. Must lie in `[0, 1)`; defaults to
    /// [`Dropout3d::DEFAULT_P`] when not set.
    #[builder(optional, default = Dropout3d::DEFAULT_P)]
    p: f32,
}

fn build_dropout3d(builder: Dropout3dBuilder) -> Result<Dropout3d, DropoutBuildError> {
    let p = builder.p;

    if !(0.0..1.0).contains(&p) {
        return Err(DropoutBuildError::InvalidProbability);
    }

    Ok(Dropout3d {
        one_minus_p: 1.0 - p,
        training: Dropout3d::DEFAULT_TRAINING,
    })
}

/// Apply 3D channel-wise dropout during training.
///
/// Randomly zero out entire channels independently with probability `p`.
/// This layer expects the channels to be last, i.e., the input shape should be
/// `NDHWC` or `DHWC`, where `N` is the batch dimension, `D` is the depth,
/// `H` is the input image height, `W` is the input image width, and `C` is
/// the number of input channels.
///
/// The remaining channels are scaled by `1 / (1-p)` to
/// maintain the expected value of each element. Unlike traditional dropout,
/// which zeros individual entries, this layer zeros entire channels. This is
/// often beneficial for convolutional layers processing 3D data, like in
/// medical imaging or video processing.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Dropout3d {
    /// `1-p`, where `p` is the probability of zeroing a channel. `p` defaults to
    /// [`Dropout3d::DEFAULT_P`] if not specified.
    ///
    /// NOTE(review): this field is public, so assigning it directly bypasses the builder's
    /// `[0, 1)` check on `p`; a value of `0.0` would make the forward pass divide by zero.
    pub one_minus_p: f32,

    /// Whether the layer is in training mode. Defaults to [`Dropout3d::DEFAULT_TRAINING`] if
    /// not specified. When `false`, the forward pass returns its input unchanged.
    pub training: bool,
}

impl Dropout3d {
    /// Training-mode flag a freshly built layer starts with.
    pub const DEFAULT_TRAINING: bool = true;

    /// Probability `p` of zeroing a channel, used when the builder is not given one.
    pub const DEFAULT_P: f32 = 0.5;
}

impl Module<&Array> for Dropout3d {
    type Error = Exception;
    type Output = Array;

    /// Applies channel-wise dropout to a channels-last 4D or 5D input.
    ///
    /// The rank is checked first, even outside training. Then, in evaluation mode or when
    /// `one_minus_p == 1.0`, `x` is returned unchanged. Otherwise one keep/drop decision is
    /// sampled per channel (and per leading batch entry for 5D input), broadcast over the three
    /// volumetric axes, and the kept channels are rescaled by `1 / one_minus_p`.
    ///
    /// # Errors
    ///
    /// Returns an [`Exception`] if `x` is not 4D or 5D, or if sampling/multiplying fails.
    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
        if !matches!(x.ndim(), 4 | 5) {
            return Err(Exception::custom("Expecting 4D or 5D input"));
        }

        if !self.training || self.one_minus_p == 1.0 {
            return Ok(x.clone());
        }

        // Collapse the three axes just before the trailing channel axis to 1 so the mask
        // broadcasts over them:
        //   4D input: (D, H, W, C)    -> mask (1, 1, 1, C)
        //   5D input: (N, D, H, W, C) -> mask (N, 1, 1, 1, C)
        let mut mask_shape = x.shape().to_vec();
        let channel_axis = mask_shape.len() - 1;
        mask_shape[channel_axis - 3..channel_axis].fill(1);

        let keep_prob = array!(self.one_minus_p);
        let keep_mask = bernoulli(&keep_prob, &mask_shape, None)?;

        let rescale = array!(1.0 / self.one_minus_p);
        let scaled_mask = multiply(rescale, keep_mask)?;
        multiply(scaled_mask, x).map_err(Into::into)
    }

    /// Switches between training (`true`) and evaluation (`false`) mode.
    fn training_mode(&mut self, mode: bool) {
        self.training = mode;
    }
}

// The following tests were ported from the swift binding:
// mlx-swift/Tests/MLXTests/IntegrationTests.swift
#[cfg(test)]
mod tests {
    use crate::random::uniform;
    use float_eq::assert_float_eq;

    use super::*;

    /// Asserts that `arr` is a float32 array with exactly the given `shape`.
    fn assert_f32_shape(arr: &Array, shape: &[i32]) {
        assert_eq!(arr.shape(), shape);
        assert_eq!(arr.dtype(), crate::Dtype::Float32);
    }

    /// Asserts the mean and sum of `arr`. Each expectation is `(value, absolute tolerance)`,
    /// with reference values taken from the mlx-swift integration tests.
    fn assert_mean_and_sum(arr: &Array, mean: (f32, f32), sum: (f32, f32)) {
        let (expected_mean, mean_tol) = mean;
        let (expected_sum, sum_tol) = sum;
        assert_float_eq!(
            arr.mean(None, None).unwrap().item::<f32>(),
            expected_mean,
            abs <= mean_tol
        );
        assert_float_eq!(
            arr.sum(None, None).unwrap().item::<f32>(),
            expected_sum,
            abs <= sum_tol
        );
    }

    #[test]
    fn test_dropout() {
        crate::random::seed(959).unwrap();
        let input = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_f32_shape(&input, &[2, 8, 16]);
        assert_mean_and_sum(&input, (0.511_429_2, 0.010_228_584), (130.925_87, 2.618_517_4));

        let output = Dropout::new().forward(&input).unwrap();
        assert_f32_shape(&output, &[2, 8, 16]);
        assert_mean_and_sum(&output, (0.477_913_62, 0.009_558_273), (122.345_89, 2.446_917_8));
    }

    #[test]
    fn test_dropout2d() {
        crate::random::seed(695).unwrap();
        let input = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_f32_shape(&input, &[2, 8, 16]);
        assert_mean_and_sum(&input, (0.457_839_9, 0.009_156_798), (117.207_016, 2.344_140_3));

        let output = Dropout2d::new().forward(&input).unwrap();
        assert_f32_shape(&output, &[2, 8, 16]);
        assert_mean_and_sum(&output, (0.368_284_34, 0.007_365_687), (94.280_79, 1.885_615_8));
    }

    #[test]
    fn test_dropout3d() {
        crate::random::seed(23).unwrap();
        let input = uniform::<_, f32>(0.0, 1.0, &[2, 8, 8, 4], None).unwrap();
        assert_f32_shape(&input, &[2, 8, 8, 4]);
        assert_mean_and_sum(&input, (0.500_606_2, 0.010_012_124), (256.310_36, 5.126_207_4));

        let output = Dropout3d::new().forward(&input).unwrap();
        assert_f32_shape(&output, &[2, 8, 8, 4]);
        assert_mean_and_sum(&output, (0.237_284_15, 0.004_745_683), (121.489_49, 2.429_789_8));
    }
}