1use crate::error::Result;
2use crate::utils::guard::Guarded;
3use crate::utils::IntoOption;
4use crate::{Array, Stream};
5use mlx_internal_macros::{default_device, generate_macro};
6
7#[generate_macro]
25#[default_device]
26#[allow(clippy::too_many_arguments)]
27pub fn conv_general_device<'a>(
28 array: impl AsRef<Array>,
29 weight: impl AsRef<Array>,
30 #[optional] strides: impl IntoOption<&'a [i32]>,
31 #[optional] padding: impl IntoOption<&'a [i32]>,
32 #[optional] kernel_dilation: impl IntoOption<&'a [i32]>,
33 #[optional] input_dilation: impl IntoOption<&'a [i32]>,
34 #[optional] groups: impl Into<Option<i32>>,
35 #[optional] flip: impl Into<Option<bool>>,
36 #[optional] stream: impl AsRef<Stream>,
37) -> Result<Array> {
38 let strides = strides.into_option().unwrap_or(&[1]);
39 let padding = padding.into_option().unwrap_or(&[0]);
40 let kernel_dilation = kernel_dilation.into_option().unwrap_or(&[1]);
41 let input_dilation = input_dilation.into_option().unwrap_or(&[1]);
42 let groups = groups.into().unwrap_or(1);
43 let flip = flip.into().unwrap_or(false);
44
45 Array::try_from_op(|res| unsafe {
46 mlx_sys::mlx_conv_general(
47 res,
48 array.as_ref().as_ptr(),
49 weight.as_ref().as_ptr(),
50 strides.as_ptr(),
51 strides.len(),
52 padding.as_ptr(),
53 padding.len(),
54 padding.as_ptr(),
55 padding.len(),
56 kernel_dilation.as_ptr(),
57 kernel_dilation.len(),
58 input_dilation.as_ptr(),
59 input_dilation.len(),
60 groups,
61 flip,
62 stream.as_ref().as_ptr(),
63 )
64 })
65}
66
67#[generate_macro]
80#[default_device]
81pub fn conv1d_device(
82 array: impl AsRef<Array>,
83 weight: impl AsRef<Array>,
84 #[optional] stride: impl Into<Option<i32>>,
85 #[optional] padding: impl Into<Option<i32>>,
86 #[optional] dilation: impl Into<Option<i32>>,
87 #[optional] groups: impl Into<Option<i32>>,
88 #[optional] stream: impl AsRef<Stream>,
89) -> Result<Array> {
90 let stride = stride.into().unwrap_or(1);
91 let padding = padding.into().unwrap_or(0);
92 let dilation = dilation.into().unwrap_or(1);
93 let groups = groups.into().unwrap_or(1);
94
95 Array::try_from_op(|res| unsafe {
96 mlx_sys::mlx_conv1d(
97 res,
98 array.as_ref().as_ptr(),
99 weight.as_ref().as_ptr(),
100 stride,
101 padding,
102 dilation,
103 groups,
104 stream.as_ref().as_ptr(),
105 )
106 })
107}
108
109#[generate_macro]
122#[default_device]
123pub fn conv2d_device(
124 array: impl AsRef<Array>,
125 weight: impl AsRef<Array>,
126 #[optional] stride: impl Into<Option<(i32, i32)>>,
127 #[optional] padding: impl Into<Option<(i32, i32)>>,
128 #[optional] dilation: impl Into<Option<(i32, i32)>>,
129 #[optional] groups: impl Into<Option<i32>>,
130 #[optional] stream: impl AsRef<Stream>,
131) -> Result<Array> {
132 let stride = stride.into().unwrap_or((1, 1));
133 let padding = padding.into().unwrap_or((0, 0));
134 let dilation = dilation.into().unwrap_or((1, 1));
135 let groups = groups.into().unwrap_or(1);
136
137 Array::try_from_op(|res| unsafe {
138 mlx_sys::mlx_conv2d(
139 res,
140 array.as_ref().as_ptr(),
141 weight.as_ref().as_ptr(),
142 stride.0,
143 stride.1,
144 padding.0,
145 padding.1,
146 dilation.0,
147 dilation.1,
148 groups,
149 stream.as_ref().as_ptr(),
150 )
151 })
152}
153
154#[generate_macro]
158#[default_device]
159pub fn conv3d_device(
160 array: impl AsRef<Array>,
161 weight: impl AsRef<Array>,
162 #[optional] stride: impl Into<Option<(i32, i32, i32)>>,
163 #[optional] padding: impl Into<Option<(i32, i32, i32)>>,
164 #[optional] dilation: impl Into<Option<(i32, i32, i32)>>,
165 #[optional] groups: impl Into<Option<i32>>,
166 #[optional] stream: impl AsRef<Stream>,
167) -> Result<Array> {
168 let stride = stride.into().unwrap_or((1, 1, 1));
169 let padding = padding.into().unwrap_or((0, 0, 0));
170 let dilation = dilation.into().unwrap_or((1, 1, 1));
171 let groups = groups.into().unwrap_or(1);
172
173 Array::try_from_op(|res| unsafe {
174 mlx_sys::mlx_conv3d(
175 res,
176 array.as_ref().as_ptr(),
177 weight.as_ref().as_ptr(),
178 stride.0,
179 stride.1,
180 stride.2,
181 padding.0,
182 padding.1,
183 padding.2,
184 dilation.0,
185 dilation.1,
186 dilation.2,
187 groups,
188 stream.as_ref().as_ptr(),
189 )
190 })
191}
192
193#[allow(clippy::too_many_arguments)]
207#[generate_macro]
208#[default_device]
209pub fn conv_transpose1d_device(
210 array: impl AsRef<Array>,
211 weight: impl AsRef<Array>,
212 #[optional] stride: impl Into<Option<i32>>,
213 #[optional] padding: impl Into<Option<i32>>,
214 #[optional] dilation: impl Into<Option<i32>>,
215 #[optional] output_padding: impl Into<Option<i32>>,
216 #[optional] groups: impl Into<Option<i32>>,
217 #[optional] stream: impl AsRef<Stream>,
218) -> Result<Array> {
219 let stride = stride.into().unwrap_or(1);
220 let padding = padding.into().unwrap_or(0);
221 let dilation = dilation.into().unwrap_or(1);
222 let output_padding = output_padding.into().unwrap_or(0);
223 let groups = groups.into().unwrap_or(1);
224
225 Array::try_from_op(|res| unsafe {
226 mlx_sys::mlx_conv_transpose1d(
227 res,
228 array.as_ref().as_ptr(),
229 weight.as_ref().as_ptr(),
230 stride,
231 padding,
232 dilation,
233 output_padding,
234 groups,
235 stream.as_ref().as_ptr(),
236 )
237 })
238}
239
240#[allow(clippy::too_many_arguments)]
255#[generate_macro]
256#[default_device]
257pub fn conv_transpose2d_device(
258 array: impl AsRef<Array>,
259 weight: impl AsRef<Array>,
260 #[optional] stride: impl Into<Option<(i32, i32)>>,
261 #[optional] padding: impl Into<Option<(i32, i32)>>,
262 #[optional] dilation: impl Into<Option<(i32, i32)>>,
263 #[optional] output_padding: impl Into<Option<(i32, i32)>>,
264 #[optional] groups: impl Into<Option<i32>>,
265 #[optional] stream: impl AsRef<Stream>,
266) -> Result<Array> {
267 let stride = stride.into().unwrap_or((1, 1));
268 let padding = padding.into().unwrap_or((0, 0));
269 let dilation = dilation.into().unwrap_or((1, 1));
270 let output_padding = output_padding.into().unwrap_or((0, 0));
271 let groups = groups.into().unwrap_or(1);
272
273 Array::try_from_op(|res| unsafe {
274 mlx_sys::mlx_conv_transpose2d(
275 res,
276 array.as_ref().as_ptr(),
277 weight.as_ref().as_ptr(),
278 stride.0,
279 stride.1,
280 padding.0,
281 padding.1,
282 dilation.0,
283 dilation.1,
284 output_padding.0,
285 output_padding.1,
286 groups,
287 stream.as_ref().as_ptr(),
288 )
289 })
290}
291
292#[allow(clippy::too_many_arguments)]
307#[generate_macro]
308#[default_device]
309pub fn conv_transpose3d_device(
310 array: impl AsRef<Array>,
311 weight: impl AsRef<Array>,
312 #[optional] stride: impl Into<Option<(i32, i32, i32)>>,
313 #[optional] padding: impl Into<Option<(i32, i32, i32)>>,
314 #[optional] dilation: impl Into<Option<(i32, i32, i32)>>,
315 #[optional] output_padding: impl Into<Option<(i32, i32, i32)>>,
316 #[optional] groups: impl Into<Option<i32>>,
317 #[optional] stream: impl AsRef<Stream>,
318) -> Result<Array> {
319 let stride = stride.into().unwrap_or((1, 1, 1));
320 let padding = padding.into().unwrap_or((0, 0, 0));
321 let dilation = dilation.into().unwrap_or((1, 1, 1));
322 let output_padding = output_padding.into().unwrap_or((0, 0, 0));
323 let groups = groups.into().unwrap_or(1);
324
325 Array::try_from_op(|res| unsafe {
326 mlx_sys::mlx_conv_transpose3d(
327 res,
328 array.as_ref().as_ptr(),
329 weight.as_ref().as_ptr(),
330 stride.0,
331 stride.1,
332 stride.2,
333 padding.0,
334 padding.1,
335 padding.2,
336 dilation.0,
337 dilation.1,
338 dilation.2,
339 output_padding.0,
340 output_padding.1,
341 output_padding.2,
342 groups,
343 stream.as_ref().as_ptr(),
344 )
345 })
346}
347
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    // NOTE(review): the argument labels in the comments below assume the
    // generated non-device wrappers keep the parameter order of their
    // `*_device` counterparts (minus the stream) — confirm against the macro.

    // 1D convolution with every optional argument given explicitly.
    #[test]
    fn test_conv1d_complex_device() {
        let input_array = Array::from_slice(
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0],
            &[1, 5, 2],
        );
        let weight_array = Array::from_slice(
            &[0.5, 0.0, -0.5, 1.0, 0.0, 1.5, 2.0, 0.0, -2.0, 1.5, 0.0, 1.0],
            &[2, 3, 2],
        );

        let result = conv1d(
            &input_array,
            &weight_array,
            Some(1), // stride
            Some(0), // padding
            Some(1), // dilation
            Some(1), // groups
        )
        .unwrap();

        assert_eq!(result.shape(), &[1, 3, 2]);

        let expected_output = [12.0, 8.0, 17.0, 13.0, 22.0, 18.0];
        assert_eq!(result.as_slice::<f32>(), &expected_output);
    }

    // 1D transposed convolution, leaving `output_padding` at its default.
    #[test]
    fn test_conv_transpose1d() {
        let input_data = [1.0, 2.0, 3.0];
        let input = Array::from_slice(&input_data, &[1, 3, 1]);

        let weight_data = [1.0, 0.5];
        let weights = Array::from_slice(&weight_data, &[1, 2, 1]);

        let result = conv_transpose1d(
            &input,
            &weights,
            Some(1), // stride
            Some(0), // padding
            Some(1), // dilation
            None,    // output padding
            Some(1), // groups
        )
        .unwrap();

        assert_eq!(result.shape(), &[1, 4, 1]);

        let expected = [1.0, 2.5, 4.0, 1.5];
        assert_eq!(result.as_slice::<f32>(), &expected);
    }

    // 2D convolution where input and weight have the same shape, so the
    // result holds a single value.
    #[test]
    fn test_conv2d() {
        let input_array = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 2, 2, 1]);
        let weight_array = Array::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2, 1]);

        let result = conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(1),      // groups
        )
        .unwrap();

        // Element-wise products of input and weight, summed.
        let expected_output = [1.0 * 1.0 + 2.0 * 0.0 + 3.0 * 0.0 + 4.0 * 1.0];
        assert_eq!(result.as_slice::<f32>(), &expected_output);
    }

    // 2D transposed convolution, leaving `output_padding` at its default.
    #[test]
    fn test_conv_transpose2d() {
        let input_data = [1.0, 2.0, 3.0, 4.0];
        let input = Array::from_slice(&input_data, &[1, 2, 2, 1]);

        let weight_data = [1.0, 0.0, 0.0, 1.0];
        let weights = Array::from_slice(&weight_data, &[1, 2, 2, 1]);

        let result = conv_transpose2d(
            &input,
            &weights,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            None,         // output padding
            Some(1),      // groups
        )
        .unwrap();

        assert_eq!(result.shape(), &[1, 3, 3, 1]);

        let expected = [1.0, 2.0, 0.0, 3.0, 5.0, 2.0, 0.0, 3.0, 4.0];
        assert_eq!(result.as_slice::<f32>(), &expected);
    }

    // 3D convolution where input and weight have the same shape, so the
    // result holds a single value.
    #[test]
    fn test_conv3d() {
        let input_array = Array::from_slice(
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            &[1, 2, 2, 2, 1],
        );
        let weight_array = Array::from_slice(
            &[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0],
            &[1, 2, 2, 2, 1],
        );

        let result = conv3d(
            &input_array,
            &weight_array,
            Some((1, 1, 1)), // stride
            Some((0, 0, 0)), // padding
            Some((1, 1, 1)), // dilation
            Some(1),         // groups
        )
        .unwrap();

        // Element-wise products of input and weight, summed.
        let expected_output = [1.0 * 1.0
            + 2.0 * 0.0
            + 3.0 * 0.0
            + 4.0 * 1.0
            + 5.0 * 0.0
            + 6.0 * 1.0
            + 7.0 * 1.0
            + 8.0 * 0.0];
        assert_eq!(result.as_slice::<f32>(), &expected_output);
    }

    // 3D transposed convolution; only the output shape is checked.
    #[test]
    fn test_conv_transpose3d() {
        let input_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let input = Array::from_slice(&input_data, &[1, 2, 2, 2, 1]);

        let weight_data = [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
        let weights = Array::from_slice(&weight_data, &[1, 2, 2, 2, 1]);

        let result = conv_transpose3d(
            &input,
            &weights,
            Some((1, 1, 1)), // stride
            Some((0, 0, 0)), // padding
            Some((1, 1, 1)), // dilation
            None,            // output padding
            Some(1),         // groups
        )
        .unwrap();

        assert_eq!(result.shape(), &[1, 3, 3, 3, 1]);
    }

    // A weight with one dimension fewer than the input must be rejected.
    #[test]
    fn test_conv_wrong_dimensions() {
        let input_array = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 2, 2, 1]);
        let weight_array = Array::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2]);

        assert!(conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(1),      // groups
        )
        .is_err());
    }

    // `groups = 2` with these shapes must be rejected.
    #[test]
    fn test_conv_invalid_group_size() {
        let input_array = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 2, 2, 1]);
        let weight_array = Array::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2, 1]);

        assert!(conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(2),      // groups
        )
        .is_err());
    }

    // Integer-valued input and weight must be rejected.
    #[test]
    fn test_conv_non_float() {
        let input_array = Array::from_slice(&[1, 2, 3, 4], &[1, 2, 2, 1]);
        let weight_array = Array::from_slice(&[1, 0, 0, 1], &[1, 2, 2, 1]);

        assert!(conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(1),      // groups
        )
        .is_err());
    }
}