1use crate::module::{Module, Param};
2use crate::{
3 error::Exception,
4 ops::{conv1d, conv2d, zeros},
5 random::uniform,
6 Array,
7};
8use mlx_internal_macros::{Buildable, Builder};
9use mlx_macros::ModuleParameters;
10
11use crate::utils::{SingleOrPair, SingleOrTriple};
12
/// Builder for [`Conv1d`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_conv1d,
    err = Exception,
)]
pub struct Conv1dBuilder {
    /// Number of input channels.
    pub input_channels: i32,

    /// Number of output channels.
    pub output_channels: i32,

    /// Size of the convolution filter.
    pub kernel_size: i32,

    /// If `true`, include a learnable bias. Default: [`Conv1d::DEFAULT_BIAS`].
    #[builder(optional, default = Conv1d::DEFAULT_BIAS)]
    pub bias: bool,

    /// Stride of the convolution. Default: [`Conv1d::DEFAULT_STRIDE`].
    #[builder(optional, default = Conv1d::DEFAULT_STRIDE)]
    pub stride: i32,

    /// Zero-padding applied to the input. Default: [`Conv1d::DEFAULT_PADDING`].
    #[builder(optional, default = Conv1d::DEFAULT_PADDING)]
    pub padding: i32,

    /// Dilation of the convolution. Default: [`Conv1d::DEFAULT_DILATION`].
    #[builder(optional, default = Conv1d::DEFAULT_DILATION)]
    pub dilation: i32,

    /// Number of input-channel groups. Default: [`Conv1d::DEFAULT_GROUPS`].
    #[builder(optional, default = Conv1d::DEFAULT_GROUPS)]
    pub groups: i32,
}
51
52fn build_conv1d(builder: Conv1dBuilder) -> Result<Conv1d, Exception> {
53 let input_channels = builder.input_channels;
54 let output_channels = builder.output_channels;
55 let kernel_size = builder.kernel_size;
56 let with_bias = builder.bias;
57
58 let scale = f32::sqrt(1.0f32 / (input_channels * kernel_size) as f32);
59 let weight = uniform::<_, f32>(
60 -scale,
61 scale,
62 &[output_channels, kernel_size, input_channels],
63 None,
64 )?;
65 let bias = if with_bias {
66 Some(zeros::<f32>(&[output_channels])?)
67 } else {
68 None
69 };
70
71 Ok(Conv1d {
72 weight: Param::new(weight),
73 bias: Param::new(bias),
74 stride: builder.stride,
75 padding: builder.padding,
76 dilation: builder.dilation,
77 groups: builder.groups,
78 })
79}
80
/// Applies a 1-dimensional convolution over the multi-channel input sequence.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Conv1d {
    /// Convolution weight of shape `[output_channels, kernel_size, input_channels]`.
    #[param]
    pub weight: Param<Array>,

    /// Optional bias of shape `[output_channels]`; `None` when built without bias.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Stride of the convolution.
    pub stride: i32,

    /// Zero-padding applied to the input.
    pub padding: i32,

    /// Dilation of the convolution.
    pub dilation: i32,

    /// Number of input-channel groups.
    pub groups: i32,
}
112
impl Conv1d {
    /// Default value for the `bias` builder option.
    pub const DEFAULT_BIAS: bool = true;

    /// Default stride.
    pub const DEFAULT_STRIDE: i32 = 1;

    /// Default padding.
    pub const DEFAULT_PADDING: i32 = 0;

    /// Default dilation.
    pub const DEFAULT_DILATION: i32 = 1;

    /// Default number of groups.
    pub const DEFAULT_GROUPS: i32 = 1;
}
129
130impl Module<&Array> for Conv1d {
131 type Error = Exception;
132 type Output = Array;
133
134 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
135 let mut y = conv1d(
136 x,
137 self.weight.as_ref(),
138 self.stride,
139 self.padding,
140 self.dilation,
141 self.groups,
142 )?;
143 if let Some(bias) = &self.bias.value {
144 y += bias;
145 }
146 Ok(y)
147 }
148
149 fn training_mode(&mut self, _: bool) {}
150}
151
/// Builder for [`Conv2d`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_conv2d,
    err = Exception,
)]
pub struct Conv2dBuilder {
    /// Number of input channels.
    pub input_channels: i32,

    /// Number of output channels.
    pub output_channels: i32,

    /// Size of the convolution filter; a single value applies to both axes.
    pub kernel_size: SingleOrPair<i32>,

    /// If `true`, include a learnable bias. Default: [`Conv2d::DEFAULT_BIAS`].
    #[builder(optional, default = Conv2d::DEFAULT_BIAS)]
    pub bias: bool,

    /// Stride of the convolution. Default: [`Conv2d::DEFAULT_STRIDE`].
    #[builder(optional, default = Conv2d::DEFAULT_STRIDE)]
    pub stride: SingleOrPair<i32>,

    /// Zero-padding applied to the input. Default: [`Conv2d::DEFAULT_PADDING`].
    #[builder(optional, default = Conv2d::DEFAULT_PADDING)]
    pub padding: SingleOrPair<i32>,

