mlx_rs/ops/
convolution.rs

1use crate::error::Result;
2use crate::utils::guard::Guarded;
3use crate::utils::IntoOption;
4use crate::{Array, Stream};
5use mlx_internal_macros::{default_device, generate_macro};
6
7/// General convolution over an input with several channels returning an error if the inputs are invalid.
8///
9/// - Only 1d and 2d convolutions are supported at the moment
10/// - the default `groups: 1` is currently supported
11///
12/// # Params
13///
14/// - array: Input array of shape `&[N, ..., C_in]`
15/// - weight: Weight array of shape `&[C_out, ..., C_in]`
16/// - strides: The kernel strides. All dimensions get the same stride if only one number is specified.
17/// - padding: The input padding. All dimensions get the same padding if only one number is specified.
18/// - kernel_dilation: The kernel dilation. All dimensions get the same dilation if only one number is specified.
19/// - input_dilation: The input dilation. All dimensions get the same dilation if only one number is specified.
20/// - groups: Input feature groups
21/// - flip: Flip the order in which the spatial dimensions of the weights are processed.
22///   Performs the cross-correlation operator when `flip` is `false` and the convolution
23///   operator otherwise.
24#[generate_macro]
25#[default_device]
26#[allow(clippy::too_many_arguments)]
27pub fn conv_general_device<'a>(
28    array: impl AsRef<Array>,
29    weight: impl AsRef<Array>,
30    #[optional] strides: impl IntoOption<&'a [i32]>,
31    #[optional] padding: impl IntoOption<&'a [i32]>,
32    #[optional] kernel_dilation: impl IntoOption<&'a [i32]>,
33    #[optional] input_dilation: impl IntoOption<&'a [i32]>,
34    #[optional] groups: impl Into<Option<i32>>,
35    #[optional] flip: impl Into<Option<bool>>,
36    #[optional] stream: impl AsRef<Stream>,
37) -> Result<Array> {
38    let strides = strides.into_option().unwrap_or(&[1]);
39    let padding = padding.into_option().unwrap_or(&[0]);
40    let kernel_dilation = kernel_dilation.into_option().unwrap_or(&[1]);
41    let input_dilation = input_dilation.into_option().unwrap_or(&[1]);
42    let groups = groups.into().unwrap_or(1);
43    let flip = flip.into().unwrap_or(false);
44
45    Array::try_from_op(|res| unsafe {
46        mlx_sys::mlx_conv_general(
47            res,
48            array.as_ref().as_ptr(),
49            weight.as_ref().as_ptr(),
50            strides.as_ptr(),
51            strides.len(),
52            padding.as_ptr(),
53            padding.len(),
54            padding.as_ptr(),
55            padding.len(),
56            kernel_dilation.as_ptr(),
57            kernel_dilation.len(),
58            input_dilation.as_ptr(),
59            input_dilation.len(),
60            groups,
61            flip,
62            stream.as_ref().as_ptr(),
63        )
64    })
65}
66
67/// 1D convolution over an input with several channels returning an error if the inputs are invalid.
68///
69/// Only the default `groups=1` is currently supported.
70///
71/// # Params
72///
73/// - array: input array of shape `&[N, H, C_in]`
74/// - weight: weight array of shape `&[C_out, H, C_in]`
75/// - stride: kernel stride. Default to 1 if not specified.
76/// - padding: input padding. Default to 0 if not specified.
77/// - dilation: kernel dilation. Default to 1 if not specified.
78/// - groups: input feature groups. Default to 1 if not specified.
79#[generate_macro]
80#[default_device]
81pub fn conv1d_device(
82    array: impl AsRef<Array>,
83    weight: impl AsRef<Array>,
84    #[optional] stride: impl Into<Option<i32>>,
85    #[optional] padding: impl Into<Option<i32>>,
86    #[optional] dilation: impl Into<Option<i32>>,
87    #[optional] groups: impl Into<Option<i32>>,
88    #[optional] stream: impl AsRef<Stream>,
89) -> Result<Array> {
90    let stride = stride.into().unwrap_or(1);
91    let padding = padding.into().unwrap_or(0);
92    let dilation = dilation.into().unwrap_or(1);
93    let groups = groups.into().unwrap_or(1);
94
95    Array::try_from_op(|res| unsafe {
96        mlx_sys::mlx_conv1d(
97            res,
98            array.as_ref().as_ptr(),
99            weight.as_ref().as_ptr(),
100            stride,
101            padding,
102            dilation,
103            groups,
104            stream.as_ref().as_ptr(),
105        )
106    })
107}
108
109/// 2D convolution over an input with several channels returning an error if the inputs are invalid.
110///
111/// Only the default `groups=1` is currently supported.
112///
113/// # Params
114///
115/// - array: input array of shape `[N, H, W, C_in]`
116/// - weight: weight array of shape `[C_out, H, W, C_in]`
117/// - stride: kernel stride. Default to (1, 1) if not specified.
118/// - padding: input padding. Default to (0, 0) if not specified.
119/// - dilation: kernel dilation. Default to (1, 1) if not specified.
120/// - groups: input feature groups. Default to 1 if not specified.
121#[generate_macro]
122#[default_device]
123pub fn conv2d_device(
124    array: impl AsRef<Array>,
125    weight: impl AsRef<Array>,
126    #[optional] stride: impl Into<Option<(i32, i32)>>,
127    #[optional] padding: impl Into<Option<(i32, i32)>>,
128    #[optional] dilation: impl Into<Option<(i32, i32)>>,
129    #[optional] groups: impl Into<Option<i32>>,
130    #[optional] stream: impl AsRef<Stream>,
131) -> Result<Array> {
132    let stride = stride.into().unwrap_or((1, 1));
133    let padding = padding.into().unwrap_or((0, 0));
134    let dilation = dilation.into().unwrap_or((1, 1));
135    let groups = groups.into().unwrap_or(1);
136
137    Array::try_from_op(|res| unsafe {
138        mlx_sys::mlx_conv2d(
139            res,
140            array.as_ref().as_ptr(),
141            weight.as_ref().as_ptr(),
142            stride.0,
143            stride.1,
144            padding.0,
145            padding.1,
146            dilation.0,
147            dilation.1,
148            groups,
149            stream.as_ref().as_ptr(),
150        )
151    })
152}
153
154/// 3D convolution over an input with several channels.
155///
156/// Only the default `groups=1` is currently supported.
157#[generate_macro]
158#[default_device]
159pub fn conv3d_device(
160    array: impl AsRef<Array>,
161    weight: impl AsRef<Array>,
162    #[optional] stride: impl Into<Option<(i32, i32, i32)>>,
163    #[optional] padding: impl Into<Option<(i32, i32, i32)>>,
164    #[optional] dilation: impl Into<Option<(i32, i32, i32)>>,
165    #[optional] groups: impl Into<Option<i32>>,
166    #[optional] stream: impl AsRef<Stream>,
167) -> Result<Array> {
168    let stride = stride.into().unwrap_or((1, 1, 1));
169    let padding = padding.into().unwrap_or((0, 0, 0));
170    let dilation = dilation.into().unwrap_or((1, 1, 1));
171    let groups = groups.into().unwrap_or(1);
172
173    Array::try_from_op(|res| unsafe {
174        mlx_sys::mlx_conv3d(
175            res,
176            array.as_ref().as_ptr(),
177            weight.as_ref().as_ptr(),
178            stride.0,
179            stride.1,
180            stride.2,
181            padding.0,
182            padding.1,
183            padding.2,
184            dilation.0,
185            dilation.1,
186            dilation.2,
187            groups,
188            stream.as_ref().as_ptr(),
189        )
190    })
191}
192
193/// 1D transposed convolution over an input with several channels.
194///
195/// Only the default `groups=1` is currently supported.
196///
197/// # Params
198///
199/// - array: input array of shape `[N, H, C_in]`
200/// - weight: weight array of shape `[C_out, H, C_in]`
201/// - stride: kernel stride. Default to 1 if not specified.
202/// - padding: input padding. Default to 0 if not specified.
203/// - dilation: kernel dilation. Default to 1 if not specified.
204/// - groups: input feature groups. Default to 1 if not specified.
205/// - stream: stream or device to evaluate on.
206#[allow(clippy::too_many_arguments)]
207#[generate_macro]
208#[default_device]
209pub fn conv_transpose1d_device(
210    array: impl AsRef<Array>,
211    weight: impl AsRef<Array>,
212    #[optional] stride: impl Into<Option<i32>>,
213    #[optional] padding: impl Into<Option<i32>>,
214    #[optional] dilation: impl Into<Option<i32>>,
215    #[optional] output_padding: impl Into<Option<i32>>,
216    #[optional] groups: impl Into<Option<i32>>,
217    #[optional] stream: impl AsRef<Stream>,
218) -> Result<Array> {
219    let stride = stride.into().unwrap_or(1);
220    let padding = padding.into().unwrap_or(0);
221    let dilation = dilation.into().unwrap_or(1);
222    let output_padding = output_padding.into().unwrap_or(0);
223    let groups = groups.into().unwrap_or(1);
224
225    Array::try_from_op(|res| unsafe {
226        mlx_sys::mlx_conv_transpose1d(
227            res,
228            array.as_ref().as_ptr(),
229            weight.as_ref().as_ptr(),
230            stride,
231            padding,
232            dilation,
233            output_padding,
234            groups,
235            stream.as_ref().as_ptr(),
236        )
237    })
238}
239
240/// 2D transposed convolution over an input with several channels.
241///
242/// Only the default `groups=1` is currently supported.
243///
244/// The numeric parameters may be given as single values:
245///
246/// # Params
247/// - array: input array of shape `[N, H, W, C_in]`
248/// - weight: weight array of shape `[C_out, H, W, C_in]`
249/// - stride: kernel stride. Default to (1, 1) if not specified.
250/// - padding: input padding. Default to (0, 0) if not specified.
251/// - dilation: kernel dilation. Default to (1, 1) if not specified.
252/// - groups: input feature groups. Default to 1 if not specified.
253/// - stream: stream or device to evaluate on.
254#[allow(clippy::too_many_arguments)]
255#[generate_macro]
256#[default_device]
257pub fn conv_transpose2d_device(
258    array: impl AsRef<Array>,
259    weight: impl AsRef<Array>,
260    #[optional] stride: impl Into<Option<(i32, i32)>>,
261    #[optional] padding: impl Into<Option<(i32, i32)>>,
262    #[optional] dilation: impl Into<Option<(i32, i32)>>,
263    #[optional] output_padding: impl Into<Option<(i32, i32)>>,
264    #[optional] groups: impl Into<Option<i32>>,
265    #[optional] stream: impl AsRef<Stream>,
266) -> Result<Array> {
267    let stride = stride.into().unwrap_or((1, 1));
268    let padding = padding.into().unwrap_or((0, 0));
269    let dilation = dilation.into().unwrap_or((1, 1));
270    let output_padding = output_padding.into().unwrap_or((0, 0));
271    let groups = groups.into().unwrap_or(1);
272
273    Array::try_from_op(|res| unsafe {
274        mlx_sys::mlx_conv_transpose2d(
275            res,
276            array.as_ref().as_ptr(),
277            weight.as_ref().as_ptr(),
278            stride.0,
279            stride.1,
280            padding.0,
281            padding.1,
282            dilation.0,
283            dilation.1,
284            output_padding.0,
285            output_padding.1,
286            groups,
287            stream.as_ref().as_ptr(),
288        )
289    })
290}
291
292/// 3D transposed convolution over an input with several channels.
293///
294/// Only the default `groups=1` is currently supported.
295///
296/// The numeric parameters may be given as single values:
297///
298/// # Params
299/// - array: input array of shape `[N, D, H, W, C_in]`
300/// - weight: weight array of shape `[C_out, D, H, W, C_in]`
301/// - stride: kernel stride. Default to (1, 1, 1) if not specified.
302/// - padding: input padding. Default to (0, 0, 0) if not specified.
303/// - dilation: kernel dilation. Default to (1, 1, 1) if not specified.
304/// - groups: input feature groups. Default to 1 if not specified.
305/// - stream: stream or device to evaluate on.
306#[allow(clippy::too_many_arguments)]
307#[generate_macro]
308#[default_device]
309pub fn conv_transpose3d_device(
310    array: impl AsRef<Array>,
311    weight: impl AsRef<Array>,
312    #[optional] stride: impl Into<Option<(i32, i32, i32)>>,
313    #[optional] padding: impl Into<Option<(i32, i32, i32)>>,
314    #[optional] dilation: impl Into<Option<(i32, i32, i32)>>,
315    #[optional] output_padding: impl Into<Option<(i32, i32, i32)>>,
316    #[optional] groups: impl Into<Option<i32>>,
317    #[optional] stream: impl AsRef<Stream>,
318) -> Result<Array> {
319    let stride = stride.into().unwrap_or((1, 1, 1));
320    let padding = padding.into().unwrap_or((0, 0, 0));
321    let dilation = dilation.into().unwrap_or((1, 1, 1));
322    let output_padding = output_padding.into().unwrap_or((0, 0, 0));
323    let groups = groups.into().unwrap_or(1);
324
325    Array::try_from_op(|res| unsafe {
326        mlx_sys::mlx_conv_transpose3d(
327            res,
328            array.as_ref().as_ptr(),
329            weight.as_ref().as_ptr(),
330            stride.0,
331            stride.1,
332            stride.2,
333            padding.0,
334            padding.1,
335            padding.2,
336            dilation.0,
337            dilation.1,
338            dilation.2,
339            output_padding.0,
340            output_padding.1,
341            output_padding.2,
342            groups,
343            stream.as_ref().as_ptr(),
344        )
345    })
346}
347
// Unit tests for the convolution wrappers. They call the default-device variants
// (`conv1d`, `conv2d`, ...) and cover both successful results and rejected inputs.
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_conv1d_complex_device() {
        // Define a 1D input with two channels
        let input_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let input_array = Array::from_slice(&input_data, &[1, 5, 2]);

        // Define a 1D kernel with two input channels and two output channels
        let weight_data = [0.5, 0.0, -0.5, 1.0, 0.0, 1.5, 2.0, 0.0, -2.0, 1.5, 0.0, 1.0];
        let weight_array = Array::from_slice(&weight_data, &[2, 3, 2]);

        let result = conv1d(
            &input_array,
            &weight_array,
            Some(1), // stride
            Some(0), // padding
            Some(1), // dilation
            Some(1), // groups
        )
        .unwrap();

        let expected_output = [12.0, 8.0, 17.0, 13.0, 22.0, 18.0];
        assert_eq!(result.shape(), &[1, 3, 2]);
        assert_eq!(result.as_slice::<f32>(), &expected_output);
    }

    #[test]
    fn test_conv_transpose1d() {
        // Single channel input
        let input = Array::from_slice(&[1.0, 2.0, 3.0], &[1, 3, 1]);
        // Single input/output channel kernel
        let weights = Array::from_slice(&[1.0, 0.5], &[1, 2, 1]);

        let result = conv_transpose1d(
            &input,
            &weights,
            Some(1), // stride
            Some(0), // padding
            Some(1), // dilation
            None,    // output padding
            Some(1), // groups
        )
        .unwrap();

        let expected = [1.0, 2.5, 4.0, 1.5];
        assert_eq!(result.shape(), &[1, 4, 1]);
        assert_eq!(result.as_slice::<f32>(), &expected);
    }

    #[test]
    fn test_conv2d() {
        // Define a 2x2 input with one channel (grayscale image or similar)
        let input_data = [1.0, 2.0, 3.0, 4.0];
        let input_shape = [1, 2, 2, 1]; // [N, H, W, C]
        let input_array = Array::from_slice(&input_data, &input_shape);

        // Define a 2x2 kernel with one input channel and one output channel
        let weight_data = [1.0, 0.0, 0.0, 1.0];
        let weight_shape = [1, 2, 2, 1]; // [C_out, H_k, W_k, C_in]
        let weight_array = Array::from_slice(&weight_data, &weight_shape);

        // Perform the convolution with no padding and stride of 1
        let result = conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(1),      // groups
        )
        .unwrap();

        // Expected result is the convolution of a 2x2 filter over a 2x2 input with valid padding, resulting in a single output value
        let expected_output = 1.0 * 1.0 + 2.0 * 0.0 + 3.0 * 0.0 + 4.0 * 1.0; // = 1*1 + 4*1 = 5
        assert_eq!(result.as_slice::<f32>(), &[expected_output]);
    }

    #[test]
    fn test_conv_transpose2d() {
        // 2x2 single channel input
        let input = Array::from_slice(&[1.0, 2.0, 3.0, 4.0], &[1, 2, 2, 1]);
        // 2x2 single channel kernel (identity-like)
        let weights = Array::from_slice(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2, 1]);

        let result = conv_transpose2d(
            &input,
            &weights,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            None,         // output padding
            Some(1),      // groups
        )
        .unwrap();

        let expected = [1.0, 2.0, 0.0, 3.0, 5.0, 2.0, 0.0, 3.0, 4.0];
        assert_eq!(result.shape(), &[1, 3, 3, 1]);
        assert_eq!(result.as_slice::<f32>(), &expected);
    }

    #[test]
    fn test_conv3d() {
        // Define a 2x2x2 input with one channel
        let input_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let input_shape = [1, 2, 2, 2, 1]; // [N, D, H, W, C]
        let input_array = Array::from_slice(&input_data, &input_shape);

        // Define a 2x2x2 kernel with one input channel and one output channel
        let weight_data = [1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0];
        let weight_shape = [1, 2, 2, 2, 1]; // [C_out, D_k, H_k, W_k, C_in]
        let weight_array = Array::from_slice(&weight_data, &weight_shape);

        // Perform the convolution with no padding and stride of 1
        let result = conv3d(
            &input_array,
            &weight_array,
            Some((1, 1, 1)), // stride
            Some((0, 0, 0)), // padding
            Some((1, 1, 1)), // dilation
            Some(1),         // groups
        )
        .unwrap();

        // Expected result is the convolution of a 2x2x2 filter over a 2x2x2 input with valid padding, resulting in a single output value
        let expected_output = 1.0 * 1.0
            + 2.0 * 0.0
            + 3.0 * 0.0
            + 4.0 * 1.0
            + 5.0 * 0.0
            + 6.0 * 1.0
            + 7.0 * 1.0
            + 8.0 * 0.0; // = 1*1 + 4*1 + 6*1 + 7*1 = 18
        assert_eq!(result.as_slice::<f32>(), &[expected_output]);
    }

    #[test]
    fn test_conv_transpose3d() {
        // 2x2x2 single channel input
        let input = Array::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[1, 2, 2, 2, 1]);
        // 2x2x2 single channel kernel
        let weights =
            Array::from_slice(&[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0], &[1, 2, 2, 2, 1]);

        let result = conv_transpose3d(
            &input,
            &weights,
            Some((1, 1, 1)), // stride
            Some((0, 0, 0)), // padding
            Some((1, 1, 1)), // dilation
            None,            // output padding
            Some(1),         // groups
        )
        .unwrap();

        // With stride 1 and no padding each input element is scattered into the output at every
        // kernel offset whose weight is non-zero (the same rule the 1D and 2D transpose tests
        // above pin down). The kernel has ones at offsets (0,0,0), (0,1,1), (1,0,1) and (1,1,0),
        // so the output is the sum of four shifted copies of the input.
        #[rustfmt::skip]
        let expected = [
            // depth 0
            1.0, 2.0, 0.0,
            3.0, 5.0, 2.0,
            0.0, 3.0, 4.0,
            // depth 1
            5.0, 7.0, 2.0,
            8.0, 18.0, 10.0,
            3.0, 11.0, 8.0,
            // depth 2
            0.0, 5.0, 6.0,
            5.0, 13.0, 8.0,
            7.0, 8.0, 0.0,
        ];
        assert_eq!(result.shape(), &[1, 3, 3, 3, 1]);
        assert_eq!(result.as_slice::<f32>(), &expected);
    }

    #[test]
    fn test_conv_wrong_dimensions() {
        let input_data = [1.0, 2.0, 3.0, 4.0];
        let input_shape = [1, 2, 2, 1]; // [N, H, W, C]
        let input_array = Array::from_slice(&input_data, &input_shape);

        // The weight is missing its trailing channel dimension, so the call must fail.
        let weight_data = [1.0, 0.0, 0.0, 1.0];
        let weight_shape = [1, 2, 2]; // [C_out, H_k, W_k]
        let weight_array = Array::from_slice(&weight_data, &weight_shape);

        let result = conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(1),      // groups
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_conv_invalid_group_size() {
        let input_data = [1.0, 2.0, 3.0, 4.0];
        let input_shape = [1, 2, 2, 1]; // [N, H, W, C]
        let input_array = Array::from_slice(&input_data, &input_shape);

        let weight_data = [1.0, 0.0, 0.0, 1.0];
        let weight_shape = [1, 2, 2, 1]; // [C_out, H_k, W_k, C_in]
        let weight_array = Array::from_slice(&weight_data, &weight_shape);

        // Two groups requested for a single-channel input; the call is expected to fail.
        let result = conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(2),      // groups
        );

        assert!(result.is_err());
    }

    #[test]
    fn test_conv_non_float() {
        // Integer data for both operands; the call is expected to fail.
        let input_data = [1, 2, 3, 4];
        let input_shape = [1, 2, 2, 1]; // [N, H, W, C]
        let input_array = Array::from_slice(&input_data, &input_shape);

        let weight_data = [1, 0, 0, 1];
        let weight_shape = [1, 2, 2, 1]; // [C_out, H_k, W_k, C_in]
        let weight_array = Array::from_slice(&weight_data, &weight_shape);

        let result = conv2d(
            &input_array,
            &weight_array,
            Some((1, 1)), // stride
            Some((0, 0)), // padding
            Some((1, 1)), // dilation
            Some(1),      // groups
        );

        assert!(result.is_err());
    }
}