mlx_rs/ops/
convolution.rs

1use crate::error::Result;
2use crate::utils::guard::Guarded;
3use crate::utils::IntoOption;
4use crate::{Array, Stream, StreamOrDevice};
5use mlx_internal_macros::{default_device, generate_macro};
6
7/// General convolution over an input with several channels returning an error if the inputs are invalid.
8///
9/// - Only 1d and 2d convolutions are supported at the moment
10/// - the default `groups: 1` is currently supported
11///
12/// # Params
13///
14/// - array: Input array of shape `&[N, ..., C_in]`
15/// - weight: Weight array of shape `&[C_out, ..., C_in]`
16/// - strides: The kernel strides. All dimensions get the same stride if only one number is specified.
17/// - padding: The input padding. All dimensions get the same padding if only one number is specified.
18/// - kernel_dilation: The kernel dilation. All dimensions get the same dilation if only one number is specified.
19/// - input_dilation: The input dilation. All dimensions get the same dilation if only one number is specified.
20/// - groups: Input feature groups
21/// - flip: Flip the order in which the spatial dimensions of the weights are processed.
22///   Performs the cross-correlation operator when `flip` is `false` and the convolution
23///   operator otherwise.
24#[generate_macro]
25#[default_device]
26#[allow(clippy::too_many_arguments)]
27pub fn conv_general_device<'a>(
28    array: impl AsRef<Array>,
29    weight: impl AsRef<Array>,
30    #[optional] strides: impl IntoOption<&'a [i32]>,
31    #[optional] padding: impl IntoOption<&'a [i32]>,
32    #[optional] kernel_dilation: impl IntoOption<&'a [i32]>,
33    #[optional] input_dilation: impl IntoOption<&'a [i32]>,
34    #[optional] groups: impl Into<Option<i32>>,
35    #[optional] flip: impl Into<Option<bool>>,
36    #[optional] stream: impl AsRef<Stream>,
37) -> Result<Array> {
38    let strides = strides.into_option().unwrap_or(&[1]);
39    let padding = padding.into_option().unwrap_or(&[0]);
40    let kernel_dilation = kernel_dilation.into_option().unwrap_or(&[1]);
41    let input_dilation = input_dilation.into_option().unwrap_or(&[1]);
42    let groups = groups.into().unwrap_or(1);
43    let flip = flip.into().unwrap_or(false);
44
45    Array::try_from_op(|res| unsafe {
46        mlx_sys::mlx_conv_general(
47            res,
48            array.as_ref().as_ptr(),
49            weight.as_ref().as_ptr(),
50            strides.as_ptr(),
51            strides.len(),
52            padding.as_ptr(),
53            padding.len(),
54            padding.as_ptr(),
55            padding.len(),
56            kernel_dilation.as_ptr(),
57            kernel_dilation.len(),
58            input_dilation.as_ptr(),
59            input_dilation.len(),
60            groups,
61            flip,
62            stream.as_ref().as_ptr(),
63        )
64    })
65}
66
67/// 1D convolution over an input with several channels returning an error if the inputs are invalid.
68///
69/// Only the default `groups=1` is currently supported.
70///
71/// # Params
72///
73/// - array: input array of shape `&[N, H, C_in]`
74/// - weight: weight array of shape `&[C_out, H, C_in]`
75/// - stride: kernel stride. Default to 1 if not specified.
76/// - padding: input padding. Default to 0 if not specified.
77/// - dilation: kernel dilation. Default to 1 if not specified.
78/// - groups: input feature groups. Default to 1 if not specified.
79#[generate_macro]
80#[default_device]
81pub fn conv1d_device(
82    array: impl AsRef<Array>,
83    weight: impl AsRef<Array>,
84    #[optional] stride: impl Into<Option<i32>>,
85    #[optional] padding: impl Into<Option<i32>>,
86    #[optional] dilation: impl Into<Option<i32>>,
87    #[optional] groups: impl Into<Option<i32>>,
88    #[optional] stream: impl AsRef<Stream>,
89) -> Result<Array> {
90    let stride = stride.into().unwrap_or(1);
91    let padding = padding.into().unwrap_or(0);
92    let dilation = dilation.into().unwrap_or(1);
93    let groups = groups.into().unwrap_or(1);
94
95    Array::try_from_op(|res| unsafe {
96        mlx_sys::mlx_conv1d(
97            res,
98            array.as_ref().as_ptr(),
99            weight.as_ref().as_ptr(),
100            stride,
101            padding,
102            dilation,
103            groups,
104            stream.as_ref().as_ptr(),
105        )
106    })
107}
108
109/// 2D convolution over an input with several channels returning an error if the inputs are invalid.
110///
111/// Only the default `groups=1` is currently supported.
112///
113/// # Params
114///
115/// - array: input array of shape `[N, H, W, C_in]`
116/// - weight: weight array of shape `[C_out, H, W, C_in]`
117/// - stride: kernel stride. Default to (1, 1) if not specified.
118/// - padding: input padding. Default to (0, 0) if not specified.
119/// - dilation: kernel dilation. Default to (1, 1) if not specified.
120/// - groups: input feature groups. Default to 1 if not specified.
121#[generate_macro]
122#[default_device]
123pub fn conv2d_device(
124    array: impl AsRef<Array>,
125    weight: impl AsRef<Array>,
126    #[optional] stride: impl Into<Option<(i32, i32)>>,
127    #[optional] padding: impl Into<Option<(i32, i32)>>,
128    #[optional] dilation: impl Into<Option<(i32, i32)>>,
129    #[optional] groups: impl Into<Option<i32>>,
130    #[optional] stream: impl AsRef<Stream>,
131) -> Result<Array> {
132    let stride = stride.into().unwrap_or((1, 1));
133    let padding = padding.into().unwrap_or((0, 0));
134    let dilation = dilation.into().unwrap_or((1, 1));
135    let groups = groups.into().unwrap_or(1);
136
137    Array::try_from_op(|res| unsafe {
138        mlx_sys::mlx_conv2d(
139            res,
140            array.as_ref().as_ptr(),
141            weight.as_ref().as_ptr(),
142            stride.0,
143            stride.1,
144            padding.0,
145            padding.1,
146            dilation.0,
147            dilation.1,
148            groups,
149            stream.as_ref().as_ptr(),
150        )
151    })
152}
153
154/// 3D convolution over an input with several channels.
155///
156/// Only the default `groups=1` is currently supported.
157#[generate_macro]
158#[default_device]
159pub fn conv3d_device(
160    array: impl AsRef<Array>,
161    weight: impl AsRef<Array>,
162    #[optional] stride: impl Into<Option<(i32, i32, i32)>>,
163    #[optional] padding: impl Into<Option<(i32, i32, i32)>>,
164    #[optional] dilation: impl Into<Option<(i32, i32, i32)>>,
165    #[optional] groups: impl Into<Option<i32>>,
166    #[optional] stream: impl AsRef<Stream>,
167) -> Result<Array> {
168    let stride = stride.into().unwrap_or((1, 1, 1));
169    let padding = padding.into().unwrap_or((0, 0, 0));
170    let dilation = dilation.into().unwrap_or((1, 1, 1));
171    let groups = groups.into().unwrap_or(1);
172
173    Array::try_from_op(|res| unsafe {
174        mlx_sys::mlx_conv3d(
175            res,
176            array.as_ref().as_ptr(),
177            weight.as_ref().as_ptr(),
178            stride.0,
179            stride.1,
180            stride.2,
181            padding.0,
182            padding.1,
183            padding.2,
184            dilation.0,
185            dilation.1,
186            dilation.2,
187            groups,
188            stream.as_ref().as_ptr(),
189        )
190    })
191}
192
193/// 1D transposed convolution over an input with several channels.
194///
195/// Only the default `groups=1` is currently supported.
196///
197/// # Params
198///
199/// - array: input array of shape `[N, H, C_in]`
200/// - weight: weight array of shape `[C_out, H, C_in]`
201/// - stride: kernel stride. Default to 1 if not specified.
202/// - padding: input padding. Default to 0 if not specified.
203/// - dilation: kernel dilation. Default to 1 if not specified.
204/// - groups: input feature groups. Default to 1 if not specified.
205/// - stream: stream or device to evaluate on.
206#[generate_macro]
207#[default_device]
208pub fn conv_transpose1d_device(
209    array: impl AsRef<Array>,
210    weight: impl AsRef<Array>,
211    #[optional] stride: impl Into<Option<i32>>,
212    #[optional] padding: impl Into<Option<i32>>,
213    #[optional] dilation: impl Into<Option<i32>>,
214    #[optional] groups: impl Into<Option<i32>>,
215    #[optional] stream: impl AsRef<Stream>,
216) -> Result<Array> {
217    let stride = stride.into().unwrap_or(1);
218    let padding = padding.into().unwrap_or(0);
219    let dilation = dilation.into().unwrap_or(1);
220    let groups = groups.into().unwrap_or(1);
221
222    Array::try_from_op(|res| unsafe {
223        mlx_sys::mlx_conv_transpose1d(
224            res,
225            array.as_ref().as_ptr(),
226            weight.as_ref().as_ptr(),
227            stride,
228            padding,
229            dilation,
230            groups,
231            stream.as_ref().as_ptr(),
232        )
233    })
234}
235
236/// 2D transposed convolution over an input with several channels.
237///
238/// Only the default `groups=1` is currently supported.
239///
240/// The numeric parameters may be given as single values:
241///
242/// # Params
243/// - array: input array of shape `[N, H, W, C_in]`
244/// - weight: weight array of shape `[C_out, H, W, C_in]`
245/// - stride: kernel stride. Default to (1, 1) if not specified.
246/// - padding: input padding. Default to (0, 0) if not specified.
247/// - dilation: kernel dilation. Default to (1, 1) if not specified.
248/// - groups: input feature groups. Default to 1 if not specified.
249/// - stream: stream or device to evaluate on.
250#[generate_macro]
251#[default_device]
252pub fn conv_transpose2d_device(
253    array: impl AsRef<Array>,
254    weight: impl AsRef<Array>,
255    #[optional] stride: impl Into<Option<(i32, i32)>>,
256    #[optional] padding: impl Into<Option<(i32, i32)>>,
257    #[optional] dilation: impl Into<Option<(i32, i32)>>,
258    #[optional] groups: impl Into<Option<i32>>,
259    #[optional] stream: impl AsRef<Stream>,
260) -> Result<Array> {
261    let stride = stride.into().unwrap_or((1, 1));
262    let padding = padding.into().unwrap_or((0, 0));
263    let dilation = dilation.into().unwrap_or((1, 1));
264    let groups = groups.into().unwrap_or(1);
265
266    Array::try_from_op(|res| unsafe {
267        mlx_sys::mlx_conv_transpose2d(
268            res,
269            array.as_ref().as_ptr(),
270            weight.as_ref().as_ptr(),
271            stride.0,
272            stride.1,
273            padding.0,
274            padding.1,
275            dilation.0,
276            dilation.1,
277            groups,
278            stream.as_ref().as_ptr(),
279        )
280    })
281}
282
283/// 3D transposed convolution over an input with several channels.
284///
285/// Only the default `groups=1` is currently supported.
286///
287/// The numeric parameters may be given as single values:
288///
289/// # Params
290/// - array: input array of shape `[N, D, H, W, C_in]`
291/// - weight: weight array of shape `[C_out, D, H, W, C_in]`
292/// - stride: kernel stride. Default to (1, 1, 1) if not specified.
293/// - padding: input padding. Default to (0, 0, 0) if not specified.
294/// - dilation: kernel dilation. Default to (1, 1, 1) if not specified.
295/// - groups: input feature groups. Default to 1 if not specified.
296/// - stream: stream or device to evaluate on.
297#[generate_macro]
298#[default_device]
299pub fn conv_transpose3d_device(
300    array: impl AsRef<Array>,
301    weight: impl AsRef<Array>,
302    #[optional] stride: impl Into<Option<(i32, i32, i32)>>,
303    #[optional] padding: impl Into<Option<(i32, i32, i32)>>,
304    #[optional] dilation: impl Into<Option<(i32, i32, i32)>>,
305    #[optional] groups: impl Into<Option<i32>>,
306    #[optional] stream: impl AsRef<Stream>,
307) -> Result<Array> {
308    let stride = stride.into().unwrap_or((1, 1, 1));
309    let padding = padding.into().unwrap_or((0, 0, 0));
310    let dilation = dilation.into().unwrap_or((1, 1, 1));
311    let groups = groups.into().unwrap_or(1);
312
313    Array::try_from_op(|res| unsafe {
314        mlx_sys::mlx_conv_transpose3d(
315            res,
316            array.as_ref().as_ptr(),
317            weight.as_ref().as_ptr(),
318            stride.0,
319            stride.1,
320            stride.2,
321            padding.0,
322            padding.1,
323            padding.2,
324            dilation.0,
325            dilation.1,
326            dilation.2,
327            groups,
328            stream.as_ref().as_ptr(),
329        )
330    })
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_conv1d_complex_device() {
        // Define a 1D input with two channels
        let input_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let input_array = Array::from_slice(&input_data, &[1, 5, 2]);

        // Define a 1D kernel with two input channels and two output channels
        let weight_data = [0.5, 0.0, -0.5, 1.0, 0.0, 1.5, 2.0, 0.0, -2.0, 1.5, 0.0, 1.0];
        let weight_array = Array::from_slice(&weight_data, &[2, 3, 2]);

        let result = conv1d(
            &input_array,
            &weight_array,
            Some(1), // stride
            Some(0), // padding
            Some(1), // dilation
            Some(1), // groups
        )
        .unwrap();

        let expected_output = [12.0, 8.0, 17.0, 13.0, 22.0, 18.0];
        assert_eq!(result.shape(), &[1, 3, 2]);
        assert_eq!(result.as_slice::<f32>(), &expected_output);
    }

    #[test]
    fn test_conv_transpose1d() {
        // Single channel input
        let input = Array::from_slice(&[1.0, 2.0, 3.0], &[1, 3, 1]);
        // Single input/output channel kernel
        let weights = Array::from_slice(&[1.0, 0.5], &[1, 2, 1]);

        let result = conv_transpose1d(
            &input,
            &weights,
            Some(1), // stride
            Some(0), // padding
            Some(1), // dilation
            Some(1), // groups
        )
        .unwrap();

        let expected = [1.0, 2.5, 4.0, 1.5];
        assert_eq!(result.shape(), &[1, 4, 1]);
        assert_eq!(result.as_slice::<f32>(), &expected);
    }

    #[test]
    fn test_conv2d() {
        // Define a 2x2 input with one channel (grayscale image or similar)
        let input_data = [1.0, 2.0, 3.0, 4.0];
        let input_shape = [1, 2, 2, 1]; // [N, H, W, C]
        let input_array = Array::from_slice(&input_data, &input_shape);

        // Define a 2x2 kernel with one input channel and one output channel
        let weight_data = [1.0, 0.0, 0.0, 1.0];
        let weight_shape = [1, 2, 2, 1]; // [C_out, H_k, W_k, C_in]
        let weight_array = Array::from_slice(&weight_data, &weight_shape);

        // Perform the convolution with no padding and stride of 1
        let result = conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(1),      // groups
        )
        .unwrap();

        // A 2x2 filter over a 2x2 input with valid padding yields a single output position.
        assert_eq!(result.shape(), &[1, 1, 1, 1]);
        let expected_output = 1.0 * 1.0 + 2.0 * 0.0 + 3.0 * 0.0 + 4.0 * 1.0; // = 1*1 + 4*1 = 5
        assert_eq!(result.as_slice::<f32>(), &[expected_output]);
    }

    #[test]
    fn test_conv_transpose2d() {
        // 2x2 single channel input
        let input = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 2, 2, 1]);
        // 2x2 single channel kernel (identity-like)
        let weights = Array::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2, 1]);

        let result = conv_transpose2d(
            &input,
            &weights,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(1),      // groups
        )
        .unwrap();

        let expected = [1.0, 2.0, 0.0, 3.0, 5.0, 2.0, 0.0, 3.0, 4.0];
        assert_eq!(result.shape(), &[1, 3, 3, 1]);
        assert_eq!(result.as_slice::<f32>(), &expected);
    }

    #[test]
    fn test_conv3d() {
        // Define a 2x2x2 input with one channel
        let input_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let input_shape = [1, 2, 2, 2, 1]; // [N, D, H, W, C]
        let input_array = Array::from_slice(&input_data, &input_shape);

        // Define a 2x2x2 kernel with one input channel and one output channel
        let weight_data = [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
        let weight_shape = [1, 2, 2, 2, 1]; // [C_out, D_k, H_k, W_k, C_in]
        let weight_array = Array::from_slice(&weight_data, &weight_shape);

        // Perform the convolution with no padding and stride of 1
        let result = conv3d(
            &input_array,
            &weight_array,
            Some((1, 1, 1)), // stride
            Some((0, 0, 0)), // padding
            Some((1, 1, 1)), // dilation
            Some(1),         // groups
        )
        .unwrap();

        // A 2x2x2 filter over a 2x2x2 input with valid padding yields a single output position.
        assert_eq!(result.shape(), &[1, 1, 1, 1, 1]);
        let expected_output = 1.0 * 1.0
            + 2.0 * 0.0
            + 3.0 * 0.0
            + 4.0 * 1.0
            + 5.0 * 0.0
            + 6.0 * 1.0
            + 7.0 * 1.0
            + 8.0 * 0.0; // = 1*1 + 4*1 + 6*1 + 7*1 = 18
        assert_eq!(result.as_slice::<f32>(), &[expected_output]);
    }

    #[test]
    fn test_conv_transpose3d() {
        // 2x2x2 single channel input
        let input = Array::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 2, 2, 2, 1]);
        // 2x2x2 single channel kernel
        let weights =
            Array::from_slice(&[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0], &[1, 2, 2, 2, 1]);

        let result = conv_transpose3d(
            &input,
            &weights,
            Some((1, 1, 1)), // stride
            Some((0, 0, 0)), // padding
            Some((1, 1, 1)), // dilation
            Some(1),         // groups
        )
        .unwrap();

        // With stride 1 and no padding, the output is the sum of copies of the input placed at
        // each non-zero kernel offset (d, h, w): (0,0,0), (0,1,1), (1,0,1) and (1,1,0).
        // Sanity check: the values sum to sum(input) * sum(kernel) = 36 * 4 = 144.
        let expected = [
            // depth 0
            1.0, 2.0, 0.0, 3.0, 5.0, 2.0, 0.0, 3.0, 4.0,
            // depth 1
            5.0, 7.0, 2.0, 8.0, 18.0, 10.0, 3.0, 11.0, 8.0,
            // depth 2
            0.0, 5.0, 6.0, 5.0, 13.0, 8.0, 7.0, 8.0, 0.0,
        ];
        assert_eq!(result.shape(), &[1, 3, 3, 3, 1]);
        assert_eq!(result.as_slice::<f32>(), &expected);
    }

    #[test]
    fn test_conv_wrong_dimensions() {
        let input_data = [1.0, 2.0, 3.0, 4.0];
        let input_shape = [1, 2, 2, 1]; // [N, H, W, C]
        let input_array = Array::from_slice(&input_data, &input_shape);

        let weight_data = [1.0, 0.0, 0.0, 1.0];
        let weight_shape = [1, 2, 2]; // [C_out, H_k, W_k]
        let weight_array = Array::from_slice(&weight_data, &weight_shape);

        let result = conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(1),      // groups
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_conv_invalid_group_size() {
        let input_data = [1.0, 2.0, 3.0, 4.0];
        let input_shape = [1, 2, 2, 1]; // [N, H, W, C]
        let input_array = Array::from_slice(&input_data, &input_shape);

        let weight_data = [1.0, 0.0, 0.0, 1.0];
        let weight_shape = [1, 2, 2, 1]; // [C_out, H_k, W_k, C_in]
        let weight_array = Array::from_slice(&weight_data, &weight_shape);

        let result = conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(2),      // groups
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_conv_non_float() {
        let input_data = [1, 2, 3, 4];
        let input_shape = [1, 2, 2, 1]; // [N, H, W, C]
        let input_array = Array::from_slice(&input_data, &input_shape);

        let weight_data = [1, 0, 0, 1];
        let weight_shape = [1, 2, 2, 1]; // [C_out, H_k, W_k, C_in]
        let weight_array = Array::from_slice(&weight_data, &weight_shape);

        let result = conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(1),      // groups
        );

        assert!(result.is_err());
    }
}