// mlx_rs/nn/convolution.rs

1use crate::module::{Module, Param};
2use crate::{
3    error::Exception,
4    ops::{conv1d, conv2d, zeros},
5    random::uniform,
6    Array,
7};
8use mlx_internal_macros::{Buildable, Builder};
9use mlx_macros::ModuleParameters;
10
11use crate::utils::{SingleOrPair, SingleOrTriple};
12
/// Builder for the `Conv1d` module.
///
/// `input_channels`, `output_channels` and `kernel_size` carry no default and must be
/// supplied; every other field falls back to the matching `Conv1d::DEFAULT_*` constant.
/// Construction is delegated to `build_conv1d`, which reports failures as an [`Exception`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_conv1d,
    err = Exception,
)]
pub struct Conv1dBuilder {
    /// Number of input channels.
    pub input_channels: i32,

    /// Number of output channels.
    pub output_channels: i32,

    /// Size of the convolution filters (number of taps along the sequence axis).
    pub kernel_size: i32,

    /// If `true`, add a learnable bias to the output. Defaults to [`Conv1d::DEFAULT_BIAS`] if
    /// not specified.
    #[builder(optional, default = Conv1d::DEFAULT_BIAS)]
    pub bias: bool,

    /// Stride. Defaults to [`Conv1d::DEFAULT_STRIDE`] if not specified.
    #[builder(optional, default = Conv1d::DEFAULT_STRIDE)]
    pub stride: i32,

    /// Padding. Defaults to [`Conv1d::DEFAULT_PADDING`] if not specified.
    #[builder(optional, default = Conv1d::DEFAULT_PADDING)]
    pub padding: i32,

    /// Dilation. Defaults to [`Conv1d::DEFAULT_DILATION`] if not specified.
    #[builder(optional, default = Conv1d::DEFAULT_DILATION)]
    pub dilation: i32,

    /// Groups. Defaults to [`Conv1d::DEFAULT_GROUPS`] if not specified.
    #[builder(optional, default = Conv1d::DEFAULT_GROUPS)]
    pub groups: i32,
}
51
52fn build_conv1d(builder: Conv1dBuilder) -> Result<Conv1d, Exception> {
53    let input_channels = builder.input_channels;
54    let output_channels = builder.output_channels;
55    let kernel_size = builder.kernel_size;
56    let with_bias = builder.bias;
57
58    let scale = f32::sqrt(1.0f32 / (input_channels * kernel_size) as f32);
59    let weight = uniform::<_, f32>(
60        -scale,
61        scale,
62        &[output_channels, kernel_size, input_channels],
63        None,
64    )?;
65    let bias = if with_bias {
66        Some(zeros::<f32>(&[output_channels])?)
67    } else {
68        None
69    };
70
71    Ok(Conv1d {
72        weight: Param::new(weight),
73        bias: Param::new(bias),
74        stride: builder.stride,
75        padding: builder.padding,
76        dilation: builder.dilation,
77        groups: builder.groups,
78    })
79}
80
/// Applies a 1-dimensional convolution over the multi-channel input sequence.
///
/// The channels are expected to be last i.e. the input shape should be `NLC` where:
///
/// - `N` is the batch dimension
/// - `L` is the sequence length
/// - `C` is the number of input channels
///
/// The forward pass calls [`conv1d`] with the stored stride, padding, dilation and groups, then
/// adds `bias` when it is present.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Conv1d {
    /// The weight of the convolution layer. Its first axis indexes the output channels and its
    /// last axis runs over the input channels.
    #[param]
    pub weight: Param<Array>,

    /// The bias of the convolution layer, or `None` when the layer was built with `bias = false`.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Stride. Defaults to [`Conv1d::DEFAULT_STRIDE`] if not specified.
    pub stride: i32,

    /// Padding. Defaults to [`Conv1d::DEFAULT_PADDING`] if not specified.
    pub padding: i32,

    /// Dilation. Defaults to [`Conv1d::DEFAULT_DILATION`] if not specified.
    pub dilation: i32,

    /// Groups. Defaults to [`Conv1d::DEFAULT_GROUPS`] if not specified.
    pub groups: i32,
}
112
113impl Conv1d {
114    /// Default value for `with_bias` if not specified.
115    pub const DEFAULT_BIAS: bool = true;
116
117    /// Default value for `stride` if not specified.
118    pub const DEFAULT_STRIDE: i32 = 1;
119
120    /// Default value for `padding` if not specified.
121    pub const DEFAULT_PADDING: i32 = 0;
122
123    /// Default value for `dilation` if not specified.
124    pub const DEFAULT_DILATION: i32 = 1;
125
126    /// Default value for `groups` if not specified.
127    pub const DEFAULT_GROUPS: i32 = 1;
128}
129
130impl Module<&Array> for Conv1d {
131    type Error = Exception;
132    type Output = Array;
133
134    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
135        let mut y = conv1d(
136            x,
137            self.weight.as_ref(),
138            self.stride,
139            self.padding,
140            self.dilation,
141            self.groups,
142        )?;
143        if let Some(bias) = &self.bias.value {
144            y += bias;
145        }
146        Ok(y)
147    }
148
149    fn training_mode(&mut self, _: bool) {}
150}
151
/// Builder for the `Conv2d` module.
///
/// `input_channels`, `output_channels` and `kernel_size` carry no default and must be
/// supplied; every other field falls back to the matching `Conv2d::DEFAULT_*` constant.
/// Construction is delegated to `build_conv2d`, which reports failures as an [`Exception`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_conv2d,
    err = Exception,
)]
pub struct Conv2dBuilder {
    /// Number of input channels.
    pub input_channels: i32,

