Function mlx_rs::ops::conv2d

pub fn conv2d(
    array: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    stride: impl Into<Option<(i32, i32)>>,
    padding: impl Into<Option<(i32, i32)>>,
    dilation: impl Into<Option<(i32, i32)>>,
    groups: impl Into<Option<i32>>,
) -> Result<Array>

Applies a 2D convolution over an input with several channels. Returns an error if the inputs are invalid.

Only the default groups=1 is currently supported.
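To make the computation concrete, here is a naive reference sketch in plain Rust (no mlx_rs calls) of what a groups = 1 convolution computes with the channels-last layouts listed under Params. It assumes cross-correlation (no kernel flip), zero padding applied to both sides, and row-major flattened buffers. `conv2d_reference` is a hypothetical name, and this is an illustration of the math, not how mlx_rs implements the operation.

```rust
/// Hypothetical naive reference for a groups = 1, channels-last 2D convolution.
/// Assumes cross-correlation (no kernel flip) and zero padding on both sides.
/// `input` is [N, H, W, C_in] and `weight` is [C_out, KH, KW, C_in], both row-major.
fn conv2d_reference(
    input: &[f32],
    in_shape: [usize; 4],
    weight: &[f32],
    w_shape: [usize; 4],
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> Result<(Vec<f32>, [usize; 4]), String> {
    let [n, h, w, c_in] = in_shape;
    let [c_out, kh, kw, w_c_in] = w_shape;
    // Reject inputs that would make the operation ill-defined.
    if w_c_in != c_in {
        return Err(format!("channel mismatch: input {c_in}, weight {w_c_in}"));
    }
    if input.len() != n * h * w * c_in || weight.len() != c_out * kh * kw * c_in {
        return Err("buffer length does not match shape".to_string());
    }
    if kh == 0 || kw == 0 || stride.0 == 0 || stride.1 == 0 || dilation.0 == 0 || dilation.1 == 0 {
        return Err("kernel size, stride and dilation must be positive".to_string());
    }
    // Spatial extent covered by a dilated kernel, and the padded input size.
    let (ext_h, ext_w) = (dilation.0 * (kh - 1) + 1, dilation.1 * (kw - 1) + 1);
    let (pad_h, pad_w) = (h + 2 * padding.0, w + 2 * padding.1);
    if pad_h < ext_h || pad_w < ext_w {
        return Err("kernel is larger than the padded input".to_string());
    }
    let h_out = (pad_h - ext_h) / stride.0 + 1;
    let w_out = (pad_w - ext_w) / stride.1 + 1;

    let mut out = vec![0.0f32; n * h_out * w_out * c_out];
    for b in 0..n {
        for oy in 0..h_out {
            for ox in 0..w_out {
                for co in 0..c_out {
                    let mut acc = 0.0f32;
                    for ky in 0..kh {
                        for kx in 0..kw {
                            // Input coordinates; out-of-range positions are padding (zero).
                            let iy = (oy * stride.0 + ky * dilation.0) as isize - padding.0 as isize;
                            let ix = (ox * stride.1 + kx * dilation.1) as isize - padding.1 as isize;
                            if iy < 0 || ix < 0 || iy >= h as isize || ix >= w as isize {
                                continue;
                            }
                            let (iy, ix) = (iy as usize, ix as usize);
                            // Dot product over the input channels.
                            for ci in 0..c_in {
                                let x = input[((b * h + iy) * w + ix) * c_in + ci];
                                let k = weight[((co * kh + ky) * kw + kx) * c_in + ci];
                                acc += x * k;
                            }
                        }
                    }
                    out[((b * h_out + oy) * w_out + ox) * c_out + co] = acc;
                }
            }
        }
    }
    Ok((out, [n, h_out, w_out, c_out]))
}
```

For example, a 3 x 3 single-channel input holding 1 to 9, convolved with a 2 x 2 kernel of ones at stride (1, 1) and no padding, gives [12.0, 16.0, 24.0, 28.0] with shape [1, 2, 2, 1].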

Params

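  • array: input array of shape [N, H, W, C_in]
  • weight: weight array of shape [C_out, KH, KW, C_in], where KH and KW are the kernel height and width
  • stride: kernel stride along the height and width dimensions. Defaults to (1, 1) if not specified.
  • padding: input padding along the height and width dimensions. Defaults to (0, 0) if not specified.
  • dilation: kernel dilation along the height and width dimensions. Defaults to (1, 1) if not specified.
  • groups: number of input feature groups. Defaults to 1 if not specified; only 1 is currently supported.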
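The defaults above combine with the standard convolution output-size formula, H_out = (H + 2 * pad_h - dilation_h * (KH - 1) - 1) / stride_h + 1 with integer division, and likewise for W. The sketch below is a hypothetical helper (`conv2d_output_shape` is not part of mlx_rs) that mirrors the signature's `impl Into<Option<...>>` parameters, so passing `None` picks up the documented defaults. That the result has shape [N, H_out, W_out, C_out] is an assumption based on the channels-last layout.

```rust
/// Hypothetical helper (not part of mlx_rs): output shape of a channels-last
/// conv2d, applying the documented defaults when an argument is `None`.
fn conv2d_output_shape(
    input: [i32; 4],  // [N, H, W, C_in]
    weight: [i32; 4], // [C_out, KH, KW, C_in]
    stride: impl Into<Option<(i32, i32)>>,
    padding: impl Into<Option<(i32, i32)>>,
    dilation: impl Into<Option<(i32, i32)>>,
) -> Option<[i32; 4]> {
    // Same defaults as listed under Params.
    let (sh, sw) = stride.into().unwrap_or((1, 1));
    let (ph, pw) = padding.into().unwrap_or((0, 0));
    let (dh, dw) = dilation.into().unwrap_or((1, 1));
    let [n, h, w, c_in] = input;
    let [c_out, kh, kw, w_c_in] = weight;

    // Invalid combinations yield None in this sketch.
    if c_in != w_c_in || kh < 1 || kw < 1 || sh < 1 || sw < 1 || dh < 1 || dw < 1 || ph < 0 || pw < 0 {
        return None;
    }
    // span + 1 is the number of valid kernel positions at stride 1.
    let span_h = h + 2 * ph - dh * (kh - 1) - 1;
    let span_w = w + 2 * pw - dw * (kw - 1) - 1;
    if span_h < 0 || span_w < 0 {
        return None; // dilated kernel does not fit inside the padded input
    }
    Some([n, span_h / sh + 1, span_w / sw + 1, c_out])
}
```

For example, a [1, 32, 32, 3] input with a [16, 3, 3, 3] weight and all defaults gives [1, 30, 30, 16]; with padding (1, 1) the spatial size is preserved at 32 x 32.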