1use crate::module::{Module, Param};
2use crate::{
3 error::Exception,
4 ops::{conv_transpose1d, conv_transpose2d, conv_transpose3d, zeros},
5 random::uniform,
6 Array,
7};
8use mlx_internal_macros::{Buildable, Builder};
9use mlx_macros::ModuleParameters;
10
11use crate::utils::{SingleOrPair, SingleOrTriple};
12
/// Builder for [`ConvTranspose1d`].
///
/// Building is delegated to `build_conv_transpose_1d`, which samples the weight and
/// (optionally) allocates a zero bias. Fields tagged `#[builder(optional, ...)]` fall back
/// to the matching `ConvTranspose1d::DEFAULT_*` constant when not set.
///
/// NOTE(review): field order is presumably significant to the `Builder` derive (e.g. the
/// argument order of a generated constructor) — keep it stable.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_conv_transpose_1d,
    err = Exception,
)]
pub struct ConvTranspose1dBuilder {
    /// Number of input channels; becomes the last axis of the weight.
    pub input_channels: i32,

    /// Number of output channels; becomes the first axis of the weight and the bias length.
    pub output_channels: i32,

    /// Kernel size; becomes the middle axis of the weight.
    pub kernel_size: i32,

    /// Whether to create a zero-initialized bias. Defaults to `ConvTranspose1d::DEFAULT_BIAS`.
    #[builder(optional, default = ConvTranspose1d::DEFAULT_BIAS)]
    pub bias: bool,

    /// Padding forwarded to `conv_transpose1d`. Defaults to `ConvTranspose1d::DEFAULT_PADDING`.
    #[builder(optional, default = ConvTranspose1d::DEFAULT_PADDING)]
    pub padding: i32,

    /// Output padding forwarded to `conv_transpose1d`.
    /// Defaults to `ConvTranspose1d::DEFAULT_OUTPUT_PADDING`.
    #[builder(optional, default = ConvTranspose1d::DEFAULT_OUTPUT_PADDING)]
    pub output_padding: i32,