    /// Number of output channels.
    pub output_channels: i32,

    /// Size of the convolution filters, given as a single value or a pair. It is converted to
    /// an `(i32, i32)` pair when the layer is built.
    pub kernel_size: SingleOrPair<i32>,

    /// If `true`, add a learnable bias to the output. Defaults to [`Conv2d::DEFAULT_BIAS`] if
    /// not specified.
    #[builder(optional, default = Conv2d::DEFAULT_BIAS)]
    pub bias: bool,

    /// Stride. Defaults to [`Conv2d::DEFAULT_STRIDE`] if not specified.
    #[builder(optional, default = Conv2d::DEFAULT_STRIDE)]
    pub stride: SingleOrPair<i32>,

    /// Padding. Defaults to [`Conv2d::DEFAULT_PADDING`] if not specified.
    #[builder(optional, default = Conv2d::DEFAULT_PADDING)]
    pub padding: SingleOrPair<i32>,

    /// Dilation. Defaults to [`Conv2d::DEFAULT_DILATION`] if not specified.
    #[builder(optional, default = Conv2d::DEFAULT_DILATION)]
    pub dilation: SingleOrPair<i32>,

    /// Groups. Defaults to [`Conv2d::DEFAULT_GROUPS`] if not specified.
    #[builder(optional, default = Conv2d::DEFAULT_GROUPS)]
    pub groups: i32,
}
190
191fn build_conv2d(builder: Conv2dBuilder) -> Result<Conv2d, Exception> {
192    let input_channels = builder.input_channels;
193    let output_channels = builder.output_channels;
194    let kernel_size: (i32, i32) = builder.kernel_size.into();
195    let with_bias = builder.bias;
196    let padding = builder.padding.into();
197    let stride = builder.stride.into();
198    let dilation = builder.dilation.into();
199
200    let scale = f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1) as f32);
201    let weight = uniform::<_, f32>(
202        -scale,
203        scale,
204        &[
205            output_channels,
206            kernel_size.0,
207            kernel_size.1,
208            input_channels,
209        ],
210        None,
211    )?;
212    let bias = if with_bias {
213        Some(zeros::<f32>(&[output_channels])?)
214    } else {
215        None
216    };
217
218    Ok(Conv2d {
219        weight: Param::new(weight),
220        bias: Param::new(bias),
221        stride,
222        padding,
223        dilation,
224        groups: builder.groups,
225    })
226}
227
/// Applies a 2-dimensional convolution over the multi-channel input image.
///
/// The channels are expected to be last i.e. the input shape should be `NHWC` where:
///
/// - `N` is the batch dimension
/// - `H` is the input image height
/// - `W` is the input image width
/// - `C` is the number of input channels
///
/// The forward pass calls [`conv2d`] with the stored stride, padding, dilation and groups, then
/// adds `bias` when it is present.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Conv2d {
    /// The weight of the convolution layer. Its first axis indexes the output channels, the
    /// next two are the kernel's spatial extents, and its last axis runs over the input
    /// channels.
    #[param]
    pub weight: Param<Array>,

    /// The bias of the convolution layer, or `None` when the layer was built with `bias = false`.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Stride. Defaults to [`Conv2d::DEFAULT_STRIDE`] if not specified.
    pub stride: (i32, i32),

    /// Padding. Defaults to [`Conv2d::DEFAULT_PADDING`] if not specified.
    pub padding: (i32, i32),

    /// Dilation. Defaults to [`Conv2d::DEFAULT_DILATION`] if not specified.
    pub dilation: (i32, i32),

    /// Groups. Defaults to [`Conv2d::DEFAULT_GROUPS`] if not specified.
    pub groups: i32,
}
260
261impl Conv2d {
262    /// Default value for `with_bias` if not specified.
263    pub const DEFAULT_BIAS: bool = true;
264
265    /// Default value for `stride` if not specified.
266    pub const DEFAULT_STRIDE: SingleOrPair = SingleOrPair::Pair(1, 1);
267
268    /// Default value for `padding` if not specified.
269    pub const DEFAULT_PADDING: SingleOrPair = SingleOrPair::Pair(0, 0);
270
271    /// Default value for `dilation` if not specified.
272    pub const DEFAULT_DILATION: SingleOrPair = SingleOrPair::Pair(1, 1);
273
274    /// Default value for `groups` if not specified.
275    pub const DEFAULT_GROUPS: i32 = 1;
276}
277
278impl Module<&Array> for Conv2d {
279    type Error = Exception;
280    type Output = Array;
281
282    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
283        let mut y = conv2d(
284            x,
285            self.weight.as_ref(),
286            self.stride,
287            self.padding,
288            self.dilation,
289            self.groups,
290        )?;
291        if let Some(bias) = &self.bias.value {
292            y += bias;
293        }
294        Ok(y)
295    }
296
297    fn training_mode(&mut self, _: bool) {}
298}
299
/// Builder for the `Conv3d` module.
///
/// `input_channels`, `output_channels` and `kernel_size` carry no default and must be
/// supplied; every other field falls back to the matching `Conv3d::DEFAULT_*` constant.
/// Construction is delegated to `build_conv3d`, which reports failures as an [`Exception`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_conv3d,
    err = Exception,
)]
pub struct Conv3dBuilder {
    /// Number of input channels.
    pub input_channels: i32,

