mlx_rs/nn/
convolution_transpose.rs

1use crate::module::{Module, Param};
2use crate::{
3    error::Exception,
4    ops::{conv_transpose1d, conv_transpose2d, conv_transpose3d, zeros},
5    random::uniform,
6    Array,
7};
8use mlx_internal_macros::{Buildable, Builder};
9use mlx_macros::ModuleParameters;
10
11use crate::utils::{SingleOrPair, SingleOrTriple};
12
13/// Builder for the `ConvTranspose1d` module.
14#[derive(Debug, Clone, Builder)]
15#[builder(
16    root = crate,
17    build_with = build_conv_transpose_1d,
18    err = Exception,
19)]
20pub struct ConvTranspose1dBuilder {
21    /// The number of input channels.
22    pub input_channels: i32,
23
24    /// The number of output channels.
25    pub output_channels: i32,
26
27    /// The size of the convolution filters.
28    pub kernel_size: i32,
29
30    /// If `true`, add a learnable bias to the output. Default to [`ConvTranspose1d::DEFAULT_BIAS`] if not
31    /// specified.
32    #[builder(optional, default = ConvTranspose1d::DEFAULT_BIAS)]
33    pub bias: bool,
34
35    /// Padding. Default to [`ConvTranspose1d::DEFAULT_PADDING`] if not specified.
36    #[builder(optional, default = ConvTranspose1d::DEFAULT_PADDING)]
37    pub padding: i32,
38
39    /// Output padding. Default to [`ConvTranspose1d::DEFAULT_OUTPUT_PADDING`] if not specified.
40    #[builder(optional, default = ConvTranspose1d::DEFAULT_OUTPUT_PADDING)]
41    pub output_padding: i32,
42
43    /// Stride. Default to [`ConvTranspose1d::DEFAULT_STRIDE`] if not specified.
44    #[builder(optional, default = ConvTranspose1d::DEFAULT_STRIDE)]
45    pub stride: i32,
46}
47
48fn build_conv_transpose_1d(builder: ConvTranspose1dBuilder) -> Result<ConvTranspose1d, Exception> {
49    let input_channels = builder.input_channels;
50    let output_channels = builder.output_channels;
51    let kernel_size = builder.kernel_size;
52
53    let bias = builder.bias;
54    let padding = builder.padding;
55    let output_padding = builder.output_padding;
56    let stride = builder.stride;
57
58    let scale = f32::sqrt(1.0f32 / (input_channels * kernel_size) as f32);
59    let weight = uniform::<_, f32>(
60        -scale,
61        scale,
62        &[output_channels, kernel_size, input_channels],
63        None,
64    )?;
65    let bias = if bias {
66        Some(zeros::<f32>(&[output_channels])?)
67    } else {
68        None
69    };
70
71    Ok(ConvTranspose1d {
72        weight: Param::new(weight),
73        bias: Param::new(bias),
74        padding,
75        output_padding,
76        stride,
77    })
78}
79
80/// Applies a 1-dimensional convolution over the multi-channel input sequence.
81///
82/// The channels are expected to be last i.e. the input shape should be `NLC` where:
83///
84/// - `N` is the batch dimension
85/// - `L` is the sequence length
86/// - `C` is the number of input channels
87#[derive(Debug, Clone, ModuleParameters, Buildable)]
88#[module(root = crate)]
89#[buildable(root = crate)]
90pub struct ConvTranspose1d {
91    /// The weight of the convolution layer.
92    #[param]
93    pub weight: Param<Array>,
94
95    /// The bias of the convolution layer.
96    #[param]
97    pub bias: Param<Option<Array>>,
98
99    /// Padding. Default to 0 if not specified.
100    pub padding: i32,
101
102    /// Output padding. Default to 0 if not specified.
103    pub output_padding: i32,
104
105    /// Stride. Default to 1 if not specified.
106    pub stride: i32,
107}
108
109impl ConvTranspose1d {
110    /// Default value for `bias` if not specified.
111    pub const DEFAULT_BIAS: bool = true;
112
113    /// Default value for `padding` if not specified.
114    pub const DEFAULT_PADDING: i32 = 0;
115
116    /// Default value for `output_padding` if not specified.
117    pub const DEFAULT_OUTPUT_PADDING: i32 = 0;
118
119    /// Default value for `stride` if not specified.
120    pub const DEFAULT_STRIDE: i32 = 1;
121}
122
123impl Module<&Array> for ConvTranspose1d {
124    type Error = Exception;
125    type Output = Array;
126
127    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
128        let mut y = conv_transpose1d(
129            x,
130            self.weight.as_ref(),
131            self.stride,
132            self.padding,
133            None,
134            self.output_padding,
135            None,
136        )?;
137        if let Some(bias) = &self.bias.value {
138            y += bias;
139        }
140        Ok(y)
141    }
142
143    fn training_mode(&mut self, _: bool) {}
144}
145
146/// Builder for the `ConvTranspose2d` module.
147#[derive(Debug, Clone, Builder)]
148#[builder(
149    root = crate,
150    build_with = build_conv_transpose_2d,
151    err = Exception,
152)]
153pub struct ConvTranspose2dBuilder {
154    /// The number of input channels.
155    pub input_channels: i32,
156
157    /// The number of output channels.
158    pub output_channels: i32,
159
160    /// The size of the convolution filters.
161    pub kernel_size: SingleOrPair<i32>,
162
163    /// If `true`, add a learnable bias to the output. Default to [`ConvTranspose2d::DEFAULT_BIAS`] if not
164    /// specified.
165    #[builder(optional, default = ConvTranspose2d::DEFAULT_BIAS)]
166    bias: bool,
167
168    /// Padding. Default to [`ConvTranspose2d::DEFAULT_PADDING`] if not specified.
169    #[builder(optional, default = ConvTranspose2d::DEFAULT_PADDING)]
170    padding: SingleOrPair<i32>,
171
172    /// Output padding. Default to [`ConvTranspose2d::DEFAULT_OUTPUT_PADDING`] if not specified.
173    #[builder(optional, default = ConvTranspose2d::DEFAULT_OUTPUT_PADDING)]
174    output_padding: SingleOrPair<i32>,
175
176    /// Stride. Default to [`ConvTranspose2d::DEFAULT_STRIDE`] if not specified.
177    #[builder(optional, default = ConvTranspose2d::DEFAULT_STRIDE)]
178    stride: SingleOrPair<i32>,
179}
180
181fn build_conv_transpose_2d(builder: ConvTranspose2dBuilder) -> Result<ConvTranspose2d, Exception> {
182    let input_channels = builder.input_channels;
183    let output_channels = builder.output_channels;
184    let kernel_size: (i32, i32) = builder.kernel_size.into();
185
186    let bias = builder.bias;
187    let padding = builder.padding.into();
188    let output_padding = builder.output_padding.into();
189    let stride = builder.stride.into();
190
191    let scale = f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1) as f32);
192    let weight = uniform::<_, f32>(
193        -scale,
194        scale,
195        &[
196            output_channels,
197            kernel_size.0,
198            kernel_size.1,
199            input_channels,
200        ],
201        None,
202    )?;
203    let bias = if bias {
204        Some(zeros::<f32>(&[output_channels])?)
205    } else {
206        None
207    };
208
209    Ok(ConvTranspose2d {
210        weight: Param::new(weight),
211        bias: Param::new(bias),
212        padding,
213        output_padding,
214        stride,
215    })
216}
217
218/// Applies a 2-dimensional convolution over the multi-channel input image.
219///
220/// The channels are expected to be last i.e. the input shape should be `NHWC` where:
221///
222/// - `N` is the batch dimension
223/// - `H` is the input image height
224/// - `W` is the input image width
225/// - `C` is the number of input channels
226#[derive(Debug, Clone, ModuleParameters, Buildable)]
227#[module(root = crate)]
228#[buildable(root = crate)]
229pub struct ConvTranspose2d {
230    /// The weight of the convolution layer.
231    #[param]
232    pub weight: Param<Array>,
233
234    /// The bias of the convolution layer.
235    #[param]
236    pub bias: Param<Option<Array>>,
237
238    /// Padding. Default to `(0, 0)` if not specified.
239    pub padding: (i32, i32),
240
241    /// Output padding. Default to `(0, 0)` if not specified.
242    pub output_padding: (i32, i32),
243
244    /// Stride. Default to `(1, 1)` if not specified.
245    pub stride: (i32, i32),
246}
247
248impl ConvTranspose2d {
249    /// Default value for `bias` if not specified.
250    pub const DEFAULT_BIAS: bool = true;
251
252    /// Default value for `padding` if not specified.
253    pub const DEFAULT_PADDING: SingleOrPair<i32> = SingleOrPair::Pair(0, 0);
254
255    /// Default value for `output_padding` if not specified.
256    pub const DEFAULT_OUTPUT_PADDING: SingleOrPair<i32> = SingleOrPair::Pair(0, 0);
257
258    /// Default value for `stride` if not specified.
259    pub const DEFAULT_STRIDE: SingleOrPair<i32> = SingleOrPair::Pair(1, 1);
260}
261
262impl Module<&Array> for ConvTranspose2d {
263    type Error = Exception;
264    type Output = Array;
265
266    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
267        let mut y = conv_transpose2d(
268            x,
269            self.weight.as_ref(),
270            self.stride,
271            self.padding,
272            None,
273            self.output_padding,
274            None,
275        )?;
276        if let Some(bias) = &self.bias.value {
277            y += bias;
278        }
279        Ok(y)
280    }
281
282    fn training_mode(&mut self, _: bool) {}
283}
284
285/// Builder for the `ConvTranspose3d` module.
286#[derive(Debug, Clone, Builder)]
287#[builder(
288    root = crate,
289    build_with = build_conv_transpose_3d,
290    err = Exception,
291)]
292pub struct ConvTranspose3dBuilder {
293    /// The number of input channels.
294    pub input_channels: i32,
295
296    /// The number of output channels.
297    pub output_channels: i32,
298
299    /// The size of the convolution filters.
300    pub kernel_size: SingleOrTriple<i32>,
301
302    /// If `true`, add a learnable bias to the output. Default to [`ConvTranspose3d::DEFAULT_BIAS`] if not
303    /// specified.
304    #[builder(optional, default = ConvTranspose3d::DEFAULT_BIAS)]
305    pub bias: bool,
306
307    /// Padding. Default to [`ConvTranspose3d::DEFAULT_PADDING`] if not specified.
308    #[builder(optional, default = ConvTranspose3d::DEFAULT_PADDING)]
309    pub padding: SingleOrTriple<i32>,
310
311    /// Output padding. Default to [`ConvTranspose3d::DEFAULT_OUTPUT_PADDING`] if not specified.
312    #[builder(optional, default = ConvTranspose3d::DEFAULT_OUTPUT_PADDING)]
313    pub output_padding: SingleOrTriple<i32>,
314
315    /// Stride. Default to [`ConvTranspose3d::DEFAULT_STRIDE`] if not specified.
316    #[builder(optional, default = ConvTranspose3d::DEFAULT_STRIDE)]
317    pub stride: SingleOrTriple<i32>,
318}
319
320fn build_conv_transpose_3d(builder: ConvTranspose3dBuilder) -> Result<ConvTranspose3d, Exception> {
321    let input_channels = builder.input_channels;
322    let output_channels = builder.output_channels;
323    let kernel_size: (i32, i32, i32) = builder.kernel_size.into();
324
325    let bias = builder.bias;
326    let padding = builder.padding.into();
327    let output_padding = builder.output_padding.into();
328    let stride = builder.stride.into();
329
330    let scale =
331        f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1 * kernel_size.2) as f32);
332    let weight = uniform::<_, f32>(
333        -scale,
334        scale,
335        &[
336            output_channels,
337            kernel_size.0,
338            kernel_size.1,
339            kernel_size.2,
340            input_channels,
341        ],
342        None,
343    )?;
344    let bias = if bias {
345        Some(zeros::<f32>(&[output_channels])?)
346    } else {
347        None
348    };
349
350    Ok(ConvTranspose3d {
351        weight: Param::new(weight),
352        bias: Param::new(bias),
353        padding,
354        output_padding,
355        stride,
356    })
357}
358
359/// Applies a 3-dimensional convolution over the multi-channel input image.
360///
361/// The channels are expected to be last i.e. the input shape should be `NHWC` where:
362///
363/// - `N` is the batch dimension
364/// - `H` is the input image height
365/// - `W` is the input image width
366/// - `C` is the number of input channels
367#[derive(Debug, Clone, ModuleParameters, Buildable)]
368#[module(root = crate)]
369#[buildable(root = crate)]
370pub struct ConvTranspose3d {
371    /// The weight of the convolution layer.
372    #[param]
373    pub weight: Param<Array>,
374
375    /// The bias of the convolution layer.
376    #[param]
377    pub bias: Param<Option<Array>>,
378
379    /// Padding. Default to `(0, 0, 0)` if not specified.
380    pub padding: (i32, i32, i32),
381
382    /// Output padding. Default to `(0, 0, 0)` if not specified.
383    pub output_padding: (i32, i32, i32),
384
385    /// Stride. Default to `(1, 1, 1)` if not specified.
386    pub stride: (i32, i32, i32),
387}
388
389impl ConvTranspose3d {
390    /// Default value for `bias` if not specified.
391    pub const DEFAULT_BIAS: bool = true;
392
393    /// Default value for `padding` if not specified.
394    pub const DEFAULT_PADDING: SingleOrTriple<i32> = SingleOrTriple::Triple(0, 0, 0);
395
396    /// Default value for `output_padding` if not specified.
397    pub const DEFAULT_OUTPUT_PADDING: SingleOrTriple<i32> = SingleOrTriple::Triple(0, 0, 0);
398
399    /// Default value for `stride` if not specified.
400    pub const DEFAULT_STRIDE: SingleOrTriple<i32> = SingleOrTriple::Triple(1, 1, 1);
401}
402
403impl Module<&Array> for ConvTranspose3d {
404    type Error = Exception;
405    type Output = Array;
406
407    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
408        let mut y = conv_transpose3d(
409            x,
410            self.weight.as_ref(),
411            self.stride,
412            self.padding,
413            None,
414            self.output_padding,
415            None,
416        )?;
417        if let Some(bias) = &self.bias.value {
418            y += bias;
419        }
420        Ok(y)
421    }
422
423    fn training_mode(&mut self, _: bool) {}
424}