    /// Stride forwarded to `conv_transpose1d`. Defaults to `ConvTranspose1d::DEFAULT_STRIDE`.
    #[builder(optional, default = ConvTranspose1d::DEFAULT_STRIDE)]
    pub stride: i32,
}
47
48fn build_conv_transpose_1d(builder: ConvTranspose1dBuilder) -> Result<ConvTranspose1d, Exception> {
49 let input_channels = builder.input_channels;
50 let output_channels = builder.output_channels;
51 let kernel_size = builder.kernel_size;
52
53 let bias = builder.bias;
54 let padding = builder.padding;
55 let output_padding = builder.output_padding;
56 let stride = builder.stride;
57
58 let scale = f32::sqrt(1.0f32 / (input_channels * kernel_size) as f32);
59 let weight = uniform::<_, f32>(
60 -scale,
61 scale,
62 &[output_channels, kernel_size, input_channels],
63 None,
64 )?;
65 let bias = if bias {
66 Some(zeros::<f32>(&[output_channels])?)
67 } else {
68 None
69 };
70
71 Ok(ConvTranspose1d {
72 weight: Param::new(weight),
73 bias: Param::new(bias),
74 padding,
75 output_padding,
76 stride,
77 })
78}
79
/// A 1-D transposed convolution layer.
///
/// Usually created through [`ConvTranspose1dBuilder`]. The forward pass calls
/// `conv_transpose1d` with the stored `stride`, `padding` and `output_padding`, then adds
/// `bias` when one is present.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct ConvTranspose1d {
    /// Convolution weight (a `#[param]`). The builder creates it with shape
    /// `[output_channels, kernel_size, input_channels]`.
    #[param]
    pub weight: Param<Array>,

    /// Optional bias (a `#[param]`) added to the convolution output. The builder
    /// zero-initializes it with length `output_channels`.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Padding forwarded to `conv_transpose1d`.
    pub padding: i32,

    /// Output padding forwarded to `conv_transpose1d`.
    pub output_padding: i32,

    /// Stride forwarded to `conv_transpose1d`.
    pub stride: i32,
}
108
109impl ConvTranspose1d {
110 pub const DEFAULT_BIAS: bool = true;
112
113 pub const DEFAULT_PADDING: i32 = 0;
115
116 pub const DEFAULT_OUTPUT_PADDING: i32 = 0;
118
119 pub const DEFAULT_STRIDE: i32 = 1;
121}
122
123impl Module<&Array> for ConvTranspose1d {
124 type Error = Exception;
125 type Output = Array;
126
127 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
128 let mut y = conv_transpose1d(
129 x,
130 self.weight.as_ref(),
131 self.stride,
132 self.padding,
133 None,
134 self.output_padding,
135 None,
136 )?;
137 if let Some(bias) = &self.bias.value {
138 y += bias;
139 }
140 Ok(y)
141 }
142
143 fn training_mode(&mut self, _: bool) {}
144}
145
146#[derive(Debug, Clone, Builder)]
148#[builder(
149 root = crate,
150 build_with = build_conv_transpose_2d,
151 err = Exception,
152)]
153pub struct ConvTranspose2dBuilder {
154 pub input_channels: i32,
156
157 pub output_channels: i32,
159
160 pub kernel_size: SingleOrPair<i32>,
162
163 #[builder(optional, default = ConvTranspose2d::DEFAULT_BIAS)]
166 bias: bool,
167
168 #[builder(optional, default = ConvTranspose2d::DEFAULT_PADDING)]
170 padding: SingleOrPair<i32>,
171
172 #[builder(optional, default = ConvTranspose2d::DEFAULT_OUTPUT_PADDING)]
174 output_padding: SingleOrPair<i32>,
175
176 #[builder(optional, default = ConvTranspose2d::DEFAULT_STRIDE)]
178 stride: SingleOrPair<i32>,
179}
180
181fn build_conv_transpose_2d(builder: ConvTranspose2dBuilder) -> Result<ConvTranspose2d, Exception> {
182 let input_channels = builder.input_channels;
183 let output_channels = builder.output_channels;
184 let kernel_size: (i32, i32) = builder.kernel_size.into();
185
186 let bias = builder.bias;
187 let padding = builder.padding.into();
188 let output_padding = builder.output_padding.into();
189 let stride = builder.stride.into();
190
191 let scale = f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1) as f32);
192 let weight = uniform::<_, f32>(
193 -scale,
194 scale,
195 &[
196 output_channels,
197 kernel_size.0,
198 kernel_size.1,
199 input_channels,
200 ],
201 None,
202 )?;
203 let bias = if bias {
204 Some(zeros::<f32>(&[output_channels])?)
205 } else {
206 None
207 };
208
209 Ok(ConvTranspose2d {
210 weight: Param::new(weight),
211 bias: Param::new(bias),
212 padding,
213 output_padding,
214 stride,
215 })
216}
217
/// A 2-D transposed convolution layer.
///
/// Usually created through [`ConvTranspose2dBuilder`]. The forward pass calls
/// `conv_transpose2d` with the stored `stride`, `padding` and `output_padding`, then adds
/// `bias` when one is present.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct ConvTranspose2d {
    /// Convolution weight (a `#[param]`). The builder creates it with shape
    /// `[output_channels, kernel_size.0, kernel_size.1, input_channels]`.
    #[param]
    pub weight: Param<Array>,

    /// Optional bias (a `#[param]`) added to the convolution output. The builder
    /// zero-initializes it with length `output_channels`.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Padding forwarded to `conv_transpose2d`.
    pub padding: (i32, i32),

    /// Output padding forwarded to `conv_transpose2d`.
    pub output_padding: (i32, i32),

    /// Stride forwarded to `conv_transpose2d`.
    pub stride: (i32, i32),
}
247
248impl ConvTranspose2d {
249 pub const DEFAULT_BIAS: bool = true;
251
252 pub const DEFAULT_PADDING: SingleOrPair<i32> = SingleOrPair::Pair(0, 0);
254
255 pub const DEFAULT_OUTPUT_PADDING: SingleOrPair<i32> = SingleOrPair::Pair(0, 0);
257
258 pub const DEFAULT_STRIDE: SingleOrPair<i32> = SingleOrPair::Pair(1, 1);
260}
261
262impl Module<&Array> for ConvTranspose2d {
263 type Error = Exception;
264 type Output = Array;
265
266 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
267 let mut y = conv_transpose2d(
268 x,
269 self.weight.as_ref(),
270 self.stride,
271 self.padding,
272 None,
273 self.output_padding,
274 None,
275 )?;
276 if let Some(bias) = &self.bias.value {
277 y += bias;
278 }
279 Ok(y)
280 }
281
282 fn training_mode(&mut self, _: bool) {}
283}
284
/// Builder for [`ConvTranspose3d`].
///
/// Building is delegated to `build_conv_transpose_3d`, which samples the weight and
/// (optionally) allocates a zero bias. Fields tagged `#[builder(optional, ...)]` fall back
/// to the matching `ConvTranspose3d::DEFAULT_*` constant when not set.
///
/// NOTE(review): field order is presumably significant to the `Builder` derive (e.g. the
/// argument order of a generated constructor) — keep it stable.
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_conv_transpose_3d,
    err = Exception,
)]
pub struct ConvTranspose3dBuilder {
    /// Number of input channels; becomes the last axis of the weight.
    pub input_channels: i32,