    /// Number of output channels.
    pub output_channels: i32,

    /// Size of the convolution filters, given as a single value or a triple. It is converted to
    /// an `(i32, i32, i32)` triple when the layer is built.
    pub kernel_size: SingleOrTriple<i32>,

    /// If `true`, add a learnable bias to the output. Defaults to [`Conv3d::DEFAULT_BIAS`] if
    /// not specified.
    #[builder(optional, default = Conv3d::DEFAULT_BIAS)]
    pub bias: bool,

    /// Stride. Defaults to [`Conv3d::DEFAULT_STRIDE`] if not specified.
    #[builder(optional, default = Conv3d::DEFAULT_STRIDE)]
    pub stride: SingleOrTriple<i32>,

    /// Padding. Defaults to [`Conv3d::DEFAULT_PADDING`] if not specified.
    #[builder(optional, default = Conv3d::DEFAULT_PADDING)]
    pub padding: SingleOrTriple<i32>,

    /// Dilation. Defaults to [`Conv3d::DEFAULT_DILATION`] if not specified.
    #[builder(optional, default = Conv3d::DEFAULT_DILATION)]
    pub dilation: SingleOrTriple<i32>,

    /// Groups. Defaults to [`Conv3d::DEFAULT_GROUPS`] if not specified.
    #[builder(optional, default = Conv3d::DEFAULT_GROUPS)]
    pub groups: i32,
}
338
339fn build_conv3d(builder: Conv3dBuilder) -> Result<Conv3d, Exception> {
340    let input_channels = builder.input_channels;
341    let output_channels = builder.output_channels;
342    let kernel_size: (i32, i32, i32) = builder.kernel_size.into();
343    let with_bias = builder.bias;
344    let padding = builder.padding.into();
345    let stride = builder.stride.into();
346    let dilation = builder.dilation.into();
347
348    let scale =
349        f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1 * kernel_size.2) as f32);
350    let weight = uniform::<_, f32>(
351        -scale,
352        scale,
353        &[
354            output_channels,
355            kernel_size.0,
356            kernel_size.1,
357            kernel_size.2,
358            input_channels,
359        ],
360        None,
361    )?;
362    let bias = if with_bias {
363        Some(zeros::<f32>(&[output_channels])?)
364    } else {
365        None
366    };
367
368    Ok(Conv3d {
369        weight: Param::new(weight),
370        bias: Param::new(bias),
371        stride,
372        padding,
373        dilation,
374        groups: builder.groups,
375    })
376}
377
378/// Applies a 3-dimensional convolution over the multi-channel input image.
379///
380/// The channels are expected to be last i.e. the input shape should be `NHWC` where:
381///
382/// - `N` is the batch dimension
383/// - `H` is the input image height
384/// - `W` is the input image width
385/// - `C` is the number of input channels
386#[derive(Debug, Clone, ModuleParameters, Buildable)]
387#[module(root = crate)]
388#[buildable(root = crate)]
389pub struct Conv3d {
390    /// The weight of the convolution layer.
391    #[param]
392    pub weight: Param<Array>,
393
394    /// The bias of the convolution layer.
395    #[param]
396    pub bias: Param<Option<Array>>,
397
398    /// Stride. Default to `(1, 1, 1)` if not specified.
399    pub stride: (i32, i32, i32),
400
401    /// Padding. Default to `(0, 0, 0)` if not specified.
402    pub padding: (i32, i32, i32),
403
404    /// Dilation. Default to `(1, 1, 1)` if not specified.
405    pub dilation: (i32, i32, i32),
406
407    /// Groups. Default to 1 if not specified.
408    pub groups: i32,
409}
410
411impl Conv3d {
412    /// Default value for `with_bias` if not specified.
413    pub const DEFAULT_BIAS: bool = true;
414
415    /// Default value for `stride` if not specified.
416    pub const DEFAULT_STRIDE: SingleOrTriple<i32> = SingleOrTriple::Triple(1, 1, 1);
417
418    /// Default value for `padding` if not specified.
419    pub const DEFAULT_PADDING: SingleOrTriple<i32> = SingleOrTriple::Triple(0, 0, 0);
420
421    /// Default value for `dilation` if not specified.
422    pub const DEFAULT_DILATION: SingleOrTriple<i32> = SingleOrTriple::Triple(1, 1, 1);
423
424    /// Default value for `groups` if not specified.
425    pub const DEFAULT_GROUPS: i32 = 1;
426}
427
428impl Module<&Array> for Conv3d {
429    type Error = Exception;
430    type Output = Array;
431
432    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
433        let mut y = crate::ops::conv3d(
434            x,
435            self.weight.as_ref(),
436            self.stride,
437            self.padding,
438            self.dilation,
439            self.groups,
440        )?;
441        if let Some(bias) = &self.bias.value {
442            y += bias;
443        }
444        Ok(y)
445    }
446
447    fn training_mode(&mut self, _: bool) {}
448}
449
// `test_conv1d` and `test_conv2d` are ported from the swift bindings:
// mlx-swift/Tests/MLXTests/IntegrationTests.swift
#[cfg(test)]
mod tests {
    use crate::module::Module;
    use crate::{random::uniform, Dtype};
    use float_eq::assert_float_eq;

