1use crate::error::Result;
2use crate::utils::guard::Guarded;
3use crate::utils::IntoOption;
4use crate::{Array, Stream, StreamOrDevice};
5use mlx_internal_macros::{default_device, generate_macro};
6
7#[generate_macro]
25#[default_device]
26#[allow(clippy::too_many_arguments)]
27pub fn conv_general_device<'a>(
28 array: impl AsRef<Array>,
29 weight: impl AsRef<Array>,
30 #[optional] strides: impl IntoOption<&'a [i32]>,
31 #[optional] padding: impl IntoOption<&'a [i32]>,
32 #[optional] kernel_dilation: impl IntoOption<&'a [i32]>,
33 #[optional] input_dilation: impl IntoOption<&'a [i32]>,
34 #[optional] groups: impl Into<Option<i32>>,
35 #[optional] flip: impl Into<Option<bool>>,
36 #[optional] stream: impl AsRef<Stream>,
37) -> Result<Array> {
38 let strides = strides.into_option().unwrap_or(&[1]);
39 let padding = padding.into_option().unwrap_or(&[0]);
40 let kernel_dilation = kernel_dilation.into_option().unwrap_or(&[1]);
41 let input_dilation = input_dilation.into_option().unwrap_or(&[1]);
42 let groups = groups.into().unwrap_or(1);
43 let flip = flip.into().unwrap_or(false);
44
45 Array::try_from_op(|res| unsafe {
46 mlx_sys::mlx_conv_general(
47 res,
48 array.as_ref().as_ptr(),
49 weight.as_ref().as_ptr(),
50 strides.as_ptr(),
51 strides.len(),
52 padding.as_ptr(),
53 padding.len(),
54 padding.as_ptr(),
55 padding.len(),
56 kernel_dilation.as_ptr(),
57 kernel_dilation.len(),
58 input_dilation.as_ptr(),
59 input_dilation.len(),
60 groups,
61 flip,
62 stream.as_ref().as_ptr(),
63 )
64 })
65}
66
67#[generate_macro]
80#[default_device]
81pub fn conv1d_device(
82 array: impl AsRef<Array>,
83 weight: impl AsRef<Array>,
84 #[optional] stride: impl Into<Option<i32>>,
85 #[optional] padding: impl Into<Option<i32>>,
86 #[optional] dilation: impl Into<Option<i32>>,
87 #[optional] groups: impl Into<Option<i32>>,
88 #[optional] stream: impl AsRef<Stream>,
89) -> Result<Array> {
90 let stride = stride.into().unwrap_or(1);
91 let padding = padding.into().unwrap_or(0);
92 let dilation = dilation.into().unwrap_or(1);
93 let groups = groups.into().unwrap_or(1);
94
95 Array::try_from_op(|res| unsafe {
96 mlx_sys::mlx_conv1d(
97 res,
98 array.as_ref().as_ptr(),
99 weight.as_ref().as_ptr(),
100 stride,
101 padding,
102 dilation,
103 groups,
104 stream.as_ref().as_ptr(),
105 )
106 })
107}
108
109#[generate_macro]
122#[default_device]
123pub fn conv2d_device(
124 array: impl AsRef<Array>,
125 weight: impl AsRef<Array>,
126 #[optional] stride: impl Into<Option<(i32, i32)>>,
127 #[optional] padding: impl Into<Option<(i32, i32)>>,
128 #[optional] dilation: impl Into<Option<(i32, i32)>>,
129 #[optional] groups: impl Into<Option<i32>>,
130 #[optional] stream: impl AsRef<Stream>,
131) -> Result<Array> {
132 let stride = stride.into().unwrap_or((1, 1));
133 let padding = padding.into().unwrap_or((0, 0));
134 let dilation = dilation.into().unwrap_or((1, 1));
135 let groups = groups.into().unwrap_or(1);
136
137 Array::try_from_op(|res| unsafe {
138 mlx_sys::mlx_conv2d(
139 res,
140 array.as_ref().as_ptr(),
141 weight.as_ref().as_ptr(),
142 stride.0,
143 stride.1,
144 padding.0,
145 padding.1,
146 dilation.0,
147 dilation.1,
148 groups,
149 stream.as_ref().as_ptr(),
150 )
151 })
152}
153
154#[generate_macro]
158#[default_device]
159pub fn conv3d_device(
160 array: impl AsRef<Array>,
161 weight: impl AsRef<Array>,
162 #[optional] stride: impl Into<Option<(i32, i32, i32)>>,
163 #[optional] padding: impl Into<Option<(i32, i32, i32)>>,
164 #[optional] dilation: impl Into<Option<(i32, i32, i32)>>,
165 #[optional] groups: impl Into<Option<i32>>,
166 #[optional] stream: impl AsRef<Stream>,
167) -> Result<Array> {
168 let stride = stride.into().unwrap_or((1, 1, 1));
169 let padding = padding.into().unwrap_or((0, 0, 0));
170 let dilation = dilation.into().unwrap_or((1, 1, 1));
171 let groups = groups.into().unwrap_or(1);
172
173 Array::try_from_op(|res| unsafe {
174 mlx_sys::mlx_conv3d(
175 res,
176 array.as_ref().as_ptr(),
177 weight.as_ref().as_ptr(),
178 stride.0,
179 stride.1,
180 stride.2,
181 padding.0,
182 padding.1,
183 padding.2,
184 dilation.0,
185 dilation.1,
186 dilation.2,
187 groups,
188 stream.as_ref().as_ptr(),
189 )
190 })
191}
192
193#[generate_macro]
207#[default_device]
208pub fn conv_transpose1d_device(
209 array: impl AsRef<Array>,
210 weight: impl AsRef<Array>,
211 #[optional] stride: impl Into<Option<i32>>,
212 #[optional] padding: impl Into<Option<i32>>,
213 #[optional] dilation: impl Into<Option<i32>>,
214 #[optional] groups: impl Into<Option<i32>>,
215 #[optional] stream: impl AsRef<Stream>,
216) -> Result<Array> {
217 let stride = stride.into().unwrap_or(1);
218 let padding = padding.into().unwrap_or(0);
219 let dilation = dilation.into().unwrap_or(1);
220 let groups = groups.into().unwrap_or(1);
221
222 Array::try_from_op(|res| unsafe {
223 mlx_sys::mlx_conv_transpose1d(
224 res,
225 array.as_ref().as_ptr(),
226 weight.as_ref().as_ptr(),
227 stride,
228 padding,
229 dilation,
230 groups,
231 stream.as_ref().as_ptr(),
232 )
233 })
234}
235
236#[generate_macro]
251#[default_device]
252pub fn conv_transpose2d_device(
253 array: impl AsRef<Array>,
254 weight: impl AsRef<Array>,
255 #[optional] stride: impl Into<Option<(i32, i32)>>,
256 #[optional] padding: impl Into<Option<(i32, i32)>>,
257 #[optional] dilation: impl Into<Option<(i32, i32)>>,
258 #[optional] groups: impl Into<Option<i32>>,
259 #[optional] stream: impl AsRef<Stream>,
260) -> Result<Array> {
261 let stride = stride.into().unwrap_or((1, 1));
262 let padding = padding.into().unwrap_or((0, 0));
263 let dilation = dilation.into().unwrap_or((1, 1));
264 let groups = groups.into().unwrap_or(1);
265
266 Array::try_from_op(|res| unsafe {
267 mlx_sys::mlx_conv_transpose2d(
268 res,
269 array.as_ref().as_ptr(),
270 weight.as_ref().as_ptr(),
271 stride.0,
272 stride.1,
273 padding.0,
274 padding.1,
275 dilation.0,
276 dilation.1,
277 groups,
278 stream.as_ref().as_ptr(),
279 )
280 })
281}
282
283#[generate_macro]
298#[default_device]
299pub fn conv_transpose3d_device(
300 array: impl AsRef<Array>,
301 weight: impl AsRef<Array>,
302 #[optional] stride: impl Into<Option<(i32, i32, i32)>>,
303 #[optional] padding: impl Into<Option<(i32, i32, i32)>>,
304 #[optional] dilation: impl Into<Option<(i32, i32, i32)>>,
305 #[optional] groups: impl Into<Option<i32>>,
306 #[optional] stream: impl AsRef<Stream>,
307) -> Result<Array> {
308 let stride = stride.into().unwrap_or((1, 1, 1));
309 let padding = padding.into().unwrap_or((0, 0, 0));
310 let dilation = dilation.into().unwrap_or((1, 1, 1));
311 let groups = groups.into().unwrap_or(1);
312
313 Array::try_from_op(|res| unsafe {
314 mlx_sys::mlx_conv_transpose3d(
315 res,
316 array.as_ref().as_ptr(),
317 weight.as_ref().as_ptr(),
318 stride.0,
319 stride.1,
320 stride.2,
321 padding.0,
322 padding.1,
323 padding.2,
324 dilation.0,
325 dilation.1,
326 dilation.2,
327 groups,
328 stream.as_ref().as_ptr(),
329 )
330 })
331}
332
333#[cfg(test)]
334mod tests {
335 use super::*;
336 use pretty_assertions::assert_eq;
337
338 #[test]
339 fn test_conv1d_complex_device() {
340 let input_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
342 let input_array = Array::from_slice(&input_data, &[1, 5, 2]);
343
344 let weight_data = [0.5, 0.0, -0.5, 1.0, 0.0, 1.5, 2.0, 0.0, -2.0, 1.5, 0.0, 1.0];
346 let weight_array = Array::from_slice(&weight_data, &[2, 3, 2]);
347
348 let result = conv1d(
349 &input_array,
350 &weight_array,
351 Some(1), Some(0), Some(1), Some(1), )
356 .unwrap();
357
358 let expected_output = [12.0, 8.0, 17.0, 13.0, 22.0, 18.0];
359 assert_eq!(result.shape(), &[1, 3, 2]);
360 assert_eq!(result.as_slice::<f32>(), &expected_output);
361 }
362
363 #[test]
364 fn test_conv_transpose1d() {
365 let input = Array::from_slice(&[1.0, 2.0, 3.0], &[1, 3, 1]);
367 let weights = Array::from_slice(&[1.0, 0.5], &[1, 2, 1]);
369
370 let result = conv_transpose1d(
371 &input,
372 &weights,
373 Some(1), Some(0), Some(1), Some(1), )
378 .unwrap();
379
380 let expected = [1.0, 2.5, 4.0, 1.5];
381 assert_eq!(result.shape(), &[1, 4, 1]);
382 assert_eq!(result.as_slice::<f32>(), &expected);
383 }
384
385 #[test]
386 fn test_conv2d() {
387 let input_data = [1.0, 2.0, 3.0, 4.0];
389 let input_shape = [1, 2, 2, 1]; let input_array = Array::from_slice(&input_data, &input_shape);
391
392 let weight_data = [1.0, 0.0, 0.0, 1.0];
394 let weight_shape = [1, 2, 2, 1]; let weight_array = Array::from_slice(&weight_data, &weight_shape);
396
397 let result = conv2d(
399 &input_array,
400 &weight_array,
401 Some((1, 1)), Some((0, 0)), Some((1, 1)), Some(1), )
406 .unwrap();
407
408 let expected_output = 1.0 * 1.0 + 2.0 * 0.0 + 3.0 * 0.0 + 4.0 * 1.0; assert_eq!(result.as_slice::<f32>(), &[expected_output]);
411 }
412
413 #[test]
414 fn test_conv_transpose2d() {
415 let input = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 2, 2, 1]);
417 let weights = Array::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2, 1]);
419
420 let result = conv_transpose2d(
421 &input,
422 &weights,
423 Some((1, 1)), Some((0, 0)), Some((1, 1)), Some(1), )
428 .unwrap();
429
430 let expected = [1.0, 2.0, 0.0, 3.0, 5.0, 2.0, 0.0, 3.0, 4.0];
431 assert_eq!(result.shape(), &[1, 3, 3, 1]);
432 assert_eq!(result.as_slice::<f32>(), &expected);
433 }
434
435 #[test]
436 fn test_conv3d() {
437 let input_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
439 let input_shape = [1, 2, 2, 2, 1]; let input_array = Array::from_slice(&input_data, &input_shape);
441
442 let weight_data = [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
444 let weight_shape = [1, 2, 2, 2, 1]; let weight_array = Array::from_slice(&weight_data, &weight_shape);
446
447 let result = conv3d(
449 &input_array,
450 &weight_array,
451 Some((1, 1, 1)), Some((0, 0, 0)), Some((1, 1, 1)), Some(1), )
456 .unwrap();
457
458 let expected_output = 1.0 * 1.0
460 + 2.0 * 0.0
461 + 3.0 * 0.0
462 + 4.0 * 1.0
463 + 5.0 * 0.0
464 + 6.0 * 1.0
465 + 7.0 * 1.0
466 + 8.0 * 0.0; assert_eq!(result.as_slice::<f32>(), &[expected_output]);
468 }
469
470 #[test]
471 fn test_conv_transpose3d() {
472 let input = Array::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 2, 2, 2, 1]);
474 let weights =
476 Array::from_slice(&[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0], &[1, 2, 2, 2, 1]);
477
478 let result = conv_transpose3d(
479 &input,
480 &weights,
481 Some((1, 1, 1)), Some((0, 0, 0)), Some((1, 1, 1)), Some(1), )
486 .unwrap();
487
488 assert_eq!(result.shape(), &[1, 3, 3, 3, 1]);
489 }
490
491 #[test]
492 fn test_conv_wrong_dimensions() {
493 let input_data = [1.0, 2.0, 3.0, 4.0];
494 let input_shape = [1, 2, 2, 1]; let input_array = Array::from_slice(&input_data, &input_shape);
496
497 let weight_data = [1.0, 0.0, 0.0, 1.0];
498 let weight_shape = [1, 2, 2]; let weight_array = Array::from_slice(&weight_data, &weight_shape);
500
501 let result = conv2d(
502 &input_array,
503 &weight_array,
504 Some((1, 1)), Some((0, 0)), Some((1, 1)), Some(1), );
509
510 assert!(result.is_err());
511 }
512
513 #[test]
514 fn test_conv_invalid_group_size() {
515 let input_data = [1.0, 2.0, 3.0, 4.0];
516 let input_shape = [1, 2, 2, 1]; let input_array = Array::from_slice(&input_data, &input_shape);
518
519 let weight_data = [1.0, 0.0, 0.0, 1.0];
520 let weight_shape = [1, 2, 2, 1]; let weight_array = Array::from_slice(&weight_data, &weight_shape);
522
523 let result = conv2d(
524 &input_array,
525 &weight_array,
526 Some((1, 1)), Some((0, 0)), Some((1, 1)), Some(2), );
531
532 assert!(result.is_err());
533 }
534
535 #[test]
536 fn test_conv_non_float() {
537 let input_data = [1, 2, 3, 4];
538 let input_shape = [1, 2, 2, 1]; let input_array = Array::from_slice(&input_data, &input_shape);
540
541 let weight_data = [1, 0, 0, 1];
542 let weight_shape = [1, 2, 2, 1]; let weight_array = Array::from_slice(&weight_data, &weight_shape);
544
545 let result = conv2d(
546 &input_array,
547 &weight_array,
548 Some((1, 1)), Some((0, 0)), Some((1, 1)), Some(1), );
553
554 assert!(result.is_err());
555 }
556}