mlx_rs::ops

Function conv1d

pub fn conv1d(
    array: impl AsRef<Array>,
    weight: impl AsRef<Array>,
    stride: impl Into<Option<i32>>,
    padding: impl Into<Option<i32>>,
    dilation: impl Into<Option<i32>>,
    groups: impl Into<Option<i32>>,
) -> Result<Array>

1D convolution over an input with several channels. Returns an error if the inputs are invalid.

Only the default groups=1 is currently supported.

Params

  • array: input array of shape [N, H, C_in], where N is the batch size, H the input length and C_in the number of input channels.
  • weight: weight array of shape [C_out, K, C_in], where C_out is the number of output channels and K the kernel width.
  • stride: kernel stride. Defaults to 1 if not specified.
  • padding: input padding. Defaults to 0 if not specified.
  • dilation: kernel dilation. Defaults to 1 if not specified.
  • groups: input feature groups. Defaults to 1 if not specified.
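To illustrate how stride, padding and dilation interact, here is a minimal plain-Rust reference computation. It is not part of mlx_rs: `conv1d_ref` is a hypothetical helper written for this example. It assumes the channels-last layout listed above, zero padding applied equally to both ends of the input, cross-correlation (the kernel is not flipped), and groups = 1. Under those assumptions the output length is (H + 2 * padding - dilation * (K - 1) - 1) / stride + 1, using integer division.

```rust
/// Hypothetical reference 1D convolution (cross-correlation) on flat buffers.
/// Layouts (assumed, channels-last): input [n, len, c_in],
/// weight [c_out, k, c_in], output [n, len_out, c_out]. Only groups = 1.
fn conv1d_ref(
    input: &[f32],
    (n, len, c_in): (usize, usize, usize),
    weight: &[f32],
    (c_out, k): (usize, usize),
    stride: usize,
    padding: usize,
    dilation: usize,
) -> (Vec<f32>, usize) {
    assert_eq!(input.len(), n * len * c_in, "input buffer does not match shape");
    assert_eq!(weight.len(), c_out * k * c_in, "weight buffer does not match shape");
    assert!(k >= 1 && stride >= 1 && dilation >= 1);

    // Number of input positions the dilated kernel covers.
    let span = dilation * (k - 1) + 1;
    let padded = len + 2 * padding;
    assert!(padded >= span, "kernel is larger than the padded input");
    let len_out = (padded - span) / stride + 1;

    let mut out = vec![0.0f32; n * len_out * c_out];
    for b in 0..n {
        for t in 0..len_out {
            for o in 0..c_out {
                let mut acc = 0.0f32;
                for j in 0..k {
                    // Index into the (virtually) zero-padded input.
                    let p = t * stride + j * dilation;
                    if p < padding || p >= padding + len {
                        continue; // padded zeros add nothing to the sum
                    }
                    let x = p - padding; // index into the real input
                    for c in 0..c_in {
                        acc += input[(b * len + x) * c_in + c]
                            * weight[(o * k + j) * c_in + c];
                    }
                }
                out[(b * len_out + t) * c_out + o] = acc;
            }
        }
    }
    (out, len_out)
}

fn main() {
    // N = 1, H = 5, C_in = 1: the sequence [1, 2, 3, 4, 5].
    let x = [1.0, 2.0, 3.0, 4.0, 5.0];
    // C_out = 1, K = 2: weights [1, 1] sum neighbouring pairs.
    let w = [1.0, 1.0];
    let (y, len_out) = conv1d_ref(&x, (1, 5, 1), &w, (1, 2), 1, 0, 1);
    println!("{len_out} {:?}", y); // prints "4 [3.0, 5.0, 7.0, 9.0]"
}
```

With the same input and weights, setting stride to 2 yields [3.0, 7.0], and setting dilation to 2 pairs elements two apart, yielding [4.0, 6.0, 8.0].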