    use crate::nn::Conv1d;

    #[test]
    fn test_conv1d() {
        crate::random::seed(819).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.512_987_5,
            abs <= 0.010_259_75
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            131.324_8,
            abs <= 2.626_496
        );
        let result = Conv1d::new(16, 2, 8).unwrap().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 1, 2]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.264_865_2,
            abs <= 0.005_297_303_7
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            1.059_460_8,
            abs <= 0.021_189_215
        );
    }

    #[test]
    fn test_conv2d() {
        crate::random::seed(62).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 8, 4], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 8, 4]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.522_504_27,
            abs <= 0.010_450_086
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            267.522_2,
            abs <= 5.350_444
        );
        let result = crate::nn::Conv2d::new(4, 2, (8, 8))
            .unwrap()
            .forward(&a)
            .unwrap();
        assert_eq!(result.shape(), &[2, 1, 1, 2]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            -0.279_321_5,
            abs <= 0.005_586_43
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            -1.117_286,
            abs <= 0.022_345_72
        );
    }

    /// Shape/dtype smoke test for `Conv3d`. It is not ported from the swift bindings, so it
    /// checks the output structure only, not reference statistics.
    #[test]
    fn test_conv3d_output_shape() {
        crate::random::seed(42).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 8, 8, 4], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 8, 8, 4]);
        let result = crate::nn::Conv3d::new(4, 2, (8, 8, 8))
            .unwrap()
            .forward(&a)
            .unwrap();
        // A kernel spanning the whole 8x8x8 volume with no padding leaves a single output
        // position along each spatial axis.
        assert_eq!(result.shape(), &[2, 1, 1, 1, 2]);
        assert_eq!(result.dtype(), Dtype::Float32);
    }
}