1use crate::module::{Module, Param};
2use crate::{
3 error::Exception,
4 ops::{conv_transpose1d, conv_transpose2d, conv_transpose3d, zeros},
5 random::uniform,
6 Array,
7};
8use mlx_internal_macros::{Buildable, Builder};
9use mlx_macros::ModuleParameters;
10
11use crate::utils::{SingleOrPair, SingleOrTriple};
12
13#[derive(Debug, Clone, Builder)]
15#[builder(
16 root = crate,
17 build_with = build_conv_transpose_1d,
18 err = Exception,
19)]
20pub struct ConvTranspose1dBuilder {
21 pub input_channels: i32,
23
24 pub output_channels: i32,
26
27 pub kernel_size: i32,
29
30 #[builder(optional, default = ConvTranspose1d::DEFAULT_BIAS)]
33 pub bias: bool,
34
35 #[builder(optional, default = ConvTranspose1d::DEFAULT_PADDING)]
37 pub padding: i32,
38
39 #[builder(optional, default = ConvTranspose1d::DEFAULT_STRIDE)]
41 pub stride: i32,
42}
43
44fn build_conv_transpose_1d(builder: ConvTranspose1dBuilder) -> Result<ConvTranspose1d, Exception> {
45 let input_channels = builder.input_channels;
46 let output_channels = builder.output_channels;
47 let kernel_size = builder.kernel_size;
48
49 let bias = builder.bias;
50 let padding = builder.padding;
51 let stride = builder.stride;
52
53 let scale = f32::sqrt(1.0f32 / (input_channels * kernel_size) as f32);
54 let weight = uniform::<_, f32>(
55 -scale,
56 scale,
57 &[output_channels, kernel_size, input_channels],
58 None,
59 )?;
60 let bias = if bias {
61 Some(zeros::<f32>(&[output_channels])?)
62 } else {
63 None
64 };
65
66 Ok(ConvTranspose1d {
67 weight: Param::new(weight),
68 bias: Param::new(bias),
69 padding,
70 stride,
71 })
72}
73
74#[derive(Debug, Clone, ModuleParameters, Buildable)]
82#[module(root = crate)]
83#[buildable(root = crate)]
84pub struct ConvTranspose1d {
85 #[param]
87 pub weight: Param<Array>,
88
89 #[param]
91 pub bias: Param<Option<Array>>,
92
93 pub padding: i32,
95
96 pub stride: i32,
98}
99
100impl ConvTranspose1d {
101 pub const DEFAULT_BIAS: bool = true;
103
104 pub const DEFAULT_PADDING: i32 = 0;
106
107 pub const DEFAULT_STRIDE: i32 = 1;
109}
110
111impl Module<&Array> for ConvTranspose1d {
112 type Error = Exception;
113 type Output = Array;
114
115 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
116 let mut y = conv_transpose1d(
117 x,
118 self.weight.as_ref(),
119 self.stride,
120 self.padding,
121 None,
122 None,
123 )?;
124 if let Some(bias) = &self.bias.value {
125 y += bias;
126 }
127 Ok(y)
128 }
129
130 fn training_mode(&mut self, _: bool) {}
131}
132
133#[derive(Debug, Clone, Builder)]
135#[builder(
136 root = crate,
137 build_with = build_conv_transpose_2d,
138 err = Exception,
139)]
140pub struct ConvTranspose2dBuilder {
141 pub input_channels: i32,
143
144 pub output_channels: i32,
146
147 pub kernel_size: SingleOrPair<i32>,
149
150 #[builder(optional, default = ConvTranspose2d::DEFAULT_BIAS)]
153 bias: bool,
154
155 #[builder(optional, default = ConvTranspose2d::DEFAULT_PADDING)]
157 padding: SingleOrPair<i32>,
158
159 #[builder(optional, default = ConvTranspose2d::DEFAULT_STRIDE)]
161 stride: SingleOrPair<i32>,
162}
163
164fn build_conv_transpose_2d(builder: ConvTranspose2dBuilder) -> Result<ConvTranspose2d, Exception> {
165 let input_channels = builder.input_channels;
166 let output_channels = builder.output_channels;
167 let kernel_size: (i32, i32) = builder.kernel_size.into();
168
169 let bias = builder.bias;
170 let padding = builder.padding.into();
171 let stride = builder.stride.into();
172
173 let scale = f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1) as f32);
174 let weight = uniform::<_, f32>(
175 -scale,
176 scale,
177 &[
178 output_channels,
179 kernel_size.0,
180 kernel_size.1,
181 input_channels,
182 ],
183 None,
184 )?;
185 let bias = if bias {
186 Some(zeros::<f32>(&[output_channels])?)
187 } else {
188 None
189 };
190
191 Ok(ConvTranspose2d {
192 weight: Param::new(weight),
193 bias: Param::new(bias),
194 padding,
195 stride,
196 })
197}
198
199#[derive(Debug, Clone, ModuleParameters, Buildable)]
208#[module(root = crate)]
209#[buildable(root = crate)]
210pub struct ConvTranspose2d {
211 #[param]
213 pub weight: Param<Array>,
214
215 #[param]
217 pub bias: Param<Option<Array>>,
218
219 pub padding: (i32, i32),
221
222 pub stride: (i32, i32),
224}
225
226impl ConvTranspose2d {
227 pub const DEFAULT_BIAS: bool = true;
229
230 pub const DEFAULT_PADDING: SingleOrPair<i32> = SingleOrPair::Pair(0, 0);
232
233 pub const DEFAULT_STRIDE: SingleOrPair<i32> = SingleOrPair::Pair(1, 1);
235}
236
237impl Module<&Array> for ConvTranspose2d {
238 type Error = Exception;
239 type Output = Array;
240
241 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
242 let mut y = conv_transpose2d(
243 x,
244 self.weight.as_ref(),
245 self.stride,
246 self.padding,
247 None,
248 None,
249 )?;
250 if let Some(bias) = &self.bias.value {
251 y += bias;
252 }
253 Ok(y)
254 }
255
256 fn training_mode(&mut self, _: bool) {}
257}
258
259#[derive(Debug, Clone, Builder)]
261#[builder(
262 root = crate,
263 build_with = build_conv_transpose_3d,
264 err = Exception,
265)]
266pub struct ConvTranspose3dBuilder {
267 pub input_channels: i32,
269
270 pub output_channels: i32,
272
273 pub kernel_size: SingleOrTriple<i32>,
275
276 #[builder(optional, default = ConvTranspose3d::DEFAULT_BIAS)]
279 pub bias: bool,
280
281 #[builder(optional, default = ConvTranspose3d::DEFAULT_PADDING)]
283 pub padding: SingleOrTriple<i32>,
284
285 #[builder(optional, default = ConvTranspose3d::DEFAULT_STRIDE)]
287 pub stride: SingleOrTriple<i32>,
288}
289
290fn build_conv_transpose_3d(builder: ConvTranspose3dBuilder) -> Result<ConvTranspose3d, Exception> {
291 let input_channels = builder.input_channels;
292 let output_channels = builder.output_channels;
293 let kernel_size: (i32, i32, i32) = builder.kernel_size.into();
294
295 let bias = builder.bias;
296 let padding = builder.padding.into();
297 let stride = builder.stride.into();
298
299 let scale =
300 f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1 * kernel_size.2) as f32);
301 let weight = uniform::<_, f32>(
302 -scale,
303 scale,
304 &[
305 output_channels,
306 kernel_size.0,
307 kernel_size.1,
308 kernel_size.2,
309 input_channels,
310 ],
311 None,
312 )?;
313 let bias = if bias {
314 Some(zeros::<f32>(&[output_channels])?)
315 } else {
316 None
317 };
318
319 Ok(ConvTranspose3d {
320 weight: Param::new(weight),
321 bias: Param::new(bias),
322 padding,
323 stride,
324 })
325}
326
327#[derive(Debug, Clone, ModuleParameters, Buildable)]
336#[module(root = crate)]
337#[buildable(root = crate)]
338pub struct ConvTranspose3d {
339 #[param]
341 pub weight: Param<Array>,
342
343 #[param]
345 pub bias: Param<Option<Array>>,
346
347 pub padding: (i32, i32, i32),
349
350 pub stride: (i32, i32, i32),
352}
353
354impl ConvTranspose3d {
355 pub const DEFAULT_BIAS: bool = true;
357
358 pub const DEFAULT_PADDING: SingleOrTriple<i32> = SingleOrTriple::Triple(0, 0, 0);
360
361 pub const DEFAULT_STRIDE: SingleOrTriple<i32> = SingleOrTriple::Triple(1, 1, 1);
363}
364
365impl Module<&Array> for ConvTranspose3d {
366 type Error = Exception;
367 type Output = Array;
368
369 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
370 let mut y = conv_transpose3d(
371 x,
372 self.weight.as_ref(),
373 self.stride,
374 self.padding,
375 None,
376 None,
377 )?;
378 if let Some(bias) = &self.bias.value {
379 y += bias;
380 }
381 Ok(y)
382 }
383
384 fn training_mode(&mut self, _: bool) {}
385}