    /// Number of output channels; becomes the first axis of the weight and the bias length.
    pub output_channels: i32,

    /// Kernel size, converted to an `(i32, i32, i32)` triple with `Into` when building.
    pub kernel_size: SingleOrTriple<i32>,

    /// Whether to create a zero-initialized bias. Defaults to `ConvTranspose3d::DEFAULT_BIAS`.
    #[builder(optional, default = ConvTranspose3d::DEFAULT_BIAS)]
    pub bias: bool,

    /// Padding forwarded to `conv_transpose3d` (converted to a triple).
    /// Defaults to `ConvTranspose3d::DEFAULT_PADDING`.
    #[builder(optional, default = ConvTranspose3d::DEFAULT_PADDING)]
    pub padding: SingleOrTriple<i32>,

    /// Output padding forwarded to `conv_transpose3d` (converted to a triple).
    /// Defaults to `ConvTranspose3d::DEFAULT_OUTPUT_PADDING`.
    #[builder(optional, default = ConvTranspose3d::DEFAULT_OUTPUT_PADDING)]
    pub output_padding: SingleOrTriple<i32>,

    /// Stride forwarded to `conv_transpose3d` (converted to a triple).
    /// Defaults to `ConvTranspose3d::DEFAULT_STRIDE`.
    #[builder(optional, default = ConvTranspose3d::DEFAULT_STRIDE)]
    pub stride: SingleOrTriple<i32>,
}
319
320fn build_conv_transpose_3d(builder: ConvTranspose3dBuilder) -> Result<ConvTranspose3d, Exception> {
321 let input_channels = builder.input_channels;
322 let output_channels = builder.output_channels;
323 let kernel_size: (i32, i32, i32) = builder.kernel_size.into();
324
325 let bias = builder.bias;
326 let padding = builder.padding.into();
327 let output_padding = builder.output_padding.into();
328 let stride = builder.stride.into();
329
330 let scale =
331 f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1 * kernel_size.2) as f32);
332 let weight = uniform::<_, f32>(
333 -scale,
334 scale,
335 &[
336 output_channels,
337 kernel_size.0,
338 kernel_size.1,
339 kernel_size.2,
340 input_channels,
341 ],
342 None,
343 )?;
344 let bias = if bias {
345 Some(zeros::<f32>(&[output_channels])?)
346 } else {
347 None
348 };
349
350 Ok(ConvTranspose3d {
351 weight: Param::new(weight),
352 bias: Param::new(bias),
353 padding,
354 output_padding,
355 stride,
356 })
357}
358
/// A 3-D transposed convolution layer.
///
/// Usually created through [`ConvTranspose3dBuilder`]. The forward pass calls
/// `conv_transpose3d` with the stored `stride`, `padding` and `output_padding`, then adds
/// `bias` when one is present.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct ConvTranspose3d {
    /// Convolution weight (a `#[param]`). The builder creates it with shape
    /// `[output_channels, kernel_size.0, kernel_size.1, kernel_size.2, input_channels]`.
    #[param]
    pub weight: Param<Array>,

    /// Optional bias (a `#[param]`) added to the convolution output. The builder
    /// zero-initializes it with length `output_channels`.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Padding forwarded to `conv_transpose3d`.
    pub padding: (i32, i32, i32),

    /// Output padding forwarded to `conv_transpose3d`.
    pub output_padding: (i32, i32, i32),

    /// Stride forwarded to `conv_transpose3d`.
    pub stride: (i32, i32, i32),
}
388
389impl ConvTranspose3d {
390 pub const DEFAULT_BIAS: bool = true;
392
393 pub const DEFAULT_PADDING: SingleOrTriple<i32> = SingleOrTriple::Triple(0, 0, 0);
395
396 pub const DEFAULT_OUTPUT_PADDING: SingleOrTriple<i32> = SingleOrTriple::Triple(0, 0, 0);
398
399 pub const DEFAULT_STRIDE: SingleOrTriple<i32> = SingleOrTriple::Triple(1, 1, 1);
401}
402
403impl Module<&Array> for ConvTranspose3d {
404 type Error = Exception;
405 type Output = Array;
406
407 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
408 let mut y = conv_transpose3d(
409 x,
410 self.weight.as_ref(),
411 self.stride,
412 self.padding,
413 None,
414 self.output_padding,
415 None,
416 )?;
417 if let Some(bias) = &self.bias.value {
418 y += bias;
419 }
420 Ok(y)
421 }
422
423 fn training_mode(&mut self, _: bool) {}
424}