mlx_rs::ops

Function conv_general

pub fn conv_general<'a>(
    array: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    strides: impl IntoOption<&'a [i32]>,
    padding: impl IntoOption<&'a [i32]>,
    kernel_dilation: impl IntoOption<&'a [i32]>,
    input_dilation: impl IntoOption<&'a [i32]>,
    groups: impl Into<Option<i32>>,
    flip: impl Into<Option<bool>>,
) -> Result<Array>

General convolution over an input with several channels. Returns an error if the inputs are invalid.

  • Only 1d and 2d convolutions are supported at the moment.
  • Only the default groups value of 1 is currently supported.
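Because invalid inputs produce an error rather than a panic, it can help to see the kind of checks and output-shape arithmetic involved. The following std-only Rust sketch is an illustration under stated assumptions, not the mlx-rs validation code: `expand` and `conv_output_shape` are hypothetical helpers, padding is assumed symmetric, and the output length per spatial dimension uses the standard formula (dilate the input, pad both sides, subtract the dilated kernel extent, divide by the stride, add one).

```rust
/// Hypothetical helper (not an mlx-rs API): expand a per-dimension parameter.
/// An empty slice falls back to `default`; a single value is repeated for
/// every spatial dimension; otherwise one value per spatial dimension is required.
fn expand(param: &[i32], spatial: usize, default: i32) -> Result<Vec<i32>, String> {
    match param.len() {
        0 => Ok(vec![default; spatial]),
        1 => Ok(vec![param[0]; spatial]),
        n if n == spatial => Ok(param.to_vec()),
        n => Err(format!("expected 1 or {} values, got {}", spatial, n)),
    }
}

/// Hypothetical shape check for a channels-last convolution (not mlx-rs code).
/// `input` is [N, ..spatial.., C_in] and `weight` is [C_out, ..spatial.., C_in].
/// Returns the output shape [N, ..spatial_out.., C_out] or an error message.
fn conv_output_shape(
    input: &[i32],
    weight: &[i32],
    strides: &[i32],
    padding: &[i32],
    kernel_dilation: &[i32],
    input_dilation: &[i32],
    groups: i32,
) -> Result<Vec<i32>, String> {
    if input.len() != weight.len() {
        return Err("input and weight must have the same rank".into());
    }
    // Spatial dimensions are everything between the leading and trailing axis.
    let spatial = input.len().saturating_sub(2);
    if spatial != 1 && spatial != 2 {
        return Err(format!("only 1d and 2d convolutions are handled, got rank {}", input.len()));
    }
    if groups != 1 {
        return Err("only groups = 1 is handled".into());
    }
    if input.iter().chain(weight).any(|&d| d < 1) {
        return Err("all dimensions must be positive".into());
    }
    if input[spatial + 1] != weight[spatial + 1] {
        return Err("input and weight must have the same number of input channels".into());
    }
    let s = expand(strides, spatial, 1)?;
    let p = expand(padding, spatial, 0)?;
    let kd = expand(kernel_dilation, spatial, 1)?;
    let id = expand(input_dilation, spatial, 1)?;
    if s.iter().chain(&kd).chain(&id).any(|&v| v < 1) || p.iter().any(|&v| v < 0) {
        return Err("strides and dilations must be >= 1 and padding >= 0".into());
    }

    let mut out = vec![input[0]];
    for i in 0..spatial {
        // Input length after input dilation and symmetric padding.
        let in_len = (input[i + 1] - 1) * id[i] + 1 + 2 * p[i];
        // Effective kernel extent after kernel dilation.
        let k_len = (weight[i + 1] - 1) * kd[i] + 1;
        if k_len > in_len {
            return Err(format!("kernel is larger than the padded input in spatial dim {}", i));
        }
        out.push((in_len - k_len) / s[i] + 1);
    }
    out.push(weight[0]);
    Ok(out)
}

fn main() {
    // 1d: batch 2, length 10, 3 channels; 8 filters of width 3; stride 2, padding 1.
    let shape = conv_output_shape(&[2, 10, 3], &[8, 3, 3], &[2], &[1], &[], &[], 1);
    println!("{:?}", shape); // Ok([2, 5, 8])
}
```

The same arithmetic is applied independently to each spatial dimension in the 2d case.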

Params

  • array: Input array of shape &[N, ..., C_in]
  • weight: Weight array of shape &[C_out, ..., C_in]
  • strides: The kernel strides. All dimensions get the same stride if only one number is specified.
  • padding: The input padding. All dimensions get the same padding if only one number is specified.
  • kernel_dilation: The kernel dilation. All dimensions get the same dilation if only one number is specified.
  • input_dilation: The input dilation. All dimensions get the same dilation if only one number is specified.
  • groups: The number of input feature groups
  • flip: Flip the order in which the spatial dimensions of the weights are processed. Performs the cross-correlation operator when flip is false and the convolution operator otherwise.
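To make these parameters concrete, here is a naive single-channel 1d reference in plain Rust. It is a sketch under stated assumptions, not the mlx-rs implementation: `conv1d_reference` is a hypothetical function, padding is applied symmetrically, input dilation inserts `input_dilation - 1` zeros between neighbouring samples, and `flip` reads the kernel in reverse so that `false` gives cross-correlation and `true` gives convolution.

```rust
/// Hypothetical reference (not the mlx-rs implementation): a naive 1d
/// convolution over a single-channel signal, mirroring the parameters of
/// conv_general. Padding is symmetric; input dilation inserts
/// (input_dilation - 1) zeros between neighbouring samples.
fn conv1d_reference(
    x: &[f32],
    w: &[f32],
    stride: usize,
    padding: usize,
    kernel_dilation: usize,
    input_dilation: usize,
    flip: bool,
) -> Vec<f32> {
    assert!(!x.is_empty() && !w.is_empty());
    assert!(stride > 0 && kernel_dilation > 0 && input_dilation > 0);

    let dilated_len = (x.len() - 1) * input_dilation + 1;
    let padded_len = dilated_len + 2 * padding;
    let k_len = (w.len() - 1) * kernel_dilation + 1;
    if k_len > padded_len {
        return Vec::new();
    }
    let out_len = (padded_len - k_len) / stride + 1;

    // Value of the padded, input-dilated signal at position `pos`.
    let sample = |pos: usize| -> f32 {
        if pos < padding {
            return 0.0;
        }
        let p = pos - padding;
        if p >= dilated_len || p % input_dilation != 0 {
            0.0
        } else {
            x[p / input_dilation]
        }
    };

    (0..out_len)
        .map(|o| {
            (0..w.len())
                .map(|k| {
                    // flip = false: cross-correlation; flip = true: kernel read in reverse.
                    let wk = if flip { w[w.len() - 1 - k] } else { w[k] };
                    wk * sample(o * stride + k * kernel_dilation)
                })
                .sum::<f32>()
        })
        .collect()
}

fn main() {
    let x: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
    let w: [f32; 3] = [1.0, 0.0, -1.0];
    println!("{:?}", conv1d_reference(&x, &w, 1, 0, 1, 1, false)); // [-2.0, -2.0]
    println!("{:?}", conv1d_reference(&x, &w, 1, 0, 1, 1, true)); // [2.0, 2.0]
}
```

With the antisymmetric kernel [1, 0, -1], the two settings of flip give results of opposite sign; for a symmetric kernel they agree.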