mlx_rs/nn/
convolution_transpose.rs

1use crate::module::{Module, Param};
2use crate::{
3    error::Exception,
4    ops::{conv_transpose1d, conv_transpose2d, conv_transpose3d, zeros},
5    random::uniform,
6    Array,
7};
8use mlx_internal_macros::{Buildable, Builder};
9use mlx_macros::ModuleParameters;
10
11use crate::utils::{SingleOrPair, SingleOrTriple};
12
13/// Builder for the `ConvTranspose1d` module.
14#[derive(Debug, Clone, Builder)]
15#[builder(
16    root = crate,
17    build_with = build_conv_transpose_1d,
18    err = Exception,
19)]
20pub struct ConvTranspose1dBuilder {
21    /// The number of input channels.
22    pub input_channels: i32,
23
24    /// The number of output channels.
25    pub output_channels: i32,
26
27    /// The size of the convolution filters.
28    pub kernel_size: i32,
29
30    /// If `true`, add a learnable bias to the output. Default to [`ConvTranspose1d::DEFAULT_BIAS`] if not
31    /// specified.
32    #[builder(optional, default = ConvTranspose1d::DEFAULT_BIAS)]
33    pub bias: bool,
34
35    /// Padding. Default to [`ConvTranspose1d::DEFAULT_PADDING`] if not specified.
36    #[builder(optional, default = ConvTranspose1d::DEFAULT_PADDING)]
37    pub padding: i32,
38
39    /// Stride. Default to [`ConvTranspose1d::DEFAULT_STRIDE`] if not specified.
40    #[builder(optional, default = ConvTranspose1d::DEFAULT_STRIDE)]
41    pub stride: i32,
42}
43
44fn build_conv_transpose_1d(builder: ConvTranspose1dBuilder) -> Result<ConvTranspose1d, Exception> {
45    let input_channels = builder.input_channels;
46    let output_channels = builder.output_channels;
47    let kernel_size = builder.kernel_size;
48
49    let bias = builder.bias;
50    let padding = builder.padding;
51    let stride = builder.stride;
52
53    let scale = f32::sqrt(1.0f32 / (input_channels * kernel_size) as f32);
54    let weight = uniform::<_, f32>(
55        -scale,
56        scale,
57        &[output_channels, kernel_size, input_channels],
58        None,
59    )?;
60    let bias = if bias {
61        Some(zeros::<f32>(&[output_channels])?)
62    } else {
63        None
64    };
65
66    Ok(ConvTranspose1d {
67        weight: Param::new(weight),
68        bias: Param::new(bias),
69        padding,
70        stride,
71    })
72}
73
74/// Applies a 1-dimensional convolution over the multi-channel input sequence.
75///
76/// The channels are expected to be last i.e. the input shape should be `NLC` where:
77///
78/// - `N` is the batch dimension
79/// - `L` is the sequence length
80/// - `C` is the number of input channels
81#[derive(Debug, Clone, ModuleParameters, Buildable)]
82#[module(root = crate)]
83#[buildable(root = crate)]
84pub struct ConvTranspose1d {
85    /// The weight of the convolution layer.
86    #[param]
87    pub weight: Param<Array>,
88
89    /// The bias of the convolution layer.
90    #[param]
91    pub bias: Param<Option<Array>>,
92
93    /// Padding. Default to 0 if not specified.
94    pub padding: i32,
95
96    /// Stride. Default to 1 if not specified.
97    pub stride: i32,
98}
99
100impl ConvTranspose1d {
101    /// Default value for `bias` if not specified.
102    pub const DEFAULT_BIAS: bool = true;
103
104    /// Default value for `padding` if not specified.
105    pub const DEFAULT_PADDING: i32 = 0;
106
107    /// Default value for `stride` if not specified.
108    pub const DEFAULT_STRIDE: i32 = 1;
109}
110
111impl Module<&Array> for ConvTranspose1d {
112    type Error = Exception;
113    type Output = Array;
114
115    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
116        let mut y = conv_transpose1d(
117            x,
118            self.weight.as_ref(),
119            self.stride,
120            self.padding,
121            None,
122            None,
123        )?;
124        if let Some(bias) = &self.bias.value {
125            y += bias;
126        }
127        Ok(y)
128    }
129
130    fn training_mode(&mut self, _: bool) {}
131}
132
133/// Builder for the `ConvTranspose2d` module.
134#[derive(Debug, Clone, Builder)]
135#[builder(
136    root = crate,
137    build_with = build_conv_transpose_2d,
138    err = Exception,
139)]
140pub struct ConvTranspose2dBuilder {
141    /// The number of input channels.
142    pub input_channels: i32,
143
144    /// The number of output channels.
145    pub output_channels: i32,
146
147    /// The size of the convolution filters.
148    pub kernel_size: SingleOrPair<i32>,
149
150    /// If `true`, add a learnable bias to the output. Default to [`ConvTranspose2d::DEFAULT_BIAS`] if not
151    /// specified.
152    #[builder(optional, default = ConvTranspose2d::DEFAULT_BIAS)]
153    bias: bool,
154
155    /// Padding. Default to [`ConvTranspose2d::DEFAULT_PADDING`] if not specified.
156    #[builder(optional, default = ConvTranspose2d::DEFAULT_PADDING)]
157    padding: SingleOrPair<i32>,
158
159    /// Stride. Default to [`ConvTranspose2d::DEFAULT_STRIDE`] if not specified.
160    #[builder(optional, default = ConvTranspose2d::DEFAULT_STRIDE)]
161    stride: SingleOrPair<i32>,
162}
163
164fn build_conv_transpose_2d(builder: ConvTranspose2dBuilder) -> Result<ConvTranspose2d, Exception> {
165    let input_channels = builder.input_channels;
166    let output_channels = builder.output_channels;
167    let kernel_size: (i32, i32) = builder.kernel_size.into();
168
169    let bias = builder.bias;
170    let padding = builder.padding.into();
171    let stride = builder.stride.into();
172
173    let scale = f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1) as f32);
174    let weight = uniform::<_, f32>(
175        -scale,
176        scale,
177        &[
178            output_channels,
179            kernel_size.0,
180            kernel_size.1,
181            input_channels,
182        ],
183        None,
184    )?;
185    let bias = if bias {
186        Some(zeros::<f32>(&[output_channels])?)
187    } else {
188        None
189    };
190
191    Ok(ConvTranspose2d {
192        weight: Param::new(weight),
193        bias: Param::new(bias),
194        padding,
195        stride,
196    })
197}
198
199/// Applies a 2-dimensional convolution over the multi-channel input image.
200///
201/// The channels are expected to be last i.e. the input shape should be `NHWC` where:
202///
203/// - `N` is the batch dimension
204/// - `H` is the input image height
205/// - `W` is the input image width
206/// - `C` is the number of input channels
207#[derive(Debug, Clone, ModuleParameters, Buildable)]
208#[module(root = crate)]
209#[buildable(root = crate)]
210pub struct ConvTranspose2d {
211    /// The weight of the convolution layer.
212    #[param]
213    pub weight: Param<Array>,
214
215    /// The bias of the convolution layer.
216    #[param]
217    pub bias: Param<Option<Array>>,
218
219    /// Padding. Default to `(0, 0)` if not specified.
220    pub padding: (i32, i32),
221
222    /// Stride. Default to `(1, 1)` if not specified.
223    pub stride: (i32, i32),
224}
225
226impl ConvTranspose2d {
227    /// Default value for `bias` if not specified.
228    pub const DEFAULT_BIAS: bool = true;
229
230    /// Default value for `padding` if not specified.
231    pub const DEFAULT_PADDING: SingleOrPair<i32> = SingleOrPair::Pair(0, 0);
232
233    /// Default value for `stride` if not specified.
234    pub const DEFAULT_STRIDE: SingleOrPair<i32> = SingleOrPair::Pair(1, 1);
235}
236
237impl Module<&Array> for ConvTranspose2d {
238    type Error = Exception;
239    type Output = Array;
240
241    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
242        let mut y = conv_transpose2d(
243            x,
244            self.weight.as_ref(),
245            self.stride,
246            self.padding,
247            None,
248            None,
249        )?;
250        if let Some(bias) = &self.bias.value {
251            y += bias;
252        }
253        Ok(y)
254    }
255
256    fn training_mode(&mut self, _: bool) {}
257}
258
259/// Builder for the `ConvTranspose3d` module.
260#[derive(Debug, Clone, Builder)]
261#[builder(
262    root = crate,
263    build_with = build_conv_transpose_3d,
264    err = Exception,
265)]
266pub struct ConvTranspose3dBuilder {
267    /// The number of input channels.
268    pub input_channels: i32,
269
270    /// The number of output channels.
271    pub output_channels: i32,
272
273    /// The size of the convolution filters.
274    pub kernel_size: SingleOrTriple<i32>,
275
276    /// If `true`, add a learnable bias to the output. Default to [`ConvTranspose3d::DEFAULT_BIAS`] if not
277    /// specified.
278    #[builder(optional, default = ConvTranspose3d::DEFAULT_BIAS)]
279    pub bias: bool,
280
281    /// Padding. Default to [`ConvTranspose3d::DEFAULT_PADDING`] if not specified.
282    #[builder(optional, default = ConvTranspose3d::DEFAULT_PADDING)]
283    pub padding: SingleOrTriple<i32>,
284
285    /// Stride. Default to [`ConvTranspose3d::DEFAULT_STRIDE`] if not specified.
286    #[builder(optional, default = ConvTranspose3d::DEFAULT_STRIDE)]
287    pub stride: SingleOrTriple<i32>,
288}
289
290fn build_conv_transpose_3d(builder: ConvTranspose3dBuilder) -> Result<ConvTranspose3d, Exception> {
291    let input_channels = builder.input_channels;
292    let output_channels = builder.output_channels;
293    let kernel_size: (i32, i32, i32) = builder.kernel_size.into();
294
295    let bias = builder.bias;
296    let padding = builder.padding.into();
297    let stride = builder.stride.into();
298
299    let scale =
300        f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1 * kernel_size.2) as f32);
301    let weight = uniform::<_, f32>(
302        -scale,
303        scale,
304        &[
305            output_channels,
306            kernel_size.0,
307            kernel_size.1,
308            kernel_size.2,
309            input_channels,
310        ],
311        None,
312    )?;
313    let bias = if bias {
314        Some(zeros::<f32>(&[output_channels])?)
315    } else {
316        None
317    };
318
319    Ok(ConvTranspose3d {
320        weight: Param::new(weight),
321        bias: Param::new(bias),
322        padding,
323        stride,
324    })
325}
326
327/// Applies a 3-dimensional convolution over the multi-channel input image.
328///
329/// The channels are expected to be last i.e. the input shape should be `NHWC` where:
330///
331/// - `N` is the batch dimension
332/// - `H` is the input image height
333/// - `W` is the input image width
334/// - `C` is the number of input channels
335#[derive(Debug, Clone, ModuleParameters, Buildable)]
336#[module(root = crate)]
337#[buildable(root = crate)]
338pub struct ConvTranspose3d {
339    /// The weight of the convolution layer.
340    #[param]
341    pub weight: Param<Array>,
342
343    /// The bias of the convolution layer.
344    #[param]
345    pub bias: Param<Option<Array>>,
346
347    /// Padding. Default to `(0, 0, 0)` if not specified.
348    pub padding: (i32, i32, i32),
349
350    /// Stride. Default to `(1, 1, 1)` if not specified.
351    pub stride: (i32, i32, i32),
352}
353
354impl ConvTranspose3d {
355    /// Default value for `bias` if not specified.
356    pub const DEFAULT_BIAS: bool = true;
357
358    /// Default value for `padding` if not specified.
359    pub const DEFAULT_PADDING: SingleOrTriple<i32> = SingleOrTriple::Triple(0, 0, 0);
360
361    /// Default value for `stride` if not specified.
362    pub const DEFAULT_STRIDE: SingleOrTriple<i32> = SingleOrTriple::Triple(1, 1, 1);
363}
364
365impl Module<&Array> for ConvTranspose3d {
366    type Error = Exception;
367    type Output = Array;
368
369    fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
370        let mut y = conv_transpose3d(
371            x,
372            self.weight.as_ref(),
373            self.stride,
374            self.padding,
375            None,
376            None,
377        )?;
378        if let Some(bias) = &self.bias.value {
379            y += bias;
380        }
381        Ok(y)
382    }
383
384    fn training_mode(&mut self, _: bool) {}
385}