    /// Dilation of the convolution. Default: [`Conv2d::DEFAULT_DILATION`].
    #[builder(optional, default = Conv2d::DEFAULT_DILATION)]
    pub dilation: SingleOrPair<i32>,

    /// Number of input-channel groups. Default: [`Conv2d::DEFAULT_GROUPS`].
    #[builder(optional, default = Conv2d::DEFAULT_GROUPS)]
    pub groups: i32,
}
190
191fn build_conv2d(builder: Conv2dBuilder) -> Result<Conv2d, Exception> {
192 let input_channels = builder.input_channels;
193 let output_channels = builder.output_channels;
194 let kernel_size: (i32, i32) = builder.kernel_size.into();
195 let with_bias = builder.bias;
196 let padding = builder.padding.into();
197 let stride = builder.stride.into();
198 let dilation = builder.dilation.into();
199
200 let scale = f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1) as f32);
201 let weight = uniform::<_, f32>(
202 -scale,
203 scale,
204 &[
205 output_channels,
206 kernel_size.0,
207 kernel_size.1,
208 input_channels,
209 ],
210 None,
211 )?;
212 let bias = if with_bias {
213 Some(zeros::<f32>(&[output_channels])?)
214 } else {
215 None
216 };
217
218 Ok(Conv2d {
219 weight: Param::new(weight),
220 bias: Param::new(bias),
221 stride,
222 padding,
223 dilation,
224 groups: builder.groups,
225 })
226}
227
/// Applies a 2-dimensional convolution over the multi-channel input image.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Conv2d {
    /// Convolution weight of shape `[output_channels, kh, kw, input_channels]`.
    #[param]
    pub weight: Param<Array>,

    /// Optional bias of shape `[output_channels]`; `None` when built without bias.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Stride of the convolution along each spatial axis.
    pub stride: (i32, i32),

    /// Zero-padding applied to the input along each spatial axis.
    pub padding: (i32, i32),

    /// Dilation of the convolution along each spatial axis.
    pub dilation: (i32, i32),

    /// Number of input-channel groups.
    pub groups: i32,
}
260
impl Conv2d {
    /// Default value for the `bias` builder option.
    pub const DEFAULT_BIAS: bool = true;

    /// Default stride.
    pub const DEFAULT_STRIDE: SingleOrPair = SingleOrPair::Pair(1, 1);

    /// Default padding.
    pub const DEFAULT_PADDING: SingleOrPair = SingleOrPair::Pair(0, 0);

    /// Default dilation.
    pub const DEFAULT_DILATION: SingleOrPair = SingleOrPair::Pair(1, 1);

    /// Default number of groups.
    pub const DEFAULT_GROUPS: i32 = 1;
}
277
278impl Module<&Array> for Conv2d {
279 type Error = Exception;
280 type Output = Array;
281
282 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
283 let mut y = conv2d(
284 x,
285 self.weight.as_ref(),
286 self.stride,
287 self.padding,
288 self.dilation,
289 self.groups,
290 )?;
291 if let Some(bias) = &self.bias.value {
292 y += bias;
293 }
294 Ok(y)
295 }
296
297 fn training_mode(&mut self, _: bool) {}
298}
299
/// Builder for [`Conv3d`].
#[derive(Debug, Clone, Builder)]
#[builder(
    root = crate,
    build_with = build_conv3d,
    err = Exception,
)]
pub struct Conv3dBuilder {
    /// Number of input channels.
    pub input_channels: i32,

    /// Number of output channels.
    pub output_channels: i32,

    /// Size of the convolution filter; a single value applies to all three axes.
    pub kernel_size: SingleOrTriple<i32>,

    /// If `true`, include a learnable bias. Default: [`Conv3d::DEFAULT_BIAS`].
    #[builder(optional, default = Conv3d::DEFAULT_BIAS)]
    pub bias: bool,

    /// Stride of the convolution. Default: [`Conv3d::DEFAULT_STRIDE`].
    #[builder(optional, default = Conv3d::DEFAULT_STRIDE)]
    pub stride: SingleOrTriple<i32>,

    /// Zero-padding applied to the input. Default: [`Conv3d::DEFAULT_PADDING`].
    #[builder(optional, default = Conv3d::DEFAULT_PADDING)]
    pub padding: SingleOrTriple<i32>,

    /// Dilation of the convolution. Default: [`Conv3d::DEFAULT_DILATION`].
    #[builder(optional, default = Conv3d::DEFAULT_DILATION)]
    pub dilation: SingleOrTriple<i32>,

    /// Number of input-channel groups. Default: [`Conv3d::DEFAULT_GROUPS`].
    #[builder(optional, default = Conv3d::DEFAULT_GROUPS)]
    pub groups: i32,
}
338
339fn build_conv3d(builder: Conv3dBuilder) -> Result<Conv3d, Exception> {
340 let input_channels = builder.input_channels;
341 let output_channels = builder.output_channels;
342 let kernel_size: (i32, i32, i32) = builder.kernel_size.into();
343 let with_bias = builder.bias;
344 let padding = builder.padding.into();
345 let stride = builder.stride.into();
346 let dilation = builder.dilation.into();
347
348 let scale =
349 f32::sqrt(1.0 / (input_channels * kernel_size.0 * kernel_size.1 * kernel_size.2) as f32);
350 let weight = uniform::<_, f32>(
351 -scale,
352 scale,
353 &[
354 output_channels,
355 kernel_size.0,
356 kernel_size.1,
357 kernel_size.2,
358 input_channels,
359 ],
360 None,
361 )?;
362 let bias = if with_bias {
363 Some(zeros::<f32>(&[output_channels])?)
364 } else {
365 None
366 };
367
368 Ok(Conv3d {
369 weight: Param::new(weight),
370 bias: Param::new(bias),
371 stride,
372 padding,
373 dilation,
374 groups: builder.groups,
375 })
376}
377
/// Applies a 3-dimensional convolution over the multi-channel input volume.
#[derive(Debug, Clone, ModuleParameters, Buildable)]
#[module(root = crate)]
#[buildable(root = crate)]
pub struct Conv3d {
    /// Convolution weight of shape `[output_channels, kd, kh, kw, input_channels]`.
    #[param]
    pub weight: Param<Array>,

    /// Optional bias of shape `[output_channels]`; `None` when built without bias.
    #[param]
    pub bias: Param<Option<Array>>,

    /// Stride of the convolution along each spatial axis.
    pub stride: (i32, i32, i32),

    /// Zero-padding applied to the input along each spatial axis.
    pub padding: (i32, i32, i32),

    /// Dilation of the convolution along each spatial axis.
    pub dilation: (i32, i32, i32),

    /// Number of input-channel groups.
    pub groups: i32,
}
410
impl Conv3d {
    /// Default value for the `bias` builder option.
    pub const DEFAULT_BIAS: bool = true;

    /// Default stride.
    pub const DEFAULT_STRIDE: SingleOrTriple<i32> = SingleOrTriple::Triple(1, 1, 1);

    /// Default padding.
    pub const DEFAULT_PADDING: SingleOrTriple<i32> = SingleOrTriple::Triple(0, 0, 0);

    /// Default dilation.
    pub const DEFAULT_DILATION: SingleOrTriple<i32> = SingleOrTriple::Triple(1, 1, 1);

    /// Default number of groups.
    pub const DEFAULT_GROUPS: i32 = 1;
}
427
428impl Module<&Array> for Conv3d {
429 type Error = Exception;
430 type Output = Array;
431
432 fn forward(&mut self, x: &Array) -> Result<Array, Self::Error> {
433 let mut y = crate::ops::conv3d(
434 x,
435 self.weight.as_ref(),
436 self.stride,
437 self.padding,
438 self.dilation,
439 self.groups,
440 )?;
441 if let Some(bias) = &self.bias.value {
442 y += bias;
443 }
444 Ok(y)
445 }
446
447 fn training_mode(&mut self, _: bool) {}
448}
449
#[cfg(test)]
mod tests {
    use crate::module::Module;
    use crate::{random::uniform, Dtype};
    use float_eq::assert_float_eq;

    use crate::nn::Conv1d;

    // Seeded smoke test: checks the input statistics and the forward-pass
    // output shape/statistics of a default Conv1d against recorded values
    // (tolerances are ~2% of the expected magnitudes).
    #[test]
    fn test_conv1d() {
        crate::random::seed(819).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 16], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 16]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.512_987_5,
            abs <= 0.010_259_75
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            131.324_8,
            abs <= 2.626_496
        );
        // Kernel spans the full sequence length, so one output position remains.
        let result = Conv1d::new(16, 2, 8).unwrap().forward(&a).unwrap();
        assert_eq!(result.shape(), &[2, 1, 2]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            0.264_865_2,
            abs <= 0.005_297_303_7
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            1.059_460_8,
            abs <= 0.021_189_215
        );
    }

    // Same structure as test_conv1d, for the 2-D case.
    #[test]
    fn test_conv2d() {
        crate::random::seed(62).unwrap();
        let a = uniform::<_, f32>(0.0, 1.0, &[2, 8, 8, 4], None).unwrap();
        assert_eq!(a.shape(), &[2, 8, 8, 4]);
        assert_eq!(a.dtype(), Dtype::Float32);
        assert_float_eq!(
            a.mean(None, None).unwrap().item::<f32>(),
            0.522_504_27,
            abs <= 0.010_450_086
        );
        assert_float_eq!(
            a.sum(None, None).unwrap().item::<f32>(),
            267.522_2,
            abs <= 5.350_444
        );
        // Kernel spans the full 8x8 input, so a single spatial position remains.
        let result = crate::nn::Conv2d::new(4, 2, (8, 8))
            .unwrap()
            .forward(&a)
            .unwrap();
        assert_eq!(result.shape(), &[2, 1, 1, 2]);
        assert_eq!(result.dtype(), Dtype::Float32);
        assert_float_eq!(
            result.mean(None, None).unwrap().item::<f32>(),
            -0.279_321_5,
            abs <= 0.005_586_43
        );
        assert_float_eq!(
            result.sum(None, None).unwrap().item::<f32>(),
            -1.117_286,
            abs <= 0.022_345_72
        );
    }
}